diff --git a/bindings/python/src/models.rs b/bindings/python/src/models.rs index b53a7731..c35ae5a9 100644 --- a/bindings/python/src/models.rs +++ b/bindings/python/src/models.rs @@ -700,56 +700,59 @@ impl PyWordPiece { #[text_signature = "(self, vocab, unk_token)"] pub struct PyWordLevel {} -impl PyWordLevel { - fn get_unk(kwargs: Option<&PyDict>) -> PyResult { - let mut unk_token = String::from(""); - - if let Some(kwargs) = kwargs { - for (key, val) in kwargs { - let key: &str = key.extract()?; - match key { - "unk_token" => unk_token = val.extract()?, - _ => println!("Ignored unknown kwargs option {}", key), - } - } - } - Ok(unk_token) - } -} - #[pymethods] impl PyWordLevel { + #[getter] + fn get_unk_token(self_: PyRef) -> String { + let super_ = self_.as_ref(); + let model = super_.model.read().unwrap(); + if let ModelWrapper::WordLevel(ref wl) = *model { + wl.unk_token.clone() + } else { + unreachable!() + } + } + + #[setter] + fn set_unk_token(self_: PyRef, unk_token: String) { + let super_ = self_.as_ref(); + let mut model = super_.model.write().unwrap(); + if let ModelWrapper::WordLevel(ref mut wl) = *model { + wl.unk_token = unk_token; + } + } + #[new] - #[args(kwargs = "**")] - fn new(vocab: Option, kwargs: Option<&PyDict>) -> PyResult<(Self, PyModel)> { - let unk_token = PyWordLevel::get_unk(kwargs)?; + #[args(unk_token = "None")] + fn new(vocab: Option, unk_token: Option) -> PyResult<(Self, PyModel)> { + let mut builder = WordLevel::builder(); if let Some(vocab) = vocab { - let model = match vocab { - PyVocab::Vocab(vocab) => WordLevel::builder() - .vocab(vocab) - .unk_token(unk_token) - .build() - .expect("Can only fail when loading from files"), + match vocab { + PyVocab::Vocab(vocab) => { + builder = builder.vocab(vocab); + } PyVocab::Filename(vocab_filename) => { deprecation_warning( "0.9.0", "WordLevel.__init__ will not create from files anymore, \ try `WordLevel.from_file` instead", )?; - WordLevel::from_file(vocab_filename, unk_token).map_err(|e| { - exceptions::PyException::new_err(format!( - "Error while loading WordLevel: {}", - e - )) - })? + builder = builder.files(vocab_filename.to_string()); } }; - - Ok((PyWordLevel {}, model.into())) - } else { - Ok((PyWordLevel {}, WordLevel::default().into())) } + if let Some(unk_token) = unk_token { + builder = builder.unk_token(unk_token); + } + + Ok(( + PyWordLevel {}, + builder + .build() + .map_err(|e| exceptions::PyException::new_err(e.to_string()))? + .into(), + )) } /// Read a :obj:`vocab.json` @@ -790,17 +793,20 @@ impl PyWordLevel { /// Returns: /// :class:`~tokenizers.models.WordLevel`: And instance of WordLevel loaded from file #[classmethod] - #[args(kwargs = "**")] + #[args(unk_token = "None")] fn from_file( _cls: &PyType, py: Python, vocab: &str, - kwargs: Option<&PyDict>, + unk_token: Option, ) -> PyResult> { let vocab = WordLevel::read_file(vocab).map_err(|e| { exceptions::PyException::new_err(format!("Error while reading WordLevel file: {}", e)) })?; - Py::new(py, PyWordLevel::new(Some(PyVocab::Vocab(vocab)), kwargs)?) + Py::new( + py, + PyWordLevel::new(Some(PyVocab::Vocab(vocab)), unk_token)?, + ) } } diff --git a/bindings/python/tests/bindings/test_models.py b/bindings/python/tests/bindings/test_models.py index 3a118f60..3cdf725d 100644 --- a/bindings/python/tests/bindings/test_models.py +++ b/bindings/python/tests/bindings/test_models.py @@ -120,3 +120,12 @@ class TestWordLevel: assert isinstance(WordLevel(roberta_files["vocab"]), Model) with pytest.deprecated_call(): assert isinstance(WordLevel(roberta_files["vocab"]), WordLevel) + + def test_can_modify(self): + model = WordLevel(unk_token="") + + assert model.unk_token == "" + + # Modify these + model.unk_token = "" + assert model.unk_token == "" diff --git a/tokenizers/src/models/wordlevel/mod.rs b/tokenizers/src/models/wordlevel/mod.rs index 5d1a4a1a..1ac2a119 100644 --- a/tokenizers/src/models/wordlevel/mod.rs +++ b/tokenizers/src/models/wordlevel/mod.rs @@ -107,7 +107,7 @@ impl WordLevelBuilder { pub struct WordLevel { vocab: HashMap, vocab_r: HashMap, - unk_token: String, + pub unk_token: String, } impl std::fmt::Debug for WordLevel {