Python - Update pyo3 version

* Use __new__ instead of static method as model constructors
This commit is contained in:
Bjarte Johansen
2020-04-06 21:16:15 +02:00
parent 2a4e5f81de
commit 2dc48e56ac
10 changed files with 322 additions and 211 deletions

View File

@ -73,6 +73,14 @@ dependencies = [
"winapi 0.3.8 (registry+https://github.com/rust-lang/crates.io-index)",
]
[[package]]
name = "cloudabi"
version = "0.0.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
dependencies = [
"bitflags 1.2.1 (registry+https://github.com/rust-lang/crates.io-index)",
]
[[package]]
name = "console"
version = "0.9.1"
@ -243,6 +251,14 @@ name = "libc"
version = "0.2.66"
source = "registry+https://github.com/rust-lang/crates.io-index"
[[package]]
name = "lock_api"
version = "0.3.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
dependencies = [
"scopeguard 1.0.0 (registry+https://github.com/rust-lang/crates.io-index)",
]
[[package]]
name = "memchr"
version = "2.2.1"
@ -278,6 +294,28 @@ name = "number_prefix"
version = "0.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
[[package]]
name = "parking_lot"
version = "0.10.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
dependencies = [
"lock_api 0.3.3 (registry+https://github.com/rust-lang/crates.io-index)",
"parking_lot_core 0.7.0 (registry+https://github.com/rust-lang/crates.io-index)",
]
[[package]]
name = "parking_lot_core"
version = "0.7.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
dependencies = [
"cfg-if 0.1.10 (registry+https://github.com/rust-lang/crates.io-index)",
"cloudabi 0.0.3 (registry+https://github.com/rust-lang/crates.io-index)",
"libc 0.2.66 (registry+https://github.com/rust-lang/crates.io-index)",
"redox_syscall 0.1.56 (registry+https://github.com/rust-lang/crates.io-index)",
"smallvec 1.1.0 (registry+https://github.com/rust-lang/crates.io-index)",
"winapi 0.3.8 (registry+https://github.com/rust-lang/crates.io-index)",
]
[[package]]
name = "paste"
version = "0.1.6"
@ -323,27 +361,26 @@ dependencies = [
[[package]]
name = "pyo3"
version = "0.8.4"
version = "0.9.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
dependencies = [
"indoc 0.3.4 (registry+https://github.com/rust-lang/crates.io-index)",
"inventory 0.1.4 (registry+https://github.com/rust-lang/crates.io-index)",
"lazy_static 1.4.0 (registry+https://github.com/rust-lang/crates.io-index)",
"libc 0.2.66 (registry+https://github.com/rust-lang/crates.io-index)",
"num-traits 0.2.10 (registry+https://github.com/rust-lang/crates.io-index)",
"parking_lot 0.10.0 (registry+https://github.com/rust-lang/crates.io-index)",
"paste 0.1.6 (registry+https://github.com/rust-lang/crates.io-index)",
"pyo3cls 0.8.4 (registry+https://github.com/rust-lang/crates.io-index)",
"pyo3cls 0.9.1 (registry+https://github.com/rust-lang/crates.io-index)",
"regex 1.3.1 (registry+https://github.com/rust-lang/crates.io-index)",
"serde 1.0.104 (registry+https://github.com/rust-lang/crates.io-index)",
"serde_json 1.0.44 (registry+https://github.com/rust-lang/crates.io-index)",
"spin 0.5.2 (registry+https://github.com/rust-lang/crates.io-index)",
"unindent 0.1.5 (registry+https://github.com/rust-lang/crates.io-index)",
"version_check 0.9.1 (registry+https://github.com/rust-lang/crates.io-index)",
]
[[package]]
name = "pyo3-derive-backend"
version = "0.8.4"
version = "0.9.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
dependencies = [
"proc-macro2 1.0.6 (registry+https://github.com/rust-lang/crates.io-index)",
@ -353,11 +390,10 @@ dependencies = [
[[package]]
name = "pyo3cls"
version = "0.8.4"
version = "0.9.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
dependencies = [
"proc-macro2 1.0.6 (registry+https://github.com/rust-lang/crates.io-index)",
"pyo3-derive-backend 0.8.4 (registry+https://github.com/rust-lang/crates.io-index)",
"pyo3-derive-backend 0.9.1 (registry+https://github.com/rust-lang/crates.io-index)",
"quote 1.0.2 (registry+https://github.com/rust-lang/crates.io-index)",
"syn 1.0.11 (registry+https://github.com/rust-lang/crates.io-index)",
]
@ -429,6 +465,11 @@ dependencies = [
"num_cpus 1.11.1 (registry+https://github.com/rust-lang/crates.io-index)",
]
[[package]]
name = "redox_syscall"
version = "0.1.56"
source = "registry+https://github.com/rust-lang/crates.io-index"
[[package]]
name = "regex"
version = "1.3.1"
@ -509,11 +550,6 @@ name = "smallvec"
version = "1.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
[[package]]
name = "spin"
version = "0.5.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
[[package]]
name = "strsim"
version = "0.8.0"
@ -574,7 +610,7 @@ dependencies = [
name = "tokenizers-python"
version = "0.7.0-rc3"
dependencies = [
"pyo3 0.8.4 (registry+https://github.com/rust-lang/crates.io-index)",
"pyo3 0.9.1 (registry+https://github.com/rust-lang/crates.io-index)",
"rayon 1.3.0 (registry+https://github.com/rust-lang/crates.io-index)",
"tokenizers 0.9.0",
]
@ -651,6 +687,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
"checksum cfg-if 0.1.10 (registry+https://github.com/rust-lang/crates.io-index)" = "4785bdd1c96b2a846b2bd7cc02e86b6b3dbf14e7e53446c4f54c92a361040822"
"checksum clap 2.33.0 (registry+https://github.com/rust-lang/crates.io-index)" = "5067f5bb2d80ef5d68b4c87db81601f0b75bca627bc2ef76b141d7b846a3c6d9"
"checksum clicolors-control 1.0.1 (registry+https://github.com/rust-lang/crates.io-index)" = "90082ee5dcdd64dc4e9e0d37fbf3ee325419e39c0092191e0393df65518f741e"
"checksum cloudabi 0.0.3 (registry+https://github.com/rust-lang/crates.io-index)" = "ddfc5b9aa5d4507acaf872de71051dfd0e309860e88966e1051e462a077aac4f"
"checksum console 0.9.1 (registry+https://github.com/rust-lang/crates.io-index)" = "f5d540c2d34ac9dd0deb5f3b5f54c36c79efa78f6b3ad19106a554d07a7b5d9f"
"checksum crossbeam-deque 0.7.2 (registry+https://github.com/rust-lang/crates.io-index)" = "c3aa945d63861bfe624b55d153a39684da1e8c0bc8fba932f7ee3a3c16cea3ca"
"checksum crossbeam-epoch 0.8.0 (registry+https://github.com/rust-lang/crates.io-index)" = "5064ebdbf05ce3cb95e45c8b086f72263f4166b29b97f6baff7ef7fe047b55ac"
@ -670,19 +707,22 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
"checksum itoa 0.4.4 (registry+https://github.com/rust-lang/crates.io-index)" = "501266b7edd0174f8530248f87f99c88fbe60ca4ef3dd486835b8d8d53136f7f"
"checksum lazy_static 1.4.0 (registry+https://github.com/rust-lang/crates.io-index)" = "e2abad23fbc42b3700f2f279844dc832adb2b2eb069b2df918f455c4e18cc646"
"checksum libc 0.2.66 (registry+https://github.com/rust-lang/crates.io-index)" = "d515b1f41455adea1313a4a2ac8a8a477634fbae63cc6100e3aebb207ce61558"
"checksum lock_api 0.3.3 (registry+https://github.com/rust-lang/crates.io-index)" = "79b2de95ecb4691949fea4716ca53cdbcfccb2c612e19644a8bad05edcf9f47b"
"checksum memchr 2.2.1 (registry+https://github.com/rust-lang/crates.io-index)" = "88579771288728879b57485cc7d6b07d648c9f0141eb955f8ab7f9d45394468e"
"checksum memoffset 0.5.3 (registry+https://github.com/rust-lang/crates.io-index)" = "75189eb85871ea5c2e2c15abbdd541185f63b408415e5051f5cac122d8c774b9"
"checksum num-traits 0.2.10 (registry+https://github.com/rust-lang/crates.io-index)" = "d4c81ffc11c212fa327657cb19dd85eb7419e163b5b076bede2bdb5c974c07e4"
"checksum num_cpus 1.11.1 (registry+https://github.com/rust-lang/crates.io-index)" = "76dac5ed2a876980778b8b85f75a71b6cbf0db0b1232ee12f826bccb00d09d72"
"checksum number_prefix 0.3.0 (registry+https://github.com/rust-lang/crates.io-index)" = "17b02fc0ff9a9e4b35b3342880f48e896ebf69f2967921fe8646bf5b7125956a"
"checksum parking_lot 0.10.0 (registry+https://github.com/rust-lang/crates.io-index)" = "92e98c49ab0b7ce5b222f2cc9193fc4efe11c6d0bd4f648e374684a6857b1cfc"
"checksum parking_lot_core 0.7.0 (registry+https://github.com/rust-lang/crates.io-index)" = "7582838484df45743c8434fbff785e8edf260c28748353d44bc0da32e0ceabf1"
"checksum paste 0.1.6 (registry+https://github.com/rust-lang/crates.io-index)" = "423a519e1c6e828f1e73b720f9d9ed2fa643dce8a7737fb43235ce0b41eeaa49"
"checksum paste-impl 0.1.6 (registry+https://github.com/rust-lang/crates.io-index)" = "4214c9e912ef61bf42b81ba9a47e8aad1b2ffaf739ab162bf96d1e011f54e6c5"
"checksum ppv-lite86 0.2.6 (registry+https://github.com/rust-lang/crates.io-index)" = "74490b50b9fbe561ac330df47c08f3f33073d2d00c150f719147d7c54522fa1b"
"checksum proc-macro-hack 0.5.11 (registry+https://github.com/rust-lang/crates.io-index)" = "ecd45702f76d6d3c75a80564378ae228a85f0b59d2f3ed43c91b4a69eb2ebfc5"
"checksum proc-macro2 1.0.6 (registry+https://github.com/rust-lang/crates.io-index)" = "9c9e470a8dc4aeae2dee2f335e8f533e2d4b347e1434e5671afc49b054592f27"
"checksum pyo3 0.8.4 (registry+https://github.com/rust-lang/crates.io-index)" = "7f9df1468dddf8a59ec799cf3b930bb75ec09deabe875ba953e06c51d1077136"
"checksum pyo3-derive-backend 0.8.4 (registry+https://github.com/rust-lang/crates.io-index)" = "9f6e56fb3e97b344a8f87d036f94578399402c6b75949de6270cd07928f790b1"
"checksum pyo3cls 0.8.4 (registry+https://github.com/rust-lang/crates.io-index)" = "97452dcdf5941627ebc5c06664a07821fc7fc88d7515f02178193a8ebe316468"
"checksum pyo3 0.9.1 (registry+https://github.com/rust-lang/crates.io-index)" = "397fef05d982d84944edac8281972ed7108be996e5635fdfdae777d78632c8f0"
"checksum pyo3-derive-backend 0.9.1 (registry+https://github.com/rust-lang/crates.io-index)" = "b8656b9d1a8c49838439b6e2c299ec927816c5c0c92b623f677d5579e9c04851"
"checksum pyo3cls 0.9.1 (registry+https://github.com/rust-lang/crates.io-index)" = "ab91823d2634de5fd56af40da20fd89cac144a15da0d9101f3c2b047af9ed9ca"
"checksum quote 1.0.2 (registry+https://github.com/rust-lang/crates.io-index)" = "053a8c8bcc71fcce321828dc897a98ab9760bef03a4fc36693c231e5b3216cfe"
"checksum rand 0.7.2 (registry+https://github.com/rust-lang/crates.io-index)" = "3ae1b169243eaf61759b8475a998f0a385e42042370f3a7dbaf35246eacc8412"
"checksum rand_chacha 0.2.1 (registry+https://github.com/rust-lang/crates.io-index)" = "03a2a90da8c7523f554344f921aa97283eadf6ac484a6d2a7d0212fa7f8d6853"
@ -690,6 +730,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
"checksum rand_hc 0.2.0 (registry+https://github.com/rust-lang/crates.io-index)" = "ca3129af7b92a17112d59ad498c6f81eaf463253766b90396d39ea7a39d6613c"
"checksum rayon 1.3.0 (registry+https://github.com/rust-lang/crates.io-index)" = "db6ce3297f9c85e16621bb8cca38a06779ffc31bb8184e1be4bed2be4678a098"
"checksum rayon-core 1.7.0 (registry+https://github.com/rust-lang/crates.io-index)" = "08a89b46efaf957e52b18062fb2f4660f8b8a4dde1807ca002690868ef2c85a9"
"checksum redox_syscall 0.1.56 (registry+https://github.com/rust-lang/crates.io-index)" = "2439c63f3f6139d1b57529d16bc3b8bb855230c8efcc5d3a896c8bea7c3b1e84"
"checksum regex 1.3.1 (registry+https://github.com/rust-lang/crates.io-index)" = "dc220bd33bdce8f093101afe22a037b8eb0e5af33592e6a9caafff0d4cb81cbd"
"checksum regex-syntax 0.6.12 (registry+https://github.com/rust-lang/crates.io-index)" = "11a7e20d1cce64ef2fed88b66d347f88bd9babb82845b2b858f3edbf59a4f716"
"checksum rustc_version 0.2.3 (registry+https://github.com/rust-lang/crates.io-index)" = "138e3e0acb6c9fb258b19b67cb8abd63c00679d2851805ea151465464fe9030a"
@ -701,7 +742,6 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
"checksum serde_derive 1.0.104 (registry+https://github.com/rust-lang/crates.io-index)" = "128f9e303a5a29922045a830221b8f78ec74a5f544944f3d5984f8ec3895ef64"
"checksum serde_json 1.0.44 (registry+https://github.com/rust-lang/crates.io-index)" = "48c575e0cc52bdd09b47f330f646cf59afc586e9c4e3ccd6fc1f625b8ea1dad7"
"checksum smallvec 1.1.0 (registry+https://github.com/rust-lang/crates.io-index)" = "44e59e0c9fa00817912ae6e4e6e3c4fe04455e75699d06eedc7d85917ed8e8f4"
"checksum spin 0.5.2 (registry+https://github.com/rust-lang/crates.io-index)" = "6e63cff320ae2c57904679ba7cb63280a3dc4613885beafb148ee7bf9aa9042d"
"checksum strsim 0.8.0 (registry+https://github.com/rust-lang/crates.io-index)" = "8ea5119cdb4c55b55d432abb513a0429384878c15dde60cc77b1c99de1a95a6a"
"checksum syn 1.0.11 (registry+https://github.com/rust-lang/crates.io-index)" = "dff0acdb207ae2fe6d5976617f887eb1e35a2ba52c13c7234c790960cdad9238"
"checksum termios 0.3.1 (registry+https://github.com/rust-lang/crates.io-index)" = "72b620c5ea021d75a735c943269bb07d30c9b77d6ac6b236bc8b5c496ef05625"

View File

@ -12,7 +12,7 @@ crate-type = ["cdylib"]
rayon = "1.2.0"
[dependencies.pyo3]
version = "0.8.4"
version = "0.9.1"
features = ["extension-module"]
[dependencies.tokenizers]

View File

@ -31,10 +31,13 @@ pub struct ByteLevel {}
#[pymethods]
impl ByteLevel {
#[new]
fn new(obj: &PyRawObject) -> PyResult<()> {
Ok(obj.init(Decoder {
fn new() -> PyResult<(Self, Decoder)> {
Ok((
ByteLevel {},
Decoder {
decoder: Container::Owned(Box::new(tk::decoders::byte_level::ByteLevel::default())),
}))
},
))
}
}
@ -44,7 +47,7 @@ pub struct WordPiece {}
impl WordPiece {
#[new]
#[args(kwargs = "**")]
fn new(obj: &PyRawObject, kwargs: Option<&PyDict>) -> PyResult<()> {
fn new(kwargs: Option<&PyDict>) -> PyResult<(Self, Decoder)> {
let mut prefix = String::from("##");
let mut cleanup = true;
@ -57,11 +60,14 @@ impl WordPiece {
}
}
Ok(obj.init(Decoder {
Ok((
WordPiece {},
Decoder {
decoder: Container::Owned(Box::new(tk::decoders::wordpiece::WordPiece::new(
prefix, cleanup,
))),
}))
},
))
}
}
@ -71,7 +77,7 @@ pub struct Metaspace {}
impl Metaspace {
#[new]
#[args(kwargs = "**")]
fn new(obj: &PyRawObject, kwargs: Option<&PyDict>) -> PyResult<()> {
fn new(kwargs: Option<&PyDict>) -> PyResult<(Self, Decoder)> {
let mut replacement = '▁';
let mut add_prefix_space = true;
@ -91,12 +97,15 @@ impl Metaspace {
}
}
Ok(obj.init(Decoder {
Ok((
Metaspace {},
Decoder {
decoder: Container::Owned(Box::new(tk::decoders::metaspace::Metaspace::new(
replacement,
add_prefix_space,
))),
}))
},
))
}
}
@ -106,7 +115,7 @@ pub struct BPEDecoder {}
impl BPEDecoder {
#[new]
#[args(kwargs = "**")]
fn new(obj: &PyRawObject, kwargs: Option<&PyDict>) -> PyResult<()> {
fn new(kwargs: Option<&PyDict>) -> PyResult<(Self, Decoder)> {
let mut suffix = String::from("</w>");
if let Some(kwargs) = kwargs {
@ -119,9 +128,12 @@ impl BPEDecoder {
}
}
Ok(obj.init(Decoder {
Ok((
BPEDecoder {},
Decoder {
decoder: Container::Owned(Box::new(tk::decoders::bpe::BPEDecoder::new(suffix))),
}))
},
))
}
}

View File

@ -40,7 +40,7 @@ impl PySequenceProtocol for Encoding {
impl Encoding {
#[staticmethod]
#[args(growing_offsets = true)]
fn merge(encodings: Vec<&Encoding>, growing_offsets: bool) -> Encoding {
fn merge(encodings: Vec<PyRef<Encoding>>, growing_offsets: bool) -> Encoding {
Encoding::new(tk::tokenizer::Encoding::merge(
encodings
.into_iter()

View File

@ -21,7 +21,7 @@ impl EncodeInput {
impl<'source> FromPyObject<'source> for EncodeInput {
fn extract(ob: &'source PyAny) -> PyResult<Self> {
let sequence: &PyList = ob.downcast_ref()?;
let sequence: &PyList = ob.downcast()?;
enum Mode {
NoOffsets,
@ -84,7 +84,7 @@ pub struct Model {
#[pymethods]
impl Model {
#[new]
fn new(_obj: &PyRawObject) -> PyResult<()> {
fn new() -> PyResult<Self> {
Err(exceptions::Exception::py_err(
"Cannot create a Model directly. Use a concrete subclass",
))
@ -142,19 +142,29 @@ impl Model {
/// BPE Model
/// Allows the creation of a BPE Model to be used with a Tokenizer
#[pyclass]
#[pyclass(extends=Model)]
pub struct BPE {}
#[pymethods]
impl BPE {
/// from_files(vocab, merges, /)
/// --
///
    /// Instantiate a new BPE model using the provided vocab and merges files
#[staticmethod]
#[new]
#[args(kwargs = "**")]
fn from_files(vocab: &str, merges: &str, kwargs: Option<&PyDict>) -> PyResult<Model> {
let mut builder = tk::models::bpe::BPE::from_files(vocab, merges);
fn new(
vocab: Option<&str>,
merges: Option<&str>,
kwargs: Option<&PyDict>,
) -> PyResult<(Self, Model)> {
let mut builder = tk::models::bpe::BPE::builder();
if let Some(vocab) = vocab {
if let Some(merges) = merges {
builder = builder.files(vocab.to_owned(), merges.to_owned());
} else {
return Err(exceptions::Exception::py_err(format!(
"Got vocab file ({}), but missing merges",
vocab
)));
}
}
if let Some(kwargs) = kwargs {
for (key, value) in kwargs {
let key: &str = key.extract()?;
@ -184,38 +194,30 @@ impl BPE {
"Error while initializing BPE: {}",
e
))),
Ok(bpe) => Ok(Model {
model: Container::Owned(Box::new(bpe)),
}),
}
}
/// empty()
/// --
///
    /// Instantiate a new BPE model with empty vocab and merges
#[staticmethod]
fn empty() -> Model {
Ok(bpe) => Ok((
BPE {},
Model {
model: Container::Owned(Box::new(tk::models::bpe::BPE::default())),
model: Container::Owned(Box::new(bpe)),
},
)),
}
}
}
/// WordPiece Model
#[pyclass]
#[pyclass(extends=Model)]
pub struct WordPiece {}
#[pymethods]
impl WordPiece {
/// from_files(vocab, /)
/// --
///
/// Instantiate a new WordPiece model using the provided vocabulary file
#[staticmethod]
#[new]
#[args(kwargs = "**")]
fn from_files(vocab: &str, kwargs: Option<&PyDict>) -> PyResult<Model> {
let mut builder = tk::models::wordpiece::WordPiece::from_files(vocab);
fn new(vocab: Option<&str>, kwargs: Option<&PyDict>) -> PyResult<(Self, Model)> {
let mut builder = tk::models::wordpiece::WordPiece::builder();
if let Some(vocab) = vocab {
builder = builder.files(vocab.to_owned());
}
if let Some(kwargs) = kwargs {
for (key, val) in kwargs {
@ -242,28 +244,24 @@ impl WordPiece {
"Error while initializing WordPiece",
))
}
Ok(wordpiece) => Ok(Model {
model: Container::Owned(Box::new(wordpiece)),
}),
}
}
#[staticmethod]
fn empty() -> Model {
Ok(wordpiece) => Ok((
WordPiece {},
Model {
model: Container::Owned(Box::new(tk::models::wordpiece::WordPiece::default())),
model: Container::Owned(Box::new(wordpiece)),
},
)),
}
}
}
#[pyclass]
#[pyclass(extends=Model)]
pub struct WordLevel {}
#[pymethods]
impl WordLevel {
#[staticmethod]
#[new]
#[args(kwargs = "**")]
fn from_files(vocab: &str, kwargs: Option<&PyDict>) -> PyResult<Model> {
fn new(vocab: Option<&str>, kwargs: Option<&PyDict>) -> PyResult<(Self, Model)> {
let mut unk_token = String::from("<unk>");
if let Some(kwargs) = kwargs {
@ -276,6 +274,7 @@ impl WordLevel {
}
}
if let Some(vocab) = vocab {
match tk::models::wordlevel::WordLevel::from_files(vocab, unk_token) {
Err(e) => {
println!("Errors: {:?}", e);
@ -283,16 +282,20 @@ impl WordLevel {
"Error while initializing WordLevel",
))
}
Ok(model) => Ok(Model {
Ok(model) => Ok((
WordLevel {},
Model {
model: Container::Owned(Box::new(model)),
}),
},
)),
}
}
#[staticmethod]
fn empty() -> Model {
} else {
Ok((
WordLevel {},
Model {
model: Container::Owned(Box::new(tk::models::wordlevel::WordLevel::default())),
},
))
}
}
}

View File

@ -16,7 +16,7 @@ pub struct BertNormalizer {}
impl BertNormalizer {
#[new]
#[args(kwargs = "**")]
fn new(obj: &PyRawObject, kwargs: Option<&PyDict>) -> PyResult<()> {
fn new(kwargs: Option<&PyDict>) -> PyResult<(Self, Normalizer)> {
let mut clean_text = true;
let mut handle_chinese_chars = true;
let mut strip_accents = true;
@ -35,14 +35,17 @@ impl BertNormalizer {
}
}
Ok(obj.init(Normalizer {
Ok((
BertNormalizer {},
Normalizer {
normalizer: Container::Owned(Box::new(tk::normalizers::bert::BertNormalizer::new(
clean_text,
handle_chinese_chars,
strip_accents,
lowercase,
))),
}))
},
))
}
}
@ -51,10 +54,13 @@ pub struct NFD {}
#[pymethods]
impl NFD {
#[new]
fn new(obj: &PyRawObject) -> PyResult<()> {
Ok(obj.init(Normalizer {
fn new() -> PyResult<(Self, Normalizer)> {
Ok((
NFD {},
Normalizer {
normalizer: Container::Owned(Box::new(tk::normalizers::unicode::NFD)),
}))
},
))
}
}
@ -63,10 +69,13 @@ pub struct NFKD {}
#[pymethods]
impl NFKD {
#[new]
fn new(obj: &PyRawObject) -> PyResult<()> {
Ok(obj.init(Normalizer {
fn new() -> PyResult<(Self, Normalizer)> {
Ok((
NFKD {},
Normalizer {
normalizer: Container::Owned(Box::new(tk::normalizers::unicode::NFKD)),
}))
},
))
}
}
@ -75,10 +84,13 @@ pub struct NFC {}
#[pymethods]
impl NFC {
#[new]
fn new(obj: &PyRawObject) -> PyResult<()> {
Ok(obj.init(Normalizer {
fn new() -> PyResult<(Self, Normalizer)> {
Ok((
NFC {},
Normalizer {
normalizer: Container::Owned(Box::new(tk::normalizers::unicode::NFC)),
}))
},
))
}
}
@ -87,10 +99,13 @@ pub struct NFKC {}
#[pymethods]
impl NFKC {
#[new]
fn new(obj: &PyRawObject) -> PyResult<()> {
Ok(obj.init(Normalizer {
fn new() -> PyResult<(Self, Normalizer)> {
Ok((
NFKC {},
Normalizer {
normalizer: Container::Owned(Box::new(tk::normalizers::unicode::NFKC)),
}))
},
))
}
}
@ -99,11 +114,11 @@ pub struct Sequence {}
#[pymethods]
impl Sequence {
#[new]
fn new(obj: &PyRawObject, normalizers: &PyList) -> PyResult<()> {
fn new(normalizers: &PyList) -> PyResult<(Self, Normalizer)> {
let normalizers = normalizers
.iter()
.map(|n| {
let normalizer: &mut Normalizer = n.extract()?;
let mut normalizer: PyRefMut<Normalizer> = n.extract()?;
if let Some(normalizer) = normalizer.normalizer.to_pointer() {
Ok(normalizer)
} else {
@ -114,11 +129,14 @@ impl Sequence {
})
.collect::<PyResult<_>>()?;
Ok(obj.init(Normalizer {
Ok((
Sequence {},
Normalizer {
normalizer: Container::Owned(Box::new(tk::normalizers::utils::Sequence::new(
normalizers,
))),
}))
},
))
}
}
@ -127,10 +145,13 @@ pub struct Lowercase {}
#[pymethods]
impl Lowercase {
#[new]
fn new(obj: &PyRawObject) -> PyResult<()> {
Ok(obj.init(Normalizer {
fn new() -> PyResult<(Self, Normalizer)> {
Ok((
Lowercase {},
Normalizer {
normalizer: Container::Owned(Box::new(tk::normalizers::utils::Lowercase)),
}))
},
))
}
}
@ -140,7 +161,7 @@ pub struct Strip {}
impl Strip {
#[new]
#[args(kwargs = "**")]
fn new(obj: &PyRawObject, kwargs: Option<&PyDict>) -> PyResult<()> {
fn new(kwargs: Option<&PyDict>) -> PyResult<(Self, Normalizer)> {
let mut left = true;
let mut right = true;
@ -153,8 +174,13 @@ impl Strip {
}
}
Ok(obj.init(Normalizer {
normalizer: Container::Owned(Box::new(tk::normalizers::strip::Strip::new(left, right))),
}))
Ok((
Strip {},
Normalizer {
normalizer: Container::Owned(Box::new(tk::normalizers::strip::Strip::new(
left, right,
))),
},
))
}
}

View File

@ -38,9 +38,8 @@ pub struct ByteLevel {}
impl ByteLevel {
#[new]
#[args(kwargs = "**")]
fn new(obj: &PyRawObject, kwargs: Option<&PyDict>) -> PyResult<()> {
fn new(kwargs: Option<&PyDict>) -> PyResult<(Self, PreTokenizer)> {
let mut byte_level = tk::pre_tokenizers::byte_level::ByteLevel::default();
if let Some(kwargs) = kwargs {
for (key, value) in kwargs {
let key: &str = key.extract()?;
@ -53,9 +52,12 @@ impl ByteLevel {
}
}
Ok(obj.init(PreTokenizer {
Ok((
ByteLevel {},
PreTokenizer {
pretok: Container::Owned(Box::new(byte_level)),
}))
},
))
}
#[staticmethod]
@ -72,10 +74,13 @@ pub struct Whitespace {}
#[pymethods]
impl Whitespace {
#[new]
fn new(obj: &PyRawObject) -> PyResult<()> {
Ok(obj.init(PreTokenizer {
fn new() -> PyResult<(Self, PreTokenizer)> {
Ok((
Whitespace {},
PreTokenizer {
pretok: Container::Owned(Box::new(tk::pre_tokenizers::whitespace::Whitespace)),
}))
},
))
}
}
@ -84,10 +89,13 @@ pub struct WhitespaceSplit {}
#[pymethods]
impl WhitespaceSplit {
#[new]
fn new(obj: &PyRawObject) -> PyResult<()> {
Ok(obj.init(PreTokenizer {
fn new() -> PyResult<(Self, PreTokenizer)> {
Ok((
WhitespaceSplit {},
PreTokenizer {
pretok: Container::Owned(Box::new(tk::pre_tokenizers::whitespace::WhitespaceSplit)),
}))
},
))
}
}
@ -96,18 +104,21 @@ pub struct CharDelimiterSplit {}
#[pymethods]
impl CharDelimiterSplit {
#[new]
pub fn new(obj: &PyRawObject, delimiter: &str) -> PyResult<()> {
pub fn new(delimiter: &str) -> PyResult<(Self, PreTokenizer)> {
let chr_delimiter = delimiter
.chars()
.nth(0)
.ok_or(exceptions::Exception::py_err(
"delimiter must be a single character",
))?;
Ok(obj.init(PreTokenizer {
Ok((
CharDelimiterSplit {},
PreTokenizer {
pretok: Container::Owned(Box::new(
tk::pre_tokenizers::delimiter::CharDelimiterSplit::new(chr_delimiter),
)),
}))
},
))
}
}
@ -116,10 +127,13 @@ pub struct BertPreTokenizer {}
#[pymethods]
impl BertPreTokenizer {
#[new]
fn new(obj: &PyRawObject) -> PyResult<()> {
Ok(obj.init(PreTokenizer {
fn new() -> PyResult<(Self, PreTokenizer)> {
Ok((
BertPreTokenizer {},
PreTokenizer {
pretok: Container::Owned(Box::new(tk::pre_tokenizers::bert::BertPreTokenizer)),
}))
},
))
}
}
@ -129,7 +143,7 @@ pub struct Metaspace {}
impl Metaspace {
#[new]
#[args(kwargs = "**")]
fn new(obj: &PyRawObject, kwargs: Option<&PyDict>) -> PyResult<()> {
fn new(kwargs: Option<&PyDict>) -> PyResult<(Self, PreTokenizer)> {
let mut replacement = '▁';
let mut add_prefix_space = true;
@ -149,12 +163,15 @@ impl Metaspace {
}
}
Ok(obj.init(PreTokenizer {
Ok((
Metaspace {},
PreTokenizer {
pretok: Container::Owned(Box::new(tk::pre_tokenizers::metaspace::Metaspace::new(
replacement,
add_prefix_space,
))),
}))
},
))
}
}

View File

@ -21,12 +21,15 @@ pub struct BertProcessing {}
#[pymethods]
impl BertProcessing {
#[new]
fn new(obj: &PyRawObject, sep: (String, u32), cls: (String, u32)) -> PyResult<()> {
Ok(obj.init(PostProcessor {
fn new(sep: (String, u32), cls: (String, u32)) -> PyResult<(Self, PostProcessor)> {
Ok((
BertProcessing {},
PostProcessor {
processor: Container::Owned(Box::new(tk::processors::bert::BertProcessing::new(
sep, cls,
))),
}))
},
))
}
}
@ -35,12 +38,15 @@ pub struct RobertaProcessing {}
#[pymethods]
impl RobertaProcessing {
#[new]
fn new(obj: &PyRawObject, sep: (String, u32), cls: (String, u32)) -> PyResult<()> {
Ok(obj.init(PostProcessor {
processor: Container::Owned(Box::new(tk::processors::roberta::RobertaProcessing::new(
sep, cls,
))),
}))
fn new(sep: (String, u32), cls: (String, u32)) -> PyResult<(Self, PostProcessor)> {
Ok((
RobertaProcessing {},
PostProcessor {
processor: Container::Owned(Box::new(
tk::processors::roberta::RobertaProcessing::new(sep, cls),
)),
},
))
}
}
@ -50,7 +56,7 @@ pub struct ByteLevel {}
impl ByteLevel {
#[new]
#[args(kwargs = "**")]
fn new(obj: &PyRawObject, kwargs: Option<&PyDict>) -> PyResult<()> {
fn new(kwargs: Option<&PyDict>) -> PyResult<(Self, PostProcessor)> {
let mut byte_level = tk::processors::byte_level::ByteLevel::default();
if let Some(kwargs) = kwargs {
@ -62,8 +68,11 @@ impl ByteLevel {
}
}
}
Ok(obj.init(PostProcessor {
Ok((
ByteLevel {},
PostProcessor {
processor: Container::Owned(Box::new(byte_level)),
}))
},
))
}
}

View File

@ -28,7 +28,7 @@ pub struct AddedToken {
impl AddedToken {
#[new]
#[args(kwargs = "**")]
fn new(obj: &PyRawObject, content: &str, kwargs: Option<&PyDict>) -> PyResult<()> {
fn new(content: &str, kwargs: Option<&PyDict>) -> PyResult<Self> {
let mut token = tk::tokenizer::AddedToken::from(content.to_owned());
if let Some(kwargs) = kwargs {
@ -43,8 +43,7 @@ impl AddedToken {
}
}
obj.init({ AddedToken { token } });
Ok(())
Ok(AddedToken { token })
}
#[getter]
@ -97,11 +96,10 @@ pub struct Tokenizer {
#[pymethods]
impl Tokenizer {
#[new]
fn new(obj: &PyRawObject, model: &mut Model) -> PyResult<()> {
fn new(mut model: PyRefMut<Model>) -> PyResult<Self> {
if let Some(model) = model.model.to_pointer() {
let tokenizer = tk::tokenizer::Tokenizer::new(model);
obj.init({ Tokenizer { tokenizer } });
Ok(())
Ok(Tokenizer { tokenizer })
} else {
Err(exceptions::Exception::py_err(
"The Model is already being used in another Tokenizer",
@ -320,7 +318,7 @@ impl Tokenizer {
content,
..Default::default()
})
} else if let Ok(token) = token.cast_as::<AddedToken>() {
} else if let Ok(token) = token.extract::<PyRef<AddedToken>>() {
Ok(token.token.clone())
} else {
Err(exceptions::Exception::py_err(
@ -342,7 +340,7 @@ impl Tokenizer {
content,
..Default::default()
})
} else if let Ok(token) = token.cast_as::<AddedToken>() {
} else if let Ok(token) = token.extract::<PyRef<AddedToken>>() {
Ok(token.token.clone())
} else {
Err(exceptions::Exception::py_err(
@ -392,7 +390,7 @@ impl Tokenizer {
}
#[setter]
fn set_model(&mut self, model: &mut Model) -> PyResult<()> {
fn set_model(&mut self, mut model: PyRefMut<Model>) -> PyResult<()> {
if let Some(model) = model.model.to_pointer() {
self.tokenizer.with_model(model);
Ok(())
@ -414,7 +412,7 @@ impl Tokenizer {
}
#[setter]
fn set_normalizer(&mut self, normalizer: &mut Normalizer) -> PyResult<()> {
fn set_normalizer(&mut self, mut normalizer: PyRefMut<Normalizer>) -> PyResult<()> {
if let Some(normalizer) = normalizer.normalizer.to_pointer() {
self.tokenizer.with_normalizer(normalizer);
Ok(())
@ -436,7 +434,7 @@ impl Tokenizer {
}
#[setter]
fn set_pre_tokenizer(&mut self, pretok: &mut PreTokenizer) -> PyResult<()> {
fn set_pre_tokenizer(&mut self, mut pretok: PyRefMut<PreTokenizer>) -> PyResult<()> {
if let Some(pretok) = pretok.pretok.to_pointer() {
self.tokenizer.with_pre_tokenizer(pretok);
Ok(())
@ -458,7 +456,7 @@ impl Tokenizer {
}
#[setter]
fn set_post_processor(&mut self, processor: &mut PostProcessor) -> PyResult<()> {
fn set_post_processor(&mut self, mut processor: PyRefMut<PostProcessor>) -> PyResult<()> {
if let Some(processor) = processor.processor.to_pointer() {
self.tokenizer.with_post_processor(processor);
Ok(())
@ -477,7 +475,7 @@ impl Tokenizer {
}
#[setter]
fn set_decoder(&mut self, decoder: &mut Decoder) -> PyResult<()> {
fn set_decoder(&mut self, mut decoder: PyRefMut<Decoder>) -> PyResult<()> {
if let Some(decoder) = decoder.decoder.to_pointer() {
self.tokenizer.with_decoder(decoder);
Ok(())

View File

@ -21,7 +21,7 @@ impl BpeTrainer {
/// Create a new BpeTrainer with the given configuration
#[new]
#[args(kwargs = "**")]
pub fn new(obj: &PyRawObject, kwargs: Option<&PyDict>) -> PyResult<()> {
pub fn new(kwargs: Option<&PyDict>) -> PyResult<(Self, Trainer)> {
let mut builder = tk::models::bpe::BpeTrainer::builder();
if let Some(kwargs) = kwargs {
for (key, val) in kwargs {
@ -40,7 +40,7 @@ impl BpeTrainer {
content,
..Default::default()
})
} else if let Ok(token) = token.cast_as::<AddedToken>() {
} else if let Ok(token) = token.extract::<PyRef<AddedToken>>() {
Ok(token.token.clone())
} else {
Err(exceptions::Exception::py_err(
@ -71,9 +71,12 @@ impl BpeTrainer {
};
}
}
Ok(obj.init(Trainer {
Ok((
BpeTrainer {},
Trainer {
trainer: Container::Owned(Box::new(builder.build())),
}))
},
))
}
}
@ -87,7 +90,7 @@ impl WordPieceTrainer {
    /// Create a new WordPieceTrainer with the given configuration
#[new]
#[args(kwargs = "**")]
pub fn new(obj: &PyRawObject, kwargs: Option<&PyDict>) -> PyResult<()> {
pub fn new(kwargs: Option<&PyDict>) -> PyResult<(Self, Trainer)> {
let mut builder = tk::models::wordpiece::WordPieceTrainer::builder();
if let Some(kwargs) = kwargs {
for (key, val) in kwargs {
@ -106,7 +109,7 @@ impl WordPieceTrainer {
content,
..Default::default()
})
} else if let Ok(token) = token.cast_as::<AddedToken>() {
} else if let Ok(token) = token.extract::<PyRef<AddedToken>>() {
Ok(token.token.clone())
} else {
Err(exceptions::Exception::py_err(
@ -138,8 +141,11 @@ impl WordPieceTrainer {
}
}
Ok(obj.init(Trainer {
Ok((
WordPieceTrainer {},
Trainer {
trainer: Container::Owned(Box::new(builder.build())),
}))
},
))
}
}