Python - PyWordLevel can get/set its attributes

This commit is contained in:
Anthony MOI
2020-11-16 14:31:49 -05:00
committed by Anthony MOI
parent 760537aad3
commit 78beae8b7d
3 changed files with 55 additions and 40 deletions

View File

@ -700,56 +700,59 @@ impl PyWordPiece {
#[text_signature = "(self, vocab, unk_token)"] #[text_signature = "(self, vocab, unk_token)"]
pub struct PyWordLevel {} pub struct PyWordLevel {}
impl PyWordLevel {
    /// Resolves the unknown token from optional keyword arguments.
    ///
    /// The result starts as `"<unk>"` and is replaced by the value stored under
    /// the `unk_token` key when that key is present. Every other key is reported
    /// on stdout and otherwise ignored.
    ///
    /// Returns an error when a key cannot be extracted as a string, or when the
    /// `unk_token` value cannot be extracted as a `String`.
    fn get_unk(kwargs: Option<&PyDict>) -> PyResult<String> {
        let mut token = String::from("<unk>");

        // No keyword arguments at all: keep the default.
        let entries = match kwargs {
            Some(entries) => entries,
            None => return Ok(token),
        };

        for (name, value) in entries {
            let name: &str = name.extract()?;
            if name == "unk_token" {
                token = value.extract()?;
            } else {
                println!("Ignored unknown kwargs option {}", name);
            }
        }

        Ok(token)
    }
}
#[pymethods] #[pymethods]
impl PyWordLevel { impl PyWordLevel {
/// Returns a copy of the `unk_token` held by the wrapped WordLevel model.
///
/// Takes a read lock on the shared model for the duration of the call.
/// Panics if the lock is poisoned, or if the wrapped model is not the
/// `WordLevel` variant (treated as impossible for this class).
#[getter]
fn get_unk_token(self_: PyRef<Self>) -> String {
    let base = self_.as_ref();
    let guard = base.model.read().unwrap();
    match &*guard {
        ModelWrapper::WordLevel(word_level) => word_level.unk_token.clone(),
        _ => unreachable!(),
    }
}
/// Replaces the `unk_token` of the wrapped WordLevel model.
///
/// Takes a write lock on the shared model for the duration of the call and
/// panics if that lock is poisoned. When the wrapped model is not the
/// `WordLevel` variant, nothing is changed.
#[setter]
fn set_unk_token(self_: PyRef<Self>, unk_token: String) {
    let base = self_.as_ref();
    let mut guard = base.model.write().unwrap();
    if let ModelWrapper::WordLevel(word_level) = &mut *guard {
        word_level.unk_token = unk_token;
    }
}
#[new] #[new]
#[args(kwargs = "**")] #[args(unk_token = "None")]
fn new(vocab: Option<PyVocab>, kwargs: Option<&PyDict>) -> PyResult<(Self, PyModel)> { fn new(vocab: Option<PyVocab>, unk_token: Option<String>) -> PyResult<(Self, PyModel)> {
let unk_token = PyWordLevel::get_unk(kwargs)?; let mut builder = WordLevel::builder();
if let Some(vocab) = vocab { if let Some(vocab) = vocab {
let model = match vocab { match vocab {
PyVocab::Vocab(vocab) => WordLevel::builder() PyVocab::Vocab(vocab) => {
.vocab(vocab) builder = builder.vocab(vocab);
.unk_token(unk_token) }
.build()
.expect("Can only fail when loading from files"),
PyVocab::Filename(vocab_filename) => { PyVocab::Filename(vocab_filename) => {
deprecation_warning( deprecation_warning(
"0.9.0", "0.9.0",
"WordLevel.__init__ will not create from files anymore, \ "WordLevel.__init__ will not create from files anymore, \
try `WordLevel.from_file` instead", try `WordLevel.from_file` instead",
)?; )?;
WordLevel::from_file(vocab_filename, unk_token).map_err(|e| { builder = builder.files(vocab_filename.to_string());
exceptions::PyException::new_err(format!(
"Error while loading WordLevel: {}",
e
))
})?
} }
}; };
Ok((PyWordLevel {}, model.into()))
} else {
Ok((PyWordLevel {}, WordLevel::default().into()))
} }
if let Some(unk_token) = unk_token {
builder = builder.unk_token(unk_token);
}
Ok((
PyWordLevel {},
builder
.build()
.map_err(|e| exceptions::PyException::new_err(e.to_string()))?
.into(),
))
} }
/// Read a :obj:`vocab.json` /// Read a :obj:`vocab.json`
@ -790,17 +793,20 @@ impl PyWordLevel {
/// Returns: /// Returns:
/// :class:`~tokenizers.models.WordLevel`: An instance of WordLevel loaded from file /// :class:`~tokenizers.models.WordLevel`: An instance of WordLevel loaded from file
#[classmethod] #[classmethod]
#[args(kwargs = "**")] #[args(unk_token = "None")]
fn from_file( fn from_file(
_cls: &PyType, _cls: &PyType,
py: Python, py: Python,
vocab: &str, vocab: &str,
kwargs: Option<&PyDict>, unk_token: Option<String>,
) -> PyResult<Py<Self>> { ) -> PyResult<Py<Self>> {
let vocab = WordLevel::read_file(vocab).map_err(|e| { let vocab = WordLevel::read_file(vocab).map_err(|e| {
exceptions::PyException::new_err(format!("Error while reading WordLevel file: {}", e)) exceptions::PyException::new_err(format!("Error while reading WordLevel file: {}", e))
})?; })?;
Py::new(py, PyWordLevel::new(Some(PyVocab::Vocab(vocab)), kwargs)?) Py::new(
py,
PyWordLevel::new(Some(PyVocab::Vocab(vocab)), unk_token)?,
)
} }
} }

View File

@ -120,3 +120,12 @@ class TestWordLevel:
assert isinstance(WordLevel(roberta_files["vocab"]), Model) assert isinstance(WordLevel(roberta_files["vocab"]), Model)
with pytest.deprecated_call(): with pytest.deprecated_call():
assert isinstance(WordLevel(roberta_files["vocab"]), WordLevel) assert isinstance(WordLevel(roberta_files["vocab"]), WordLevel)
def test_can_modify(self):
    """The `unk_token` attribute is readable after construction and can be reassigned."""
    word_level = WordLevel(unk_token="<oov>")

    # The value passed to the constructor is exposed as an attribute.
    assert word_level.unk_token == "<oov>"

    # Assigning a new value is reflected on the next read.
    word_level.unk_token = "<unk>"
    assert word_level.unk_token == "<unk>"

View File

@ -107,7 +107,7 @@ impl WordLevelBuilder {
pub struct WordLevel { pub struct WordLevel {
vocab: HashMap<String, u32>, vocab: HashMap<String, u32>,
vocab_r: HashMap<u32, String>, vocab_r: HashMap<u32, String>,
unk_token: String, pub unk_token: String,
} }
impl std::fmt::Debug for WordLevel { impl std::fmt::Debug for WordLevel {