Python - PyWordLevel can get/set its attributes

This commit is contained in:
Anthony MOI
2020-11-16 14:31:49 -05:00
committed by Anthony MOI
parent 760537aad3
commit 78beae8b7d
3 changed files with 55 additions and 40 deletions

View File

@ -700,56 +700,59 @@ impl PyWordPiece {
#[text_signature = "(self, vocab, unk_token)"]
pub struct PyWordLevel {}
impl PyWordLevel {
    /// Resolve the unknown token from optional keyword arguments.
    ///
    /// Starts from the default `"<unk>"` and overrides it with the value stored
    /// under the `unk_token` key when that key is present. Every other key is
    /// reported on stdout and otherwise ignored.
    ///
    /// Returns an error when a key cannot be extracted as `&str` or when the
    /// `unk_token` value cannot be extracted as a `String`.
    fn get_unk(kwargs: Option<&PyDict>) -> PyResult<String> {
        // No kwargs at all: nothing can override the default.
        let options = match kwargs {
            Some(options) => options,
            None => return Ok(String::from("<unk>")),
        };

        let mut unk_token = String::from("<unk>");
        for (raw_name, value) in options {
            let name: &str = raw_name.extract()?;
            if name == "unk_token" {
                unk_token = value.extract()?;
            } else {
                println!("Ignored unknown kwargs option {}", name);
            }
        }
        Ok(unk_token)
    }
}
#[pymethods]
impl PyWordLevel {
/// Return a copy of the unknown token held by the wrapped model.
///
/// Takes read access on the shared model through `model.read()` and clones
/// the `unk_token` field of the `WordLevel` variant. Panics if `read()`
/// returns an error, and hits `unreachable!()` if the wrapped model is any
/// other variant.
#[getter]
fn get_unk_token(self_: PyRef<Self>) -> String {
    let guard = self_.as_ref().model.read().unwrap();
    match &*guard {
        ModelWrapper::WordLevel(word_level) => word_level.unk_token.clone(),
        _ => unreachable!(),
    }
}
/// Replace the unknown token held by the wrapped model.
///
/// Takes write access on the shared model through `model.write()` and
/// overwrites the `unk_token` field of the `WordLevel` variant. Panics if
/// `write()` returns an error.
#[setter]
fn set_unk_token(self_: PyRef<Self>, unk_token: String) {
    let mut guard = self_.as_ref().model.write().unwrap();
    // Any other variant is left untouched: the assignment is silently skipped.
    if let ModelWrapper::WordLevel(word_level) = &mut *guard {
        word_level.unk_token = unk_token;
    }
}
#[new]
#[args(kwargs = "**")]
fn new(vocab: Option<PyVocab>, kwargs: Option<&PyDict>) -> PyResult<(Self, PyModel)> {
let unk_token = PyWordLevel::get_unk(kwargs)?;
#[args(unk_token = "None")]
fn new(vocab: Option<PyVocab>, unk_token: Option<String>) -> PyResult<(Self, PyModel)> {
let mut builder = WordLevel::builder();
if let Some(vocab) = vocab {
let model = match vocab {
PyVocab::Vocab(vocab) => WordLevel::builder()
.vocab(vocab)
.unk_token(unk_token)
.build()
.expect("Can only fail when loading from files"),
match vocab {
PyVocab::Vocab(vocab) => {
builder = builder.vocab(vocab);
}
PyVocab::Filename(vocab_filename) => {
deprecation_warning(
"0.9.0",
"WordLevel.__init__ will not create from files anymore, \
try `WordLevel.from_file` instead",
)?;
WordLevel::from_file(vocab_filename, unk_token).map_err(|e| {
exceptions::PyException::new_err(format!(
"Error while loading WordLevel: {}",
e
))
})?
builder = builder.files(vocab_filename.to_string());
}
};
Ok((PyWordLevel {}, model.into()))
} else {
Ok((PyWordLevel {}, WordLevel::default().into()))
}
if let Some(unk_token) = unk_token {
builder = builder.unk_token(unk_token);
}
Ok((
PyWordLevel {},
builder
.build()
.map_err(|e| exceptions::PyException::new_err(e.to_string()))?
.into(),
))
}
/// Read a :obj:`vocab.json`
@ -790,17 +793,20 @@ impl PyWordLevel {
/// Returns:
/// :class:`~tokenizers.models.WordLevel`: An instance of WordLevel loaded from file
#[classmethod]
#[args(kwargs = "**")]
#[args(unk_token = "None")]
fn from_file(
_cls: &PyType,
py: Python,
vocab: &str,
kwargs: Option<&PyDict>,
unk_token: Option<String>,
) -> PyResult<Py<Self>> {
let vocab = WordLevel::read_file(vocab).map_err(|e| {
exceptions::PyException::new_err(format!("Error while reading WordLevel file: {}", e))
})?;
Py::new(py, PyWordLevel::new(Some(PyVocab::Vocab(vocab)), kwargs)?)
Py::new(
py,
PyWordLevel::new(Some(PyVocab::Vocab(vocab)), unk_token)?,
)
}
}

View File

@ -120,3 +120,12 @@ class TestWordLevel:
assert isinstance(WordLevel(roberta_files["vocab"]), Model)
with pytest.deprecated_call():
assert isinstance(WordLevel(roberta_files["vocab"]), WordLevel)
def test_can_modify(self):
    """`unk_token` passed to the constructor can be read back and reassigned."""
    word_level = WordLevel(unk_token="<oov>")
    assert word_level.unk_token == "<oov>"

    # Reassign the attribute and check that the new value is the one read back
    word_level.unk_token = "<unk>"
    assert word_level.unk_token == "<unk>"

View File

@ -107,7 +107,7 @@ impl WordLevelBuilder {
pub struct WordLevel {
vocab: HashMap<String, u32>,
vocab_r: HashMap<u32, String>,
unk_token: String,
pub unk_token: String,
}
impl std::fmt::Debug for WordLevel {