update init and src for bindings python

This commit is contained in:
Arthur Zucker
2023-09-01 21:07:01 +00:00
parent 587748ab09
commit 9f0c703f03
2 changed files with 28 additions and 12 deletions

View File

@@ -28,9 +28,10 @@ class AddedToken:
normalized (:obj:`bool`, defaults to :obj:`True` with :meth:`~tokenizers.Tokenizer.add_tokens` and :obj:`False` with :meth:`~tokenizers.Tokenizer.add_special_tokens`):
Defines whether this token should match against the normalized version of the input
text. For example, with the added token ``"yesterday"``, and a normalizer in charge of
lowercasing the text, the token could be extract from the input ``"I saw a lion
Yesterday"``.
lowercasing the text, the token could be extracted from the input ``"I saw a lion
Yesterday"``.
special (:obj:`bool`, defaults to :obj:`False` with :meth:`~tokenizers.Tokenizer.add_tokens` and :obj:`True` with :meth:`~tokenizers.Tokenizer.add_special_tokens`):
Defines whether this token should be skipped when decoding.
"""
def __init__(self, content, single_word=False, lstrip=False, rstrip=False, normalized=True):
@@ -64,6 +65,12 @@ class AddedToken:
"""
Get the value of the :obj:`single_word` option
"""
pass
@property
def special(self):
"""
Get the value of the :obj:`special` option
"""
pass
class Encoding:
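For context, a hypothetical Python snippet showing how the documented option is meant to behave. Note that the stub's __init__ above does not list special yet; the Rust kwarg handling in the second file is what accepts it:

from tokenizers import AddedToken

# `special` marks tokens that should be skipped when decoding; it defaults to False.
ctrl = AddedToken("<ctrl>", special=True)
assert ctrl.special
assert not AddedToken("plain").special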

View File

@@ -59,17 +59,17 @@ use crate::utils::{MaybeSizedIterator, PyBufferedIterator};
#[pyclass(dict, module = "tokenizers", name = "AddedToken")]
pub struct PyAddedToken {
pub content: String,
pub is_special_token: bool,
pub special: bool,
pub single_word: Option<bool>,
pub lstrip: Option<bool>,
pub rstrip: Option<bool>,
pub normalized: Option<bool>,
}
impl PyAddedToken {
pub fn from<S: Into<String>>(content: S, is_special_token: Option<bool>) -> Self {
pub fn from<S: Into<String>>(content: S, special: Option<bool>) -> Self {
Self {
content: content.into(),
is_special_token: is_special_token.unwrap_or(false),
special: special.unwrap_or(false),
single_word: None,
lstrip: None,
rstrip: None,
@@ -78,7 +78,7 @@ impl PyAddedToken {
}
pub fn get_token(&self) -> tk::tokenizer::AddedToken {
let mut token = tk::AddedToken::from(&self.content, self.is_special_token);
let mut token = tk::AddedToken::from(&self.content, self.special);
if let Some(sw) = self.single_word {
token = token.single_word(sw);
@@ -105,6 +105,7 @@ impl PyAddedToken {
dict.set_item("lstrip", token.lstrip)?;
dict.set_item("rstrip", token.rstrip)?;
dict.set_item("normalized", token.normalized)?;
dict.set_item("special", token.special)?;
Ok(dict)
}
@@ -118,7 +119,7 @@ impl From<tk::AddedToken> for PyAddedToken {
lstrip: Some(token.lstrip),
rstrip: Some(token.rstrip),
normalized: Some(token.normalized),
is_special_token: !token.normalized,
special: token.special,
}
}
}
@@ -138,6 +139,7 @@ impl PyAddedToken {
"lstrip" => token.lstrip = Some(value.extract()?),
"rstrip" => token.rstrip = Some(value.extract()?),
"normalized" => token.normalized = Some(value.extract()?),
"special" => token.special = Some(value.extract()?),
_ => println!("Ignored unknown kwarg option {}", key),
}
}
@@ -161,6 +163,7 @@ impl PyAddedToken {
"lstrip" => self.lstrip = Some(value.extract()?),
"rstrip" => self.rstrip = Some(value.extract()?),
"normalized" => self.normalized = Some(value.extract()?),
"special" => self.special = Some(value.extract()?),
_ => {}
}
}
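Since the "special" key is now written by the dict serializer and read back by both keyword paths above, the flag should survive a state round-trip. A minimal sketch, assuming AddedToken keeps using this dict for __getstate__/__setstate__ pickling:

import pickle
from tokenizers import AddedToken

tok = AddedToken("<ctrl>", special=True)
restored = pickle.loads(pickle.dumps(tok))
assert restored.special  # carried through the new "special" dict entry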
@@ -199,6 +202,11 @@ impl PyAddedToken {
fn get_normalized(&self) -> bool {
self.get_token().normalized
}
/// Get the value of the :obj:`special` option
#[getter]
fn get_special(&self) -> bool {
self.get_token().special
}
fn __str__(&self) -> PyResult<&str> {
Ok(&self.content)
@@ -212,12 +220,13 @@ impl PyAddedToken {
let token = self.get_token();
Ok(format!(
"AddedToken(\"{}\", rstrip={}, lstrip={}, single_word={}, normalized={})",
"AddedToken(\"{}\", rstrip={}, lstrip={}, single_word={}, normalized={}, special={})",
self.content,
bool_to_python(token.rstrip),
bool_to_python(token.lstrip),
bool_to_python(token.single_word),
bool_to_python(token.normalized)
bool_to_python(token.normalized),
bool_to_python(token.special)
))
}
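With the new getter and the updated __repr__, the flag becomes visible from Python. A small sketch of what, given the format string above, this is expected to print:

from tokenizers import AddedToken

print(repr(AddedToken("hello")))
# AddedToken("hello", rstrip=False, lstrip=False, single_word=False, normalized=True, special=False)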
@@ -1090,7 +1099,7 @@ impl PyTokenizer {
if let Ok(content) = token.extract::<String>() {
Ok(PyAddedToken::from(content, Some(false)).get_token())
} else if let Ok(mut token) = token.extract::<PyRefMut<PyAddedToken>>() {
token.is_special_token = false;
token.special = false;
Ok(token.get_token())
} else {
Err(exceptions::PyTypeError::new_err(
@@ -1127,7 +1136,7 @@ impl PyTokenizer {
if let Ok(content) = token.extract::<String>() {
Ok(tk::tokenizer::AddedToken::from(content, true))
} else if let Ok(mut token) = token.extract::<PyRefMut<PyAddedToken>>() {
token.is_special_token = true;
token.special = true;
Ok(token.get_token())
} else {
Err(exceptions::PyTypeError::new_err(
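Taken together, the two call sites pin the flag regardless of what the caller passes: add_tokens always stores regular tokens (special = false), while add_special_tokens always stores special ones (special = true). A minimal sketch of the resulting Python-side behavior, using an empty BPE model purely for illustration:

from tokenizers import Tokenizer, AddedToken
from tokenizers.models import BPE

tokenizer = Tokenizer(BPE())

# add_tokens overrides any user-supplied value with special=False ...
tokenizer.add_tokens([AddedToken("word", special=True)])

# ... while add_special_tokens forces special=True; plain strings become
# AddedToken::from(content, true) on the Rust side.
tokenizer.add_special_tokens(["<ctrl>"])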