Update PyO3 (#426)

Anthony MOI
2020-09-22 12:00:20 -04:00
committed by GitHub
parent 8e220dbdd4
commit 940f8bd8fa
13 changed files with 156 additions and 178 deletions
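Most of the churn below is the mechanical PyO3 0.11 -> 0.12 migration: exception types gained a `Py` prefix and are constructed with `new_err` instead of `py_err`. A minimal sketch of the rename, using a hypothetical `check` function rather than code from this diff:

use pyo3::exceptions;
use pyo3::PyErr;

// PyO3 0.11 spelling (removed below):
//     exceptions::ValueError::py_err("flag must be set")
// PyO3 0.12 spelling (added below):
fn check(flag: bool) -> Result<(), PyErr> {
    if !flag {
        return Err(exceptions::PyValueError::new_err("flag must be set"));
    }
    Ok(())
}

The rename also makes the intended Python exception class explicit at each call site, which is why several sites below switch from a generic `Exception` to `PyValueError` or `PyTypeError`.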

View File

@@ -484,14 +484,14 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
[[package]]
name = "numpy"
version = "0.11.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
source = "git+https://github.com/pyo3/rust-numpy/?rev=e331befa27fede78d4662edf08fa0508db39be01#e331befa27fede78d4662edf08fa0508db39be01"
dependencies = [
"cfg-if 0.1.10 (registry+https://github.com/rust-lang/crates.io-index)",
"libc 0.2.77 (registry+https://github.com/rust-lang/crates.io-index)",
"ndarray 0.13.1 (registry+https://github.com/rust-lang/crates.io-index)",
"num-complex 0.2.4 (registry+https://github.com/rust-lang/crates.io-index)",
"num-traits 0.2.12 (registry+https://github.com/rust-lang/crates.io-index)",
"pyo3 0.11.1 (registry+https://github.com/rust-lang/crates.io-index)",
"pyo3 0.12.0 (registry+https://github.com/rust-lang/crates.io-index)",
]
[[package]]
@@ -580,7 +580,7 @@ dependencies = [
[[package]]
name = "pyo3"
version = "0.11.1"
version = "0.12.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
dependencies = [
"ctor 0.1.15 (registry+https://github.com/rust-lang/crates.io-index)",
@@ -589,13 +589,13 @@ dependencies = [
"libc 0.2.77 (registry+https://github.com/rust-lang/crates.io-index)",
"parking_lot 0.11.0 (registry+https://github.com/rust-lang/crates.io-index)",
"paste 0.1.18 (registry+https://github.com/rust-lang/crates.io-index)",
"pyo3cls 0.11.1 (registry+https://github.com/rust-lang/crates.io-index)",
"pyo3cls 0.12.0 (registry+https://github.com/rust-lang/crates.io-index)",
"unindent 0.1.6 (registry+https://github.com/rust-lang/crates.io-index)",
]
[[package]]
name = "pyo3-derive-backend"
version = "0.11.1"
version = "0.12.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
dependencies = [
"proc-macro2 1.0.21 (registry+https://github.com/rust-lang/crates.io-index)",
@@ -605,10 +605,10 @@ dependencies = [
[[package]]
name = "pyo3cls"
version = "0.11.1"
version = "0.12.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
dependencies = [
"pyo3-derive-backend 0.11.1 (registry+https://github.com/rust-lang/crates.io-index)",
"pyo3-derive-backend 0.12.0 (registry+https://github.com/rust-lang/crates.io-index)",
"quote 1.0.7 (registry+https://github.com/rust-lang/crates.io-index)",
"syn 1.0.41 (registry+https://github.com/rust-lang/crates.io-index)",
]
@@ -895,8 +895,8 @@ dependencies = [
"env_logger 0.7.1 (registry+https://github.com/rust-lang/crates.io-index)",
"libc 0.2.77 (registry+https://github.com/rust-lang/crates.io-index)",
"ndarray 0.13.1 (registry+https://github.com/rust-lang/crates.io-index)",
"numpy 0.11.0 (registry+https://github.com/rust-lang/crates.io-index)",
"pyo3 0.11.1 (registry+https://github.com/rust-lang/crates.io-index)",
"numpy 0.11.0 (git+https://github.com/pyo3/rust-numpy/?rev=e331befa27fede78d4662edf08fa0508db39be01)",
"pyo3 0.12.0 (registry+https://github.com/rust-lang/crates.io-index)",
"rayon 1.4.0 (registry+https://github.com/rust-lang/crates.io-index)",
"serde 1.0.116 (registry+https://github.com/rust-lang/crates.io-index)",
"serde_json 1.0.57 (registry+https://github.com/rust-lang/crates.io-index)",
@@ -1037,7 +1037,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
"checksum num-traits 0.2.12 (registry+https://github.com/rust-lang/crates.io-index)" = "ac267bcc07f48ee5f8935ab0d24f316fb722d7a1292e2913f0cc196b29ffd611"
"checksum num_cpus 1.13.0 (registry+https://github.com/rust-lang/crates.io-index)" = "05499f3756671c15885fee9034446956fff3f243d6077b91e5767df161f766b3"
"checksum number_prefix 0.3.0 (registry+https://github.com/rust-lang/crates.io-index)" = "17b02fc0ff9a9e4b35b3342880f48e896ebf69f2967921fe8646bf5b7125956a"
"checksum numpy 0.11.0 (registry+https://github.com/rust-lang/crates.io-index)" = "1afa393812b0e4c33d8a46054028ce39d79f954dfa98298f25691abf00b24e39"
"checksum numpy 0.11.0 (git+https://github.com/pyo3/rust-numpy/?rev=e331befa27fede78d4662edf08fa0508db39be01)" = "<none>"
"checksum onig 6.1.0 (registry+https://github.com/rust-lang/crates.io-index)" = "8a155d13862da85473665694f4c05d77fb96598bdceeaf696433c84ea9567e20"
"checksum onig_sys 69.5.1 (registry+https://github.com/rust-lang/crates.io-index)" = "9bff06597a6b17855040955cae613af000fc0bfc8ad49ea68b9479a74e59292d"
"checksum parking_lot 0.11.0 (registry+https://github.com/rust-lang/crates.io-index)" = "a4893845fa2ca272e647da5d0e46660a314ead9c2fdd9a883aabc32e481a8733"
@@ -1048,9 +1048,9 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
"checksum ppv-lite86 0.2.9 (registry+https://github.com/rust-lang/crates.io-index)" = "c36fa947111f5c62a733b652544dd0016a43ce89619538a8ef92724a6f501a20"
"checksum proc-macro-hack 0.5.18 (registry+https://github.com/rust-lang/crates.io-index)" = "99c605b9a0adc77b7211c6b1f722dcb613d68d66859a44f3d485a6da332b0598"
"checksum proc-macro2 1.0.21 (registry+https://github.com/rust-lang/crates.io-index)" = "36e28516df94f3dd551a587da5357459d9b36d945a7c37c3557928c1c2ff2a2c"
"checksum pyo3 0.11.1 (registry+https://github.com/rust-lang/crates.io-index)" = "9ca8710ffa8211c9a62a8a3863c4267c710dc42a82a7fd29c97de465d7ea6b7d"
"checksum pyo3-derive-backend 0.11.1 (registry+https://github.com/rust-lang/crates.io-index)" = "58ad070bf6967b0d29ea74931ffcf9c6bbe8402a726e9afbeafadc0a287cc2b3"
"checksum pyo3cls 0.11.1 (registry+https://github.com/rust-lang/crates.io-index)" = "c3fa17e1ea569d0bf3b7c00f2a9eea831ca05e55dd76f1794c541abba1c64baa"
"checksum pyo3 0.12.0 (registry+https://github.com/rust-lang/crates.io-index)" = "4fe7ed6655de90fd3bf799f25ce8b695e89eeb098d03db802723617649fcaa74"
"checksum pyo3-derive-backend 0.12.0 (registry+https://github.com/rust-lang/crates.io-index)" = "d021655b5f22aeee2eaf3c3f88503250ad6cf90e37a875e1bb1854fceb86b3d0"
"checksum pyo3cls 0.12.0 (registry+https://github.com/rust-lang/crates.io-index)" = "228ee516c00ac54fb70032d8cd928f213e70c87c340cdf44e2eba028d1f613ea"
"checksum quick-error 1.2.3 (registry+https://github.com/rust-lang/crates.io-index)" = "a1d01941d82fa2ab50be1e79e6714289dd7cde78eba4c074bc5a4374f650dfe0"
"checksum quote 1.0.7 (registry+https://github.com/rust-lang/crates.io-index)" = "aa563d17ecb180e500da1cfd2b028310ac758de548efdd203e18f283af693f37"
"checksum rand 0.7.3 (registry+https://github.com/rust-lang/crates.io-index)" = "6a6b1679d49b24bbfe0c803429aa1874472f50d9b363131f0e89fc356b544d03"

View File

@@ -14,12 +14,10 @@ serde = { version = "1.0", features = [ "rc", "derive" ]}
serde_json = "1.0"
libc = "0.2"
env_logger = "0.7.1"
numpy = "0.11"
pyo3 = "0.12"
numpy = { git = "https://github.com/pyo3/rust-numpy/", rev = "e331befa27fede78d4662edf08fa0508db39be01" }
ndarray = "0.13"
-[dependencies.pyo3]
-version = "0.11"
[dependencies.tokenizers]
version = "*"
path = "../../tokenizers"

View File

@@ -31,21 +31,15 @@ impl PyDecoder {
let base = self.clone();
let gil = Python::acquire_gil();
let py = gil.python();
match &self.decoder {
PyDecoderWrapper::Custom(_) => Py::new(py, base).map(Into::into),
Ok(match &self.decoder {
PyDecoderWrapper::Custom(_) => Py::new(py, base)?.into_py(py),
PyDecoderWrapper::Wrapped(inner) => match inner.as_ref() {
-DecoderWrapper::Metaspace(_) => {
-Py::new(py, (PyMetaspaceDec {}, base)).map(Into::into)
-}
-DecoderWrapper::WordPiece(_) => {
-Py::new(py, (PyWordPieceDec {}, base)).map(Into::into)
-}
-DecoderWrapper::ByteLevel(_) => {
-Py::new(py, (PyByteLevelDec {}, base)).map(Into::into)
-}
-DecoderWrapper::BPE(_) => Py::new(py, (PyBPEDecoder {}, base)).map(Into::into),
+DecoderWrapper::Metaspace(_) => Py::new(py, (PyMetaspaceDec {}, base))?.into_py(py),
+DecoderWrapper::WordPiece(_) => Py::new(py, (PyWordPieceDec {}, base))?.into_py(py),
+DecoderWrapper::ByteLevel(_) => Py::new(py, (PyByteLevelDec {}, base))?.into_py(py),
+DecoderWrapper::BPE(_) => Py::new(py, (PyBPEDecoder {}, base))?.into_py(py),
},
-}
+})
}
}
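The `get_as_subtype` blocks throughout this diff all follow the same translation: PyO3 0.12 drops the implicit `Py<T> -> PyObject` conversion that `.map(Into::into)` relied on, so each arm now converts explicitly with `into_py`. A minimal sketch of the new pattern, with a hypothetical `MyClass` pyclass:

use pyo3::prelude::*;

#[pyclass]
struct MyClass {}

fn as_object() -> PyResult<PyObject> {
    let gil = Python::acquire_gil();
    let py = gil.python();
    // 0.11: Py::new(py, MyClass {}).map(Into::into)
    // 0.12: the error propagates with `?`, then the typed Py<MyClass>
    // is converted into an untyped PyObject explicitly.
    Ok(Py::new(py, MyClass {})?.into_py(py))
}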
@@ -65,7 +59,7 @@ impl PyDecoder {
fn __getstate__(&self, py: Python) -> PyResult<PyObject> {
let data = serde_json::to_string(&self.decoder).map_err(|e| {
-exceptions::Exception::py_err(format!(
+exceptions::PyException::new_err(format!(
"Error while attempting to pickle Decoder: {}",
e
))
@@ -77,7 +71,7 @@ impl PyDecoder {
match state.extract::<&PyBytes>(py) {
Ok(s) => {
self.decoder = serde_json::from_slice(s.as_bytes()).map_err(|e| {
-exceptions::Exception::py_err(format!(
+exceptions::PyException::new_err(format!(
"Error while attempting to unpickle Decoder: {}",
e
))
@@ -143,7 +137,7 @@ impl PyMetaspaceDec {
"replacement" => {
let s: &str = value.extract()?;
replacement = s.chars().next().ok_or_else(|| {
exceptions::Exception::py_err("replacement must be a character")
exceptions::PyValueError::new_err("replacement must be a character")
})?;
}
"add_prefix_space" => add_prefix_space = value.extract()?,
@@ -202,9 +196,7 @@ impl Decoder for CustomDecoder {
Ok(res) => Ok(res
.cast_as::<PyString>(py)
.map_err(|_| PyError::from("`decode` is expected to return a str"))?
-.to_string()
-.map_err(|_| PyError::from("`decode` is expected to return a str"))?
-.into_owned()),
+.to_string()),
Err(e) => {
e.print(py);
Err(Box::new(PyError::from("Error while calling `decode`")))
@@ -273,7 +265,7 @@ impl Decoder for PyDecoderWrapper {
mod test {
use std::sync::Arc;
-use pyo3::{AsPyRef, Py, PyObject, Python};
+use pyo3::prelude::*;
use tk::decoders::metaspace::Metaspace;
use tk::decoders::DecoderWrapper;
@@ -306,8 +298,9 @@ mod test {
_ => panic!("Expected wrapped, not custom."),
}
let gil = Python::acquire_gil();
+let py = gil.python();
let py_msp = PyDecoder::new(Metaspace::default().into());
-let obj: PyObject = Py::new(gil.python(), py_msp).unwrap().into();
+let obj: PyObject = Py::new(py, py_msp).unwrap().into_py(py);
let py_seq = PyDecoderWrapper::Custom(Arc::new(CustomDecoder::new(obj).unwrap()));
assert!(serde_json::to_string(&py_seq).is_err());
}

View File

@@ -48,7 +48,7 @@ impl PyEncoding {
fn __getstate__(&self, py: Python) -> PyResult<PyObject> {
let data = serde_json::to_string(&self.encoding).map_err(|e| {
-exceptions::Exception::py_err(format!(
+exceptions::PyException::new_err(format!(
"Error while attempting to pickle Encoding: {}",
e.to_string()
))
@@ -60,7 +60,7 @@ impl PyEncoding {
match state.extract::<&PyBytes>(py) {
Ok(s) => {
self.encoding = serde_json::from_slice(s.as_bytes()).map_err(|e| {
-exceptions::Exception::py_err(format!(
+exceptions::PyException::new_err(format!(
"Error while attempting to unpickle Encoding: {}",
e.to_string()
))
@@ -171,7 +171,7 @@ impl PyEncoding {
one of `left` or `right`",
other
))
-.into_pyerr()),
+.into_pyerr::<exceptions::PyValueError>()),
}?;
}
"pad_id" => pad_id = value.extract()?,

View File

@@ -1,5 +1,6 @@
use pyo3::exceptions;
use pyo3::prelude::*;
+use pyo3::type_object::PyTypeObject;
use std::fmt::{Display, Formatter, Result as FmtResult};
use tokenizers::tokenizer::Result;
@@ -9,8 +10,8 @@ impl PyError {
pub fn from(s: &str) -> Self {
PyError(String::from(s))
}
-pub fn into_pyerr(self) -> PyErr {
-exceptions::Exception::py_err(format!("{}", self))
+pub fn into_pyerr<T: PyTypeObject>(self) -> PyErr {
+PyErr::new::<T, _>(format!("{}", self))
}
}
impl Display for PyError {
@@ -24,7 +25,7 @@ pub struct ToPyResult<T>(pub Result<T>);
impl<T> std::convert::Into<PyResult<T>> for ToPyResult<T> {
fn into(self) -> PyResult<T> {
self.0
-.map_err(|e| exceptions::Exception::py_err(format!("{}", e)))
+.map_err(|e| exceptions::PyException::new_err(format!("{}", e)))
}
}
impl<T> ToPyResult<T> {
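With `into_pyerr` now generic over the exception type, callers pick the Python exception class at the call site, while `ToPyResult` keeps funnelling arbitrary tokenizer errors into a plain `PyException`. A short usage sketch, assuming the `PyError` and `ToPyResult` definitions above:

use pyo3::{exceptions, PyErr};

fn reject(reason: &str) -> PyErr {
    // The exception class is chosen via the type parameter,
    // e.g. a ValueError here instead of a bare Exception.
    PyError::from(reason).into_pyerr::<exceptions::PyValueError>()
}

This is what the `.into_pyerr::<exceptions::PyValueError>()` call sites elsewhere in the diff rely on, and `ToPyResult(...).into()` gives the same one-liner treatment to fallible tokenizer calls (see the `train` method later in the diff).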

View File

@@ -34,12 +34,12 @@ impl PyModel {
let base = self.clone();
let gil = Python::acquire_gil();
let py = gil.python();
-match self.model.as_ref() {
-ModelWrapper::BPE(_) => Py::new(py, (PyBPE {}, base)).map(Into::into),
-ModelWrapper::WordPiece(_) => Py::new(py, (PyWordPiece {}, base)).map(Into::into),
-ModelWrapper::WordLevel(_) => Py::new(py, (PyWordLevel {}, base)).map(Into::into),
-ModelWrapper::Unigram(_) => Py::new(py, (PyUnigram {}, base)).map(Into::into),
-}
+Ok(match self.model.as_ref() {
+ModelWrapper::BPE(_) => Py::new(py, (PyBPE {}, base))?.into_py(py),
+ModelWrapper::WordPiece(_) => Py::new(py, (PyWordPiece {}, base))?.into_py(py),
+ModelWrapper::WordLevel(_) => Py::new(py, (PyWordLevel {}, base))?.into_py(py),
+ModelWrapper::Unigram(_) => Py::new(py, (PyUnigram {}, base))?.into_py(py),
+})
}
}
@@ -82,7 +82,7 @@ impl PyModel {
fn __getstate__(&self, py: Python) -> PyResult<PyObject> {
let data = serde_json::to_string(&self.model).map_err(|e| {
-exceptions::Exception::py_err(format!(
+exceptions::PyException::new_err(format!(
"Error while attempting to pickle Model: {}",
e.to_string()
))
@@ -94,7 +94,7 @@ impl PyModel {
match state.extract::<&PyBytes>(py) {
Ok(s) => {
self.model = serde_json::from_slice(s.as_bytes()).map_err(|e| {
-exceptions::Exception::py_err(format!(
+exceptions::PyException::new_err(format!(
"Error while attempting to unpickle Model: {}",
e.to_string()
))
@@ -130,7 +130,7 @@ impl PyBPE {
kwargs: Option<&PyDict>,
) -> PyResult<(Self, PyModel)> {
if (vocab.is_some() && merges.is_none()) || (vocab.is_none() && merges.is_some()) {
-return Err(exceptions::ValueError::py_err(
+return Err(exceptions::PyValueError::new_err(
"`vocab` and `merges` must be both specified",
));
}
@@ -164,7 +164,7 @@ impl PyBPE {
}
match builder.build() {
-Err(e) => Err(exceptions::Exception::py_err(format!(
+Err(e) => Err(exceptions::PyException::new_err(format!(
"Error while initializing BPE: {}",
e
))),
@@ -207,12 +207,10 @@ impl PyWordPiece {
}
match builder.build() {
-Err(e) => {
-println!("Errors: {:?}", e);
-Err(exceptions::Exception::py_err(
-"Error while initializing WordPiece",
-))
-}
+Err(e) => Err(exceptions::PyException::new_err(format!(
+"Error while initializing WordPiece: {}",
+e
+))),
Ok(wordpiece) => Ok((PyWordPiece {}, PyModel::new(Arc::new(wordpiece.into())))),
}
}
@@ -240,12 +238,10 @@ impl PyWordLevel {
if let Some(vocab) = vocab {
match WordLevel::from_files(vocab, unk_token) {
-Err(e) => {
-println!("Errors: {:?}", e);
-Err(exceptions::Exception::py_err(
-"Error while initializing WordLevel",
-))
-}
+Err(e) => Err(exceptions::PyException::new_err(format!(
+"Error while initializing WordLevel: {}",
+e
+))),
Ok(model) => Ok((PyWordLevel {}, PyModel::new(Arc::new(model.into())))),
}
} else {
@@ -263,13 +259,13 @@ pub struct PyUnigram {}
#[pymethods]
impl PyUnigram {
#[new]
-fn new(vocab: Option<String>) -> PyResult<(Self, PyModel)> {
+fn new(vocab: Option<&str>) -> PyResult<(Self, PyModel)> {
match vocab {
-Some(vocab) => match Unigram::load(&std::path::Path::new(&vocab)) {
-Err(e) => {
-println!("Errors: {:?}", e);
-Err(exceptions::Exception::py_err("Error while loading Unigram"))
-}
+Some(vocab) => match Unigram::load(vocab) {
+Err(e) => Err(exceptions::PyException::new_err(format!(
+"Error while loading Unigram: {}",
+e
+))),
Ok(model) => Ok((PyUnigram {}, PyModel::new(Arc::new(model.into())))),
},
None => Ok((
@@ -283,7 +279,7 @@ impl PyUnigram {
#[cfg(test)]
mod test {
use crate::models::PyModel;
-use pyo3::{AsPyRef, Python};
+use pyo3::prelude::*;
use std::sync::Arc;
use tk::models::bpe::BPE;
use tk::models::ModelWrapper;

View File

@@ -29,35 +29,31 @@ impl PyNormalizer {
let base = self.clone();
let gil = Python::acquire_gil();
let py = gil.python();
-match self.normalizer {
-PyNormalizerWrapper::Sequence(_) => Py::new(py, (PySequence {}, base)).map(Into::into),
+Ok(match self.normalizer {
+PyNormalizerWrapper::Sequence(_) => Py::new(py, (PySequence {}, base))?.into_py(py),
PyNormalizerWrapper::Wrapped(ref inner) => match inner.as_ref() {
-NormalizerWrapper::Sequence(_) => {
-Py::new(py, (PySequence {}, base)).map(Into::into)
-}
+NormalizerWrapper::Sequence(_) => Py::new(py, (PySequence {}, base))?.into_py(py),
NormalizerWrapper::BertNormalizer(_) => {
-Py::new(py, (PyBertNormalizer {}, base)).map(Into::into)
+Py::new(py, (PyBertNormalizer {}, base))?.into_py(py)
}
NormalizerWrapper::StripNormalizer(_) => {
-Py::new(py, (PyBertNormalizer {}, base)).map(Into::into)
+Py::new(py, (PyBertNormalizer {}, base))?.into_py(py)
}
NormalizerWrapper::StripAccents(_) => {
-Py::new(py, (PyStripAccents {}, base)).map(Into::into)
-}
-NormalizerWrapper::NFC(_) => Py::new(py, (PyNFC {}, base)).map(Into::into),
-NormalizerWrapper::NFD(_) => Py::new(py, (PyNFD {}, base)).map(Into::into),
-NormalizerWrapper::NFKC(_) => Py::new(py, (PyNFKC {}, base)).map(Into::into),
-NormalizerWrapper::NFKD(_) => Py::new(py, (PyNFKD {}, base)).map(Into::into),
-NormalizerWrapper::Lowercase(_) => {
-Py::new(py, (PyLowercase {}, base)).map(Into::into)
+Py::new(py, (PyStripAccents {}, base))?.into_py(py)
}
+NormalizerWrapper::NFC(_) => Py::new(py, (PyNFC {}, base))?.into_py(py),
+NormalizerWrapper::NFD(_) => Py::new(py, (PyNFD {}, base))?.into_py(py),
+NormalizerWrapper::NFKC(_) => Py::new(py, (PyNFKC {}, base))?.into_py(py),
+NormalizerWrapper::NFKD(_) => Py::new(py, (PyNFKD {}, base))?.into_py(py),
+NormalizerWrapper::Lowercase(_) => Py::new(py, (PyLowercase {}, base))?.into_py(py),
NormalizerWrapper::Precompiled(_) => {
-Py::new(py, (PyPrecompiled {}, base)).map(Into::into)
+Py::new(py, (PyPrecompiled {}, base))?.into_py(py)
}
-NormalizerWrapper::Replace(_) => Py::new(py, (PyReplace {}, base)).map(Into::into),
-NormalizerWrapper::Nmt(_) => Py::new(py, (PyNmt {}, base)).map(Into::into),
+NormalizerWrapper::Replace(_) => Py::new(py, (PyReplace {}, base))?.into_py(py),
+NormalizerWrapper::Nmt(_) => Py::new(py, (PyNmt {}, base))?.into_py(py),
},
-}
+})
}
}
@@ -71,9 +67,9 @@ impl Normalizer for PyNormalizer {
impl PyNormalizer {
fn __getstate__(&self, py: Python) -> PyResult<PyObject> {
let data = serde_json::to_string(&self.normalizer).map_err(|e| {
-exceptions::Exception::py_err(format!(
+exceptions::PyException::new_err(format!(
"Error while attempting to pickle Normalizer: {}",
-e.to_string()
+e
))
})?;
Ok(PyBytes::new(py, data.as_bytes()).to_object(py))
@@ -83,9 +79,9 @@ impl PyNormalizer {
match state.extract::<&PyBytes>(py) {
Ok(s) => {
self.normalizer = serde_json::from_slice(s.as_bytes()).map_err(|e| {
-exceptions::Exception::py_err(format!(
+exceptions::PyException::new_err(format!(
"Error while attempting to unpickle Normalizer: {}",
-e.to_string()
+e
))
})?;
Ok(())
@@ -315,9 +311,9 @@ impl PyPrecompiled {
PyPrecompiled {},
Precompiled::from(precompiled_charsmap)
.map_err(|e| {
-exceptions::Exception::py_err(format!(
+exceptions::PyException::new_err(format!(
"Error while attempting to build Precompiled normalizer: {}",
-e.to_string()
+e
))
})?
.into(),
@@ -337,7 +333,7 @@ impl PyReplace {
#[cfg(test)]
mod test {
-use pyo3::{AsPyRef, Python};
+use pyo3::prelude::*;
use tk::normalizers::unicode::{NFC, NFKC};
use tk::normalizers::utils::Sequence;
use tk::normalizers::NormalizerWrapper;

View File

@@ -37,39 +37,35 @@ impl PyPreTokenizer {
let base = self.clone();
let gil = Python::acquire_gil();
let py = gil.python();
match &self.pretok {
PyPreTokenizerWrapper::Custom(_) => Py::new(py, base).map(Into::into),
PyPreTokenizerWrapper::Sequence(_) => {
Py::new(py, (PySequence {}, base)).map(Into::into)
}
Ok(match &self.pretok {
PyPreTokenizerWrapper::Custom(_) => Py::new(py, base)?.into_py(py),
PyPreTokenizerWrapper::Sequence(_) => Py::new(py, (PySequence {}, base))?.into_py(py),
PyPreTokenizerWrapper::Wrapped(inner) => match inner.as_ref() {
PreTokenizerWrapper::Whitespace(_) => {
-Py::new(py, (PyWhitespace {}, base)).map(Into::into)
+Py::new(py, (PyWhitespace {}, base))?.into_py(py)
}
PreTokenizerWrapper::Punctuation(_) => {
-Py::new(py, (PyPunctuation {}, base)).map(Into::into)
-}
-PreTokenizerWrapper::Sequence(_) => {
-Py::new(py, (PySequence {}, base)).map(Into::into)
+Py::new(py, (PyPunctuation {}, base))?.into_py(py)
}
+PreTokenizerWrapper::Sequence(_) => Py::new(py, (PySequence {}, base))?.into_py(py),
PreTokenizerWrapper::Metaspace(_) => {
-Py::new(py, (PyMetaspace {}, base)).map(Into::into)
+Py::new(py, (PyMetaspace {}, base))?.into_py(py)
}
PreTokenizerWrapper::Delimiter(_) => {
-Py::new(py, (PyCharDelimiterSplit {}, base)).map(Into::into)
+Py::new(py, (PyCharDelimiterSplit {}, base))?.into_py(py)
}
PreTokenizerWrapper::WhitespaceSplit(_) => {
-Py::new(py, (PyWhitespaceSplit {}, base)).map(Into::into)
+Py::new(py, (PyWhitespaceSplit {}, base))?.into_py(py)
}
PreTokenizerWrapper::ByteLevel(_) => {
-Py::new(py, (PyByteLevel {}, base)).map(Into::into)
+Py::new(py, (PyByteLevel {}, base))?.into_py(py)
}
PreTokenizerWrapper::BertPreTokenizer(_) => {
-Py::new(py, (PyBertPreTokenizer {}, base)).map(Into::into)
+Py::new(py, (PyBertPreTokenizer {}, base))?.into_py(py)
}
-PreTokenizerWrapper::Digits(_) => Py::new(py, (PyDigits {}, base)).map(Into::into),
+PreTokenizerWrapper::Digits(_) => Py::new(py, (PyDigits {}, base))?.into_py(py),
},
-}
+})
}
}
@@ -91,7 +87,7 @@ impl PyPreTokenizer {
fn __getstate__(&self, py: Python) -> PyResult<PyObject> {
let data = serde_json::to_string(&self.pretok).map_err(|e| {
-exceptions::Exception::py_err(format!(
+exceptions::PyException::new_err(format!(
"Error while attempting to pickle PreTokenizer: {}",
e.to_string()
))
@@ -103,9 +99,9 @@ impl PyPreTokenizer {
match state.extract::<&PyBytes>(py) {
Ok(s) => {
let unpickled = serde_json::from_slice(s.as_bytes()).map_err(|e| {
-exceptions::Exception::py_err(format!(
+exceptions::PyException::new_err(format!(
"Error while attempting to unpickle PreTokenizer: {}",
-e.to_string()
+e
))
})?;
self.pretok = unpickled;
@@ -187,10 +183,9 @@ pub struct PyCharDelimiterSplit {}
impl PyCharDelimiterSplit {
#[new]
pub fn new(delimiter: &str) -> PyResult<(Self, PyPreTokenizer)> {
-let chr_delimiter = delimiter
-.chars()
-.next()
-.ok_or_else(|| exceptions::Exception::py_err("delimiter must be a single character"))?;
+let chr_delimiter = delimiter.chars().next().ok_or_else(|| {
+exceptions::PyValueError::new_err("delimiter must be a single character")
+})?;
Ok((
PyCharDelimiterSplit {},
CharDelimiterSplit::new(chr_delimiter).into(),
@@ -267,7 +262,7 @@ impl PyMetaspace {
"replacement" => {
let s: &str = value.extract()?;
replacement = s.chars().next().ok_or_else(|| {
exceptions::Exception::py_err("replacement must be a character")
exceptions::PyValueError::new_err("replacement must be a character")
})?;
}
"add_prefix_space" => add_prefix_space = value.extract()?,
@@ -417,7 +412,7 @@ impl PreTokenizer for PyPreTokenizerWrapper {
#[cfg(test)]
mod test {
-use pyo3::{AsPyRef, Py, PyObject, Python};
+use pyo3::prelude::*;
use tk::pre_tokenizers::whitespace::Whitespace;
use tk::pre_tokenizers::PreTokenizerWrapper;
@@ -451,8 +446,9 @@ mod test {
_ => panic!("Expected wrapped, not custom."),
}
let gil = Python::acquire_gil();
+let py = gil.python();
let py_wsp = PyPreTokenizer::new(Whitespace::default().into());
-let obj: PyObject = Py::new(gil.python(), py_wsp).unwrap().into();
+let obj: PyObject = Py::new(py, py_wsp).unwrap().into_py(py);
let py_seq: PyPreTokenizerWrapper =
PyPreTokenizerWrapper::Custom(Arc::new(CustomPreTokenizer::new(obj).unwrap()));
assert!(serde_json::to_string(&py_seq).is_err());

View File

@@ -30,20 +30,16 @@ impl PyPostProcessor {
let base = self.clone();
let gil = Python::acquire_gil();
let py = gil.python();
-match self.processor.as_ref() {
-PostProcessorWrapper::ByteLevel(_) => {
-Py::new(py, (PyByteLevel {}, base)).map(Into::into)
-}
-PostProcessorWrapper::Bert(_) => {
-Py::new(py, (PyBertProcessing {}, base)).map(Into::into)
-}
+Ok(match self.processor.as_ref() {
+PostProcessorWrapper::ByteLevel(_) => Py::new(py, (PyByteLevel {}, base))?.into_py(py),
+PostProcessorWrapper::Bert(_) => Py::new(py, (PyBertProcessing {}, base))?.into_py(py),
PostProcessorWrapper::Roberta(_) => {
-Py::new(py, (PyRobertaProcessing {}, base)).map(Into::into)
+Py::new(py, (PyRobertaProcessing {}, base))?.into_py(py)
}
PostProcessorWrapper::Template(_) => {
-Py::new(py, (PyTemplateProcessing {}, base)).map(Into::into)
+Py::new(py, (PyTemplateProcessing {}, base))?.into_py(py)
}
-}
+})
}
}
@@ -67,7 +63,7 @@ impl PostProcessor for PyPostProcessor {
impl PyPostProcessor {
fn __getstate__(&self, py: Python) -> PyResult<PyObject> {
let data = serde_json::to_string(self.processor.as_ref()).map_err(|e| {
-exceptions::Exception::py_err(format!(
+exceptions::PyException::new_err(format!(
"Error while attempting to pickle PostProcessor: {}",
e.to_string()
))
@@ -79,7 +75,7 @@ impl PyPostProcessor {
match state.extract::<&PyBytes>(py) {
Ok(s) => {
self.processor = serde_json::from_slice(s.as_bytes()).map_err(|e| {
-exceptions::Exception::py_err(format!(
+exceptions::PyException::new_err(format!(
"Error while attempting to unpickle PostProcessor: {}",
e.to_string()
))
@@ -181,11 +177,11 @@ impl FromPyObject<'_> for PySpecialToken {
} else if let Ok(d) = ob.downcast::<PyDict>() {
let id = d
.get_item("id")
.ok_or_else(|| exceptions::ValueError::py_err("`id` must be specified"))?
.ok_or_else(|| exceptions::PyValueError::new_err("`id` must be specified"))?
.extract::<String>()?;
let ids = d
.get_item("ids")
.ok_or_else(|| exceptions::ValueError::py_err("`ids` must be specified"))?
.ok_or_else(|| exceptions::PyValueError::new_err("`ids` must be specified"))?
.extract::<Vec<u32>>()?;
let type_ids = d.get_item("type_ids").map_or_else(
|| Ok(vec![None; ids.len()]),
@@ -193,14 +189,14 @@
)?;
let tokens = d
.get_item("tokens")
.ok_or_else(|| exceptions::ValueError::py_err("`tokens` must be specified"))?
.ok_or_else(|| exceptions::PyValueError::new_err("`tokens` must be specified"))?
.extract::<Vec<String>>()?;
Ok(Self(
ToPyResult(SpecialToken::new(id, ids, type_ids, tokens)).into_py()?,
))
} else {
-Err(exceptions::TypeError::py_err(
+Err(exceptions::PyTypeError::new_err(
"Expected Union[Tuple[str, int], Tuple[int, str], dict]",
))
}
@@ -223,7 +219,7 @@ impl FromPyObject<'_> for PyTemplate {
} else if let Ok(s) = ob.extract::<Vec<&str>>() {
Ok(Self(s.into()))
} else {
-Err(exceptions::TypeError::py_err(
+Err(exceptions::PyTypeError::new_err(
"Expected Union[str, List[str]]",
))
}
@@ -252,7 +248,7 @@ impl PyTemplateProcessing {
if let Some(sp) = special_tokens {
builder.special_tokens(sp);
}
-let processor = builder.build().map_err(exceptions::ValueError::py_err)?;
+let processor = builder.build().map_err(exceptions::PyValueError::new_err)?;
Ok((
PyTemplateProcessing {},
@@ -265,7 +261,7 @@ impl PyTemplateProcessing {
mod test {
use std::sync::Arc;
-use pyo3::{AsPyRef, Python};
+use pyo3::prelude::*;
use tk::processors::bert::BertProcessing;
use tk::processors::PostProcessorWrapper;

View File

@@ -175,9 +175,9 @@ impl PyObjectProtocol for PyAddedToken {
struct TextInputSequence<'s>(tk::InputSequence<'s>);
impl<'s> FromPyObject<'s> for TextInputSequence<'s> {
fn extract(ob: &'s PyAny) -> PyResult<Self> {
-let err = exceptions::ValueError::py_err("TextInputSequence must be str");
+let err = exceptions::PyTypeError::new_err("TextInputSequence must be str");
if let Ok(s) = ob.downcast::<PyString>() {
-Ok(Self(s.to_string().map_err(|_| err)?.into()))
+Ok(Self(s.to_string_lossy().into()))
} else {
Err(err)
}
@@ -207,7 +207,9 @@ impl FromPyObject<'_> for PyArrayUnicode {
// type_num == 19 => Unicode
if type_num != 19 {
return Err(exceptions::TypeError::py_err("Expected a np.array[str]"));
return Err(exceptions::PyTypeError::new_err(
"Expected a np.array[dtype='U']",
));
}
unsafe {
@@ -224,7 +226,7 @@
let py = gil.python();
let obj = PyObject::from_owned_ptr(py, unicode);
let s = obj.cast_as::<PyString>(py)?;
-Ok(s.to_string()?.trim_matches(char::from(0)).to_owned())
+Ok(s.to_string_lossy().trim_matches(char::from(0)).to_owned())
})
.collect::<PyResult<Vec<_>>>()?;
@@ -247,7 +249,9 @@ impl FromPyObject<'_> for PyArrayStr {
let n_elem = array.shape()[0];
if type_num != 17 {
return Err(exceptions::TypeError::py_err("Expected a np.array[str]"));
return Err(exceptions::PyTypeError::new_err(
"Expected a np.array[dtype='O']",
));
}
unsafe {
@@ -259,7 +263,7 @@
let gil = Python::acquire_gil();
let py = gil.python();
let s = obj.cast_as::<PyString>(py)?;
-Ok(s.to_string()?.into_owned())
+Ok(s.to_string_lossy().into_owned())
})
.collect::<PyResult<Vec<_>>>()?;
@@ -292,7 +296,7 @@ impl<'s> FromPyObject<'s> for PreTokenizedInputSequence<'s> {
return Ok(Self(seq.into()));
}
}
-Err(exceptions::ValueError::py_err(
+Err(exceptions::PyTypeError::new_err(
"PreTokenizedInputSequence must be Union[List[str], Tuple[str]]",
))
}
@@ -319,7 +323,7 @@ impl<'s> FromPyObject<'s> for TextEncodeInput<'s> {
return Ok(Self((first, second).into()));
}
}
-Err(exceptions::ValueError::py_err(
+Err(exceptions::PyTypeError::new_err(
"TextEncodeInput must be Union[TextInputSequence, Tuple[InputSequence, InputSequence]]",
))
}
@@ -346,7 +350,7 @@ impl<'s> FromPyObject<'s> for PreTokenizedEncodeInput<'s> {
return Ok(Self((first, second).into()));
}
}
-Err(exceptions::ValueError::py_err(
+Err(exceptions::PyTypeError::new_err(
"PreTokenizedEncodeInput must be Union[PreTokenizedInputSequence, \
Tuple[PreTokenizedInputSequence, PreTokenizedInputSequence]]",
))
@@ -385,9 +389,9 @@ impl PyTokenizer {
fn __getstate__(&self, py: Python) -> PyResult<PyObject> {
let data = serde_json::to_string(&self.tokenizer).map_err(|e| {
-exceptions::Exception::py_err(format!(
+exceptions::PyException::new_err(format!(
"Error while attempting to pickle Tokenizer: {}",
-e.to_string()
+e
))
})?;
Ok(PyBytes::new(py, data.as_bytes()).to_object(py))
@@ -397,9 +401,9 @@ impl PyTokenizer {
match state.extract::<&PyBytes>(py) {
Ok(s) => {
self.tokenizer = serde_json::from_slice(s.as_bytes()).map_err(|e| {
-exceptions::Exception::py_err(format!(
+exceptions::PyException::new_err(format!(
"Error while attempting to unpickle Tokenizer: {}",
-e.to_string()
+e
))
})?;
Ok(())
@@ -429,9 +433,9 @@ impl PyTokenizer {
#[staticmethod]
fn from_buffer(buffer: &PyBytes) -> PyResult<Self> {
let tokenizer = serde_json::from_slice(buffer.as_bytes()).map_err(|e| {
-exceptions::Exception::py_err(format!(
+exceptions::PyValueError::new_err(format!(
"Cannot instantiate Tokenizer from buffer: {}",
-e.to_string()
+e
))
})?;
Ok(Self { tokenizer })
@@ -485,7 +489,7 @@ impl PyTokenizer {
one of `longest_first`, `only_first`, or `only_second`",
value
))
-.into_pyerr()),
+.into_pyerr::<exceptions::PyValueError>()),
}?
}
_ => println!("Ignored unknown kwarg option {}", key),
@@ -533,7 +537,7 @@ impl PyTokenizer {
one of `left` or `right`",
other
))
-.into_pyerr()),
+.into_pyerr::<exceptions::PyValueError>()),
}?;
}
"pad_to_multiple_of" => {
@@ -716,7 +720,7 @@ impl PyTokenizer {
token.is_special_token = false;
Ok(token.get_token())
} else {
-Err(exceptions::Exception::py_err(
+Err(exceptions::PyTypeError::new_err(
"Input must be a List[Union[str, AddedToken]]",
))
}
@@ -736,7 +740,7 @@ impl PyTokenizer {
token.is_special_token = true;
Ok(token.get_token())
} else {
-Err(exceptions::Exception::py_err(
+Err(exceptions::PyTypeError::new_err(
"Input must be a List[Union[str, AddedToken]]",
))
}
@@ -747,10 +751,7 @@ impl PyTokenizer {
}
fn train(&mut self, trainer: &PyTrainer, files: Vec<String>) -> PyResult<()> {
-self.tokenizer
-.train_and_replace(trainer, files)
-.map_err(|e| exceptions::Exception::py_err(format!("{}", e)))?;
-Ok(())
+ToPyResult(self.tokenizer.train_and_replace(trainer, files)).into()
}
#[args(pair = "None", add_special_tokens = true)]

View File

@@ -73,7 +73,7 @@ impl PyBpeTrainer {
token.is_special_token = true;
Ok(token.get_token())
} else {
-Err(exceptions::Exception::py_err(
+Err(exceptions::PyTypeError::new_err(
"special_tokens must be a List[Union[str, AddedToken]]",
))
}
@@ -137,7 +137,7 @@ impl PyWordPieceTrainer {
token.is_special_token = true;
Ok(token.get_token())
} else {
-Err(exceptions::Exception::py_err(
+Err(exceptions::PyTypeError::new_err(
"special_tokens must be a List[Union[str, AddedToken]]",
))
}
@@ -205,7 +205,7 @@ impl PyUnigramTrainer {
token.is_special_token = true;
Ok(token.get_token())
} else {
-Err(exceptions::Exception::py_err(
+Err(exceptions::PyTypeError::new_err(
"special_tokens must be a List[Union[str, AddedToken]]",
))
}
@@ -220,9 +220,10 @@ impl PyUnigramTrainer {
}
}
-let trainer: tokenizers::models::unigram::UnigramTrainer = builder
-.build()
-.map_err(|_| exceptions::Exception::py_err("Cannot build UnigramTrainer"))?;
+let trainer: tokenizers::models::unigram::UnigramTrainer =
+builder.build().map_err(|e| {
+exceptions::PyException::new_err(format!("Cannot build UnigramTrainer: {}", e))
+})?;
Ok((PyUnigramTrainer {}, PyTrainer::new(trainer.into())))
}
}

View File

@@ -238,12 +238,12 @@ class TestTokenizer:
)
# Mal formed
with pytest.raises(ValueError, match="InputSequence must be str"):
with pytest.raises(TypeError, match="TextInputSequence must be str"):
tokenizer.encode([["my", "name"]])
tokenizer.encode("My name is john", [["pair"]])
tokenizer.encode("my name is john", ["pair"])
with pytest.raises(ValueError, match="InputSequence must be Union[List[str]"):
with pytest.raises(TypeError, match="InputSequence must be Union[List[str]"):
tokenizer.encode("My name is john", is_pretokenized=True)
tokenizer.encode("My name is john", ["pair"], is_pretokenized=True)
tokenizer.encode(["My", "name", "is", "John"], "pair", is_pretokenized=True)

View File

@@ -366,9 +366,9 @@ impl Unigram {
/// use tokenizers::models::unigram::Unigram;
/// use std::path::Path;
///
-/// let model = Unigram::load(Path::new("mymodel-unigram.json")).unwrap();
+/// let model = Unigram::load("mymodel-unigram.json").unwrap();
/// ```
pub fn load(path: &Path) -> Result<Unigram> {
pub fn load<P: AsRef<Path>>(path: P) -> Result<Unigram> {
let file = File::open(path).unwrap();
let reader = BufReader::new(file);
let u = serde_json::from_reader(reader)?;
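Loosening the signature to `AsRef<Path>` keeps the old `Path`-based callers compiling while letting the Python bindings pass the extracted `&str` straight through. For illustration, both call styles below satisfy `P: AsRef<Path>` (hypothetical file name):

use std::path::Path;
use tokenizers::models::unigram::Unigram;

fn main() {
    // A &str works directly now...
    let _ = Unigram::load("mymodel-unigram.json");
    // ...and existing Path-based code is unaffected.
    let _ = Unigram::load(Path::new("mymodel-unigram.json"));
}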