Rust - Add AddedVocabulary + normalized option on AddedToken

This commit is contained in:
Anthony MOI
2020-06-15 22:46:30 -04:00
parent 7dff86b704
commit 397cc539da
7 changed files with 562 additions and 74 deletions

View File

@@ -28,8 +28,8 @@ pub struct AddedToken {
impl AddedToken {
#[new]
#[args(kwargs = "**")]
fn new(content: &str, kwargs: Option<&PyDict>) -> PyResult<Self> {
let mut token = tk::tokenizer::AddedToken::from(content.to_owned());
fn new(content: &str, is_special_token: bool, kwargs: Option<&PyDict>) -> PyResult<Self> {
let mut token = tk::tokenizer::AddedToken::from(content.to_owned(), is_special_token);
if let Some(kwargs) = kwargs {
for (key, value) in kwargs {
@@ -38,6 +38,7 @@ impl AddedToken {
"single_word" => token = token.single_word(value.extract()?),
"lstrip" => token = token.lstrip(value.extract()?),
"rstrip" => token = token.rstrip(value.extract()?),
"normalized" => token = token.normalized(value.extract()?),
_ => println!("Ignored unknown kwarg option {}", key),
}
}
@@ -65,6 +66,11 @@ impl AddedToken {
/// Returns the `single_word` flag stored on the wrapped `token`.
///
/// This is a plain read of the field; nothing is modified.
fn get_single_word(&self) -> bool {
self.token.single_word
}
/// Returns the `normalized` flag stored on the wrapped `token`.
///
/// This is the read side of the `"normalized"` kwarg accepted by the
/// constructor. It is marked `#[getter]`, so it is presumably surfaced to
/// Python as a read-only `normalized` attribute (confirm against the PyO3
/// version in use).
#[getter]
fn get_normalized(&self) -> bool {
self.token.normalized
}
}
#[pyproto]
impl PyObjectProtocol for AddedToken {
@@ -533,7 +539,7 @@ impl Tokenizer {
self.tokenizer.token_to_id(token)
}
fn id_to_token(&self, id: u32) -> Option<String> {
fn id_to_token(&self, id: u32) -> Option<&str> {
self.tokenizer.id_to_token(id)
}
@@ -542,10 +548,7 @@ impl Tokenizer {
.into_iter()
.map(|token| {
if let Ok(content) = token.extract::<String>() {
Ok(tk::tokenizer::AddedToken {
content,
..Default::default()
})
Ok(tk::tokenizer::AddedToken::from(content, false))
} else if let Ok(token) = token.extract::<PyRef<AddedToken>>() {
Ok(token.token.clone())
} else {
@@ -564,10 +567,7 @@ impl Tokenizer {
.into_iter()
.map(|token| {
if let Ok(content) = token.extract::<String>() {
Ok(tk::tokenizer::AddedToken {
content,
..Default::default()
})
Ok(tk::tokenizer::AddedToken::from(content, true))
} else if let Ok(token) = token.extract::<PyRef<AddedToken>>() {
Ok(token.token.clone())
} else {