update init and src for bindings python

Arthur Zucker
2023-09-01 21:07:01 +00:00
parent 587748ab09
commit 9f0c703f03
2 changed files with 28 additions and 12 deletions

View File

@@ -28,9 +28,10 @@ class AddedToken:
         normalized (:obj:`bool`, defaults to :obj:`True` with :meth:`~tokenizers.Tokenizer.add_tokens` and :obj:`False` with :meth:`~tokenizers.Tokenizer.add_special_tokens`):
             Defines whether this token should match against the normalized version of the input
             text. For example, with the added token ``"yesterday"``, and a normalizer in charge of
-            lowercasing the text, the token could be extract from the input ``"I saw a lion
+            lowercasing the text, the token could be extracted from the input ``"I saw a lion
             Yesterday"``.
+        special (:obj:`bool`, defaults to :obj:`False` with :meth:`~tokenizers.Tokenizer.add_tokens` and :obj:`False` with :meth:`~tokenizers.Tokenizer.add_special_tokens`):
+            Defines whether this token should be skipped when decoding.
     """

     def __init__(self, content, single_word=False, lstrip=False, rstrip=False, normalized=True):
@@ -65,6 +66,12 @@ class AddedToken:
         Get the value of the :obj:`single_word` option
         """
         pass
+    @property
+    def special(self):
+        """
+        Get the value of the :obj:`special` option
+        """
+        pass

 class Encoding:
     """

View File

@@ -59,17 +59,17 @@ use crate::utils::{MaybeSizedIterator, PyBufferedIterator};
 #[pyclass(dict, module = "tokenizers", name = "AddedToken")]
 pub struct PyAddedToken {
     pub content: String,
-    pub is_special_token: bool,
+    pub special: bool,
     pub single_word: Option<bool>,
     pub lstrip: Option<bool>,
     pub rstrip: Option<bool>,
     pub normalized: Option<bool>,
 }
 impl PyAddedToken {
-    pub fn from<S: Into<String>>(content: S, is_special_token: Option<bool>) -> Self {
+    pub fn from<S: Into<String>>(content: S, special: Option<bool>) -> Self {
         Self {
             content: content.into(),
-            is_special_token: is_special_token.unwrap_or(false),
+            special: special.unwrap_or(false),
             single_word: None,
             lstrip: None,
             rstrip: None,
@@ -78,7 +78,7 @@ impl PyAddedToken {
     }

     pub fn get_token(&self) -> tk::tokenizer::AddedToken {
-        let mut token = tk::AddedToken::from(&self.content, self.is_special_token);
+        let mut token = tk::AddedToken::from(&self.content, self.special);
         if let Some(sw) = self.single_word {
             token = token.single_word(sw);
@@ -105,6 +105,7 @@ impl PyAddedToken {
         dict.set_item("lstrip", token.lstrip)?;
         dict.set_item("rstrip", token.rstrip)?;
         dict.set_item("normalized", token.normalized)?;
+        dict.set_item("special", token.special)?;

         Ok(dict)
     }
@@ -118,7 +119,7 @@ impl From<tk::AddedToken> for PyAddedToken {
             lstrip: Some(token.lstrip),
             rstrip: Some(token.rstrip),
             normalized: Some(token.normalized),
-            is_special_token: !token.normalized,
+            special: token.special,
         }
     }
 }
@@ -138,6 +139,7 @@ impl PyAddedToken {
                     "lstrip" => token.lstrip = Some(value.extract()?),
                     "rstrip" => token.rstrip = Some(value.extract()?),
                     "normalized" => token.normalized = Some(value.extract()?),
+                    "special" => token.special = value.extract()?,
                     _ => println!("Ignored unknown kwarg option {}", key),
                 }
             }
@@ -161,6 +163,7 @@ impl PyAddedToken {
                     "lstrip" => self.lstrip = Some(value.extract()?),
                     "rstrip" => self.rstrip = Some(value.extract()?),
                     "normalized" => self.normalized = Some(value.extract()?),
+                    "special" => self.special = value.extract()?,
                     _ => {}
                 }
             }
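Together with the `set_item("special", ...)` line added above, this arm means the flag is now part of the object's state dict, so it should survive pickling. A quick sketch, again assuming a build from this branch:

import pickle
from tokenizers import AddedToken

tok = AddedToken("<eos>", special=True)
restored = pickle.loads(pickle.dumps(tok))
print(restored.special)  # True: "special" round-trips through the state dict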
@@ -199,6 +202,11 @@ impl PyAddedToken {
     fn get_normalized(&self) -> bool {
         self.get_token().normalized
     }

+    /// Get the value of the :obj:`special` option
+    #[getter]
+    fn get_special(&self) -> bool {
+        self.get_token().special
+    }

     fn __str__(&self) -> PyResult<&str> {
         Ok(&self.content)
@@ -212,12 +220,13 @@ impl PyAddedToken {
         let token = self.get_token();

         Ok(format!(
-            "AddedToken(\"{}\", rstrip={}, lstrip={}, single_word={}, normalized={})",
+            "AddedToken(\"{}\", rstrip={}, lstrip={}, single_word={}, normalized={}, special={})",
             self.content,
             bool_to_python(token.rstrip),
             bool_to_python(token.lstrip),
             bool_to_python(token.single_word),
-            bool_to_python(token.normalized)
+            bool_to_python(token.normalized),
+            bool_to_python(token.special)
         ))
     }
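The repr now surfaces the flag as well. Assuming `tk::AddedToken::from` still defaults `normalized` to the negation of `special` (the old `From` impl above hints at that coupling), a session against this branch should print something like:

from tokenizers import AddedToken

print(repr(AddedToken("<eos>", special=True)))
# AddedToken("<eos>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True)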
@@ -1090,7 +1099,7 @@ impl PyTokenizer {
                 if let Ok(content) = token.extract::<String>() {
                     Ok(PyAddedToken::from(content, Some(false)).get_token())
                 } else if let Ok(mut token) = token.extract::<PyRefMut<PyAddedToken>>() {
-                    token.is_special_token = false;
+                    token.special = false;
                     Ok(token.get_token())
                 } else {
                     Err(exceptions::PyTypeError::new_err(
@@ -1127,7 +1136,7 @@ impl PyTokenizer {
                 if let Ok(content) = token.extract::<String>() {
                     Ok(tk::tokenizer::AddedToken::from(content, true))
                 } else if let Ok(mut token) = token.extract::<PyRefMut<PyAddedToken>>() {
-                    token.is_special_token = true;
+                    token.special = true;
                     Ok(token.get_token())
                 } else {
                     Err(exceptions::PyTypeError::new_err(
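These last two hunks are the callers that force the flag: `add_tokens` clears `special`, while `add_special_tokens` sets it, and because the token is extracted as a `PyRefMut`, the Python-side object is mutated in place. An illustrative sketch (the `BPE` model is just a placeholder so a `Tokenizer` can be built):

from tokenizers import AddedToken, Tokenizer
from tokenizers.models import BPE

tokenizer = Tokenizer(BPE())

pad = AddedToken("<pad>")            # special defaults to False
tokenizer.add_special_tokens([pad])  # Rust side sets token.special = true
print(pad.special)                   # True: mutated in place via PyRefMut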