Python - Add bindings for new AddedToken options

Author: Anthony MOI
Date:   2020-03-23 16:39:39 -04:00
Parent: b1998da070
Commit: c65d53892d

11 changed files with 169 additions and 47 deletions


@@ -18,6 +18,34 @@ use tk::tokenizer::{
     PaddingDirection, PaddingParams, PaddingStrategy, TruncationParams, TruncationStrategy,
 };
 
+#[pyclass(dict)]
+pub struct AddedToken {
+    pub token: tk::tokenizer::AddedToken,
+}
+
+#[pymethods]
+impl AddedToken {
+    #[new]
+    #[args(kwargs = "**")]
+    fn new(obj: &PyRawObject, content: &str, kwargs: Option<&PyDict>) -> PyResult<()> {
+        let mut token = tk::tokenizer::AddedToken::from(content.to_owned());
+        if let Some(kwargs) = kwargs {
+            for (key, value) in kwargs {
+                let key: &str = key.extract()?;
+                match key {
+                    "single_word" => token = token.single_word(value.extract()?),
+                    "lstrip" => token = token.lstrip(value.extract()?),
+                    "rstrip" => token = token.rstrip(value.extract()?),
+                    _ => println!("Ignored unknown kwarg option {}", key),
+                }
+            }
+        }
+
+        obj.init({ AddedToken { token } });
+        Ok(())
+    }
+}
+
 #[pyclass(dict)]
 pub struct Tokenizer {
     tokenizer: tk::tokenizer::Tokenizer,
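
The `#[args(kwargs = "**")]` attribute makes the constructor accept arbitrary keyword arguments, which are matched by name against the builder methods on `tk::tokenizer::AddedToken`; anything unrecognized is reported and skipped rather than raising. From Python this would be used roughly as follows (a sketch only; the `tokenizers` import path is an assumption, not part of this diff):

    # Sketch: assumes the built extension is importable as `tokenizers`
    from tokenizers import AddedToken

    # Recognized kwargs override the Rust-side defaults one by one
    mask = AddedToken("[MASK]", single_word=True, lstrip=True, rstrip=False)

    # Unknown kwargs do not raise: the binding prints
    # "Ignored unknown kwarg option ..." and continues
    sep = AddedToken("[SEP]", unknown_option=True)
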
@@ -256,14 +284,11 @@ impl Tokenizer {
                         content,
                         ..Default::default()
                     })
-                } else if let Ok((content, single_word)) = token.extract::<(String, bool)>() {
-                    Ok(tk::tokenizer::AddedToken {
-                        content,
-                        single_word,
-                    })
+                } else if let Ok(token) = token.cast_as::<AddedToken>() {
+                    Ok(token.token.clone())
                 } else {
                     Err(exceptions::Exception::py_err(
-                        "Input must be a list[str] or list[(str, bool)]",
+                        "Input must be a List[Union[str, AddedToken]]",
                     ))
                 }
             })
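
With the branches above, `add_tokens` now accepts a list whose items may be plain strings (all options default) or `AddedToken` instances, and anything else fails with the `List[Union[str, AddedToken]]` error. A hedged usage sketch (assumes an already constructed `Tokenizer`):

    from tokenizers import AddedToken

    # `tokenizer` is any existing tokenizers.Tokenizer instance
    added = tokenizer.add_tokens([
        "plain_token",                            # str: default options
        AddedToken("special", single_word=True),  # carries its own options
    ])
    print(f"{added} tokens were added to the vocabulary")
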
@@ -272,7 +297,25 @@ impl Tokenizer {
         Ok(self.tokenizer.add_tokens(&tokens))
     }
 
-    fn add_special_tokens(&mut self, tokens: Vec<&str>) -> PyResult<usize> {
+    fn add_special_tokens(&mut self, tokens: &PyList) -> PyResult<usize> {
+        let tokens = tokens
+            .into_iter()
+            .map(|token| {
+                if let Ok(content) = token.extract::<String>() {
+                    Ok(tk::tokenizer::AddedToken {
+                        content,
+                        ..Default::default()
+                    })
+                } else if let Ok(token) = token.cast_as::<AddedToken>() {
+                    Ok(token.token.clone())
+                } else {
+                    Err(exceptions::Exception::py_err(
+                        "Input must be a List[Union[str, AddedToken]]",
+                    ))
+                }
+            })
+            .collect::<PyResult<Vec<_>>>()?;
+
         Ok(self.tokenizer.add_special_tokens(&tokens))
     }
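
`add_special_tokens` gets the same `&PyList` treatment, so special tokens can now carry per-token options too. Same assumptions as the sketches above:

    from tokenizers import AddedToken

    # `tokenizer` is any existing tokenizers.Tokenizer instance
    num_added = tokenizer.add_special_tokens([
        "[CLS]",                            # str: default options
        AddedToken("[MASK]", lstrip=True),  # lstrip controls whitespace handling on the left
    ])
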