Address @n1t0 comments.

This commit is contained in:
Nicolas Patry
2020-09-01 22:07:39 +02:00
parent d624645cf3
commit c0798acacf
10 changed files with 74 additions and 167 deletions

View File

@@ -110,13 +110,8 @@ class Unigram(Model):
vocab: ('`optional`) string:
Path to a vocabulary JSON file.
is_spm_file: ('`optional`) bool:
If the file came out of sentencepiece, we need to load it differently
"""
@staticmethod
def __init__(
self, vocab: Optional[str], is_spm_file: Optional[bool],
):
def __init__(self, vocab: Optional[str]):
pass

View File

@@ -263,41 +263,15 @@ pub struct PyUnigram {}
#[pymethods]
impl PyUnigram {
#[new]
#[args(kwargs = "**")]
fn new(vocab: Option<&str>, kwargs: Option<&PyDict>) -> PyResult<(Self, PyModel)> {
let mut is_spm_file = false;
if let Some(kwargs) = kwargs {
for (key, val) in kwargs {
let key: &str = key.extract()?;
match key {
"is_spm_file" => is_spm_file = val.extract()?,
_ => println!("Ignored unknown kwargs option {}", key),
}
}
}
fn new(vocab: Option<&str>) -> PyResult<(Self, PyModel)> {
if let Some(vocab) = vocab {
let path = Path::new(vocab);
if is_spm_file {
match Unigram::load_spm(path) {
Err(e) => {
println!("Errors: {:?}", e);
Err(exceptions::Exception::py_err(
"Error while initializing Unigram from spm file",
))
}
Ok(model) => Ok((PyUnigram {}, PyModel::new(Arc::new(model.into())))),
}
} else {
match Unigram::load(path) {
Err(e) => {
println!("Errors: {:?}", e);
Err(exceptions::Exception::py_err(
"Error while initializing Unigram",
))
}
Ok(model) => Ok((PyUnigram {}, PyModel::new(Arc::new(model.into())))),
match Unigram::load(path) {
Err(e) => {
println!("Errors: {:?}", e);
Err(exceptions::Exception::py_err("Error while loading Unigram"))
}
Ok(model) => Ok((PyUnigram {}, PyModel::new(Arc::new(model.into())))),
}
} else {
Ok((

View File

@@ -190,25 +190,7 @@ impl PyUnigramTrainer {
"show_progress" => builder.show_progress(val.extract()?),
"n_sub_iterations" => builder.n_sub_iterations(val.extract()?),
"shrinking_factor" => builder.shrinking_factor(val.extract()?),
"space_char" => {
let string: String = val.extract()?;
if string.chars().collect::<Vec<_>>().len() != 1 {
return Err(exceptions::Exception::py_err(
"space_char must be 1 unicode char long",
));
}
builder.space_char(string.chars().next().ok_or_else(|| {
exceptions::Exception::py_err("space_char must not be 0 width")
})?)
}
"unk_token" => builder.unk_token(val.extract()?),
"split_by_number" => builder.split_by_number(val.extract()?),
"treat_whitespace_as_suffix" => {
builder.treat_whitespace_as_suffix(val.extract()?)
}
"split_by_unicode_script" => builder.split_by_unicode_script(val.extract()?),
"split_by_digits" => builder.split_by_digits(val.extract()?),
"split_by_whitespace" => builder.split_by_whitespace(val.extract()?),
"max_piece_length" => builder.max_piece_length(val.extract()?),
"seed_size" => builder.seed_size(val.extract()?),
"special_tokens" => builder.special_tokens(