mirror of
https://github.com/mii443/tokenizers.git
synced 2025-12-05 12:18:20 +00:00
Address @n1t0 comments.
This commit is contained in:
@@ -110,13 +110,8 @@ class Unigram(Model):
|
||||
vocab: ('`optional`) string:
|
||||
Path to a vocabulary JSON file.
|
||||
|
||||
is_spm_file: ('`optional`) bool:
|
||||
If the file came out of sentencepiece, we need to load it differently
|
||||
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def __init__(
|
||||
self, vocab: Optional[str], is_spm_file: Optional[bool],
|
||||
):
|
||||
def __init__(self, vocab: Optional[str]):
|
||||
pass
|
||||
|
||||
@@ -263,41 +263,15 @@ pub struct PyUnigram {}
|
||||
#[pymethods]
|
||||
impl PyUnigram {
|
||||
#[new]
|
||||
#[args(kwargs = "**")]
|
||||
fn new(vocab: Option<&str>, kwargs: Option<&PyDict>) -> PyResult<(Self, PyModel)> {
|
||||
let mut is_spm_file = false;
|
||||
if let Some(kwargs) = kwargs {
|
||||
for (key, val) in kwargs {
|
||||
let key: &str = key.extract()?;
|
||||
match key {
|
||||
"is_spm_file" => is_spm_file = val.extract()?,
|
||||
_ => println!("Ignored unknown kwargs option {}", key),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn new(vocab: Option<&str>) -> PyResult<(Self, PyModel)> {
|
||||
if let Some(vocab) = vocab {
|
||||
let path = Path::new(vocab);
|
||||
if is_spm_file {
|
||||
match Unigram::load_spm(path) {
|
||||
Err(e) => {
|
||||
println!("Errors: {:?}", e);
|
||||
Err(exceptions::Exception::py_err(
|
||||
"Error while initializing Unigram from spm file",
|
||||
))
|
||||
}
|
||||
Ok(model) => Ok((PyUnigram {}, PyModel::new(Arc::new(model.into())))),
|
||||
}
|
||||
} else {
|
||||
match Unigram::load(path) {
|
||||
Err(e) => {
|
||||
println!("Errors: {:?}", e);
|
||||
Err(exceptions::Exception::py_err(
|
||||
"Error while initializing Unigram",
|
||||
))
|
||||
}
|
||||
Ok(model) => Ok((PyUnigram {}, PyModel::new(Arc::new(model.into())))),
|
||||
match Unigram::load(path) {
|
||||
Err(e) => {
|
||||
println!("Errors: {:?}", e);
|
||||
Err(exceptions::Exception::py_err("Error while loading Unigram"))
|
||||
}
|
||||
Ok(model) => Ok((PyUnigram {}, PyModel::new(Arc::new(model.into())))),
|
||||
}
|
||||
} else {
|
||||
Ok((
|
||||
|
||||
@@ -190,25 +190,7 @@ impl PyUnigramTrainer {
|
||||
"show_progress" => builder.show_progress(val.extract()?),
|
||||
"n_sub_iterations" => builder.n_sub_iterations(val.extract()?),
|
||||
"shrinking_factor" => builder.shrinking_factor(val.extract()?),
|
||||
"space_char" => {
|
||||
let string: String = val.extract()?;
|
||||
if string.chars().collect::<Vec<_>>().len() != 1 {
|
||||
return Err(exceptions::Exception::py_err(
|
||||
"space_char must be 1 unicode char long",
|
||||
));
|
||||
}
|
||||
builder.space_char(string.chars().next().ok_or_else(|| {
|
||||
exceptions::Exception::py_err("space_char must not be 0 width")
|
||||
})?)
|
||||
}
|
||||
"unk_token" => builder.unk_token(val.extract()?),
|
||||
"split_by_number" => builder.split_by_number(val.extract()?),
|
||||
"treat_whitespace_as_suffix" => {
|
||||
builder.treat_whitespace_as_suffix(val.extract()?)
|
||||
}
|
||||
"split_by_unicode_script" => builder.split_by_unicode_script(val.extract()?),
|
||||
"split_by_digits" => builder.split_by_digits(val.extract()?),
|
||||
"split_by_whitespace" => builder.split_by_whitespace(val.extract()?),
|
||||
"max_piece_length" => builder.max_piece_length(val.extract()?),
|
||||
"seed_size" => builder.seed_size(val.extract()?),
|
||||
"special_tokens" => builder.special_tokens(
|
||||
|
||||
Reference in New Issue
Block a user