mirror of
https://github.com/mii443/tokenizers.git
synced 2025-08-23 00:35:35 +00:00
Python - PyWordLevel can get/set its attributes
This commit is contained in:
@ -700,56 +700,59 @@ impl PyWordPiece {
|
||||
#[text_signature = "(self, vocab, unk_token)"]
|
||||
pub struct PyWordLevel {}
|
||||
|
||||
impl PyWordLevel {
|
||||
fn get_unk(kwargs: Option<&PyDict>) -> PyResult<String> {
|
||||
let mut unk_token = String::from("<unk>");
|
||||
|
||||
if let Some(kwargs) = kwargs {
|
||||
for (key, val) in kwargs {
|
||||
let key: &str = key.extract()?;
|
||||
match key {
|
||||
"unk_token" => unk_token = val.extract()?,
|
||||
_ => println!("Ignored unknown kwargs option {}", key),
|
||||
}
|
||||
}
|
||||
}
|
||||
Ok(unk_token)
|
||||
}
|
||||
}
|
||||
|
||||
#[pymethods]
|
||||
impl PyWordLevel {
|
||||
#[getter]
|
||||
fn get_unk_token(self_: PyRef<Self>) -> String {
|
||||
let super_ = self_.as_ref();
|
||||
let model = super_.model.read().unwrap();
|
||||
if let ModelWrapper::WordLevel(ref wl) = *model {
|
||||
wl.unk_token.clone()
|
||||
} else {
|
||||
unreachable!()
|
||||
}
|
||||
}
|
||||
|
||||
#[setter]
|
||||
fn set_unk_token(self_: PyRef<Self>, unk_token: String) {
|
||||
let super_ = self_.as_ref();
|
||||
let mut model = super_.model.write().unwrap();
|
||||
if let ModelWrapper::WordLevel(ref mut wl) = *model {
|
||||
wl.unk_token = unk_token;
|
||||
}
|
||||
}
|
||||
|
||||
#[new]
|
||||
#[args(kwargs = "**")]
|
||||
fn new(vocab: Option<PyVocab>, kwargs: Option<&PyDict>) -> PyResult<(Self, PyModel)> {
|
||||
let unk_token = PyWordLevel::get_unk(kwargs)?;
|
||||
#[args(unk_token = "None")]
|
||||
fn new(vocab: Option<PyVocab>, unk_token: Option<String>) -> PyResult<(Self, PyModel)> {
|
||||
let mut builder = WordLevel::builder();
|
||||
|
||||
if let Some(vocab) = vocab {
|
||||
let model = match vocab {
|
||||
PyVocab::Vocab(vocab) => WordLevel::builder()
|
||||
.vocab(vocab)
|
||||
.unk_token(unk_token)
|
||||
.build()
|
||||
.expect("Can only fail when loading from files"),
|
||||
match vocab {
|
||||
PyVocab::Vocab(vocab) => {
|
||||
builder = builder.vocab(vocab);
|
||||
}
|
||||
PyVocab::Filename(vocab_filename) => {
|
||||
deprecation_warning(
|
||||
"0.9.0",
|
||||
"WordLevel.__init__ will not create from files anymore, \
|
||||
try `WordLevel.from_file` instead",
|
||||
)?;
|
||||
WordLevel::from_file(vocab_filename, unk_token).map_err(|e| {
|
||||
exceptions::PyException::new_err(format!(
|
||||
"Error while loading WordLevel: {}",
|
||||
e
|
||||
))
|
||||
})?
|
||||
builder = builder.files(vocab_filename.to_string());
|
||||
}
|
||||
};
|
||||
|
||||
Ok((PyWordLevel {}, model.into()))
|
||||
} else {
|
||||
Ok((PyWordLevel {}, WordLevel::default().into()))
|
||||
}
|
||||
if let Some(unk_token) = unk_token {
|
||||
builder = builder.unk_token(unk_token);
|
||||
}
|
||||
|
||||
Ok((
|
||||
PyWordLevel {},
|
||||
builder
|
||||
.build()
|
||||
.map_err(|e| exceptions::PyException::new_err(e.to_string()))?
|
||||
.into(),
|
||||
))
|
||||
}
|
||||
|
||||
/// Read a :obj:`vocab.json`
|
||||
@ -790,17 +793,20 @@ impl PyWordLevel {
|
||||
/// Returns:
|
||||
/// :class:`~tokenizers.models.WordLevel`: An instance of WordLevel loaded from file
|
||||
#[classmethod]
|
||||
#[args(kwargs = "**")]
|
||||
#[args(unk_token = "None")]
|
||||
fn from_file(
|
||||
_cls: &PyType,
|
||||
py: Python,
|
||||
vocab: &str,
|
||||
kwargs: Option<&PyDict>,
|
||||
unk_token: Option<String>,
|
||||
) -> PyResult<Py<Self>> {
|
||||
let vocab = WordLevel::read_file(vocab).map_err(|e| {
|
||||
exceptions::PyException::new_err(format!("Error while reading WordLevel file: {}", e))
|
||||
})?;
|
||||
Py::new(py, PyWordLevel::new(Some(PyVocab::Vocab(vocab)), kwargs)?)
|
||||
Py::new(
|
||||
py,
|
||||
PyWordLevel::new(Some(PyVocab::Vocab(vocab)), unk_token)?,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -120,3 +120,12 @@ class TestWordLevel:
|
||||
assert isinstance(WordLevel(roberta_files["vocab"]), Model)
|
||||
with pytest.deprecated_call():
|
||||
assert isinstance(WordLevel(roberta_files["vocab"]), WordLevel)
|
||||
|
||||
def test_can_modify(self):
    # The unknown token passed to the constructor is exposed as an attribute.
    word_level = WordLevel(unk_token="<oov>")
    assert word_level.unk_token == "<oov>"

    # Assigning to the attribute is reflected when it is read back.
    word_level.unk_token = "<unk>"
    assert word_level.unk_token == "<unk>"
|
||||
|
@ -107,7 +107,7 @@ impl WordLevelBuilder {
|
||||
pub struct WordLevel {
|
||||
vocab: HashMap<String, u32>,
|
||||
vocab_r: HashMap<u32, String>,
|
||||
unk_token: String,
|
||||
pub unk_token: String,
|
||||
}
|
||||
|
||||
impl std::fmt::Debug for WordLevel {
|
||||
|
Reference in New Issue
Block a user