Mirror of https://github.com/mii443/tokenizers.git, synced 2025-08-22 16:25:30 +00:00
Update PyO3 (#426)
bindings/python/Cargo.lock (generated): 26 lines changed
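Most of the Rust churn in the hunks below applies two recurring PyO3 0.12 changes: exception constructors move from `exceptions::Exception::py_err(...)` to the `Py`-prefixed exception types with `new_err(...)` (often a more specific `PyValueError` or `PyTypeError`), and `Py::new(py, obj).map(Into::into)` becomes `Py::new(py, obj)?.into_py(py)`. A minimal sketch of the new shape, using a hypothetical `MyType` pyclass rather than anything from this repository:

use pyo3::exceptions;
use pyo3::prelude::*;

// Hypothetical pyclass, only here to illustrate the PyO3 0.12 calling conventions.
#[pyclass]
struct MyType {}

fn as_object(py: Python) -> PyResult<PyObject> {
    // PyO3 0.11: Py::new(py, MyType {}).map(Into::into)
    // PyO3 0.12: Py::new returns PyResult<Py<T>>, and Py<T> converts with into_py.
    Ok(Py::new(py, MyType {})?.into_py(py))
}

fn fail_with_message(msg: &str) -> PyErr {
    // PyO3 0.11: exceptions::Exception::py_err(msg)
    // PyO3 0.12: exception types are prefixed with `Py` and constructed with new_err.
    exceptions::PyException::new_err(msg.to_string())
}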
@@ -484,14 +484,14 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
 [[package]]
 name = "numpy"
 version = "0.11.0"
-source = "registry+https://github.com/rust-lang/crates.io-index"
+source = "git+https://github.com/pyo3/rust-numpy/?rev=e331befa27fede78d4662edf08fa0508db39be01#e331befa27fede78d4662edf08fa0508db39be01"
 dependencies = [
  "cfg-if 0.1.10 (registry+https://github.com/rust-lang/crates.io-index)",
  "libc 0.2.77 (registry+https://github.com/rust-lang/crates.io-index)",
  "ndarray 0.13.1 (registry+https://github.com/rust-lang/crates.io-index)",
  "num-complex 0.2.4 (registry+https://github.com/rust-lang/crates.io-index)",
  "num-traits 0.2.12 (registry+https://github.com/rust-lang/crates.io-index)",
- "pyo3 0.11.1 (registry+https://github.com/rust-lang/crates.io-index)",
+ "pyo3 0.12.0 (registry+https://github.com/rust-lang/crates.io-index)",
 ]
 
 [[package]]
@@ -580,7 +580,7 @@ dependencies = [
 
 [[package]]
 name = "pyo3"
-version = "0.11.1"
+version = "0.12.0"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 dependencies = [
  "ctor 0.1.15 (registry+https://github.com/rust-lang/crates.io-index)",
@@ -589,13 +589,13 @@ dependencies = [
  "libc 0.2.77 (registry+https://github.com/rust-lang/crates.io-index)",
  "parking_lot 0.11.0 (registry+https://github.com/rust-lang/crates.io-index)",
  "paste 0.1.18 (registry+https://github.com/rust-lang/crates.io-index)",
- "pyo3cls 0.11.1 (registry+https://github.com/rust-lang/crates.io-index)",
+ "pyo3cls 0.12.0 (registry+https://github.com/rust-lang/crates.io-index)",
  "unindent 0.1.6 (registry+https://github.com/rust-lang/crates.io-index)",
 ]
 
 [[package]]
 name = "pyo3-derive-backend"
-version = "0.11.1"
+version = "0.12.0"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 dependencies = [
  "proc-macro2 1.0.21 (registry+https://github.com/rust-lang/crates.io-index)",
@@ -605,10 +605,10 @@ dependencies = [
 
 [[package]]
 name = "pyo3cls"
-version = "0.11.1"
+version = "0.12.0"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 dependencies = [
- "pyo3-derive-backend 0.11.1 (registry+https://github.com/rust-lang/crates.io-index)",
+ "pyo3-derive-backend 0.12.0 (registry+https://github.com/rust-lang/crates.io-index)",
  "quote 1.0.7 (registry+https://github.com/rust-lang/crates.io-index)",
  "syn 1.0.41 (registry+https://github.com/rust-lang/crates.io-index)",
 ]
@@ -895,8 +895,8 @@ dependencies = [
  "env_logger 0.7.1 (registry+https://github.com/rust-lang/crates.io-index)",
  "libc 0.2.77 (registry+https://github.com/rust-lang/crates.io-index)",
  "ndarray 0.13.1 (registry+https://github.com/rust-lang/crates.io-index)",
- "numpy 0.11.0 (registry+https://github.com/rust-lang/crates.io-index)",
- "pyo3 0.11.1 (registry+https://github.com/rust-lang/crates.io-index)",
+ "numpy 0.11.0 (git+https://github.com/pyo3/rust-numpy/?rev=e331befa27fede78d4662edf08fa0508db39be01)",
+ "pyo3 0.12.0 (registry+https://github.com/rust-lang/crates.io-index)",
  "rayon 1.4.0 (registry+https://github.com/rust-lang/crates.io-index)",
  "serde 1.0.116 (registry+https://github.com/rust-lang/crates.io-index)",
  "serde_json 1.0.57 (registry+https://github.com/rust-lang/crates.io-index)",
@@ -1037,7 +1037,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
 "checksum num-traits 0.2.12 (registry+https://github.com/rust-lang/crates.io-index)" = "ac267bcc07f48ee5f8935ab0d24f316fb722d7a1292e2913f0cc196b29ffd611"
 "checksum num_cpus 1.13.0 (registry+https://github.com/rust-lang/crates.io-index)" = "05499f3756671c15885fee9034446956fff3f243d6077b91e5767df161f766b3"
 "checksum number_prefix 0.3.0 (registry+https://github.com/rust-lang/crates.io-index)" = "17b02fc0ff9a9e4b35b3342880f48e896ebf69f2967921fe8646bf5b7125956a"
-"checksum numpy 0.11.0 (registry+https://github.com/rust-lang/crates.io-index)" = "1afa393812b0e4c33d8a46054028ce39d79f954dfa98298f25691abf00b24e39"
+"checksum numpy 0.11.0 (git+https://github.com/pyo3/rust-numpy/?rev=e331befa27fede78d4662edf08fa0508db39be01)" = "<none>"
 "checksum onig 6.1.0 (registry+https://github.com/rust-lang/crates.io-index)" = "8a155d13862da85473665694f4c05d77fb96598bdceeaf696433c84ea9567e20"
 "checksum onig_sys 69.5.1 (registry+https://github.com/rust-lang/crates.io-index)" = "9bff06597a6b17855040955cae613af000fc0bfc8ad49ea68b9479a74e59292d"
 "checksum parking_lot 0.11.0 (registry+https://github.com/rust-lang/crates.io-index)" = "a4893845fa2ca272e647da5d0e46660a314ead9c2fdd9a883aabc32e481a8733"
@@ -1048,9 +1048,9 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
 "checksum ppv-lite86 0.2.9 (registry+https://github.com/rust-lang/crates.io-index)" = "c36fa947111f5c62a733b652544dd0016a43ce89619538a8ef92724a6f501a20"
 "checksum proc-macro-hack 0.5.18 (registry+https://github.com/rust-lang/crates.io-index)" = "99c605b9a0adc77b7211c6b1f722dcb613d68d66859a44f3d485a6da332b0598"
 "checksum proc-macro2 1.0.21 (registry+https://github.com/rust-lang/crates.io-index)" = "36e28516df94f3dd551a587da5357459d9b36d945a7c37c3557928c1c2ff2a2c"
-"checksum pyo3 0.11.1 (registry+https://github.com/rust-lang/crates.io-index)" = "9ca8710ffa8211c9a62a8a3863c4267c710dc42a82a7fd29c97de465d7ea6b7d"
-"checksum pyo3-derive-backend 0.11.1 (registry+https://github.com/rust-lang/crates.io-index)" = "58ad070bf6967b0d29ea74931ffcf9c6bbe8402a726e9afbeafadc0a287cc2b3"
-"checksum pyo3cls 0.11.1 (registry+https://github.com/rust-lang/crates.io-index)" = "c3fa17e1ea569d0bf3b7c00f2a9eea831ca05e55dd76f1794c541abba1c64baa"
+"checksum pyo3 0.12.0 (registry+https://github.com/rust-lang/crates.io-index)" = "4fe7ed6655de90fd3bf799f25ce8b695e89eeb098d03db802723617649fcaa74"
+"checksum pyo3-derive-backend 0.12.0 (registry+https://github.com/rust-lang/crates.io-index)" = "d021655b5f22aeee2eaf3c3f88503250ad6cf90e37a875e1bb1854fceb86b3d0"
+"checksum pyo3cls 0.12.0 (registry+https://github.com/rust-lang/crates.io-index)" = "228ee516c00ac54fb70032d8cd928f213e70c87c340cdf44e2eba028d1f613ea"
 "checksum quick-error 1.2.3 (registry+https://github.com/rust-lang/crates.io-index)" = "a1d01941d82fa2ab50be1e79e6714289dd7cde78eba4c074bc5a4374f650dfe0"
 "checksum quote 1.0.7 (registry+https://github.com/rust-lang/crates.io-index)" = "aa563d17ecb180e500da1cfd2b028310ac758de548efdd203e18f283af693f37"
 "checksum rand 0.7.3 (registry+https://github.com/rust-lang/crates.io-index)" = "6a6b1679d49b24bbfe0c803429aa1874472f50d9b363131f0e89fc356b544d03"

@@ -14,12 +14,10 @@ serde = { version = "1.0", features = [ "rc", "derive" ]}
 serde_json = "1.0"
 libc = "0.2"
 env_logger = "0.7.1"
-numpy = "0.11"
+pyo3 = "0.12"
+numpy = { git = "https://github.com/pyo3/rust-numpy/", rev = "e331befa27fede78d4662edf08fa0508db39be01" }
 ndarray = "0.13"
 
-[dependencies.pyo3]
-version = "0.11"
-
 [dependencies.tokenizers]
 version = "*"
 path = "../../tokenizers"

@@ -31,21 +31,15 @@ impl PyDecoder {
         let base = self.clone();
         let gil = Python::acquire_gil();
         let py = gil.python();
-        match &self.decoder {
-            PyDecoderWrapper::Custom(_) => Py::new(py, base).map(Into::into),
+        Ok(match &self.decoder {
+            PyDecoderWrapper::Custom(_) => Py::new(py, base)?.into_py(py),
             PyDecoderWrapper::Wrapped(inner) => match inner.as_ref() {
-                DecoderWrapper::Metaspace(_) => {
-                    Py::new(py, (PyMetaspaceDec {}, base)).map(Into::into)
-                }
-                DecoderWrapper::WordPiece(_) => {
-                    Py::new(py, (PyWordPieceDec {}, base)).map(Into::into)
-                }
-                DecoderWrapper::ByteLevel(_) => {
-                    Py::new(py, (PyByteLevelDec {}, base)).map(Into::into)
-                }
-                DecoderWrapper::BPE(_) => Py::new(py, (PyBPEDecoder {}, base)).map(Into::into),
+                DecoderWrapper::Metaspace(_) => Py::new(py, (PyMetaspaceDec {}, base))?.into_py(py),
+                DecoderWrapper::WordPiece(_) => Py::new(py, (PyWordPieceDec {}, base))?.into_py(py),
+                DecoderWrapper::ByteLevel(_) => Py::new(py, (PyByteLevelDec {}, base))?.into_py(py),
+                DecoderWrapper::BPE(_) => Py::new(py, (PyBPEDecoder {}, base))?.into_py(py),
             },
-        }
+        })
     }
 }
 
@@ -65,7 +59,7 @@ impl PyDecoder {
 
     fn __getstate__(&self, py: Python) -> PyResult<PyObject> {
         let data = serde_json::to_string(&self.decoder).map_err(|e| {
-            exceptions::Exception::py_err(format!(
+            exceptions::PyException::new_err(format!(
                 "Error while attempting to pickle Decoder: {}",
                 e
             ))
@@ -77,7 +71,7 @@ impl PyDecoder {
         match state.extract::<&PyBytes>(py) {
             Ok(s) => {
                 self.decoder = serde_json::from_slice(s.as_bytes()).map_err(|e| {
-                    exceptions::Exception::py_err(format!(
+                    exceptions::PyException::new_err(format!(
                         "Error while attempting to unpickle Decoder: {}",
                         e
                     ))
@@ -143,7 +137,7 @@ impl PyMetaspaceDec {
                 "replacement" => {
                     let s: &str = value.extract()?;
                     replacement = s.chars().next().ok_or_else(|| {
-                        exceptions::Exception::py_err("replacement must be a character")
+                        exceptions::PyValueError::new_err("replacement must be a character")
                     })?;
                 }
                 "add_prefix_space" => add_prefix_space = value.extract()?,
@@ -202,9 +196,7 @@ impl Decoder for CustomDecoder {
             Ok(res) => Ok(res
                 .cast_as::<PyString>(py)
                 .map_err(|_| PyError::from("`decode` is expected to return a str"))?
-                .to_string()
-                .map_err(|_| PyError::from("`decode` is expected to return a str"))?
-                .into_owned()),
+                .to_string()),
             Err(e) => {
                 e.print(py);
                 Err(Box::new(PyError::from("Error while calling `decode`")))
@@ -273,7 +265,7 @@ impl Decoder for PyDecoderWrapper {
 mod test {
     use std::sync::Arc;
 
-    use pyo3::{AsPyRef, Py, PyObject, Python};
+    use pyo3::prelude::*;
     use tk::decoders::metaspace::Metaspace;
     use tk::decoders::DecoderWrapper;
 
@@ -306,8 +298,9 @@ mod test {
             _ => panic!("Expected wrapped, not custom."),
         }
         let gil = Python::acquire_gil();
+        let py = gil.python();
         let py_msp = PyDecoder::new(Metaspace::default().into());
-        let obj: PyObject = Py::new(gil.python(), py_msp).unwrap().into();
+        let obj: PyObject = Py::new(py, py_msp).unwrap().into_py(py);
         let py_seq = PyDecoderWrapper::Custom(Arc::new(CustomDecoder::new(obj).unwrap()));
         assert!(serde_json::to_string(&py_seq).is_err());
     }

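The `get_as_subtype` rewrite above (and the matching ones in the model, normalizer, pre-tokenizer and processor hunks further down) wraps the whole match in `Ok(...)` so each arm can use `?` on `Py::new` and then erase the concrete type with `into_py(py)`. A hedged sketch of that shape with invented `PyVariantA`/`PyVariantB` subclasses, not the actual tokenizers types:

use pyo3::prelude::*;

#[pyclass(subclass)]
#[derive(Clone)]
struct PyBase {}

#[pyclass(extends=PyBase)]
struct PyVariantA {}

#[pyclass(extends=PyBase)]
struct PyVariantB {}

enum Wrapper {
    A,
    B,
}

// Returns the object as its most specific Python subclass, PyO3 0.12 style.
fn get_as_subtype(wrapper: &Wrapper, base: PyBase, py: Python) -> PyResult<PyObject> {
    Ok(match wrapper {
        // Py::new(...)? propagates the allocation error; into_py erases the
        // concrete Py<T> into a plain PyObject.
        Wrapper::A => Py::new(py, (PyVariantA {}, base.clone()))?.into_py(py),
        Wrapper::B => Py::new(py, (PyVariantB {}, base))?.into_py(py),
    })
}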
@@ -48,7 +48,7 @@ impl PyEncoding {
 
     fn __getstate__(&self, py: Python) -> PyResult<PyObject> {
         let data = serde_json::to_string(&self.encoding).map_err(|e| {
-            exceptions::Exception::py_err(format!(
+            exceptions::PyException::new_err(format!(
                 "Error while attempting to pickle Encoding: {}",
                 e.to_string()
             ))
@@ -60,7 +60,7 @@ impl PyEncoding {
         match state.extract::<&PyBytes>(py) {
             Ok(s) => {
                 self.encoding = serde_json::from_slice(s.as_bytes()).map_err(|e| {
-                    exceptions::Exception::py_err(format!(
+                    exceptions::PyException::new_err(format!(
                         "Error while attempting to unpickle Encoding: {}",
                         e.to_string()
                     ))
@@ -171,7 +171,7 @@ impl PyEncoding {
                         one of `left` or `right`",
                         other
                     ))
-                    .into_pyerr()),
+                    .into_pyerr::<exceptions::PyValueError>()),
                 }?;
             }
             "pad_id" => pad_id = value.extract()?,

@@ -1,5 +1,6 @@
 use pyo3::exceptions;
 use pyo3::prelude::*;
+use pyo3::type_object::PyTypeObject;
 use std::fmt::{Display, Formatter, Result as FmtResult};
 use tokenizers::tokenizer::Result;
 
@@ -9,8 +10,8 @@ impl PyError {
     pub fn from(s: &str) -> Self {
         PyError(String::from(s))
     }
-    pub fn into_pyerr(self) -> PyErr {
-        exceptions::Exception::py_err(format!("{}", self))
+    pub fn into_pyerr<T: PyTypeObject>(self) -> PyErr {
+        PyErr::new::<T, _>(format!("{}", self))
     }
 }
 impl Display for PyError {
@@ -24,7 +25,7 @@ pub struct ToPyResult<T>(pub Result<T>);
 impl<T> std::convert::Into<PyResult<T>> for ToPyResult<T> {
     fn into(self) -> PyResult<T> {
         self.0
-            .map_err(|e| exceptions::Exception::py_err(format!("{}", e)))
+            .map_err(|e| exceptions::PyException::new_err(format!("{}", e)))
     }
 }
 impl<T> ToPyResult<T> {

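Making `into_pyerr` generic over `PyTypeObject`, as above, lets each call site choose which Python exception class is raised instead of always producing the base `Exception`; later hunks call it as `.into_pyerr::<exceptions::PyValueError>()`. A small hedged sketch of such a helper in use, with a made-up `parse_direction` function:

use pyo3::exceptions;
use pyo3::prelude::*;
use pyo3::type_object::PyTypeObject;

struct PyError(String);

impl PyError {
    // Mirrors the helper in the diff: the caller picks the exception class.
    fn into_pyerr<T: PyTypeObject>(self) -> PyErr {
        PyErr::new::<T, _>(self.0)
    }
}

fn parse_direction(dir: &str) -> PyResult<bool> {
    match dir {
        "left" => Ok(false),
        "right" => Ok(true),
        // Raised on the Python side as a ValueError rather than a bare Exception.
        other => Err(PyError(format!("unknown direction: {}", other))
            .into_pyerr::<exceptions::PyValueError>()),
    }
}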
@@ -34,12 +34,12 @@ impl PyModel {
         let base = self.clone();
         let gil = Python::acquire_gil();
         let py = gil.python();
-        match self.model.as_ref() {
-            ModelWrapper::BPE(_) => Py::new(py, (PyBPE {}, base)).map(Into::into),
-            ModelWrapper::WordPiece(_) => Py::new(py, (PyWordPiece {}, base)).map(Into::into),
-            ModelWrapper::WordLevel(_) => Py::new(py, (PyWordLevel {}, base)).map(Into::into),
-            ModelWrapper::Unigram(_) => Py::new(py, (PyUnigram {}, base)).map(Into::into),
-        }
+        Ok(match self.model.as_ref() {
+            ModelWrapper::BPE(_) => Py::new(py, (PyBPE {}, base))?.into_py(py),
+            ModelWrapper::WordPiece(_) => Py::new(py, (PyWordPiece {}, base))?.into_py(py),
+            ModelWrapper::WordLevel(_) => Py::new(py, (PyWordLevel {}, base))?.into_py(py),
+            ModelWrapper::Unigram(_) => Py::new(py, (PyUnigram {}, base))?.into_py(py),
+        })
     }
 }
 
@@ -82,7 +82,7 @@ impl PyModel {
 
     fn __getstate__(&self, py: Python) -> PyResult<PyObject> {
         let data = serde_json::to_string(&self.model).map_err(|e| {
-            exceptions::Exception::py_err(format!(
+            exceptions::PyException::new_err(format!(
                 "Error while attempting to pickle Model: {}",
                 e.to_string()
             ))
@@ -94,7 +94,7 @@ impl PyModel {
         match state.extract::<&PyBytes>(py) {
             Ok(s) => {
                 self.model = serde_json::from_slice(s.as_bytes()).map_err(|e| {
-                    exceptions::Exception::py_err(format!(
+                    exceptions::PyException::new_err(format!(
                         "Error while attempting to unpickle Model: {}",
                         e.to_string()
                     ))
@@ -130,7 +130,7 @@ impl PyBPE {
         kwargs: Option<&PyDict>,
     ) -> PyResult<(Self, PyModel)> {
         if (vocab.is_some() && merges.is_none()) || (vocab.is_none() && merges.is_some()) {
-            return Err(exceptions::ValueError::py_err(
+            return Err(exceptions::PyValueError::new_err(
                 "`vocab` and `merges` must be both specified",
             ));
         }
@@ -164,7 +164,7 @@ impl PyBPE {
         }
 
         match builder.build() {
-            Err(e) => Err(exceptions::Exception::py_err(format!(
+            Err(e) => Err(exceptions::PyException::new_err(format!(
                 "Error while initializing BPE: {}",
                 e
             ))),
@@ -207,12 +207,10 @@ impl PyWordPiece {
         }
 
         match builder.build() {
-            Err(e) => {
-                println!("Errors: {:?}", e);
-                Err(exceptions::Exception::py_err(
-                    "Error while initializing WordPiece",
-                ))
-            }
+            Err(e) => Err(exceptions::PyException::new_err(format!(
+                "Error while initializing WordPiece: {}",
+                e
+            ))),
             Ok(wordpiece) => Ok((PyWordPiece {}, PyModel::new(Arc::new(wordpiece.into())))),
         }
     }
@@ -240,12 +238,10 @@ impl PyWordLevel {
 
         if let Some(vocab) = vocab {
             match WordLevel::from_files(vocab, unk_token) {
-                Err(e) => {
-                    println!("Errors: {:?}", e);
-                    Err(exceptions::Exception::py_err(
-                        "Error while initializing WordLevel",
-                    ))
-                }
+                Err(e) => Err(exceptions::PyException::new_err(format!(
+                    "Error while initializing WordLevel: {}",
+                    e
+                ))),
                 Ok(model) => Ok((PyWordLevel {}, PyModel::new(Arc::new(model.into())))),
             }
         } else {
@@ -263,13 +259,13 @@ pub struct PyUnigram {}
 #[pymethods]
 impl PyUnigram {
     #[new]
-    fn new(vocab: Option<String>) -> PyResult<(Self, PyModel)> {
+    fn new(vocab: Option<&str>) -> PyResult<(Self, PyModel)> {
         match vocab {
-            Some(vocab) => match Unigram::load(&std::path::Path::new(&vocab)) {
-                Err(e) => {
-                    println!("Errors: {:?}", e);
-                    Err(exceptions::Exception::py_err("Error while loading Unigram"))
-                }
+            Some(vocab) => match Unigram::load(vocab) {
+                Err(e) => Err(exceptions::PyException::new_err(format!(
+                    "Error while loading Unigram: {}",
+                    e
+                ))),
                 Ok(model) => Ok((PyUnigram {}, PyModel::new(Arc::new(model.into())))),
             },
             None => Ok((
@@ -283,7 +279,7 @@ impl PyUnigram {
 #[cfg(test)]
 mod test {
     use crate::models::PyModel;
-    use pyo3::{AsPyRef, Python};
+    use pyo3::prelude::*;
     use std::sync::Arc;
     use tk::models::bpe::BPE;
     use tk::models::ModelWrapper;

@@ -29,35 +29,31 @@ impl PyNormalizer {
         let base = self.clone();
         let gil = Python::acquire_gil();
         let py = gil.python();
-        match self.normalizer {
-            PyNormalizerWrapper::Sequence(_) => Py::new(py, (PySequence {}, base)).map(Into::into),
+        Ok(match self.normalizer {
+            PyNormalizerWrapper::Sequence(_) => Py::new(py, (PySequence {}, base))?.into_py(py),
             PyNormalizerWrapper::Wrapped(ref inner) => match inner.as_ref() {
-                NormalizerWrapper::Sequence(_) => {
-                    Py::new(py, (PySequence {}, base)).map(Into::into)
-                }
+                NormalizerWrapper::Sequence(_) => Py::new(py, (PySequence {}, base))?.into_py(py),
                 NormalizerWrapper::BertNormalizer(_) => {
-                    Py::new(py, (PyBertNormalizer {}, base)).map(Into::into)
+                    Py::new(py, (PyBertNormalizer {}, base))?.into_py(py)
                 }
                 NormalizerWrapper::StripNormalizer(_) => {
-                    Py::new(py, (PyBertNormalizer {}, base)).map(Into::into)
+                    Py::new(py, (PyBertNormalizer {}, base))?.into_py(py)
                 }
                 NormalizerWrapper::StripAccents(_) => {
-                    Py::new(py, (PyStripAccents {}, base)).map(Into::into)
-                }
-                NormalizerWrapper::NFC(_) => Py::new(py, (PyNFC {}, base)).map(Into::into),
-                NormalizerWrapper::NFD(_) => Py::new(py, (PyNFD {}, base)).map(Into::into),
-                NormalizerWrapper::NFKC(_) => Py::new(py, (PyNFKC {}, base)).map(Into::into),
-                NormalizerWrapper::NFKD(_) => Py::new(py, (PyNFKD {}, base)).map(Into::into),
-                NormalizerWrapper::Lowercase(_) => {
-                    Py::new(py, (PyLowercase {}, base)).map(Into::into)
+                    Py::new(py, (PyStripAccents {}, base))?.into_py(py)
                 }
+                NormalizerWrapper::NFC(_) => Py::new(py, (PyNFC {}, base))?.into_py(py),
+                NormalizerWrapper::NFD(_) => Py::new(py, (PyNFD {}, base))?.into_py(py),
+                NormalizerWrapper::NFKC(_) => Py::new(py, (PyNFKC {}, base))?.into_py(py),
+                NormalizerWrapper::NFKD(_) => Py::new(py, (PyNFKD {}, base))?.into_py(py),
+                NormalizerWrapper::Lowercase(_) => Py::new(py, (PyLowercase {}, base))?.into_py(py),
                 NormalizerWrapper::Precompiled(_) => {
-                    Py::new(py, (PyPrecompiled {}, base)).map(Into::into)
+                    Py::new(py, (PyPrecompiled {}, base))?.into_py(py)
                 }
-                NormalizerWrapper::Replace(_) => Py::new(py, (PyReplace {}, base)).map(Into::into),
-                NormalizerWrapper::Nmt(_) => Py::new(py, (PyNmt {}, base)).map(Into::into),
+                NormalizerWrapper::Replace(_) => Py::new(py, (PyReplace {}, base))?.into_py(py),
+                NormalizerWrapper::Nmt(_) => Py::new(py, (PyNmt {}, base))?.into_py(py),
             },
-        }
+        })
     }
 }
 
@@ -71,9 +67,9 @@ impl Normalizer for PyNormalizer {
 impl PyNormalizer {
     fn __getstate__(&self, py: Python) -> PyResult<PyObject> {
         let data = serde_json::to_string(&self.normalizer).map_err(|e| {
-            exceptions::Exception::py_err(format!(
+            exceptions::PyException::new_err(format!(
                 "Error while attempting to pickle Normalizer: {}",
-                e.to_string()
+                e
             ))
         })?;
         Ok(PyBytes::new(py, data.as_bytes()).to_object(py))
@@ -83,9 +79,9 @@ impl PyNormalizer {
         match state.extract::<&PyBytes>(py) {
             Ok(s) => {
                 self.normalizer = serde_json::from_slice(s.as_bytes()).map_err(|e| {
-                    exceptions::Exception::py_err(format!(
+                    exceptions::PyException::new_err(format!(
                         "Error while attempting to unpickle Normalizer: {}",
-                        e.to_string()
+                        e
                     ))
                 })?;
                 Ok(())
@@ -315,9 +311,9 @@ impl PyPrecompiled {
             PyPrecompiled {},
             Precompiled::from(precompiled_charsmap)
                 .map_err(|e| {
-                    exceptions::Exception::py_err(format!(
+                    exceptions::PyException::new_err(format!(
                         "Error while attempting to build Precompiled normalizer: {}",
-                        e.to_string()
+                        e
                     ))
                 })?
                 .into(),
@@ -337,7 +333,7 @@ impl PyReplace {
 
 #[cfg(test)]
 mod test {
-    use pyo3::{AsPyRef, Python};
+    use pyo3::prelude::*;
     use tk::normalizers::unicode::{NFC, NFKC};
     use tk::normalizers::utils::Sequence;
     use tk::normalizers::NormalizerWrapper;

@@ -37,39 +37,35 @@ impl PyPreTokenizer {
         let base = self.clone();
         let gil = Python::acquire_gil();
         let py = gil.python();
-        match &self.pretok {
-            PyPreTokenizerWrapper::Custom(_) => Py::new(py, base).map(Into::into),
-            PyPreTokenizerWrapper::Sequence(_) => {
-                Py::new(py, (PySequence {}, base)).map(Into::into)
-            }
+        Ok(match &self.pretok {
+            PyPreTokenizerWrapper::Custom(_) => Py::new(py, base)?.into_py(py),
+            PyPreTokenizerWrapper::Sequence(_) => Py::new(py, (PySequence {}, base))?.into_py(py),
             PyPreTokenizerWrapper::Wrapped(inner) => match inner.as_ref() {
                 PreTokenizerWrapper::Whitespace(_) => {
-                    Py::new(py, (PyWhitespace {}, base)).map(Into::into)
+                    Py::new(py, (PyWhitespace {}, base))?.into_py(py)
                 }
                 PreTokenizerWrapper::Punctuation(_) => {
-                    Py::new(py, (PyPunctuation {}, base)).map(Into::into)
-                }
-                PreTokenizerWrapper::Sequence(_) => {
-                    Py::new(py, (PySequence {}, base)).map(Into::into)
+                    Py::new(py, (PyPunctuation {}, base))?.into_py(py)
                 }
+                PreTokenizerWrapper::Sequence(_) => Py::new(py, (PySequence {}, base))?.into_py(py),
                 PreTokenizerWrapper::Metaspace(_) => {
-                    Py::new(py, (PyMetaspace {}, base)).map(Into::into)
+                    Py::new(py, (PyMetaspace {}, base))?.into_py(py)
                 }
                 PreTokenizerWrapper::Delimiter(_) => {
-                    Py::new(py, (PyCharDelimiterSplit {}, base)).map(Into::into)
+                    Py::new(py, (PyCharDelimiterSplit {}, base))?.into_py(py)
                 }
                 PreTokenizerWrapper::WhitespaceSplit(_) => {
-                    Py::new(py, (PyWhitespaceSplit {}, base)).map(Into::into)
+                    Py::new(py, (PyWhitespaceSplit {}, base))?.into_py(py)
                 }
                 PreTokenizerWrapper::ByteLevel(_) => {
-                    Py::new(py, (PyByteLevel {}, base)).map(Into::into)
+                    Py::new(py, (PyByteLevel {}, base))?.into_py(py)
                 }
                 PreTokenizerWrapper::BertPreTokenizer(_) => {
-                    Py::new(py, (PyBertPreTokenizer {}, base)).map(Into::into)
+                    Py::new(py, (PyBertPreTokenizer {}, base))?.into_py(py)
                 }
-                PreTokenizerWrapper::Digits(_) => Py::new(py, (PyDigits {}, base)).map(Into::into),
+                PreTokenizerWrapper::Digits(_) => Py::new(py, (PyDigits {}, base))?.into_py(py),
             },
-        }
+        })
     }
 }
 
@@ -91,7 +87,7 @@ impl PyPreTokenizer {
 
     fn __getstate__(&self, py: Python) -> PyResult<PyObject> {
         let data = serde_json::to_string(&self.pretok).map_err(|e| {
-            exceptions::Exception::py_err(format!(
+            exceptions::PyException::new_err(format!(
                 "Error while attempting to pickle PreTokenizer: {}",
                 e.to_string()
             ))
@@ -103,9 +99,9 @@ impl PyPreTokenizer {
         match state.extract::<&PyBytes>(py) {
             Ok(s) => {
                 let unpickled = serde_json::from_slice(s.as_bytes()).map_err(|e| {
-                    exceptions::Exception::py_err(format!(
+                    exceptions::PyException::new_err(format!(
                         "Error while attempting to unpickle PreTokenizer: {}",
-                        e.to_string()
+                        e
                     ))
                 })?;
                 self.pretok = unpickled;
@@ -187,10 +183,9 @@ pub struct PyCharDelimiterSplit {}
 impl PyCharDelimiterSplit {
     #[new]
     pub fn new(delimiter: &str) -> PyResult<(Self, PyPreTokenizer)> {
-        let chr_delimiter = delimiter
-            .chars()
-            .next()
-            .ok_or_else(|| exceptions::Exception::py_err("delimiter must be a single character"))?;
+        let chr_delimiter = delimiter.chars().next().ok_or_else(|| {
+            exceptions::PyValueError::new_err("delimiter must be a single character")
+        })?;
         Ok((
             PyCharDelimiterSplit {},
             CharDelimiterSplit::new(chr_delimiter).into(),
@@ -267,7 +262,7 @@ impl PyMetaspace {
             "replacement" => {
                 let s: &str = value.extract()?;
                 replacement = s.chars().next().ok_or_else(|| {
-                    exceptions::Exception::py_err("replacement must be a character")
+                    exceptions::PyValueError::new_err("replacement must be a character")
                 })?;
             }
             "add_prefix_space" => add_prefix_space = value.extract()?,
@@ -417,7 +412,7 @@ impl PreTokenizer for PyPreTokenizerWrapper {
 
 #[cfg(test)]
 mod test {
-    use pyo3::{AsPyRef, Py, PyObject, Python};
+    use pyo3::prelude::*;
     use tk::pre_tokenizers::whitespace::Whitespace;
     use tk::pre_tokenizers::PreTokenizerWrapper;
 
@@ -451,8 +446,9 @@ mod test {
             _ => panic!("Expected wrapped, not custom."),
         }
         let gil = Python::acquire_gil();
+        let py = gil.python();
         let py_wsp = PyPreTokenizer::new(Whitespace::default().into());
-        let obj: PyObject = Py::new(gil.python(), py_wsp).unwrap().into();
+        let obj: PyObject = Py::new(py, py_wsp).unwrap().into_py(py);
         let py_seq: PyPreTokenizerWrapper =
             PyPreTokenizerWrapper::Custom(Arc::new(CustomPreTokenizer::new(obj).unwrap()));
         assert!(serde_json::to_string(&py_seq).is_err());

@@ -30,20 +30,16 @@ impl PyPostProcessor {
         let base = self.clone();
         let gil = Python::acquire_gil();
         let py = gil.python();
-        match self.processor.as_ref() {
-            PostProcessorWrapper::ByteLevel(_) => {
-                Py::new(py, (PyByteLevel {}, base)).map(Into::into)
-            }
-            PostProcessorWrapper::Bert(_) => {
-                Py::new(py, (PyBertProcessing {}, base)).map(Into::into)
-            }
+        Ok(match self.processor.as_ref() {
+            PostProcessorWrapper::ByteLevel(_) => Py::new(py, (PyByteLevel {}, base))?.into_py(py),
+            PostProcessorWrapper::Bert(_) => Py::new(py, (PyBertProcessing {}, base))?.into_py(py),
             PostProcessorWrapper::Roberta(_) => {
-                Py::new(py, (PyRobertaProcessing {}, base)).map(Into::into)
+                Py::new(py, (PyRobertaProcessing {}, base))?.into_py(py)
             }
             PostProcessorWrapper::Template(_) => {
-                Py::new(py, (PyTemplateProcessing {}, base)).map(Into::into)
+                Py::new(py, (PyTemplateProcessing {}, base))?.into_py(py)
             }
-        }
+        })
     }
 }
 
@@ -67,7 +63,7 @@ impl PostProcessor for PyPostProcessor {
 impl PyPostProcessor {
     fn __getstate__(&self, py: Python) -> PyResult<PyObject> {
         let data = serde_json::to_string(self.processor.as_ref()).map_err(|e| {
-            exceptions::Exception::py_err(format!(
+            exceptions::PyException::new_err(format!(
                 "Error while attempting to pickle PostProcessor: {}",
                 e.to_string()
             ))
@@ -79,7 +75,7 @@ impl PyPostProcessor {
         match state.extract::<&PyBytes>(py) {
             Ok(s) => {
                 self.processor = serde_json::from_slice(s.as_bytes()).map_err(|e| {
-                    exceptions::Exception::py_err(format!(
+                    exceptions::PyException::new_err(format!(
                         "Error while attempting to unpickle PostProcessor: {}",
                         e.to_string()
                     ))
@@ -181,11 +177,11 @@ impl FromPyObject<'_> for PySpecialToken {
         } else if let Ok(d) = ob.downcast::<PyDict>() {
             let id = d
                 .get_item("id")
-                .ok_or_else(|| exceptions::ValueError::py_err("`id` must be specified"))?
+                .ok_or_else(|| exceptions::PyValueError::new_err("`id` must be specified"))?
                 .extract::<String>()?;
             let ids = d
                 .get_item("ids")
-                .ok_or_else(|| exceptions::ValueError::py_err("`ids` must be specified"))?
+                .ok_or_else(|| exceptions::PyValueError::new_err("`ids` must be specified"))?
                 .extract::<Vec<u32>>()?;
             let type_ids = d.get_item("type_ids").map_or_else(
                 || Ok(vec![None; ids.len()]),
@@ -193,14 +189,14 @@ impl FromPyObject<'_> for PySpecialToken {
             )?;
             let tokens = d
                 .get_item("tokens")
-                .ok_or_else(|| exceptions::ValueError::py_err("`tokens` must be specified"))?
+                .ok_or_else(|| exceptions::PyValueError::new_err("`tokens` must be specified"))?
                 .extract::<Vec<String>>()?;
 
             Ok(Self(
                 ToPyResult(SpecialToken::new(id, ids, type_ids, tokens)).into_py()?,
             ))
         } else {
-            Err(exceptions::TypeError::py_err(
+            Err(exceptions::PyTypeError::new_err(
                 "Expected Union[Tuple[str, int], Tuple[int, str], dict]",
             ))
         }
@@ -223,7 +219,7 @@ impl FromPyObject<'_> for PyTemplate {
         } else if let Ok(s) = ob.extract::<Vec<&str>>() {
             Ok(Self(s.into()))
         } else {
-            Err(exceptions::TypeError::py_err(
+            Err(exceptions::PyTypeError::new_err(
                 "Expected Union[str, List[str]]",
             ))
         }
@@ -252,7 +248,7 @@ impl PyTemplateProcessing {
         if let Some(sp) = special_tokens {
             builder.special_tokens(sp);
         }
-        let processor = builder.build().map_err(exceptions::ValueError::py_err)?;
+        let processor = builder.build().map_err(exceptions::PyValueError::new_err)?;
 
         Ok((
             PyTemplateProcessing {},
@@ -265,7 +261,7 @@ impl PyTemplateProcessing {
 mod test {
     use std::sync::Arc;
 
-    use pyo3::{AsPyRef, Python};
+    use pyo3::prelude::*;
     use tk::processors::bert::BertProcessing;
     use tk::processors::PostProcessorWrapper;
 

@@ -175,9 +175,9 @@ impl PyObjectProtocol for PyAddedToken {
 struct TextInputSequence<'s>(tk::InputSequence<'s>);
 impl<'s> FromPyObject<'s> for TextInputSequence<'s> {
     fn extract(ob: &'s PyAny) -> PyResult<Self> {
-        let err = exceptions::ValueError::py_err("TextInputSequence must be str");
+        let err = exceptions::PyTypeError::new_err("TextInputSequence must be str");
         if let Ok(s) = ob.downcast::<PyString>() {
-            Ok(Self(s.to_string().map_err(|_| err)?.into()))
+            Ok(Self(s.to_string_lossy().into()))
         } else {
             Err(err)
         }
@@ -207,7 +207,9 @@ impl FromPyObject<'_> for PyArrayUnicode {
 
         // type_num == 19 => Unicode
         if type_num != 19 {
-            return Err(exceptions::TypeError::py_err("Expected a np.array[str]"));
+            return Err(exceptions::PyTypeError::new_err(
+                "Expected a np.array[dtype='U']",
+            ));
         }
 
         unsafe {
@@ -224,7 +226,7 @@ impl FromPyObject<'_> for PyArrayUnicode {
                     let py = gil.python();
                     let obj = PyObject::from_owned_ptr(py, unicode);
                     let s = obj.cast_as::<PyString>(py)?;
-                    Ok(s.to_string()?.trim_matches(char::from(0)).to_owned())
+                    Ok(s.to_string_lossy().trim_matches(char::from(0)).to_owned())
                 })
                 .collect::<PyResult<Vec<_>>>()?;
 
@@ -247,7 +249,9 @@ impl FromPyObject<'_> for PyArrayStr {
         let n_elem = array.shape()[0];
 
         if type_num != 17 {
-            return Err(exceptions::TypeError::py_err("Expected a np.array[str]"));
+            return Err(exceptions::PyTypeError::new_err(
+                "Expected a np.array[dtype='O']",
+            ));
         }
 
         unsafe {
@@ -259,7 +263,7 @@ impl FromPyObject<'_> for PyArrayStr {
                     let gil = Python::acquire_gil();
                     let py = gil.python();
                     let s = obj.cast_as::<PyString>(py)?;
-                    Ok(s.to_string()?.into_owned())
+                    Ok(s.to_string_lossy().into_owned())
                 })
                 .collect::<PyResult<Vec<_>>>()?;
 
@@ -292,7 +296,7 @@ impl<'s> FromPyObject<'s> for PreTokenizedInputSequence<'s> {
                 return Ok(Self(seq.into()));
             }
         }
-        Err(exceptions::ValueError::py_err(
+        Err(exceptions::PyTypeError::new_err(
            "PreTokenizedInputSequence must be Union[List[str], Tuple[str]]",
        ))
    }
@@ -319,7 +323,7 @@ impl<'s> FromPyObject<'s> for TextEncodeInput<'s> {
                 return Ok(Self((first, second).into()));
             }
         }
-        Err(exceptions::ValueError::py_err(
+        Err(exceptions::PyTypeError::new_err(
            "TextEncodeInput must be Union[TextInputSequence, Tuple[InputSequence, InputSequence]]",
        ))
    }
@@ -346,7 +350,7 @@ impl<'s> FromPyObject<'s> for PreTokenizedEncodeInput<'s> {
                 return Ok(Self((first, second).into()));
             }
         }
-        Err(exceptions::ValueError::py_err(
+        Err(exceptions::PyTypeError::new_err(
            "PreTokenizedEncodeInput must be Union[PreTokenizedInputSequence, \
            Tuple[PreTokenizedInputSequence, PreTokenizedInputSequence]]",
        ))
@@ -385,9 +389,9 @@ impl PyTokenizer {
 
     fn __getstate__(&self, py: Python) -> PyResult<PyObject> {
         let data = serde_json::to_string(&self.tokenizer).map_err(|e| {
-            exceptions::Exception::py_err(format!(
+            exceptions::PyException::new_err(format!(
                 "Error while attempting to pickle Tokenizer: {}",
-                e.to_string()
+                e
             ))
         })?;
         Ok(PyBytes::new(py, data.as_bytes()).to_object(py))
@@ -397,9 +401,9 @@ impl PyTokenizer {
         match state.extract::<&PyBytes>(py) {
             Ok(s) => {
                 self.tokenizer = serde_json::from_slice(s.as_bytes()).map_err(|e| {
-                    exceptions::Exception::py_err(format!(
+                    exceptions::PyException::new_err(format!(
                         "Error while attempting to unpickle Tokenizer: {}",
-                        e.to_string()
+                        e
                     ))
                 })?;
                 Ok(())
@@ -429,9 +433,9 @@ impl PyTokenizer {
     #[staticmethod]
     fn from_buffer(buffer: &PyBytes) -> PyResult<Self> {
         let tokenizer = serde_json::from_slice(buffer.as_bytes()).map_err(|e| {
-            exceptions::Exception::py_err(format!(
+            exceptions::PyValueError::new_err(format!(
                 "Cannot instantiate Tokenizer from buffer: {}",
-                e.to_string()
+                e
             ))
         })?;
         Ok(Self { tokenizer })
@@ -485,7 +489,7 @@ impl PyTokenizer {
                         one of `longest_first`, `only_first`, or `only_second`",
                         value
                     ))
-                    .into_pyerr()),
+                    .into_pyerr::<exceptions::PyValueError>()),
                 }?
             }
             _ => println!("Ignored unknown kwarg option {}", key),
@@ -533,7 +537,7 @@ impl PyTokenizer {
                         one of `left` or `right`",
                         other
                     ))
-                    .into_pyerr()),
+                    .into_pyerr::<exceptions::PyValueError>()),
                 }?;
             }
             "pad_to_multiple_of" => {
@@ -716,7 +720,7 @@ impl PyTokenizer {
                     token.is_special_token = false;
                     Ok(token.get_token())
                 } else {
-                    Err(exceptions::Exception::py_err(
+                    Err(exceptions::PyTypeError::new_err(
                         "Input must be a List[Union[str, AddedToken]]",
                     ))
                 }
@@ -736,7 +740,7 @@ impl PyTokenizer {
                     token.is_special_token = true;
                     Ok(token.get_token())
                 } else {
-                    Err(exceptions::Exception::py_err(
+                    Err(exceptions::PyTypeError::new_err(
                         "Input must be a List[Union[str, AddedToken]]",
                     ))
                 }
@@ -747,10 +751,7 @@ impl PyTokenizer {
     }
 
     fn train(&mut self, trainer: &PyTrainer, files: Vec<String>) -> PyResult<()> {
-        self.tokenizer
-            .train_and_replace(trainer, files)
-            .map_err(|e| exceptions::Exception::py_err(format!("{}", e)))?;
-        Ok(())
+        ToPyResult(self.tokenizer.train_and_replace(trainer, files)).into()
     }
 
     #[args(pair = "None", add_special_tokens = true)]

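Several hunks above swap `s.to_string()?` for `s.to_string_lossy()` when reading a `PyString`: under PyO3 0.12, `to_string_lossy` returns a `Cow<str>` directly rather than a `PyResult`, so the `?` and the extra error mapping go away. A hedged sketch under that assumption, with a hypothetical `read_str` helper:

use pyo3::prelude::*;
use pyo3::types::PyString;

// Reads a Python str into an owned Rust String, PyO3 0.12 style.
fn read_str(ob: &PyAny) -> PyResult<String> {
    let s = ob
        .downcast::<PyString>()
        .map_err(|_| pyo3::exceptions::PyTypeError::new_err("expected a str"))?;
    // to_string_lossy() yields Cow<str> infallibly; invalid UTF-8 is replaced.
    Ok(s.to_string_lossy().into_owned())
}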
@@ -73,7 +73,7 @@ impl PyBpeTrainer {
                     token.is_special_token = true;
                     Ok(token.get_token())
                 } else {
-                    Err(exceptions::Exception::py_err(
+                    Err(exceptions::PyTypeError::new_err(
                         "special_tokens must be a List[Union[str, AddedToken]]",
                     ))
                 }
@@ -137,7 +137,7 @@ impl PyWordPieceTrainer {
                     token.is_special_token = true;
                     Ok(token.get_token())
                 } else {
-                    Err(exceptions::Exception::py_err(
+                    Err(exceptions::PyTypeError::new_err(
                         "special_tokens must be a List[Union[str, AddedToken]]",
                     ))
                 }
@@ -205,7 +205,7 @@ impl PyUnigramTrainer {
                     token.is_special_token = true;
                     Ok(token.get_token())
                 } else {
-                    Err(exceptions::Exception::py_err(
+                    Err(exceptions::PyTypeError::new_err(
                         "special_tokens must be a List[Union[str, AddedToken]]",
                     ))
                 }
@@ -220,9 +220,10 @@ impl PyUnigramTrainer {
             }
         }
 
-        let trainer: tokenizers::models::unigram::UnigramTrainer = builder
-            .build()
-            .map_err(|_| exceptions::Exception::py_err("Cannot build UnigramTrainer"))?;
+        let trainer: tokenizers::models::unigram::UnigramTrainer =
+            builder.build().map_err(|e| {
+                exceptions::PyException::new_err(format!("Cannot build UnigramTrainer: {}", e))
+            })?;
         Ok((PyUnigramTrainer {}, PyTrainer::new(trainer.into())))
     }
 }

@@ -238,12 +238,12 @@ class TestTokenizer:
         )
 
         # Mal formed
-        with pytest.raises(ValueError, match="InputSequence must be str"):
+        with pytest.raises(TypeError, match="TextInputSequence must be str"):
             tokenizer.encode([["my", "name"]])
             tokenizer.encode("My name is john", [["pair"]])
             tokenizer.encode("my name is john", ["pair"])
 
-        with pytest.raises(ValueError, match="InputSequence must be Union[List[str]"):
+        with pytest.raises(TypeError, match="InputSequence must be Union[List[str]"):
            tokenizer.encode("My name is john", is_pretokenized=True)
            tokenizer.encode("My name is john", ["pair"], is_pretokenized=True)
            tokenizer.encode(["My", "name", "is", "John"], "pair", is_pretokenized=True)

|
||||
/// use tokenizers::models::unigram::Unigram;
|
||||
/// use std::path::Path;
|
||||
///
|
||||
/// let model = Unigram::load(Path::new("mymodel-unigram.json")).unwrap();
|
||||
/// let model = Unigram::load("mymodel-unigram.json").unwrap();
|
||||
/// ```
|
||||
pub fn load(path: &Path) -> Result<Unigram> {
|
||||
pub fn load<P: AsRef<Path>>(path: P) -> Result<Unigram> {
|
||||
let file = File::open(path).unwrap();
|
||||
let reader = BufReader::new(file);
|
||||
let u = serde_json::from_reader(reader)?;
|
||||
|
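The final hunk relaxes `Unigram::load` to accept any `AsRef<Path>`, which is what allows the Python binding to pass a plain `&str` (see the `Option<&str>` change in the Unigram model hunk earlier). A hedged illustration of the signature pattern with a stand-in `load` function, not the real Unigram code:

use std::fs::File;
use std::io::BufReader;
use std::path::Path;

#[derive(Debug)]
struct Model;

// Generic over AsRef<Path>: callers can pass &str, String, &Path, or PathBuf.
fn load<P: AsRef<Path>>(path: P) -> std::io::Result<Model> {
    let file = File::open(path)?;
    let _reader = BufReader::new(file);
    // ... deserialize the model from _reader here ...
    Ok(Model)
}

fn main() -> std::io::Result<()> {
    // Both forms compile against the same signature.
    let _a = load("mymodel-unigram.json")?;
    let _b = load(Path::new("mymodel-unigram.json"))?;
    Ok(())
}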