mirror of
https://github.com/mii443/tokenizers.git
synced 2025-09-01 14:59:20 +00:00
Python - Add bindings for new AddedToken options
This commit is contained in:
@ -18,6 +18,34 @@ use tk::tokenizer::{
|
||||
PaddingDirection, PaddingParams, PaddingStrategy, TruncationParams, TruncationStrategy,
|
||||
};
|
||||
|
||||
/// Python-exposed wrapper class around the core library's
/// `tk::tokenizer::AddedToken`.
#[pyclass(dict)]
|
||||
pub struct AddedToken {
|
||||
/// The wrapped core token definition; cloned out by the tokenizer
/// bindings when a Python `AddedToken` instance is passed in.
pub token: tk::tokenizer::AddedToken,
|
||||
}
|
||||
#[pymethods]
|
||||
impl AddedToken {
|
||||
#[new]
|
||||
#[args(kwargs = "**")]
|
||||
fn new(obj: &PyRawObject, content: &str, kwargs: Option<&PyDict>) -> PyResult<()> {
|
||||
let mut token = tk::tokenizer::AddedToken::from(content.to_owned());
|
||||
|
||||
if let Some(kwargs) = kwargs {
|
||||
for (key, value) in kwargs {
|
||||
let key: &str = key.extract()?;
|
||||
match key {
|
||||
"single_word" => token = token.single_word(value.extract()?),
|
||||
"lstrip" => token = token.lstrip(value.extract()?),
|
||||
"rstrip" => token = token.rstrip(value.extract()?),
|
||||
_ => println!("Ignored unknown kwarg option {}", key),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
obj.init({ AddedToken { token } });
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
#[pyclass(dict)]
|
||||
pub struct Tokenizer {
|
||||
tokenizer: tk::tokenizer::Tokenizer,
|
||||
@ -256,14 +284,11 @@ impl Tokenizer {
|
||||
content,
|
||||
..Default::default()
|
||||
})
|
||||
} else if let Ok((content, single_word)) = token.extract::<(String, bool)>() {
|
||||
Ok(tk::tokenizer::AddedToken {
|
||||
content,
|
||||
single_word,
|
||||
})
|
||||
} else if let Ok(token) = token.cast_as::<AddedToken>() {
|
||||
Ok(token.token.clone())
|
||||
} else {
|
||||
Err(exceptions::Exception::py_err(
|
||||
"Input must be a list[str] or list[(str, bool)]",
|
||||
"Input must be a List[Union[str, AddedToken]]",
|
||||
))
|
||||
}
|
||||
})
|
||||
@ -272,7 +297,25 @@ impl Tokenizer {
|
||||
Ok(self.tokenizer.add_tokens(&tokens))
|
||||
}
|
||||
|
||||
fn add_special_tokens(&mut self, tokens: Vec<&str>) -> PyResult<usize> {
|
||||
fn add_special_tokens(&mut self, tokens: &PyList) -> PyResult<usize> {
|
||||
let tokens = tokens
|
||||
.into_iter()
|
||||
.map(|token| {
|
||||
if let Ok(content) = token.extract::<String>() {
|
||||
Ok(tk::tokenizer::AddedToken {
|
||||
content,
|
||||
..Default::default()
|
||||
})
|
||||
} else if let Ok(token) = token.cast_as::<AddedToken>() {
|
||||
Ok(token.token.clone())
|
||||
} else {
|
||||
Err(exceptions::Exception::py_err(
|
||||
"Input must be a List[Union[str, AddedToken]]",
|
||||
))
|
||||
}
|
||||
})
|
||||
.collect::<PyResult<Vec<_>>>()?;
|
||||
|
||||
Ok(self.tokenizer.add_special_tokens(&tokens))
|
||||
}
|
||||
|
||||
|
Reference in New Issue
Block a user