diff --git a/bindings/python/Cargo.lock b/bindings/python/Cargo.lock index d46a46b4..665093fa 100644 --- a/bindings/python/Cargo.lock +++ b/bindings/python/Cargo.lock @@ -73,6 +73,14 @@ dependencies = [ "winapi 0.3.8 (registry+https://github.com/rust-lang/crates.io-index)", ] +[[package]] +name = "cloudabi" +version = "0.0.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +dependencies = [ + "bitflags 1.2.1 (registry+https://github.com/rust-lang/crates.io-index)", +] + [[package]] name = "console" version = "0.9.1" @@ -243,6 +251,14 @@ name = "libc" version = "0.2.66" source = "registry+https://github.com/rust-lang/crates.io-index" +[[package]] +name = "lock_api" +version = "0.3.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +dependencies = [ + "scopeguard 1.0.0 (registry+https://github.com/rust-lang/crates.io-index)", +] + [[package]] name = "memchr" version = "2.2.1" @@ -278,6 +294,28 @@ name = "number_prefix" version = "0.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" +[[package]] +name = "parking_lot" +version = "0.10.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +dependencies = [ + "lock_api 0.3.3 (registry+https://github.com/rust-lang/crates.io-index)", + "parking_lot_core 0.7.0 (registry+https://github.com/rust-lang/crates.io-index)", +] + +[[package]] +name = "parking_lot_core" +version = "0.7.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +dependencies = [ + "cfg-if 0.1.10 (registry+https://github.com/rust-lang/crates.io-index)", + "cloudabi 0.0.3 (registry+https://github.com/rust-lang/crates.io-index)", + "libc 0.2.66 (registry+https://github.com/rust-lang/crates.io-index)", + "redox_syscall 0.1.56 (registry+https://github.com/rust-lang/crates.io-index)", + "smallvec 1.1.0 (registry+https://github.com/rust-lang/crates.io-index)", + "winapi 0.3.8 (registry+https://github.com/rust-lang/crates.io-index)", +] + [[package]] name = "paste" version = 
"0.1.6" @@ -323,27 +361,26 @@ dependencies = [ [[package]] name = "pyo3" -version = "0.8.4" +version = "0.9.1" source = "registry+https://github.com/rust-lang/crates.io-index" dependencies = [ "indoc 0.3.4 (registry+https://github.com/rust-lang/crates.io-index)", "inventory 0.1.4 (registry+https://github.com/rust-lang/crates.io-index)", - "lazy_static 1.4.0 (registry+https://github.com/rust-lang/crates.io-index)", "libc 0.2.66 (registry+https://github.com/rust-lang/crates.io-index)", "num-traits 0.2.10 (registry+https://github.com/rust-lang/crates.io-index)", + "parking_lot 0.10.0 (registry+https://github.com/rust-lang/crates.io-index)", "paste 0.1.6 (registry+https://github.com/rust-lang/crates.io-index)", - "pyo3cls 0.8.4 (registry+https://github.com/rust-lang/crates.io-index)", + "pyo3cls 0.9.1 (registry+https://github.com/rust-lang/crates.io-index)", "regex 1.3.1 (registry+https://github.com/rust-lang/crates.io-index)", "serde 1.0.104 (registry+https://github.com/rust-lang/crates.io-index)", "serde_json 1.0.44 (registry+https://github.com/rust-lang/crates.io-index)", - "spin 0.5.2 (registry+https://github.com/rust-lang/crates.io-index)", "unindent 0.1.5 (registry+https://github.com/rust-lang/crates.io-index)", "version_check 0.9.1 (registry+https://github.com/rust-lang/crates.io-index)", ] [[package]] name = "pyo3-derive-backend" -version = "0.8.4" +version = "0.9.1" source = "registry+https://github.com/rust-lang/crates.io-index" dependencies = [ "proc-macro2 1.0.6 (registry+https://github.com/rust-lang/crates.io-index)", @@ -353,11 +390,10 @@ dependencies = [ [[package]] name = "pyo3cls" -version = "0.8.4" +version = "0.9.1" source = "registry+https://github.com/rust-lang/crates.io-index" dependencies = [ - "proc-macro2 1.0.6 (registry+https://github.com/rust-lang/crates.io-index)", - "pyo3-derive-backend 0.8.4 (registry+https://github.com/rust-lang/crates.io-index)", + "pyo3-derive-backend 0.9.1 (registry+https://github.com/rust-lang/crates.io-index)", 
"quote 1.0.2 (registry+https://github.com/rust-lang/crates.io-index)", "syn 1.0.11 (registry+https://github.com/rust-lang/crates.io-index)", ] @@ -429,6 +465,11 @@ dependencies = [ "num_cpus 1.11.1 (registry+https://github.com/rust-lang/crates.io-index)", ] +[[package]] +name = "redox_syscall" +version = "0.1.56" +source = "registry+https://github.com/rust-lang/crates.io-index" + [[package]] name = "regex" version = "1.3.1" @@ -509,11 +550,6 @@ name = "smallvec" version = "1.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" -[[package]] -name = "spin" -version = "0.5.2" -source = "registry+https://github.com/rust-lang/crates.io-index" - [[package]] name = "strsim" version = "0.8.0" @@ -574,7 +610,7 @@ dependencies = [ name = "tokenizers-python" version = "0.7.0-rc3" dependencies = [ - "pyo3 0.8.4 (registry+https://github.com/rust-lang/crates.io-index)", + "pyo3 0.9.1 (registry+https://github.com/rust-lang/crates.io-index)", "rayon 1.3.0 (registry+https://github.com/rust-lang/crates.io-index)", "tokenizers 0.9.0", ] @@ -651,6 +687,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" "checksum cfg-if 0.1.10 (registry+https://github.com/rust-lang/crates.io-index)" = "4785bdd1c96b2a846b2bd7cc02e86b6b3dbf14e7e53446c4f54c92a361040822" "checksum clap 2.33.0 (registry+https://github.com/rust-lang/crates.io-index)" = "5067f5bb2d80ef5d68b4c87db81601f0b75bca627bc2ef76b141d7b846a3c6d9" "checksum clicolors-control 1.0.1 (registry+https://github.com/rust-lang/crates.io-index)" = "90082ee5dcdd64dc4e9e0d37fbf3ee325419e39c0092191e0393df65518f741e" +"checksum cloudabi 0.0.3 (registry+https://github.com/rust-lang/crates.io-index)" = "ddfc5b9aa5d4507acaf872de71051dfd0e309860e88966e1051e462a077aac4f" "checksum console 0.9.1 (registry+https://github.com/rust-lang/crates.io-index)" = "f5d540c2d34ac9dd0deb5f3b5f54c36c79efa78f6b3ad19106a554d07a7b5d9f" "checksum crossbeam-deque 0.7.2 (registry+https://github.com/rust-lang/crates.io-index)" = 
"c3aa945d63861bfe624b55d153a39684da1e8c0bc8fba932f7ee3a3c16cea3ca" "checksum crossbeam-epoch 0.8.0 (registry+https://github.com/rust-lang/crates.io-index)" = "5064ebdbf05ce3cb95e45c8b086f72263f4166b29b97f6baff7ef7fe047b55ac" @@ -670,19 +707,22 @@ source = "registry+https://github.com/rust-lang/crates.io-index" "checksum itoa 0.4.4 (registry+https://github.com/rust-lang/crates.io-index)" = "501266b7edd0174f8530248f87f99c88fbe60ca4ef3dd486835b8d8d53136f7f" "checksum lazy_static 1.4.0 (registry+https://github.com/rust-lang/crates.io-index)" = "e2abad23fbc42b3700f2f279844dc832adb2b2eb069b2df918f455c4e18cc646" "checksum libc 0.2.66 (registry+https://github.com/rust-lang/crates.io-index)" = "d515b1f41455adea1313a4a2ac8a8a477634fbae63cc6100e3aebb207ce61558" +"checksum lock_api 0.3.3 (registry+https://github.com/rust-lang/crates.io-index)" = "79b2de95ecb4691949fea4716ca53cdbcfccb2c612e19644a8bad05edcf9f47b" "checksum memchr 2.2.1 (registry+https://github.com/rust-lang/crates.io-index)" = "88579771288728879b57485cc7d6b07d648c9f0141eb955f8ab7f9d45394468e" "checksum memoffset 0.5.3 (registry+https://github.com/rust-lang/crates.io-index)" = "75189eb85871ea5c2e2c15abbdd541185f63b408415e5051f5cac122d8c774b9" "checksum num-traits 0.2.10 (registry+https://github.com/rust-lang/crates.io-index)" = "d4c81ffc11c212fa327657cb19dd85eb7419e163b5b076bede2bdb5c974c07e4" "checksum num_cpus 1.11.1 (registry+https://github.com/rust-lang/crates.io-index)" = "76dac5ed2a876980778b8b85f75a71b6cbf0db0b1232ee12f826bccb00d09d72" "checksum number_prefix 0.3.0 (registry+https://github.com/rust-lang/crates.io-index)" = "17b02fc0ff9a9e4b35b3342880f48e896ebf69f2967921fe8646bf5b7125956a" +"checksum parking_lot 0.10.0 (registry+https://github.com/rust-lang/crates.io-index)" = "92e98c49ab0b7ce5b222f2cc9193fc4efe11c6d0bd4f648e374684a6857b1cfc" +"checksum parking_lot_core 0.7.0 (registry+https://github.com/rust-lang/crates.io-index)" = "7582838484df45743c8434fbff785e8edf260c28748353d44bc0da32e0ceabf1" 
"checksum paste 0.1.6 (registry+https://github.com/rust-lang/crates.io-index)" = "423a519e1c6e828f1e73b720f9d9ed2fa643dce8a7737fb43235ce0b41eeaa49" "checksum paste-impl 0.1.6 (registry+https://github.com/rust-lang/crates.io-index)" = "4214c9e912ef61bf42b81ba9a47e8aad1b2ffaf739ab162bf96d1e011f54e6c5" "checksum ppv-lite86 0.2.6 (registry+https://github.com/rust-lang/crates.io-index)" = "74490b50b9fbe561ac330df47c08f3f33073d2d00c150f719147d7c54522fa1b" "checksum proc-macro-hack 0.5.11 (registry+https://github.com/rust-lang/crates.io-index)" = "ecd45702f76d6d3c75a80564378ae228a85f0b59d2f3ed43c91b4a69eb2ebfc5" "checksum proc-macro2 1.0.6 (registry+https://github.com/rust-lang/crates.io-index)" = "9c9e470a8dc4aeae2dee2f335e8f533e2d4b347e1434e5671afc49b054592f27" -"checksum pyo3 0.8.4 (registry+https://github.com/rust-lang/crates.io-index)" = "7f9df1468dddf8a59ec799cf3b930bb75ec09deabe875ba953e06c51d1077136" -"checksum pyo3-derive-backend 0.8.4 (registry+https://github.com/rust-lang/crates.io-index)" = "9f6e56fb3e97b344a8f87d036f94578399402c6b75949de6270cd07928f790b1" -"checksum pyo3cls 0.8.4 (registry+https://github.com/rust-lang/crates.io-index)" = "97452dcdf5941627ebc5c06664a07821fc7fc88d7515f02178193a8ebe316468" +"checksum pyo3 0.9.1 (registry+https://github.com/rust-lang/crates.io-index)" = "397fef05d982d84944edac8281972ed7108be996e5635fdfdae777d78632c8f0" +"checksum pyo3-derive-backend 0.9.1 (registry+https://github.com/rust-lang/crates.io-index)" = "b8656b9d1a8c49838439b6e2c299ec927816c5c0c92b623f677d5579e9c04851" +"checksum pyo3cls 0.9.1 (registry+https://github.com/rust-lang/crates.io-index)" = "ab91823d2634de5fd56af40da20fd89cac144a15da0d9101f3c2b047af9ed9ca" "checksum quote 1.0.2 (registry+https://github.com/rust-lang/crates.io-index)" = "053a8c8bcc71fcce321828dc897a98ab9760bef03a4fc36693c231e5b3216cfe" "checksum rand 0.7.2 (registry+https://github.com/rust-lang/crates.io-index)" = "3ae1b169243eaf61759b8475a998f0a385e42042370f3a7dbaf35246eacc8412" "checksum 
rand_chacha 0.2.1 (registry+https://github.com/rust-lang/crates.io-index)" = "03a2a90da8c7523f554344f921aa97283eadf6ac484a6d2a7d0212fa7f8d6853" @@ -690,6 +730,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" "checksum rand_hc 0.2.0 (registry+https://github.com/rust-lang/crates.io-index)" = "ca3129af7b92a17112d59ad498c6f81eaf463253766b90396d39ea7a39d6613c" "checksum rayon 1.3.0 (registry+https://github.com/rust-lang/crates.io-index)" = "db6ce3297f9c85e16621bb8cca38a06779ffc31bb8184e1be4bed2be4678a098" "checksum rayon-core 1.7.0 (registry+https://github.com/rust-lang/crates.io-index)" = "08a89b46efaf957e52b18062fb2f4660f8b8a4dde1807ca002690868ef2c85a9" +"checksum redox_syscall 0.1.56 (registry+https://github.com/rust-lang/crates.io-index)" = "2439c63f3f6139d1b57529d16bc3b8bb855230c8efcc5d3a896c8bea7c3b1e84" "checksum regex 1.3.1 (registry+https://github.com/rust-lang/crates.io-index)" = "dc220bd33bdce8f093101afe22a037b8eb0e5af33592e6a9caafff0d4cb81cbd" "checksum regex-syntax 0.6.12 (registry+https://github.com/rust-lang/crates.io-index)" = "11a7e20d1cce64ef2fed88b66d347f88bd9babb82845b2b858f3edbf59a4f716" "checksum rustc_version 0.2.3 (registry+https://github.com/rust-lang/crates.io-index)" = "138e3e0acb6c9fb258b19b67cb8abd63c00679d2851805ea151465464fe9030a" @@ -701,7 +742,6 @@ source = "registry+https://github.com/rust-lang/crates.io-index" "checksum serde_derive 1.0.104 (registry+https://github.com/rust-lang/crates.io-index)" = "128f9e303a5a29922045a830221b8f78ec74a5f544944f3d5984f8ec3895ef64" "checksum serde_json 1.0.44 (registry+https://github.com/rust-lang/crates.io-index)" = "48c575e0cc52bdd09b47f330f646cf59afc586e9c4e3ccd6fc1f625b8ea1dad7" "checksum smallvec 1.1.0 (registry+https://github.com/rust-lang/crates.io-index)" = "44e59e0c9fa00817912ae6e4e6e3c4fe04455e75699d06eedc7d85917ed8e8f4" -"checksum spin 0.5.2 (registry+https://github.com/rust-lang/crates.io-index)" = "6e63cff320ae2c57904679ba7cb63280a3dc4613885beafb148ee7bf9aa9042d" 
"checksum strsim 0.8.0 (registry+https://github.com/rust-lang/crates.io-index)" = "8ea5119cdb4c55b55d432abb513a0429384878c15dde60cc77b1c99de1a95a6a" "checksum syn 1.0.11 (registry+https://github.com/rust-lang/crates.io-index)" = "dff0acdb207ae2fe6d5976617f887eb1e35a2ba52c13c7234c790960cdad9238" "checksum termios 0.3.1 (registry+https://github.com/rust-lang/crates.io-index)" = "72b620c5ea021d75a735c943269bb07d30c9b77d6ac6b236bc8b5c496ef05625" diff --git a/bindings/python/Cargo.toml b/bindings/python/Cargo.toml index 2dd697cb..8ba84aab 100644 --- a/bindings/python/Cargo.toml +++ b/bindings/python/Cargo.toml @@ -12,7 +12,7 @@ crate-type = ["cdylib"] rayon = "1.2.0" [dependencies.pyo3] -version = "0.8.4" +version = "0.9.1" features = ["extension-module"] [dependencies.tokenizers] diff --git a/bindings/python/src/decoders.rs b/bindings/python/src/decoders.rs index 215aaff4..2ae5e91c 100644 --- a/bindings/python/src/decoders.rs +++ b/bindings/python/src/decoders.rs @@ -31,10 +31,13 @@ pub struct ByteLevel {} #[pymethods] impl ByteLevel { #[new] - fn new(obj: &PyRawObject) -> PyResult<()> { - Ok(obj.init(Decoder { - decoder: Container::Owned(Box::new(tk::decoders::byte_level::ByteLevel::default())), - })) + fn new() -> PyResult<(Self, Decoder)> { + Ok(( + ByteLevel {}, + Decoder { + decoder: Container::Owned(Box::new(tk::decoders::byte_level::ByteLevel::default())), + }, + )) } } @@ -44,7 +47,7 @@ pub struct WordPiece {} impl WordPiece { #[new] #[args(kwargs = "**")] - fn new(obj: &PyRawObject, kwargs: Option<&PyDict>) -> PyResult<()> { + fn new(kwargs: Option<&PyDict>) -> PyResult<(Self, Decoder)> { let mut prefix = String::from("##"); let mut cleanup = true; @@ -57,11 +60,14 @@ impl WordPiece { } } - Ok(obj.init(Decoder { - decoder: Container::Owned(Box::new(tk::decoders::wordpiece::WordPiece::new( - prefix, cleanup, - ))), - })) + Ok(( + WordPiece {}, + Decoder { + decoder: Container::Owned(Box::new(tk::decoders::wordpiece::WordPiece::new( + prefix, cleanup, + ))), + 
}, + )) } } @@ -71,7 +77,7 @@ pub struct Metaspace {} impl Metaspace { #[new] #[args(kwargs = "**")] - fn new(obj: &PyRawObject, kwargs: Option<&PyDict>) -> PyResult<()> { + fn new(kwargs: Option<&PyDict>) -> PyResult<(Self, Decoder)> { let mut replacement = '▁'; let mut add_prefix_space = true; @@ -91,12 +97,15 @@ } } - Ok(obj.init(Decoder { - decoder: Container::Owned(Box::new(tk::decoders::metaspace::Metaspace::new( - replacement, - add_prefix_space, - ))), - })) + Ok(( + Metaspace {}, + Decoder { + decoder: Container::Owned(Box::new(tk::decoders::metaspace::Metaspace::new( + replacement, + add_prefix_space, + ))), + }, + )) } } @@ -106,7 +115,7 @@ pub struct BPEDecoder {} impl BPEDecoder { #[new] #[args(kwargs = "**")] - fn new(obj: &PyRawObject, kwargs: Option<&PyDict>) -> PyResult<()> { + fn new(kwargs: Option<&PyDict>) -> PyResult<(Self, Decoder)> { let mut suffix = String::from(""); if let Some(kwargs) = kwargs { @@ -119,9 +128,12 @@ } } - Ok(obj.init(Decoder { - decoder: Container::Owned(Box::new(tk::decoders::bpe::BPEDecoder::new(suffix))), - })) + Ok(( + BPEDecoder {}, + Decoder { + decoder: Container::Owned(Box::new(tk::decoders::bpe::BPEDecoder::new(suffix))), + }, + )) } } diff --git a/bindings/python/src/encoding.rs b/bindings/python/src/encoding.rs index aef53b33..7f2c2a7f 100644 --- a/bindings/python/src/encoding.rs +++ b/bindings/python/src/encoding.rs @@ -40,7 +40,7 @@ impl PySequenceProtocol for Encoding { impl Encoding { #[staticmethod] #[args(growing_offsets = true)] - fn merge(encodings: Vec<&Encoding>, growing_offsets: bool) -> Encoding { + fn merge(encodings: Vec<PyRef<Encoding>>, growing_offsets: bool) -> Encoding { Encoding::new(tk::tokenizer::Encoding::merge( encodings .into_iter() diff --git a/bindings/python/src/models.rs b/bindings/python/src/models.rs index 53c1b5a0..f268387a 100644 --- a/bindings/python/src/models.rs +++ b/bindings/python/src/models.rs @@ -21,7 +21,7 @@ impl EncodeInput { impl<'source>
FromPyObject<'source> for EncodeInput { fn extract(ob: &'source PyAny) -> PyResult<Self> { - let sequence: &PyList = ob.downcast_ref()?; + let sequence: &PyList = ob.downcast()?; enum Mode { NoOffsets, @@ -84,7 +84,7 @@ pub struct Model { #[pymethods] impl Model { #[new] - fn new(_obj: &PyRawObject) -> PyResult<()> { + fn new() -> PyResult<Self> { Err(exceptions::Exception::py_err( "Cannot create a Model directly. Use a concrete subclass", )) @@ -142,19 +142,29 @@ impl Model { /// BPE Model /// Allows the creation of a BPE Model to be used with a Tokenizer -#[pyclass] +#[pyclass(extends=Model)] pub struct BPE {} #[pymethods] impl BPE { - /// from_files(vocab, merges, /) - /// -- - /// - /// Instanciate a new BPE model using the provided vocab and merges files - #[staticmethod] + #[new] #[args(kwargs = "**")] - fn from_files(vocab: &str, merges: &str, kwargs: Option<&PyDict>) -> PyResult<Model> { - let mut builder = tk::models::bpe::BPE::from_files(vocab, merges); + fn new( + vocab: Option<&str>, + merges: Option<&str>, + kwargs: Option<&PyDict>, + ) -> PyResult<(Self, Model)> { + let mut builder = tk::models::bpe::BPE::builder(); + if let Some(vocab) = vocab { + if let Some(merges) = merges { + builder = builder.files(vocab.to_owned(), merges.to_owned()); + } else { + return Err(exceptions::Exception::py_err(format!( + "Got vocab file ({}), but missing merges", + vocab + ))); + } + } if let Some(kwargs) = kwargs { for (key, value) in kwargs { let key: &str = key.extract()?; @@ -184,38 +194,30 @@ impl BPE { "Error while initializing BPE: {}", e ))), - Ok(bpe) => Ok(Model { - model: Container::Owned(Box::new(bpe)), - }), - } - } - - /// empty() - /// -- - /// - /// Instanciate a new BPE model with empty vocab and merges - #[staticmethod] - fn empty() -> Model { - Model { - model: Container::Owned(Box::new(tk::models::bpe::BPE::default())), + Ok(bpe) => Ok(( + BPE {}, + Model { + model: Container::Owned(Box::new(bpe)), + }, + )), } } } /// WordPiece Model -#[pyclass]
+#[pyclass(extends=Model)] pub struct WordPiece {} #[pymethods] impl WordPiece { - /// from_files(vocab, /) - /// -- - /// - /// Instantiate a new WordPiece model using the provided vocabulary file - #[staticmethod] + #[new] #[args(kwargs = "**")] - fn from_files(vocab: &str, kwargs: Option<&PyDict>) -> PyResult<Model> { - let mut builder = tk::models::wordpiece::WordPiece::from_files(vocab); + fn new(vocab: Option<&str>, kwargs: Option<&PyDict>) -> PyResult<(Self, Model)> { + let mut builder = tk::models::wordpiece::WordPiece::builder(); + + if let Some(vocab) = vocab { + builder = builder.files(vocab.to_owned()); + } if let Some(kwargs) = kwargs { for (key, val) in kwargs { @@ -242,28 +244,24 @@ impl WordPiece { "Error while initializing WordPiece", )) } - Ok(wordpiece) => Ok(Model { - model: Container::Owned(Box::new(wordpiece)), - }), - } - } - - #[staticmethod] - fn empty() -> Model { - Model { - model: Container::Owned(Box::new(tk::models::wordpiece::WordPiece::default())), + Ok(wordpiece) => Ok(( + WordPiece {}, + Model { + model: Container::Owned(Box::new(wordpiece)), + }, + )), } } } -#[pyclass] +#[pyclass(extends=Model)] pub struct WordLevel {} #[pymethods] impl WordLevel { - #[staticmethod] + #[new] #[args(kwargs = "**")] - fn from_files(vocab: &str, kwargs: Option<&PyDict>) -> PyResult<Model> { + fn new(vocab: Option<&str>, kwargs: Option<&PyDict>) -> PyResult<(Self, Model)> { let mut unk_token = String::from(""); if let Some(kwargs) = kwargs { @@ -276,23 +274,28 @@ } } - match tk::models::wordlevel::WordLevel::from_files(vocab, unk_token) { - Err(e) => { - println!("Errors: {:?}", e); - Err(exceptions::Exception::py_err( - "Error while initializing WordLevel", - )) + if let Some(vocab) = vocab { + match tk::models::wordlevel::WordLevel::from_files(vocab, unk_token) { + Err(e) => { + println!("Errors: {:?}", e); + Err(exceptions::Exception::py_err( + "Error while initializing WordLevel", + )) + } + Ok(model) => Ok(( + WordLevel {}, + Model { + model:
Container::Owned(Box::new(model)), + }, + )), } - Ok(model) => Ok(Model { - model: Container::Owned(Box::new(model)), - }), - } - } - - #[staticmethod] - fn empty() -> Model { - Model { - model: Container::Owned(Box::new(tk::models::wordlevel::WordLevel::default())), + } else { + Ok(( + WordLevel {}, + Model { + model: Container::Owned(Box::new(tk::models::wordlevel::WordLevel::default())), + }, + )) } } } diff --git a/bindings/python/src/normalizers.rs b/bindings/python/src/normalizers.rs index fe3c6e6c..07aa97b1 100644 --- a/bindings/python/src/normalizers.rs +++ b/bindings/python/src/normalizers.rs @@ -16,7 +16,7 @@ pub struct BertNormalizer {} impl BertNormalizer { #[new] #[args(kwargs = "**")] - fn new(obj: &PyRawObject, kwargs: Option<&PyDict>) -> PyResult<()> { + fn new(kwargs: Option<&PyDict>) -> PyResult<(Self, Normalizer)> { let mut clean_text = true; let mut handle_chinese_chars = true; let mut strip_accents = true; @@ -35,14 +35,17 @@ impl BertNormalizer { } } - Ok(obj.init(Normalizer { - normalizer: Container::Owned(Box::new(tk::normalizers::bert::BertNormalizer::new( - clean_text, - handle_chinese_chars, - strip_accents, - lowercase, - ))), - })) + Ok(( + BertNormalizer {}, + Normalizer { + normalizer: Container::Owned(Box::new(tk::normalizers::bert::BertNormalizer::new( + clean_text, + handle_chinese_chars, + strip_accents, + lowercase, + ))), + }, + )) } } @@ -51,10 +54,13 @@ pub struct NFD {} #[pymethods] impl NFD { #[new] - fn new(obj: &PyRawObject) -> PyResult<()> { - Ok(obj.init(Normalizer { - normalizer: Container::Owned(Box::new(tk::normalizers::unicode::NFD)), - })) + fn new() -> PyResult<(Self, Normalizer)> { + Ok(( + NFD {}, + Normalizer { + normalizer: Container::Owned(Box::new(tk::normalizers::unicode::NFD)), + }, + )) } } @@ -63,10 +69,13 @@ pub struct NFKD {} #[pymethods] impl NFKD { #[new] - fn new(obj: &PyRawObject) -> PyResult<()> { - Ok(obj.init(Normalizer { - normalizer: Container::Owned(Box::new(tk::normalizers::unicode::NFKD)), - 
})) + fn new() -> PyResult<(Self, Normalizer)> { + Ok(( + NFKD {}, + Normalizer { + normalizer: Container::Owned(Box::new(tk::normalizers::unicode::NFKD)), + }, + )) } } @@ -75,10 +84,13 @@ pub struct NFC {} #[pymethods] impl NFC { #[new] - fn new(obj: &PyRawObject) -> PyResult<()> { - Ok(obj.init(Normalizer { - normalizer: Container::Owned(Box::new(tk::normalizers::unicode::NFC)), - })) + fn new() -> PyResult<(Self, Normalizer)> { + Ok(( + NFC {}, + Normalizer { + normalizer: Container::Owned(Box::new(tk::normalizers::unicode::NFC)), + }, + )) } } @@ -87,10 +99,13 @@ pub struct NFKC {} #[pymethods] impl NFKC { #[new] - fn new(obj: &PyRawObject) -> PyResult<()> { - Ok(obj.init(Normalizer { - normalizer: Container::Owned(Box::new(tk::normalizers::unicode::NFKC)), - })) + fn new() -> PyResult<(Self, Normalizer)> { + Ok(( + NFKC {}, + Normalizer { + normalizer: Container::Owned(Box::new(tk::normalizers::unicode::NFKC)), + }, + )) } } @@ -99,11 +114,11 @@ pub struct Sequence {} #[pymethods] impl Sequence { #[new] - fn new(obj: &PyRawObject, normalizers: &PyList) -> PyResult<()> { + fn new(normalizers: &PyList) -> PyResult<(Self, Normalizer)> { let normalizers = normalizers .iter() .map(|n| { - let normalizer: &mut Normalizer = n.extract()?; + let mut normalizer: PyRefMut<Normalizer> = n.extract()?; if let Some(normalizer) = normalizer.normalizer.to_pointer() { Ok(normalizer) } else { @@ -114,11 +129,14 @@ impl Sequence { }) .collect::<PyResult<Vec<_>>>()?; - Ok(obj.init(Normalizer { - normalizer: Container::Owned(Box::new(tk::normalizers::utils::Sequence::new( - normalizers, - ))), - })) + Ok(( + Sequence {}, + Normalizer { + normalizer: Container::Owned(Box::new(tk::normalizers::utils::Sequence::new( + normalizers, + ))), + }, + )) } } @@ -127,10 +145,13 @@ pub struct Lowercase {} #[pymethods] impl Lowercase { #[new] - fn new(obj: &PyRawObject) -> PyResult<()> { - Ok(obj.init(Normalizer { - normalizer: Container::Owned(Box::new(tk::normalizers::utils::Lowercase)), - })) + fn new() ->
PyResult<(Self, Normalizer)> { + Ok(( + Lowercase {}, + Normalizer { + normalizer: Container::Owned(Box::new(tk::normalizers::utils::Lowercase)), + }, + )) } } @@ -140,7 +161,7 @@ pub struct Strip {} impl Strip { #[new] #[args(kwargs = "**")] - fn new(obj: &PyRawObject, kwargs: Option<&PyDict>) -> PyResult<()> { + fn new(kwargs: Option<&PyDict>) -> PyResult<(Self, Normalizer)> { let mut left = true; let mut right = true; @@ -153,8 +174,13 @@ impl Strip { } } - Ok(obj.init(Normalizer { - normalizer: Container::Owned(Box::new(tk::normalizers::strip::Strip::new(left, right))), - })) + Ok(( + Strip {}, + Normalizer { + normalizer: Container::Owned(Box::new(tk::normalizers::strip::Strip::new( + left, right, + ))), + }, + )) } } diff --git a/bindings/python/src/pre_tokenizers.rs b/bindings/python/src/pre_tokenizers.rs index af48b0b1..27fe6fd2 100644 --- a/bindings/python/src/pre_tokenizers.rs +++ b/bindings/python/src/pre_tokenizers.rs @@ -38,9 +38,8 @@ pub struct ByteLevel {} impl ByteLevel { #[new] #[args(kwargs = "**")] - fn new(obj: &PyRawObject, kwargs: Option<&PyDict>) -> PyResult<()> { + fn new(kwargs: Option<&PyDict>) -> PyResult<(Self, PreTokenizer)> { let mut byte_level = tk::pre_tokenizers::byte_level::ByteLevel::default(); - if let Some(kwargs) = kwargs { for (key, value) in kwargs { let key: &str = key.extract()?; @@ -53,9 +52,12 @@ impl ByteLevel { } } - Ok(obj.init(PreTokenizer { - pretok: Container::Owned(Box::new(byte_level)), - })) + Ok(( + ByteLevel {}, + PreTokenizer { + pretok: Container::Owned(Box::new(byte_level)), + }, + )) } #[staticmethod] @@ -72,10 +74,13 @@ pub struct Whitespace {} #[pymethods] impl Whitespace { #[new] - fn new(obj: &PyRawObject) -> PyResult<()> { - Ok(obj.init(PreTokenizer { - pretok: Container::Owned(Box::new(tk::pre_tokenizers::whitespace::Whitespace)), - })) + fn new() -> PyResult<(Self, PreTokenizer)> { + Ok(( + Whitespace {}, + PreTokenizer { + pretok: 
Container::Owned(Box::new(tk::pre_tokenizers::whitespace::Whitespace)), + }, + )) } } @@ -84,10 +89,13 @@ pub struct WhitespaceSplit {} #[pymethods] impl WhitespaceSplit { #[new] - fn new(obj: &PyRawObject) -> PyResult<()> { - Ok(obj.init(PreTokenizer { - pretok: Container::Owned(Box::new(tk::pre_tokenizers::whitespace::WhitespaceSplit)), - })) + fn new() -> PyResult<(Self, PreTokenizer)> { + Ok(( + WhitespaceSplit {}, + PreTokenizer { + pretok: Container::Owned(Box::new(tk::pre_tokenizers::whitespace::WhitespaceSplit)), + }, + )) } } @@ -96,18 +104,21 @@ pub struct CharDelimiterSplit {} #[pymethods] impl CharDelimiterSplit { #[new] - pub fn new(obj: &PyRawObject, delimiter: &str) -> PyResult<()> { + pub fn new(delimiter: &str) -> PyResult<(Self, PreTokenizer)> { let chr_delimiter = delimiter .chars() .nth(0) .ok_or(exceptions::Exception::py_err( "delimiter must be a single character", ))?; - Ok(obj.init(PreTokenizer { - pretok: Container::Owned(Box::new( - tk::pre_tokenizers::delimiter::CharDelimiterSplit::new(chr_delimiter), - )), - })) + Ok(( + CharDelimiterSplit {}, + PreTokenizer { + pretok: Container::Owned(Box::new( + tk::pre_tokenizers::delimiter::CharDelimiterSplit::new(chr_delimiter), + )), + }, + )) } } @@ -116,10 +127,13 @@ pub struct BertPreTokenizer {} #[pymethods] impl BertPreTokenizer { #[new] - fn new(obj: &PyRawObject) -> PyResult<()> { - Ok(obj.init(PreTokenizer { - pretok: Container::Owned(Box::new(tk::pre_tokenizers::bert::BertPreTokenizer)), - })) + fn new() -> PyResult<(Self, PreTokenizer)> { + Ok(( + BertPreTokenizer {}, + PreTokenizer { + pretok: Container::Owned(Box::new(tk::pre_tokenizers::bert::BertPreTokenizer)), + }, + )) } } @@ -129,7 +143,7 @@ pub struct Metaspace {} impl Metaspace { #[new] #[args(kwargs = "**")] - fn new(obj: &PyRawObject, kwargs: Option<&PyDict>) -> PyResult<()> { + fn new(kwargs: Option<&PyDict>) -> PyResult<(Self, PreTokenizer)> { let mut replacement = '▁'; let mut add_prefix_space = true; @@ -149,12 +163,15 @@ 
impl Metaspace { } } - Ok(obj.init(PreTokenizer { - pretok: Container::Owned(Box::new(tk::pre_tokenizers::metaspace::Metaspace::new( - replacement, - add_prefix_space, - ))), - })) + Ok(( + Metaspace {}, + PreTokenizer { + pretok: Container::Owned(Box::new(tk::pre_tokenizers::metaspace::Metaspace::new( + replacement, + add_prefix_space, + ))), + }, + )) } } diff --git a/bindings/python/src/processors.rs b/bindings/python/src/processors.rs index 4ef1ab78..e63175e8 100644 --- a/bindings/python/src/processors.rs +++ b/bindings/python/src/processors.rs @@ -21,12 +21,15 @@ pub struct BertProcessing {} #[pymethods] impl BertProcessing { #[new] - fn new(obj: &PyRawObject, sep: (String, u32), cls: (String, u32)) -> PyResult<()> { - Ok(obj.init(PostProcessor { - processor: Container::Owned(Box::new(tk::processors::bert::BertProcessing::new( - sep, cls, - ))), - })) + fn new(sep: (String, u32), cls: (String, u32)) -> PyResult<(Self, PostProcessor)> { + Ok(( + BertProcessing {}, + PostProcessor { + processor: Container::Owned(Box::new(tk::processors::bert::BertProcessing::new( + sep, cls, + ))), + }, + )) } } @@ -35,12 +38,15 @@ pub struct RobertaProcessing {} #[pymethods] impl RobertaProcessing { #[new] - fn new(obj: &PyRawObject, sep: (String, u32), cls: (String, u32)) -> PyResult<()> { - Ok(obj.init(PostProcessor { - processor: Container::Owned(Box::new(tk::processors::roberta::RobertaProcessing::new( - sep, cls, - ))), - })) + fn new(sep: (String, u32), cls: (String, u32)) -> PyResult<(Self, PostProcessor)> { + Ok(( + RobertaProcessing {}, + PostProcessor { + processor: Container::Owned(Box::new( + tk::processors::roberta::RobertaProcessing::new(sep, cls), + )), + }, + )) } } @@ -50,7 +56,7 @@ pub struct ByteLevel {} impl ByteLevel { #[new] #[args(kwargs = "**")] - fn new(obj: &PyRawObject, kwargs: Option<&PyDict>) -> PyResult<()> { + fn new(kwargs: Option<&PyDict>) -> PyResult<(Self, PostProcessor)> { let mut byte_level = tk::processors::byte_level::ByteLevel::default(); 
if let Some(kwargs) = kwargs { @@ -62,8 +68,11 @@ impl ByteLevel { } } } - Ok(obj.init(PostProcessor { - processor: Container::Owned(Box::new(byte_level)), - })) + Ok(( + ByteLevel {}, + PostProcessor { + processor: Container::Owned(Box::new(byte_level)), + }, + )) } } diff --git a/bindings/python/src/tokenizer.rs b/bindings/python/src/tokenizer.rs index 09805ae8..b7363261 100644 --- a/bindings/python/src/tokenizer.rs +++ b/bindings/python/src/tokenizer.rs @@ -28,7 +28,7 @@ pub struct AddedToken { impl AddedToken { #[new] #[args(kwargs = "**")] - fn new(obj: &PyRawObject, content: &str, kwargs: Option<&PyDict>) -> PyResult<()> { + fn new(content: &str, kwargs: Option<&PyDict>) -> PyResult<AddedToken> { let mut token = tk::tokenizer::AddedToken::from(content.to_owned()); if let Some(kwargs) = kwargs { @@ -43,8 +43,7 @@ impl AddedToken { } } - obj.init({ AddedToken { token } }); - Ok(()) + Ok(AddedToken { token }) } #[getter] @@ -97,11 +96,10 @@ pub struct Tokenizer { #[pymethods] impl Tokenizer { #[new] - fn new(obj: &PyRawObject, model: &mut Model) -> PyResult<()> { + fn new(mut model: PyRefMut<Model>) -> PyResult<Self> { if let Some(model) = model.model.to_pointer() { let tokenizer = tk::tokenizer::Tokenizer::new(model); - obj.init({ Tokenizer { tokenizer } }); - Ok(()) + Ok(Tokenizer { tokenizer }) } else { Err(exceptions::Exception::py_err( "The Model is already being used in another Tokenizer", @@ -320,7 +318,7 @@ impl Tokenizer { content, ..Default::default() }) - } else if let Ok(token) = token.cast_as::<AddedToken>() { + } else if let Ok(token) = token.extract::<PyRef<AddedToken>>() { Ok(token.token.clone()) } else { Err(exceptions::Exception::py_err( @@ -342,7 +340,7 @@ impl Tokenizer { content, ..Default::default() }) - } else if let Ok(token) = token.cast_as::<AddedToken>() { + } else if let Ok(token) = token.extract::<PyRef<AddedToken>>() { Ok(token.token.clone()) } else { Err(exceptions::Exception::py_err( @@ -392,7 +390,7 @@ impl Tokenizer { } #[setter] - fn set_model(&mut self, model: &mut Model) -> PyResult<()> { + fn set_model(&mut
self, mut model: PyRefMut<Model>) -> PyResult<()> { if let Some(model) = model.model.to_pointer() { self.tokenizer.with_model(model); Ok(()) @@ -414,7 +412,7 @@ impl Tokenizer { } #[setter] - fn set_normalizer(&mut self, normalizer: &mut Normalizer) -> PyResult<()> { + fn set_normalizer(&mut self, mut normalizer: PyRefMut<Normalizer>) -> PyResult<()> { if let Some(normalizer) = normalizer.normalizer.to_pointer() { self.tokenizer.with_normalizer(normalizer); Ok(()) @@ -436,7 +434,7 @@ impl Tokenizer { } #[setter] - fn set_pre_tokenizer(&mut self, pretok: &mut PreTokenizer) -> PyResult<()> { + fn set_pre_tokenizer(&mut self, mut pretok: PyRefMut<PreTokenizer>) -> PyResult<()> { if let Some(pretok) = pretok.pretok.to_pointer() { self.tokenizer.with_pre_tokenizer(pretok); Ok(()) @@ -458,7 +456,7 @@ impl Tokenizer { } #[setter] - fn set_post_processor(&mut self, processor: &mut PostProcessor) -> PyResult<()> { + fn set_post_processor(&mut self, mut processor: PyRefMut<PostProcessor>) -> PyResult<()> { if let Some(processor) = processor.processor.to_pointer() { self.tokenizer.with_post_processor(processor); Ok(()) @@ -477,7 +475,7 @@ impl Tokenizer { } #[setter] - fn set_decoder(&mut self, decoder: &mut Decoder) -> PyResult<()> { + fn set_decoder(&mut self, mut decoder: PyRefMut<Decoder>) -> PyResult<()> { if let Some(decoder) = decoder.decoder.to_pointer() { self.tokenizer.with_decoder(decoder); Ok(()) diff --git a/bindings/python/src/trainers.rs b/bindings/python/src/trainers.rs index 019cf74e..76a2528e 100644 --- a/bindings/python/src/trainers.rs +++ b/bindings/python/src/trainers.rs @@ -21,7 +21,7 @@ impl BpeTrainer { /// Create a new BpeTrainer with the given configuration #[new] #[args(kwargs = "**")] - pub fn new(obj: &PyRawObject, kwargs: Option<&PyDict>) -> PyResult<()> { + pub fn new(kwargs: Option<&PyDict>) -> PyResult<(Self, Trainer)> { let mut builder = tk::models::bpe::BpeTrainer::builder(); if let Some(kwargs) = kwargs { for (key, val) in kwargs { @@ -40,7 +40,7 @@ impl BpeTrainer { content, ..Default::default()
}) - } else if let Ok(token) = token.cast_as::<AddedToken>() { + } else if let Ok(token) = token.extract::<PyRef<AddedToken>>() { Ok(token.token.clone()) } else { Err(exceptions::Exception::py_err( @@ -71,9 +71,12 @@ impl BpeTrainer { }; } } - Ok(obj.init(Trainer { - trainer: Container::Owned(Box::new(builder.build())), - })) + Ok(( + BpeTrainer {}, + Trainer { + trainer: Container::Owned(Box::new(builder.build())), + }, + )) } } @@ -87,7 +90,7 @@ impl WordPieceTrainer { /// Create a new BpeTrainer with the given configuration #[new] #[args(kwargs = "**")] - pub fn new(obj: &PyRawObject, kwargs: Option<&PyDict>) -> PyResult<()> { + pub fn new(kwargs: Option<&PyDict>) -> PyResult<(Self, Trainer)> { let mut builder = tk::models::wordpiece::WordPieceTrainer::builder(); if let Some(kwargs) = kwargs { for (key, val) in kwargs { @@ -106,7 +109,7 @@ impl WordPieceTrainer { content, ..Default::default() }) - } else if let Ok(token) = token.cast_as::<AddedToken>() { + } else if let Ok(token) = token.extract::<PyRef<AddedToken>>() { Ok(token.token.clone()) } else { Err(exceptions::Exception::py_err( @@ -138,8 +141,11 @@ impl WordPieceTrainer { } } - Ok(obj.init(Trainer { - trainer: Container::Owned(Box::new(builder.build())), - })) + Ok(( + WordPieceTrainer {}, + Trainer { + trainer: Container::Owned(Box::new(builder.build())), + }, + )) } }