Update PyO3 (#426)

Anthony MOI
2020-09-22 12:00:20 -04:00
committed by GitHub
parent 8e220dbdd4
commit 940f8bd8fa
13 changed files with 156 additions and 178 deletions
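Most of this diff is one mechanical migration: PyO3 0.12 renames its exception types with a Py prefix (Exception becomes PyException, ValueError becomes PyValueError, TypeError becomes PyTypeError) and replaces the py_err constructor with new_err. A minimal sketch of the rename, assuming pyo3 0.12 (make_err is an illustrative helper, not part of this codebase):

use pyo3::exceptions;
use pyo3::PyErr;

fn make_err(msg: String) -> PyErr {
    // pyo3 0.11 spelled this: exceptions::Exception::py_err(msg)
    exceptions::PyException::new_err(msg)
}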

View File: bindings/python/Cargo.lock

@@ -484,14 +484,14 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
 [[package]]
 name = "numpy"
 version = "0.11.0"
-source = "registry+https://github.com/rust-lang/crates.io-index"
+source = "git+https://github.com/pyo3/rust-numpy/?rev=e331befa27fede78d4662edf08fa0508db39be01#e331befa27fede78d4662edf08fa0508db39be01"
 dependencies = [
  "cfg-if 0.1.10 (registry+https://github.com/rust-lang/crates.io-index)",
  "libc 0.2.77 (registry+https://github.com/rust-lang/crates.io-index)",
  "ndarray 0.13.1 (registry+https://github.com/rust-lang/crates.io-index)",
  "num-complex 0.2.4 (registry+https://github.com/rust-lang/crates.io-index)",
  "num-traits 0.2.12 (registry+https://github.com/rust-lang/crates.io-index)",
- "pyo3 0.11.1 (registry+https://github.com/rust-lang/crates.io-index)",
+ "pyo3 0.12.0 (registry+https://github.com/rust-lang/crates.io-index)",
 ]
 
 [[package]]
@@ -580,7 +580,7 @@ dependencies = [
 [[package]]
 name = "pyo3"
-version = "0.11.1"
+version = "0.12.0"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 dependencies = [
  "ctor 0.1.15 (registry+https://github.com/rust-lang/crates.io-index)",
@@ -589,13 +589,13 @@ dependencies = [
  "libc 0.2.77 (registry+https://github.com/rust-lang/crates.io-index)",
  "parking_lot 0.11.0 (registry+https://github.com/rust-lang/crates.io-index)",
  "paste 0.1.18 (registry+https://github.com/rust-lang/crates.io-index)",
- "pyo3cls 0.11.1 (registry+https://github.com/rust-lang/crates.io-index)",
+ "pyo3cls 0.12.0 (registry+https://github.com/rust-lang/crates.io-index)",
  "unindent 0.1.6 (registry+https://github.com/rust-lang/crates.io-index)",
 ]
 
 [[package]]
 name = "pyo3-derive-backend"
-version = "0.11.1"
+version = "0.12.0"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 dependencies = [
  "proc-macro2 1.0.21 (registry+https://github.com/rust-lang/crates.io-index)",
@@ -605,10 +605,10 @@ dependencies = [
 [[package]]
 name = "pyo3cls"
-version = "0.11.1"
+version = "0.12.0"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 dependencies = [
- "pyo3-derive-backend 0.11.1 (registry+https://github.com/rust-lang/crates.io-index)",
+ "pyo3-derive-backend 0.12.0 (registry+https://github.com/rust-lang/crates.io-index)",
  "quote 1.0.7 (registry+https://github.com/rust-lang/crates.io-index)",
  "syn 1.0.41 (registry+https://github.com/rust-lang/crates.io-index)",
 ]
@@ -895,8 +895,8 @@ dependencies = [
  "env_logger 0.7.1 (registry+https://github.com/rust-lang/crates.io-index)",
  "libc 0.2.77 (registry+https://github.com/rust-lang/crates.io-index)",
  "ndarray 0.13.1 (registry+https://github.com/rust-lang/crates.io-index)",
- "numpy 0.11.0 (registry+https://github.com/rust-lang/crates.io-index)",
- "pyo3 0.11.1 (registry+https://github.com/rust-lang/crates.io-index)",
+ "numpy 0.11.0 (git+https://github.com/pyo3/rust-numpy/?rev=e331befa27fede78d4662edf08fa0508db39be01)",
+ "pyo3 0.12.0 (registry+https://github.com/rust-lang/crates.io-index)",
  "rayon 1.4.0 (registry+https://github.com/rust-lang/crates.io-index)",
  "serde 1.0.116 (registry+https://github.com/rust-lang/crates.io-index)",
  "serde_json 1.0.57 (registry+https://github.com/rust-lang/crates.io-index)",
@@ -1037,7 +1037,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
 "checksum num-traits 0.2.12 (registry+https://github.com/rust-lang/crates.io-index)" = "ac267bcc07f48ee5f8935ab0d24f316fb722d7a1292e2913f0cc196b29ffd611"
 "checksum num_cpus 1.13.0 (registry+https://github.com/rust-lang/crates.io-index)" = "05499f3756671c15885fee9034446956fff3f243d6077b91e5767df161f766b3"
 "checksum number_prefix 0.3.0 (registry+https://github.com/rust-lang/crates.io-index)" = "17b02fc0ff9a9e4b35b3342880f48e896ebf69f2967921fe8646bf5b7125956a"
-"checksum numpy 0.11.0 (registry+https://github.com/rust-lang/crates.io-index)" = "1afa393812b0e4c33d8a46054028ce39d79f954dfa98298f25691abf00b24e39"
+"checksum numpy 0.11.0 (git+https://github.com/pyo3/rust-numpy/?rev=e331befa27fede78d4662edf08fa0508db39be01)" = "<none>"
 "checksum onig 6.1.0 (registry+https://github.com/rust-lang/crates.io-index)" = "8a155d13862da85473665694f4c05d77fb96598bdceeaf696433c84ea9567e20"
 "checksum onig_sys 69.5.1 (registry+https://github.com/rust-lang/crates.io-index)" = "9bff06597a6b17855040955cae613af000fc0bfc8ad49ea68b9479a74e59292d"
 "checksum parking_lot 0.11.0 (registry+https://github.com/rust-lang/crates.io-index)" = "a4893845fa2ca272e647da5d0e46660a314ead9c2fdd9a883aabc32e481a8733"
@@ -1048,9 +1048,9 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
 "checksum ppv-lite86 0.2.9 (registry+https://github.com/rust-lang/crates.io-index)" = "c36fa947111f5c62a733b652544dd0016a43ce89619538a8ef92724a6f501a20"
 "checksum proc-macro-hack 0.5.18 (registry+https://github.com/rust-lang/crates.io-index)" = "99c605b9a0adc77b7211c6b1f722dcb613d68d66859a44f3d485a6da332b0598"
 "checksum proc-macro2 1.0.21 (registry+https://github.com/rust-lang/crates.io-index)" = "36e28516df94f3dd551a587da5357459d9b36d945a7c37c3557928c1c2ff2a2c"
-"checksum pyo3 0.11.1 (registry+https://github.com/rust-lang/crates.io-index)" = "9ca8710ffa8211c9a62a8a3863c4267c710dc42a82a7fd29c97de465d7ea6b7d"
-"checksum pyo3-derive-backend 0.11.1 (registry+https://github.com/rust-lang/crates.io-index)" = "58ad070bf6967b0d29ea74931ffcf9c6bbe8402a726e9afbeafadc0a287cc2b3"
-"checksum pyo3cls 0.11.1 (registry+https://github.com/rust-lang/crates.io-index)" = "c3fa17e1ea569d0bf3b7c00f2a9eea831ca05e55dd76f1794c541abba1c64baa"
+"checksum pyo3 0.12.0 (registry+https://github.com/rust-lang/crates.io-index)" = "4fe7ed6655de90fd3bf799f25ce8b695e89eeb098d03db802723617649fcaa74"
+"checksum pyo3-derive-backend 0.12.0 (registry+https://github.com/rust-lang/crates.io-index)" = "d021655b5f22aeee2eaf3c3f88503250ad6cf90e37a875e1bb1854fceb86b3d0"
+"checksum pyo3cls 0.12.0 (registry+https://github.com/rust-lang/crates.io-index)" = "228ee516c00ac54fb70032d8cd928f213e70c87c340cdf44e2eba028d1f613ea"
 "checksum quick-error 1.2.3 (registry+https://github.com/rust-lang/crates.io-index)" = "a1d01941d82fa2ab50be1e79e6714289dd7cde78eba4c074bc5a4374f650dfe0"
 "checksum quote 1.0.7 (registry+https://github.com/rust-lang/crates.io-index)" = "aa563d17ecb180e500da1cfd2b028310ac758de548efdd203e18f283af693f37"
 "checksum rand 0.7.3 (registry+https://github.com/rust-lang/crates.io-index)" = "6a6b1679d49b24bbfe0c803429aa1874472f50d9b363131f0e89fc356b544d03"

View File: bindings/python/Cargo.toml

@@ -14,12 +14,10 @@ serde = { version = "1.0", features = [ "rc", "derive" ]}
 serde_json = "1.0"
 libc = "0.2"
 env_logger = "0.7.1"
-numpy = "0.11"
+pyo3 = "0.12"
+numpy = { git = "https://github.com/pyo3/rust-numpy/", rev = "e331befa27fede78d4662edf08fa0508db39be01" }
 ndarray = "0.13"
 
-[dependencies.pyo3]
-version = "0.11"
-
 [dependencies.tokenizers]
 version = "*"
 path = "../../tokenizers"

View File: bindings/python/src/decoders.rs

@@ -31,21 +31,15 @@ impl PyDecoder {
         let base = self.clone();
         let gil = Python::acquire_gil();
         let py = gil.python();
-        match &self.decoder {
-            PyDecoderWrapper::Custom(_) => Py::new(py, base).map(Into::into),
+        Ok(match &self.decoder {
+            PyDecoderWrapper::Custom(_) => Py::new(py, base)?.into_py(py),
             PyDecoderWrapper::Wrapped(inner) => match inner.as_ref() {
-                DecoderWrapper::Metaspace(_) => {
-                    Py::new(py, (PyMetaspaceDec {}, base)).map(Into::into)
-                }
-                DecoderWrapper::WordPiece(_) => {
-                    Py::new(py, (PyWordPieceDec {}, base)).map(Into::into)
-                }
-                DecoderWrapper::ByteLevel(_) => {
-                    Py::new(py, (PyByteLevelDec {}, base)).map(Into::into)
-                }
-                DecoderWrapper::BPE(_) => Py::new(py, (PyBPEDecoder {}, base)).map(Into::into),
+                DecoderWrapper::Metaspace(_) => Py::new(py, (PyMetaspaceDec {}, base))?.into_py(py),
+                DecoderWrapper::WordPiece(_) => Py::new(py, (PyWordPieceDec {}, base))?.into_py(py),
+                DecoderWrapper::ByteLevel(_) => Py::new(py, (PyByteLevelDec {}, base))?.into_py(py),
+                DecoderWrapper::BPE(_) => Py::new(py, (PyBPEDecoder {}, base))?.into_py(py),
             },
-        }
+        })
     }
 }
@@ -65,7 +59,7 @@ impl PyDecoder {
     fn __getstate__(&self, py: Python) -> PyResult<PyObject> {
         let data = serde_json::to_string(&self.decoder).map_err(|e| {
-            exceptions::Exception::py_err(format!(
+            exceptions::PyException::new_err(format!(
                 "Error while attempting to pickle Decoder: {}",
                 e
             ))
@@ -77,7 +71,7 @@ impl PyDecoder {
         match state.extract::<&PyBytes>(py) {
             Ok(s) => {
                 self.decoder = serde_json::from_slice(s.as_bytes()).map_err(|e| {
-                    exceptions::Exception::py_err(format!(
+                    exceptions::PyException::new_err(format!(
                         "Error while attempting to unpickle Decoder: {}",
                         e
                     ))
@@ -143,7 +137,7 @@ impl PyMetaspaceDec {
                 "replacement" => {
                     let s: &str = value.extract()?;
                     replacement = s.chars().next().ok_or_else(|| {
-                        exceptions::Exception::py_err("replacement must be a character")
+                        exceptions::PyValueError::new_err("replacement must be a character")
                     })?;
                 }
                 "add_prefix_space" => add_prefix_space = value.extract()?,
@@ -202,9 +196,7 @@ impl Decoder for CustomDecoder {
             Ok(res) => Ok(res
                 .cast_as::<PyString>(py)
                 .map_err(|_| PyError::from("`decode` is expected to return a str"))?
-                .to_string()
-                .map_err(|_| PyError::from("`decode` is expected to return a str"))?
-                .into_owned()),
+                .to_string()),
             Err(e) => {
                 e.print(py);
                 Err(Box::new(PyError::from("Error while calling `decode`")))
@@ -273,7 +265,7 @@ impl Decoder for PyDecoderWrapper {
 mod test {
     use std::sync::Arc;
 
-    use pyo3::{AsPyRef, Py, PyObject, Python};
+    use pyo3::prelude::*;
     use tk::decoders::metaspace::Metaspace;
     use tk::decoders::DecoderWrapper;
@@ -306,8 +298,9 @@ mod test {
             _ => panic!("Expected wrapped, not custom."),
         }
 
         let gil = Python::acquire_gil();
+        let py = gil.python();
         let py_msp = PyDecoder::new(Metaspace::default().into());
-        let obj: PyObject = Py::new(gil.python(), py_msp).unwrap().into();
+        let obj: PyObject = Py::new(py, py_msp).unwrap().into_py(py);
         let py_seq = PyDecoderWrapper::Custom(Arc::new(CustomDecoder::new(obj).unwrap()));
         assert!(serde_json::to_string(&py_seq).is_err());
     }
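The get_as_subtype rewrite above is the other recurring 0.12 pattern in this commit: Py::new(py, ...)? allocates the wrapper on the Python heap, and .into_py(py) erases the typed handle into a PyObject, replacing the 0.11 .map(Into::into) chain. A minimal sketch, assuming pyo3 0.12; Counter is a hypothetical pyclass standing in for wrappers like PyDecoder:

use pyo3::prelude::*;

#[pyclass]
struct Counter {
    value: u64,
}

fn wrap_counter(py: Python) -> PyResult<PyObject> {
    // Py::new returns PyResult<Py<Counter>>; into_py(py) turns the typed
    // handle into an untyped PyObject, which is why the surrounding match
    // is now wrapped in Ok(...) and each arm uses `?`.
    Ok(Py::new(py, Counter { value: 0 })?.into_py(py))
}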

View File: bindings/python/src/encoding.rs

@@ -48,7 +48,7 @@ impl PyEncoding {
     fn __getstate__(&self, py: Python) -> PyResult<PyObject> {
         let data = serde_json::to_string(&self.encoding).map_err(|e| {
-            exceptions::Exception::py_err(format!(
+            exceptions::PyException::new_err(format!(
                 "Error while attempting to pickle Encoding: {}",
                 e.to_string()
             ))
@@ -60,7 +60,7 @@ impl PyEncoding {
         match state.extract::<&PyBytes>(py) {
             Ok(s) => {
                 self.encoding = serde_json::from_slice(s.as_bytes()).map_err(|e| {
-                    exceptions::Exception::py_err(format!(
+                    exceptions::PyException::new_err(format!(
                         "Error while attempting to unpickle Encoding: {}",
                         e.to_string()
                     ))
@@ -171,7 +171,7 @@ impl PyEncoding {
                         one of `left` or `right`",
                         other
                     ))
-                    .into_pyerr()),
+                    .into_pyerr::<exceptions::PyValueError>()),
                 }?;
             }
             "pad_id" => pad_id = value.extract()?,

View File: bindings/python/src/error.rs

@@ -1,5 +1,6 @@
 use pyo3::exceptions;
 use pyo3::prelude::*;
+use pyo3::type_object::PyTypeObject;
 use std::fmt::{Display, Formatter, Result as FmtResult};
 use tokenizers::tokenizer::Result;
@@ -9,8 +10,8 @@ impl PyError {
     pub fn from(s: &str) -> Self {
         PyError(String::from(s))
     }
-    pub fn into_pyerr(self) -> PyErr {
-        exceptions::Exception::py_err(format!("{}", self))
+    pub fn into_pyerr<T: PyTypeObject>(self) -> PyErr {
+        PyErr::new::<T, _>(format!("{}", self))
     }
 }
 
 impl Display for PyError {
@@ -24,7 +25,7 @@ pub struct ToPyResult<T>(pub Result<T>);
 impl<T> std::convert::Into<PyResult<T>> for ToPyResult<T> {
     fn into(self) -> PyResult<T> {
         self.0
-            .map_err(|e| exceptions::Exception::py_err(format!("{}", e)))
+            .map_err(|e| exceptions::PyException::new_err(format!("{}", e)))
     }
 }
 
 impl<T> ToPyResult<T> {
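into_pyerr is now generic over the target exception class through the PyTypeObject bound, so each call site chooses what gets raised on the Python side (see the .into_pyerr::<exceptions::PyValueError>() calls elsewhere in this diff). A sketch of the same shape, assuming pyo3 0.12; MyError is a hypothetical stand-in for PyError:

use pyo3::exceptions;
use pyo3::prelude::*;
use pyo3::type_object::PyTypeObject;

struct MyError(String);

impl MyError {
    fn into_pyerr<T: PyTypeObject>(self) -> PyErr {
        // The caller picks the Python exception class via T.
        PyErr::new::<T, _>(self.0)
    }
}

fn check_direction(dir: &str) -> PyResult<()> {
    match dir {
        "left" | "right" => Ok(()),
        // Raised as ValueError on the Python side.
        _ => Err(MyError(format!("unknown direction: {}", dir))
            .into_pyerr::<exceptions::PyValueError>()),
    }
}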

View File: bindings/python/src/models.rs

@@ -34,12 +34,12 @@ impl PyModel {
         let base = self.clone();
         let gil = Python::acquire_gil();
         let py = gil.python();
-        match self.model.as_ref() {
-            ModelWrapper::BPE(_) => Py::new(py, (PyBPE {}, base)).map(Into::into),
-            ModelWrapper::WordPiece(_) => Py::new(py, (PyWordPiece {}, base)).map(Into::into),
-            ModelWrapper::WordLevel(_) => Py::new(py, (PyWordLevel {}, base)).map(Into::into),
-            ModelWrapper::Unigram(_) => Py::new(py, (PyUnigram {}, base)).map(Into::into),
-        }
+        Ok(match self.model.as_ref() {
+            ModelWrapper::BPE(_) => Py::new(py, (PyBPE {}, base))?.into_py(py),
+            ModelWrapper::WordPiece(_) => Py::new(py, (PyWordPiece {}, base))?.into_py(py),
+            ModelWrapper::WordLevel(_) => Py::new(py, (PyWordLevel {}, base))?.into_py(py),
+            ModelWrapper::Unigram(_) => Py::new(py, (PyUnigram {}, base))?.into_py(py),
+        })
     }
 }
@@ -82,7 +82,7 @@ impl PyModel {
     fn __getstate__(&self, py: Python) -> PyResult<PyObject> {
         let data = serde_json::to_string(&self.model).map_err(|e| {
-            exceptions::Exception::py_err(format!(
+            exceptions::PyException::new_err(format!(
                 "Error while attempting to pickle Model: {}",
                 e.to_string()
             ))
@@ -94,7 +94,7 @@ impl PyModel {
         match state.extract::<&PyBytes>(py) {
             Ok(s) => {
                 self.model = serde_json::from_slice(s.as_bytes()).map_err(|e| {
-                    exceptions::Exception::py_err(format!(
+                    exceptions::PyException::new_err(format!(
                         "Error while attempting to unpickle Model: {}",
                         e.to_string()
                     ))
@@ -130,7 +130,7 @@ impl PyBPE {
         kwargs: Option<&PyDict>,
     ) -> PyResult<(Self, PyModel)> {
         if (vocab.is_some() && merges.is_none()) || (vocab.is_none() && merges.is_some()) {
-            return Err(exceptions::ValueError::py_err(
+            return Err(exceptions::PyValueError::new_err(
                 "`vocab` and `merges` must be both specified",
             ));
         }
@@ -164,7 +164,7 @@
         }
 
         match builder.build() {
-            Err(e) => Err(exceptions::Exception::py_err(format!(
+            Err(e) => Err(exceptions::PyException::new_err(format!(
                 "Error while initializing BPE: {}",
                 e
             ))),
@@ -207,12 +207,10 @@ impl PyWordPiece {
         }
 
         match builder.build() {
-            Err(e) => {
-                println!("Errors: {:?}", e);
-                Err(exceptions::Exception::py_err(
-                    "Error while initializing WordPiece",
-                ))
-            }
+            Err(e) => Err(exceptions::PyException::new_err(format!(
+                "Error while initializing WordPiece: {}",
+                e
+            ))),
             Ok(wordpiece) => Ok((PyWordPiece {}, PyModel::new(Arc::new(wordpiece.into())))),
         }
     }
@@ -240,12 +238,10 @@ impl PyWordLevel {
         if let Some(vocab) = vocab {
             match WordLevel::from_files(vocab, unk_token) {
-                Err(e) => {
-                    println!("Errors: {:?}", e);
-                    Err(exceptions::Exception::py_err(
-                        "Error while initializing WordLevel",
-                    ))
-                }
+                Err(e) => Err(exceptions::PyException::new_err(format!(
+                    "Error while initializing WordLevel: {}",
+                    e
+                ))),
                 Ok(model) => Ok((PyWordLevel {}, PyModel::new(Arc::new(model.into())))),
             }
         } else {
@@ -263,13 +259,13 @@ pub struct PyUnigram {}
 #[pymethods]
 impl PyUnigram {
     #[new]
-    fn new(vocab: Option<String>) -> PyResult<(Self, PyModel)> {
+    fn new(vocab: Option<&str>) -> PyResult<(Self, PyModel)> {
         match vocab {
-            Some(vocab) => match Unigram::load(&std::path::Path::new(&vocab)) {
-                Err(e) => {
-                    println!("Errors: {:?}", e);
-                    Err(exceptions::Exception::py_err("Error while loading Unigram"))
-                }
+            Some(vocab) => match Unigram::load(vocab) {
+                Err(e) => Err(exceptions::PyException::new_err(format!(
+                    "Error while loading Unigram: {}",
+                    e
+                ))),
                 Ok(model) => Ok((PyUnigram {}, PyModel::new(Arc::new(model.into())))),
             },
             None => Ok((
@@ -283,7 +279,7 @@ impl PyUnigram {
 #[cfg(test)]
 mod test {
     use crate::models::PyModel;
-    use pyo3::{AsPyRef, Python};
+    use pyo3::prelude::*;
     use std::sync::Arc;
     use tk::models::bpe::BPE;
     use tk::models::ModelWrapper;
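All of the __getstate__/__setstate__ pairs in these bindings follow the same recipe: serialize the wrapped Rust value to JSON with serde, hand the bytes to Python as PyBytes, and map serde errors to a Python exception. A self-contained sketch of the recipe, assuming pyo3 0.12 and serde's derive feature (PyState and State are illustrative names, not part of this codebase):

use pyo3::exceptions;
use pyo3::prelude::*;
use pyo3::types::PyBytes;
use serde::{Deserialize, Serialize};

#[derive(Serialize, Deserialize, Default)]
struct State {
    vocab_size: usize,
}

#[pyclass]
struct PyState {
    state: State,
}

#[pymethods]
impl PyState {
    fn __getstate__(&self, py: Python) -> PyResult<PyObject> {
        // Serialize to JSON and give Python the raw bytes to pickle.
        let data = serde_json::to_string(&self.state)
            .map_err(|e| exceptions::PyException::new_err(format!("pickle failed: {}", e)))?;
        Ok(PyBytes::new(py, data.as_bytes()).to_object(py))
    }

    fn __setstate__(&mut self, py: Python, state: PyObject) -> PyResult<()> {
        // Accept the bytes back from pickle and rebuild the Rust value.
        let bytes = state.extract::<&PyBytes>(py)?;
        self.state = serde_json::from_slice(bytes.as_bytes())
            .map_err(|e| exceptions::PyException::new_err(format!("unpickle failed: {}", e)))?;
        Ok(())
    }
}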

View File: bindings/python/src/normalizers.rs

@@ -29,35 +29,31 @@ impl PyNormalizer {
         let base = self.clone();
         let gil = Python::acquire_gil();
         let py = gil.python();
-        match self.normalizer {
-            PyNormalizerWrapper::Sequence(_) => Py::new(py, (PySequence {}, base)).map(Into::into),
+        Ok(match self.normalizer {
+            PyNormalizerWrapper::Sequence(_) => Py::new(py, (PySequence {}, base))?.into_py(py),
             PyNormalizerWrapper::Wrapped(ref inner) => match inner.as_ref() {
-                NormalizerWrapper::Sequence(_) => {
-                    Py::new(py, (PySequence {}, base)).map(Into::into)
-                }
+                NormalizerWrapper::Sequence(_) => Py::new(py, (PySequence {}, base))?.into_py(py),
                 NormalizerWrapper::BertNormalizer(_) => {
-                    Py::new(py, (PyBertNormalizer {}, base)).map(Into::into)
+                    Py::new(py, (PyBertNormalizer {}, base))?.into_py(py)
                 }
                 NormalizerWrapper::StripNormalizer(_) => {
-                    Py::new(py, (PyBertNormalizer {}, base)).map(Into::into)
+                    Py::new(py, (PyBertNormalizer {}, base))?.into_py(py)
                 }
                 NormalizerWrapper::StripAccents(_) => {
-                    Py::new(py, (PyStripAccents {}, base)).map(Into::into)
+                    Py::new(py, (PyStripAccents {}, base))?.into_py(py)
                 }
-                NormalizerWrapper::NFC(_) => Py::new(py, (PyNFC {}, base)).map(Into::into),
-                NormalizerWrapper::NFD(_) => Py::new(py, (PyNFD {}, base)).map(Into::into),
-                NormalizerWrapper::NFKC(_) => Py::new(py, (PyNFKC {}, base)).map(Into::into),
-                NormalizerWrapper::NFKD(_) => Py::new(py, (PyNFKD {}, base)).map(Into::into),
-                NormalizerWrapper::Lowercase(_) => {
-                    Py::new(py, (PyLowercase {}, base)).map(Into::into)
-                }
+                NormalizerWrapper::NFC(_) => Py::new(py, (PyNFC {}, base))?.into_py(py),
+                NormalizerWrapper::NFD(_) => Py::new(py, (PyNFD {}, base))?.into_py(py),
+                NormalizerWrapper::NFKC(_) => Py::new(py, (PyNFKC {}, base))?.into_py(py),
+                NormalizerWrapper::NFKD(_) => Py::new(py, (PyNFKD {}, base))?.into_py(py),
+                NormalizerWrapper::Lowercase(_) => Py::new(py, (PyLowercase {}, base))?.into_py(py),
                 NormalizerWrapper::Precompiled(_) => {
-                    Py::new(py, (PyPrecompiled {}, base)).map(Into::into)
+                    Py::new(py, (PyPrecompiled {}, base))?.into_py(py)
                 }
-                NormalizerWrapper::Replace(_) => Py::new(py, (PyReplace {}, base)).map(Into::into),
-                NormalizerWrapper::Nmt(_) => Py::new(py, (PyNmt {}, base)).map(Into::into),
+                NormalizerWrapper::Replace(_) => Py::new(py, (PyReplace {}, base))?.into_py(py),
+                NormalizerWrapper::Nmt(_) => Py::new(py, (PyNmt {}, base))?.into_py(py),
             },
-        }
+        })
     }
 }
@@ -71,9 +67,9 @@ impl Normalizer for PyNormalizer {
 impl PyNormalizer {
     fn __getstate__(&self, py: Python) -> PyResult<PyObject> {
         let data = serde_json::to_string(&self.normalizer).map_err(|e| {
-            exceptions::Exception::py_err(format!(
+            exceptions::PyException::new_err(format!(
                 "Error while attempting to pickle Normalizer: {}",
-                e.to_string()
+                e
             ))
         })?;
         Ok(PyBytes::new(py, data.as_bytes()).to_object(py))
@@ -83,9 +79,9 @@ impl PyNormalizer {
         match state.extract::<&PyBytes>(py) {
             Ok(s) => {
                 self.normalizer = serde_json::from_slice(s.as_bytes()).map_err(|e| {
-                    exceptions::Exception::py_err(format!(
+                    exceptions::PyException::new_err(format!(
                         "Error while attempting to unpickle Normalizer: {}",
-                        e.to_string()
+                        e
                     ))
                 })?;
                 Ok(())
@@ -315,9 +311,9 @@ impl PyPrecompiled {
             PyPrecompiled {},
             Precompiled::from(precompiled_charsmap)
                 .map_err(|e| {
-                    exceptions::Exception::py_err(format!(
+                    exceptions::PyException::new_err(format!(
                         "Error while attempting to build Precompiled normalizer: {}",
-                        e.to_string()
+                        e
                     ))
                 })?
                 .into(),
@@ -337,7 +333,7 @@ impl PyReplace {
 #[cfg(test)]
 mod test {
-    use pyo3::{AsPyRef, Python};
+    use pyo3::prelude::*;
     use tk::normalizers::unicode::{NFC, NFKC};
     use tk::normalizers::utils::Sequence;
     use tk::normalizers::NormalizerWrapper;

View File: bindings/python/src/pre_tokenizers.rs

@@ -37,39 +37,35 @@ impl PyPreTokenizer {
         let base = self.clone();
         let gil = Python::acquire_gil();
         let py = gil.python();
-        match &self.pretok {
-            PyPreTokenizerWrapper::Custom(_) => Py::new(py, base).map(Into::into),
-            PyPreTokenizerWrapper::Sequence(_) => {
-                Py::new(py, (PySequence {}, base)).map(Into::into)
-            }
+        Ok(match &self.pretok {
+            PyPreTokenizerWrapper::Custom(_) => Py::new(py, base)?.into_py(py),
+            PyPreTokenizerWrapper::Sequence(_) => Py::new(py, (PySequence {}, base))?.into_py(py),
             PyPreTokenizerWrapper::Wrapped(inner) => match inner.as_ref() {
                 PreTokenizerWrapper::Whitespace(_) => {
-                    Py::new(py, (PyWhitespace {}, base)).map(Into::into)
+                    Py::new(py, (PyWhitespace {}, base))?.into_py(py)
                 }
                 PreTokenizerWrapper::Punctuation(_) => {
-                    Py::new(py, (PyPunctuation {}, base)).map(Into::into)
+                    Py::new(py, (PyPunctuation {}, base))?.into_py(py)
                 }
-                PreTokenizerWrapper::Sequence(_) => {
-                    Py::new(py, (PySequence {}, base)).map(Into::into)
-                }
+                PreTokenizerWrapper::Sequence(_) => Py::new(py, (PySequence {}, base))?.into_py(py),
                 PreTokenizerWrapper::Metaspace(_) => {
-                    Py::new(py, (PyMetaspace {}, base)).map(Into::into)
+                    Py::new(py, (PyMetaspace {}, base))?.into_py(py)
                 }
                 PreTokenizerWrapper::Delimiter(_) => {
-                    Py::new(py, (PyCharDelimiterSplit {}, base)).map(Into::into)
+                    Py::new(py, (PyCharDelimiterSplit {}, base))?.into_py(py)
                 }
                 PreTokenizerWrapper::WhitespaceSplit(_) => {
-                    Py::new(py, (PyWhitespaceSplit {}, base)).map(Into::into)
+                    Py::new(py, (PyWhitespaceSplit {}, base))?.into_py(py)
                 }
                 PreTokenizerWrapper::ByteLevel(_) => {
-                    Py::new(py, (PyByteLevel {}, base)).map(Into::into)
+                    Py::new(py, (PyByteLevel {}, base))?.into_py(py)
                 }
                 PreTokenizerWrapper::BertPreTokenizer(_) => {
-                    Py::new(py, (PyBertPreTokenizer {}, base)).map(Into::into)
+                    Py::new(py, (PyBertPreTokenizer {}, base))?.into_py(py)
                 }
-                PreTokenizerWrapper::Digits(_) => Py::new(py, (PyDigits {}, base)).map(Into::into),
+                PreTokenizerWrapper::Digits(_) => Py::new(py, (PyDigits {}, base))?.into_py(py),
             },
-        }
+        })
     }
 }
@@ -91,7 +87,7 @@ impl PyPreTokenizer {
     fn __getstate__(&self, py: Python) -> PyResult<PyObject> {
         let data = serde_json::to_string(&self.pretok).map_err(|e| {
-            exceptions::Exception::py_err(format!(
+            exceptions::PyException::new_err(format!(
                 "Error while attempting to pickle PreTokenizer: {}",
                 e.to_string()
             ))
@@ -103,9 +99,9 @@ impl PyPreTokenizer {
         match state.extract::<&PyBytes>(py) {
             Ok(s) => {
                 let unpickled = serde_json::from_slice(s.as_bytes()).map_err(|e| {
-                    exceptions::Exception::py_err(format!(
+                    exceptions::PyException::new_err(format!(
                         "Error while attempting to unpickle PreTokenizer: {}",
-                        e.to_string()
+                        e
                     ))
                 })?;
                 self.pretok = unpickled;
@@ -187,10 +183,9 @@ pub struct PyCharDelimiterSplit {}
 impl PyCharDelimiterSplit {
     #[new]
     pub fn new(delimiter: &str) -> PyResult<(Self, PyPreTokenizer)> {
-        let chr_delimiter = delimiter
-            .chars()
-            .next()
-            .ok_or_else(|| exceptions::Exception::py_err("delimiter must be a single character"))?;
+        let chr_delimiter = delimiter.chars().next().ok_or_else(|| {
+            exceptions::PyValueError::new_err("delimiter must be a single character")
+        })?;
         Ok((
             PyCharDelimiterSplit {},
             CharDelimiterSplit::new(chr_delimiter).into(),
@@ -267,7 +262,7 @@ impl PyMetaspace {
                 "replacement" => {
                     let s: &str = value.extract()?;
                     replacement = s.chars().next().ok_or_else(|| {
-                        exceptions::Exception::py_err("replacement must be a character")
+                        exceptions::PyValueError::new_err("replacement must be a character")
                     })?;
                 }
                 "add_prefix_space" => add_prefix_space = value.extract()?,
@@ -417,7 +412,7 @@ impl PreTokenizer for PyPreTokenizerWrapper {
 #[cfg(test)]
 mod test {
-    use pyo3::{AsPyRef, Py, PyObject, Python};
+    use pyo3::prelude::*;
     use tk::pre_tokenizers::whitespace::Whitespace;
     use tk::pre_tokenizers::PreTokenizerWrapper;
@@ -451,8 +446,9 @@ mod test {
             _ => panic!("Expected wrapped, not custom."),
         }
 
         let gil = Python::acquire_gil();
+        let py = gil.python();
         let py_wsp = PyPreTokenizer::new(Whitespace::default().into());
-        let obj: PyObject = Py::new(gil.python(), py_wsp).unwrap().into();
+        let obj: PyObject = Py::new(py, py_wsp).unwrap().into_py(py);
         let py_seq: PyPreTokenizerWrapper =
             PyPreTokenizerWrapper::Custom(Arc::new(CustomPreTokenizer::new(obj).unwrap()));
         assert!(serde_json::to_string(&py_seq).is_err());

View File: bindings/python/src/processors.rs

@@ -30,20 +30,16 @@ impl PyPostProcessor {
         let base = self.clone();
         let gil = Python::acquire_gil();
         let py = gil.python();
-        match self.processor.as_ref() {
-            PostProcessorWrapper::ByteLevel(_) => {
-                Py::new(py, (PyByteLevel {}, base)).map(Into::into)
-            }
-            PostProcessorWrapper::Bert(_) => {
-                Py::new(py, (PyBertProcessing {}, base)).map(Into::into)
-            }
+        Ok(match self.processor.as_ref() {
+            PostProcessorWrapper::ByteLevel(_) => Py::new(py, (PyByteLevel {}, base))?.into_py(py),
+            PostProcessorWrapper::Bert(_) => Py::new(py, (PyBertProcessing {}, base))?.into_py(py),
             PostProcessorWrapper::Roberta(_) => {
-                Py::new(py, (PyRobertaProcessing {}, base)).map(Into::into)
+                Py::new(py, (PyRobertaProcessing {}, base))?.into_py(py)
             }
             PostProcessorWrapper::Template(_) => {
-                Py::new(py, (PyTemplateProcessing {}, base)).map(Into::into)
+                Py::new(py, (PyTemplateProcessing {}, base))?.into_py(py)
             }
-        }
+        })
     }
 }
@@ -67,7 +63,7 @@ impl PostProcessor for PyPostProcessor {
 impl PyPostProcessor {
     fn __getstate__(&self, py: Python) -> PyResult<PyObject> {
         let data = serde_json::to_string(self.processor.as_ref()).map_err(|e| {
-            exceptions::Exception::py_err(format!(
+            exceptions::PyException::new_err(format!(
                 "Error while attempting to pickle PostProcessor: {}",
                 e.to_string()
             ))
@@ -79,7 +75,7 @@ impl PyPostProcessor {
         match state.extract::<&PyBytes>(py) {
             Ok(s) => {
                 self.processor = serde_json::from_slice(s.as_bytes()).map_err(|e| {
-                    exceptions::Exception::py_err(format!(
+                    exceptions::PyException::new_err(format!(
                         "Error while attempting to unpickle PostProcessor: {}",
                         e.to_string()
                     ))
@@ -181,11 +177,11 @@ impl FromPyObject<'_> for PySpecialToken {
         } else if let Ok(d) = ob.downcast::<PyDict>() {
             let id = d
                 .get_item("id")
-                .ok_or_else(|| exceptions::ValueError::py_err("`id` must be specified"))?
+                .ok_or_else(|| exceptions::PyValueError::new_err("`id` must be specified"))?
                 .extract::<String>()?;
             let ids = d
                 .get_item("ids")
-                .ok_or_else(|| exceptions::ValueError::py_err("`ids` must be specified"))?
+                .ok_or_else(|| exceptions::PyValueError::new_err("`ids` must be specified"))?
                 .extract::<Vec<u32>>()?;
             let type_ids = d.get_item("type_ids").map_or_else(
                 || Ok(vec![None; ids.len()]),
@@ -193,14 +189,14 @@ impl FromPyObject<'_> for PySpecialToken {
             )?;
             let tokens = d
                 .get_item("tokens")
-                .ok_or_else(|| exceptions::ValueError::py_err("`tokens` must be specified"))?
+                .ok_or_else(|| exceptions::PyValueError::new_err("`tokens` must be specified"))?
                 .extract::<Vec<String>>()?;
 
             Ok(Self(
                 ToPyResult(SpecialToken::new(id, ids, type_ids, tokens)).into_py()?,
             ))
         } else {
-            Err(exceptions::TypeError::py_err(
+            Err(exceptions::PyTypeError::new_err(
                 "Expected Union[Tuple[str, int], Tuple[int, str], dict]",
             ))
         }
@@ -223,7 +219,7 @@ impl FromPyObject<'_> for PyTemplate {
         } else if let Ok(s) = ob.extract::<Vec<&str>>() {
             Ok(Self(s.into()))
         } else {
-            Err(exceptions::TypeError::py_err(
+            Err(exceptions::PyTypeError::new_err(
                 "Expected Union[str, List[str]]",
             ))
         }
@@ -252,7 +248,7 @@ impl PyTemplateProcessing {
         if let Some(sp) = special_tokens {
             builder.special_tokens(sp);
         }
-        let processor = builder.build().map_err(exceptions::ValueError::py_err)?;
+        let processor = builder.build().map_err(exceptions::PyValueError::new_err)?;
 
         Ok((
             PyTemplateProcessing {},
@@ -265,7 +261,7 @@ impl PyTemplateProcessing {
 mod test {
     use std::sync::Arc;
 
-    use pyo3::{AsPyRef, Python};
+    use pyo3::prelude::*;
     use tk::processors::bert::BertProcessing;
     use tk::processors::PostProcessorWrapper;

View File: bindings/python/src/tokenizer.rs

@@ -175,9 +175,9 @@ impl PyObjectProtocol for PyAddedToken {
 struct TextInputSequence<'s>(tk::InputSequence<'s>);
 impl<'s> FromPyObject<'s> for TextInputSequence<'s> {
     fn extract(ob: &'s PyAny) -> PyResult<Self> {
-        let err = exceptions::ValueError::py_err("TextInputSequence must be str");
+        let err = exceptions::PyTypeError::new_err("TextInputSequence must be str");
         if let Ok(s) = ob.downcast::<PyString>() {
-            Ok(Self(s.to_string().map_err(|_| err)?.into()))
+            Ok(Self(s.to_string_lossy().into()))
         } else {
             Err(err)
         }
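PyString::to_string in pyo3 0.11 returned a PyResult, which is why the old code above needed a map_err; this diff switches to to_string_lossy, which is infallible and substitutes the replacement character for ill-formed UTF-8. A one-function sketch (as_rust_string is an illustrative name):

use pyo3::types::PyString;

fn as_rust_string(py_str: &PyString) -> String {
    // No PyResult here: the lossy conversion cannot fail.
    py_str.to_string_lossy().into_owned()
}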
@@ -207,7 +207,9 @@ impl FromPyObject<'_> for PyArrayUnicode {
         // type_num == 19 => Unicode
         if type_num != 19 {
-            return Err(exceptions::TypeError::py_err("Expected a np.array[str]"));
+            return Err(exceptions::PyTypeError::new_err(
+                "Expected a np.array[dtype='U']",
+            ));
         }
 
         unsafe {
@@ -224,7 +226,7 @@ impl FromPyObject<'_> for PyArrayUnicode {
                     let py = gil.python();
                     let obj = PyObject::from_owned_ptr(py, unicode);
                     let s = obj.cast_as::<PyString>(py)?;
-                    Ok(s.to_string()?.trim_matches(char::from(0)).to_owned())
+                    Ok(s.to_string_lossy().trim_matches(char::from(0)).to_owned())
                 })
                 .collect::<PyResult<Vec<_>>>()?;
@@ -247,7 +249,9 @@ impl FromPyObject<'_> for PyArrayStr {
         let n_elem = array.shape()[0];
 
         if type_num != 17 {
-            return Err(exceptions::TypeError::py_err("Expected a np.array[str]"));
+            return Err(exceptions::PyTypeError::new_err(
+                "Expected a np.array[dtype='O']",
+            ));
         }
 
         unsafe {
@@ -259,7 +263,7 @@
                     let gil = Python::acquire_gil();
                     let py = gil.python();
                     let s = obj.cast_as::<PyString>(py)?;
-                    Ok(s.to_string()?.into_owned())
+                    Ok(s.to_string_lossy().into_owned())
                 })
                 .collect::<PyResult<Vec<_>>>()?;
@@ -292,7 +296,7 @@ impl<'s> FromPyObject<'s> for PreTokenizedInputSequence<'s> {
                 return Ok(Self(seq.into()));
             }
         }
-        Err(exceptions::ValueError::py_err(
+        Err(exceptions::PyTypeError::new_err(
            "PreTokenizedInputSequence must be Union[List[str], Tuple[str]]",
        ))
    }
@@ -319,7 +323,7 @@ impl<'s> FromPyObject<'s> for TextEncodeInput<'s> {
                 return Ok(Self((first, second).into()));
             }
         }
-        Err(exceptions::ValueError::py_err(
+        Err(exceptions::PyTypeError::new_err(
            "TextEncodeInput must be Union[TextInputSequence, Tuple[InputSequence, InputSequence]]",
        ))
    }
@@ -346,7 +350,7 @@ impl<'s> FromPyObject<'s> for PreTokenizedEncodeInput<'s> {
                 return Ok(Self((first, second).into()));
             }
         }
-        Err(exceptions::ValueError::py_err(
+        Err(exceptions::PyTypeError::new_err(
            "PreTokenizedEncodeInput must be Union[PreTokenizedInputSequence, \
            Tuple[PreTokenizedInputSequence, PreTokenizedInputSequence]]",
        ))
@@ -385,9 +389,9 @@ impl PyTokenizer {
     fn __getstate__(&self, py: Python) -> PyResult<PyObject> {
         let data = serde_json::to_string(&self.tokenizer).map_err(|e| {
-            exceptions::Exception::py_err(format!(
+            exceptions::PyException::new_err(format!(
                 "Error while attempting to pickle Tokenizer: {}",
-                e.to_string()
+                e
             ))
         })?;
         Ok(PyBytes::new(py, data.as_bytes()).to_object(py))
@@ -397,9 +401,9 @@ impl PyTokenizer {
         match state.extract::<&PyBytes>(py) {
             Ok(s) => {
                 self.tokenizer = serde_json::from_slice(s.as_bytes()).map_err(|e| {
-                    exceptions::Exception::py_err(format!(
+                    exceptions::PyException::new_err(format!(
                         "Error while attempting to unpickle Tokenizer: {}",
-                        e.to_string()
+                        e
                     ))
                 })?;
                 Ok(())
@@ -429,9 +433,9 @@ impl PyTokenizer {
     #[staticmethod]
     fn from_buffer(buffer: &PyBytes) -> PyResult<Self> {
         let tokenizer = serde_json::from_slice(buffer.as_bytes()).map_err(|e| {
-            exceptions::Exception::py_err(format!(
+            exceptions::PyValueError::new_err(format!(
                 "Cannot instantiate Tokenizer from buffer: {}",
-                e.to_string()
+                e
            ))
        })?;
        Ok(Self { tokenizer })
@@ -485,7 +489,7 @@ impl PyTokenizer {
                         one of `longest_first`, `only_first`, or `only_second`",
                         value
                     ))
-                    .into_pyerr()),
+                    .into_pyerr::<exceptions::PyValueError>()),
                 }?
             }
             _ => println!("Ignored unknown kwarg option {}", key),
@@ -533,7 +537,7 @@ impl PyTokenizer {
                         one of `left` or `right`",
                         other
                     ))
-                    .into_pyerr()),
+                    .into_pyerr::<exceptions::PyValueError>()),
                 }?;
             }
             "pad_to_multiple_of" => {
@@ -716,7 +720,7 @@ impl PyTokenizer {
                     token.is_special_token = false;
                     Ok(token.get_token())
                 } else {
-                    Err(exceptions::Exception::py_err(
+                    Err(exceptions::PyTypeError::new_err(
                         "Input must be a List[Union[str, AddedToken]]",
                     ))
                 }
@@ -736,7 +740,7 @@ impl PyTokenizer {
                     token.is_special_token = true;
                     Ok(token.get_token())
                 } else {
-                    Err(exceptions::Exception::py_err(
+                    Err(exceptions::PyTypeError::new_err(
                         "Input must be a List[Union[str, AddedToken]]",
                     ))
                 }
@@ -747,10 +751,7 @@ impl PyTokenizer {
     }
 
     fn train(&mut self, trainer: &PyTrainer, files: Vec<String>) -> PyResult<()> {
-        self.tokenizer
-            .train_and_replace(trainer, files)
-            .map_err(|e| exceptions::Exception::py_err(format!("{}", e)))?;
-        Ok(())
+        ToPyResult(self.tokenizer.train_and_replace(trainer, files)).into()
     }
 
     #[args(pair = "None", add_special_tokens = true)]
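The train binding above now funnels its error through the ToPyResult adapter defined in error.rs instead of a hand-written map_err. A simplified sketch of that adapter pattern, assuming pyo3 0.12 (String stands in for the tokenizers error type):

use pyo3::exceptions;
use pyo3::prelude::*;

struct ToPyResult<T>(Result<T, String>);

impl<T> From<ToPyResult<T>> for PyResult<T> {
    fn from(v: ToPyResult<T>) -> PyResult<T> {
        // One shared conversion point instead of a map_err per method.
        v.0.map_err(exceptions::PyException::new_err)
    }
}

fn train() -> Result<(), String> {
    Err("no training files provided".into())
}

fn train_binding() -> PyResult<()> {
    ToPyResult(train()).into()
}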

View File: bindings/python/src/trainers.rs

@@ -73,7 +73,7 @@ impl PyBpeTrainer {
                         token.is_special_token = true;
                         Ok(token.get_token())
                     } else {
-                        Err(exceptions::Exception::py_err(
+                        Err(exceptions::PyTypeError::new_err(
                             "special_tokens must be a List[Union[str, AddedToken]]",
                         ))
                     }
@@ -137,7 +137,7 @@ impl PyWordPieceTrainer {
                         token.is_special_token = true;
                         Ok(token.get_token())
                     } else {
-                        Err(exceptions::Exception::py_err(
+                        Err(exceptions::PyTypeError::new_err(
                             "special_tokens must be a List[Union[str, AddedToken]]",
                         ))
                     }
@@ -205,7 +205,7 @@ impl PyUnigramTrainer {
                         token.is_special_token = true;
                         Ok(token.get_token())
                     } else {
-                        Err(exceptions::Exception::py_err(
+                        Err(exceptions::PyTypeError::new_err(
                             "special_tokens must be a List[Union[str, AddedToken]]",
                         ))
                     }
@@ -220,9 +220,10 @@ impl PyUnigramTrainer {
             }
         }
 
-        let trainer: tokenizers::models::unigram::UnigramTrainer = builder
-            .build()
-            .map_err(|_| exceptions::Exception::py_err("Cannot build UnigramTrainer"))?;
+        let trainer: tokenizers::models::unigram::UnigramTrainer =
+            builder.build().map_err(|e| {
+                exceptions::PyException::new_err(format!("Cannot build UnigramTrainer: {}", e))
+            })?;
 
         Ok((PyUnigramTrainer {}, PyTrainer::new(trainer.into())))
     }
 }

View File: bindings/python/tests/bindings/test_tokenizer.py

@@ -238,12 +238,12 @@ class TestTokenizer:
         )
 
         # Mal formed
-        with pytest.raises(ValueError, match="InputSequence must be str"):
+        with pytest.raises(TypeError, match="TextInputSequence must be str"):
             tokenizer.encode([["my", "name"]])
             tokenizer.encode("My name is john", [["pair"]])
             tokenizer.encode("my name is john", ["pair"])
-        with pytest.raises(ValueError, match="InputSequence must be Union[List[str]"):
+        with pytest.raises(TypeError, match="InputSequence must be Union[List[str]"):
             tokenizer.encode("My name is john", is_pretokenized=True)
             tokenizer.encode("My name is john", ["pair"], is_pretokenized=True)
             tokenizer.encode(["My", "name", "is", "John"], "pair", is_pretokenized=True)

View File: tokenizers/src/models/unigram/model.rs

@@ -366,9 +366,9 @@ impl Unigram {
     /// use tokenizers::models::unigram::Unigram;
     /// use std::path::Path;
     ///
-    /// let model = Unigram::load(Path::new("mymodel-unigram.json")).unwrap();
+    /// let model = Unigram::load("mymodel-unigram.json").unwrap();
     /// ```
-    pub fn load(path: &Path) -> Result<Unigram> {
+    pub fn load<P: AsRef<Path>>(path: P) -> Result<Unigram> {
         let file = File::open(path).unwrap();
         let reader = BufReader::new(file);
         let u = serde_json::from_reader(reader)?;
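Unigram::load is now generic over AsRef<Path>, so the Python binding can pass its extracted &str straight through instead of constructing a Path by hand (compare the PyUnigram::new change earlier in this diff). A minimal sketch of the signature change (the load helper here is illustrative, not the real loader):

use std::fs::File;
use std::io::BufReader;
use std::path::Path;

fn load<P: AsRef<Path>>(path: P) -> std::io::Result<BufReader<File>> {
    // &str, String, and &Path all satisfy AsRef<Path>.
    Ok(BufReader::new(File::open(path)?))
}

fn main() {
    // Both call forms compile against the same generic signature.
    let _ = load("mymodel-unigram.json");
    let _ = load(Path::new("mymodel-unigram.json"));
}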