Mirror of https://github.com/mii443/tokenizers.git (synced 2025-09-04 00:09:34 +00:00)
Python - Improve AddedToken interface
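In short: the Python-side `AddedToken` now stores its options (`single_word`, `lstrip`, `rstrip`, `normalized`) as overridable flags, drops the `is_special_token` constructor argument, and lets the call site (`add_tokens`, `add_special_tokens`, trainer `special_tokens`) decide whether the token is special, which in turn picks the default for `normalized` (True for regular tokens, False for special ones). A minimal usage sketch of the new interface, based on the tests further down in this diff:

from tokenizers import Tokenizer, AddedToken
from tokenizers.models import BPE

tokenizer = Tokenizer(BPE())

# No second positional argument anymore; all flags are keyword arguments.
mask = AddedToken("<mask>", lstrip=True)

# Added as a regular token: `normalized` resolves to True unless set explicitly.
tokenizer.add_tokens([mask, "another_token"])

# Added as a special token: `normalized` resolves to False unless set explicitly.
tokenizer.add_special_tokens([AddedToken("[PAD]"), "[CLS]"])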
@@ -22,94 +22,133 @@ use tk::tokenizer::{

 #[pyclass(dict, module = "tokenizers")]
 pub struct AddedToken {
-    pub token: tk::tokenizer::AddedToken,
+    pub content: String,
+    pub is_special_token: bool,
+    pub single_word: Option<bool>,
+    pub lstrip: Option<bool>,
+    pub rstrip: Option<bool>,
+    pub normalized: Option<bool>,
 }
+impl AddedToken {
+    pub fn from<S: Into<String>>(content: S, is_special_token: Option<bool>) -> Self {
+        Self {
+            content: content.into(),
+            is_special_token: is_special_token.unwrap_or(false),
+            single_word: None,
+            lstrip: None,
+            rstrip: None,
+            normalized: None,
+        }
+    }
+
+    pub fn get_token(&self) -> tk::tokenizer::AddedToken {
+        let mut token = tk::AddedToken::from(&self.content, self.is_special_token);
+
+        if let Some(sw) = self.single_word {
+            token = token.single_word(sw);
+        }
+        if let Some(ls) = self.lstrip {
+            token = token.lstrip(ls);
+        }
+        if let Some(rs) = self.rstrip {
+            token = token.rstrip(rs);
+        }
+        if let Some(n) = self.normalized {
+            token = token.normalized(n);
+        }
+
+        token
+    }
+
+    pub fn as_pydict<'py>(&self, py: Python<'py>) -> PyResult<&'py PyDict> {
+        let dict = PyDict::new(py);
+        let token = self.get_token();
+
+        dict.set_item("content", token.content)?;
+        dict.set_item("single_word", token.single_word)?;
+        dict.set_item("lstrip", token.lstrip)?;
+        dict.set_item("rstrip", token.rstrip)?;
+        dict.set_item("normalized", token.normalized)?;
+
+        Ok(dict)
+    }
+}
+
 #[pymethods]
 impl AddedToken {
     #[new]
     #[args(kwargs = "**")]
-    fn new(content: &str, is_special_token: bool, kwargs: Option<&PyDict>) -> PyResult<Self> {
-        let mut token = tk::tokenizer::AddedToken::from(content, is_special_token);
+    fn new(content: Option<&str>, kwargs: Option<&PyDict>) -> PyResult<Self> {
+        let mut token = AddedToken::from(content.unwrap_or(""), None);

         if let Some(kwargs) = kwargs {
             for (key, value) in kwargs {
                 let key: &str = key.extract()?;
                 match key {
-                    "single_word" => token = token.single_word(value.extract()?),
-                    "lstrip" => token = token.lstrip(value.extract()?),
-                    "rstrip" => token = token.rstrip(value.extract()?),
-                    "normalized" => token = token.normalized(value.extract()?),
+                    "single_word" => token.single_word = Some(value.extract()?),
+                    "lstrip" => token.lstrip = Some(value.extract()?),
+                    "rstrip" => token.rstrip = Some(value.extract()?),
+                    "normalized" => token.normalized = Some(value.extract()?),
                     _ => println!("Ignored unknown kwarg option {}", key),
                 }
             }
         }

-        Ok(AddedToken { token })
+        Ok(token)
     }

-    fn __getstate__(&self, py: Python) -> PyResult<PyObject> {
-        let data = serde_json::to_string(&self.token).map_err(|e| {
-            exceptions::Exception::py_err(format!(
-                "Error while attempting to pickle AddedToken: {}",
-                e.to_string()
-            ))
-        })?;
-        Ok(PyBytes::new(py, data.as_bytes()).to_object(py))
+    fn __getstate__<'py>(&self, py: Python<'py>) -> PyResult<&'py PyDict> {
+        self.as_pydict(py)
     }

     fn __setstate__(&mut self, py: Python, state: PyObject) -> PyResult<()> {
-        match state.extract::<&PyBytes>(py) {
-            Ok(s) => {
-                self.token = serde_json::from_slice(s.as_bytes()).map_err(|e| {
-                    exceptions::Exception::py_err(format!(
-                        "Error while attempting to unpickle AddedToken: {}",
-                        e.to_string()
-                    ))
-                })?;
+        match state.extract::<&PyDict>(py) {
+            Ok(state) => {
+                for (key, value) in state {
+                    let key: &str = key.extract()?;
+                    match key {
+                        "single_word" => self.single_word = Some(value.extract()?),
+                        "lstrip" => self.lstrip = Some(value.extract()?),
+                        "rstrip" => self.rstrip = Some(value.extract()?),
+                        "normalized" => self.normalized = Some(value.extract()?),
+                        _ => {}
+                    }
+                }
                 Ok(())
             }
             Err(e) => Err(e),
         }
     }

-    fn __getnewargs__<'p>(&self, py: Python<'p>) -> PyResult<&'p PyTuple> {
-        // We don't really care about the values of `content` & `is_special_token` here because
-        // they will get overriden by `__setstate__`
-        let content: PyObject = "".into_py(py);
-        let is_special_token: PyObject = false.into_py(py);
-        let args = PyTuple::new(py, vec![content, is_special_token]);
-        Ok(args)
-    }
-
     #[getter]
     fn get_content(&self) -> &str {
-        &self.token.content
+        &self.content
     }

     #[getter]
     fn get_rstrip(&self) -> bool {
-        self.token.rstrip
+        self.get_token().rstrip
     }

     #[getter]
     fn get_lstrip(&self) -> bool {
-        self.token.lstrip
+        self.get_token().lstrip
     }

     #[getter]
     fn get_single_word(&self) -> bool {
-        self.token.single_word
+        self.get_token().single_word
     }

     #[getter]
     fn get_normalized(&self) -> bool {
-        self.token.normalized
+        self.get_token().normalized
     }
 }
 #[pyproto]
 impl PyObjectProtocol for AddedToken {
     fn __str__(&'p self) -> PyResult<&'p str> {
-        Ok(&self.token.content)
+        Ok(&self.content)
     }

     fn __repr__(&self) -> PyResult<String> {
@@ -118,13 +157,14 @@ impl PyObjectProtocol for AddedToken {
             false => "False",
         };

+        let token = self.get_token();
         Ok(format!(
             "AddedToken(\"{}\", rstrip={}, lstrip={}, single_word={}, normalized={})",
-            self.token.content,
-            bool_to_python(self.token.rstrip),
-            bool_to_python(self.token.lstrip),
-            bool_to_python(self.token.single_word),
-            bool_to_python(self.token.normalized)
+            self.content,
+            bool_to_python(token.rstrip),
+            bool_to_python(token.lstrip),
+            bool_to_python(token.single_word),
+            bool_to_python(token.normalized)
         ))
     }
 }
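Because `__getstate__` now returns a plain dict built by `as_pydict` and `__setstate__` reads the optional flags back from that dict, the removed `__getnewargs__` is no longer needed for pickling. A small check in the spirit of the updated test (a sketch; the test itself only asserts the round-trip type):

import pickle

from tokenizers import AddedToken

token = AddedToken("<mask>", rstrip=True)
restored = pickle.loads(pickle.dumps(token))

# The updated test only checks that unpickling yields an AddedToken again.
assert isinstance(restored, AddedToken)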
@@ -583,9 +623,10 @@ impl Tokenizer {
             .into_iter()
             .map(|token| {
                 if let Ok(content) = token.extract::<String>() {
-                    Ok(tk::tokenizer::AddedToken::from(content, false))
-                } else if let Ok(token) = token.extract::<PyRef<AddedToken>>() {
-                    Ok(token.token.clone())
+                    Ok(AddedToken::from(content, Some(false)).get_token())
+                } else if let Ok(mut token) = token.extract::<PyRefMut<AddedToken>>() {
+                    token.is_special_token = false;
+                    Ok(token.get_token())
                 } else {
                     Err(exceptions::Exception::py_err(
                         "Input must be a List[Union[str, AddedToken]]",
@@ -603,8 +644,9 @@ impl Tokenizer {
             .map(|token| {
                 if let Ok(content) = token.extract::<String>() {
                     Ok(tk::tokenizer::AddedToken::from(content, true))
-                } else if let Ok(token) = token.extract::<PyRef<AddedToken>>() {
-                    Ok(token.token.clone())
+                } else if let Ok(mut token) = token.extract::<PyRefMut<AddedToken>>() {
+                    token.is_special_token = true;
+                    Ok(token.get_token())
                 } else {
                     Err(exceptions::Exception::py_err(
                         "Input must be a List[Union[str, AddedToken]]",
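These two `Tokenizer` methods are what set `is_special_token` now: `add_tokens` forces it to false and `add_special_tokens` to true before converting to the underlying `tk::AddedToken`, so the same `AddedToken` object ends up with different `normalized` defaults depending on where it is passed. A sketch mirroring the updated tests:

from tokenizers import Tokenizer, AddedToken
from tokenizers.models import BPE

tokenizer = Tokenizer(BPE())

regular = AddedToken("the")
special = AddedToken("[MASK]")

tokenizer.add_tokens([regular])          # treated as non-special
tokenizer.add_special_tokens([special])  # treated as special

assert regular.normalized == True   # default for regular tokens
assert special.normalized == False  # default for special tokens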
@@ -36,12 +36,12 @@ impl BpeTrainer {
             .into_iter()
             .map(|token| {
                 if let Ok(content) = token.extract::<String>() {
-                    Ok(tk::tokenizer::AddedToken {
-                        content,
-                        ..Default::default()
-                    })
-                } else if let Ok(token) = token.extract::<PyRef<AddedToken>>() {
-                    Ok(token.token.clone())
+                    Ok(AddedToken::from(content, Some(true)).get_token())
+                } else if let Ok(mut token) =
+                    token.extract::<PyRefMut<AddedToken>>()
+                {
+                    token.is_special_token = true;
+                    Ok(token.get_token())
                 } else {
                     Err(exceptions::Exception::py_err(
                         "special_tokens must be a List[Union[str, AddedToken]]",
@@ -105,12 +105,12 @@ impl WordPieceTrainer {
             .into_iter()
             .map(|token| {
                 if let Ok(content) = token.extract::<String>() {
-                    Ok(tk::tokenizer::AddedToken {
-                        content,
-                        ..Default::default()
-                    })
-                } else if let Ok(token) = token.extract::<PyRef<AddedToken>>() {
-                    Ok(token.token.clone())
+                    Ok(AddedToken::from(content, Some(true)).get_token())
+                } else if let Ok(mut token) =
+                    token.extract::<PyRefMut<AddedToken>>()
+                {
+                    token.is_special_token = true;
+                    Ok(token.get_token())
                 } else {
                     Err(exceptions::Exception::py_err(
                         "special_tokens must be a List[Union[str, AddedToken]]",
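The trainers get the same treatment: `special_tokens` entries may be plain strings or `AddedToken` objects, and both are converted with the special flag set. A hedged sketch of how this might look from Python (the trainer accepts more options than shown here):

from tokenizers import AddedToken
from tokenizers.trainers import BpeTrainer

# Strings and AddedToken objects can be mixed; both end up registered as special tokens.
trainer = BpeTrainer(special_tokens=["[UNK]", AddedToken("[MASK]", lstrip=True)])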
@@ -12,49 +12,46 @@ from tokenizers.implementations import BertWordPieceTokenizer

 class TestAddedToken:
     def test_instantiate_with_content_only(self):
-        added_token = AddedToken("<mask>", True)
+        added_token = AddedToken("<mask>")
         assert type(added_token) == AddedToken
         assert str(added_token) == "<mask>"
         assert (
             repr(added_token)
-            == 'AddedToken("<mask>", rstrip=False, lstrip=False, single_word=False, normalized=False)'
+            == 'AddedToken("<mask>", rstrip=False, lstrip=False, single_word=False, normalized=True)'
         )
         assert added_token.rstrip == False
         assert added_token.lstrip == False
         assert added_token.single_word == False
-        assert added_token.normalized == False
+        assert added_token.normalized == True
         assert isinstance(pickle.loads(pickle.dumps(added_token)), AddedToken)

     def test_can_set_rstrip(self):
-        added_token = AddedToken("<mask>", True, rstrip=True)
+        added_token = AddedToken("<mask>", rstrip=True)
         assert added_token.rstrip == True
         assert added_token.lstrip == False
         assert added_token.single_word == False
+        assert added_token.normalized == True

     def test_can_set_lstrip(self):
-        added_token = AddedToken("<mask>", True, lstrip=True)
+        added_token = AddedToken("<mask>", lstrip=True)
         assert added_token.rstrip == False
         assert added_token.lstrip == True
         assert added_token.single_word == False
+        assert added_token.normalized == True

     def test_can_set_single_world(self):
-        added_token = AddedToken("<mask>", True, single_word=True)
+        added_token = AddedToken("<mask>", single_word=True)
         assert added_token.rstrip == False
         assert added_token.lstrip == False
         assert added_token.single_word == True
+        assert added_token.normalized == True

     def test_can_set_normalized(self):
-        added_token = AddedToken("<mask>", True, normalized=True)
+        added_token = AddedToken("<mask>", normalized=False)
         assert added_token.rstrip == False
         assert added_token.lstrip == False
         assert added_token.single_word == False
-        assert added_token.normalized == True
-
-    def test_second_argument_defines_normalized(self):
-        added_token = AddedToken("<mask>", True)
         assert added_token.normalized == False
-        added_token = AddedToken("<mask>", False)
-        assert added_token.normalized == True


 class TestTokenizer:
@@ -91,10 +88,12 @@ class TestTokenizer:
         added = tokenizer.add_tokens(["my", "name", "is", "john"])
         assert added == 4

-        added = tokenizer.add_tokens(
-            [AddedToken("the", False), AddedToken("quick", False, rstrip=True)]
-        )
+        tokens = [AddedToken("the"), AddedToken("quick", normalized=False), AddedToken()]
+        assert tokens[0].normalized == True
+        added = tokenizer.add_tokens(tokens)
         assert added == 2
+        assert tokens[0].normalized == True
+        assert tokens[1].normalized == False

     def test_add_special_tokens(self):
         tokenizer = Tokenizer(BPE())
@@ -104,10 +103,12 @@ class TestTokenizer:
         assert added == 4

         # Can add special tokens as `AddedToken`
-        added = tokenizer.add_special_tokens(
-            [AddedToken("the", False), AddedToken("quick", False, rstrip=True)]
-        )
+        tokens = [AddedToken("the"), AddedToken("quick", normalized=True), AddedToken()]
+        assert tokens[0].normalized == True
+        added = tokenizer.add_special_tokens(tokens)
         assert added == 2
+        assert tokens[0].normalized == False
+        assert tokens[1].normalized == True

     def test_encode(self):
         tokenizer = Tokenizer(BPE())
@@ -201,8 +201,7 @@ class AddedToken:

     def __new__(
         cls,
-        content: str,
-        is_special_token: bool,
+        content: str = "",
         single_word: bool = False,
         lstrip: bool = False,
         rstrip: bool = False,
@@ -214,10 +213,6 @@ class AddedToken:
         content: str:
             The content of the token

-        is_special_token: bool:
-            Whether this token is a special token. This has an impact on the default value for
-            `normalized` which is False for special tokens, but True for others.
-
         single_word: bool
             Whether this token should only match against single words. If True,
             this token will never match inside of a word. For example the token `ing` would
@@ -219,7 +219,7 @@ impl AddedVocabulary {
         normalizer: Option<&dyn Normalizer>,
     ) -> usize {
         for token in tokens {
-            if !self.special_tokens_set.contains(&token.content) {
+            if !token.content.is_empty() && !self.special_tokens_set.contains(&token.content) {
                 self.special_tokens.push(token.to_owned());
                 self.special_tokens_set.insert(token.content.clone());
             }
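The `AddedVocabulary` guard above is what makes a default-constructed `AddedToken()` (empty content) harmless: it is skipped when registering special tokens, which is why the updated tests pass three tokens but expect `added == 2`. Roughly:

from tokenizers import Tokenizer, AddedToken
from tokenizers.models import BPE

tokenizer = Tokenizer(BPE())

# The empty token is ignored, so only two special tokens are actually added.
added = tokenizer.add_special_tokens([AddedToken("the"), AddedToken("quick"), AddedToken()])
assert added == 2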