diff --git a/bindings/python/py_src/tokenizers/__init__.py b/bindings/python/py_src/tokenizers/__init__.py
index dfbf0333..efd57429 100644
--- a/bindings/python/py_src/tokenizers/__init__.py
+++ b/bindings/python/py_src/tokenizers/__init__.py
@@ -89,6 +89,7 @@ from .tokenizers import (
     pre_tokenizers,
     processors,
     trainers,
+    __version__,
 )
 from .implementations import (
     BertWordPieceTokenizer,
diff --git a/bindings/python/src/tokenizer.rs b/bindings/python/src/tokenizer.rs
index a491a036..7462c359 100644
--- a/bindings/python/src/tokenizer.rs
+++ b/bindings/python/src/tokenizer.rs
@@ -75,7 +75,7 @@ impl PyAddedToken {
             single_word: None,
             lstrip: None,
             rstrip: None,
-            normalized: None,
+            normalized: Some(!special.unwrap_or(true)),
         }
     }
diff --git a/bindings/python/src/trainers.rs b/bindings/python/src/trainers.rs
index 54fcc8b0..1c1c9310 100644
--- a/bindings/python/src/trainers.rs
+++ b/bindings/python/src/trainers.rs
@@ -226,7 +226,7 @@ impl PyBpeTrainer {
                     if let Ok(content) = token.extract::<String>() {
                         Ok(tk::tokenizer::AddedToken::from(content, true))
                     } else if let Ok(mut token) = token.extract::<PyRefMut<PyAddedToken>>() {
-                        token.special = false;
+                        token.special = true;
                         Ok(token.get_token())
                     } else {
                         Err(exceptions::PyTypeError::new_err(
diff --git a/bindings/python/tests/bindings/test_tokenizer.py b/bindings/python/tests/bindings/test_tokenizer.py
index abfa8f90..5c002046 100644
--- a/bindings/python/tests/bindings/test_tokenizer.py
+++ b/bindings/python/tests/bindings/test_tokenizer.py
@@ -24,7 +24,7 @@ class TestAddedToken:
         assert added_token.special == False
         added_token.special = True
         assert added_token.special == True
-
+        added_token.special = False
         assert str(added_token) == ""
         assert (
             repr(added_token)
diff --git a/bindings/python/tests/bindings/test_trainers.py b/bindings/python/tests/bindings/test_trainers.py
index 2bb8f014..48b69c26 100644
--- a/bindings/python/tests/bindings/test_trainers.py
+++ b/bindings/python/tests/bindings/test_trainers.py
@@ -34,8 +34,8 @@ class TestBpeTrainer:
         assert trainer.min_frequency == 12
         assert trainer.show_progress == False
         assert trainer.special_tokens == [
-            AddedToken("1"),
-            AddedToken("2"),
+            AddedToken("1", special=True),
+            AddedToken("2", special=True),
         ]
         assert trainer.limit_alphabet == 13
         assert sorted(trainer.initial_alphabet) == ["a", "b", "c"]
@@ -91,8 +91,8 @@ class TestWordPieceTrainer:
         assert trainer.min_frequency == 12
         assert trainer.show_progress == False
         assert trainer.special_tokens == [
-            AddedToken("1"),
-            AddedToken("2"),
+            AddedToken("1", special=True),
+            AddedToken("2", special=True),
         ]
         assert trainer.limit_alphabet == 13
         assert sorted(trainer.initial_alphabet) == ["a", "b", "c"]
@@ -131,8 +131,8 @@ class TestWordLevelTrainer:
         assert trainer.min_frequency == 12
         assert trainer.show_progress == False
         assert trainer.special_tokens == [
-            AddedToken("1"),
-            AddedToken("2"),
+            AddedToken("1", special=True),
+            AddedToken("2", special=True),
         ]

         # Modify these
@@ -272,8 +272,8 @@ class TestUnigram:
         assert trainer.vocab_size == 12345
         assert trainer.show_progress == False
         assert trainer.special_tokens == [
-            AddedToken("1", normalized=False),
-            AddedToken("2", lstrip=True, normalized=False),
+            AddedToken("1", normalized=False, special=True),
+            AddedToken("2", lstrip=True, normalized=False, special=True),
         ]
         assert sorted(trainer.initial_alphabet) == ["a", "b", "c"]
diff --git a/tokenizers/src/tokenizer/added_vocabulary.rs b/tokenizers/src/tokenizer/added_vocabulary.rs
index 77a72f62..49ced2db 100644
--- a/tokenizers/src/tokenizer/added_vocabulary.rs
+++ b/tokenizers/src/tokenizer/added_vocabulary.rs
@@ -609,7 +609,7 @@ mod tests {
         );
         assert_eq!(vocab.len(), 3);
-        assert_eq!(vocab.get_vocab_r()[&0], added_token);
+        assert_eq!(vocab.get_added_tokens_decoder()[&0], added_token);
     }

     #[test]
@@ -650,7 +650,7 @@ mod tests {
         assert_eq!(vocab.len(), 3); // New token was added
         assert!(vocab.is_special_token("test"));
         assert_eq!(
-            *vocab.get_vocab_r(),
+            *vocab.get_added_tokens_decoder(),
             HashMap::from([
                 (0, AddedToken::from("test", true)),
                 (2, AddedToken::from("added_token_1", true)),
diff --git a/tokenizers/src/tokenizer/mod.rs b/tokenizers/src/tokenizer/mod.rs
index 97390d91..915fe7bf 100644
--- a/tokenizers/src/tokenizer/mod.rs
+++ b/tokenizers/src/tokenizer/mod.rs
@@ -669,7 +669,7 @@ where
         // TODO ArthurZ THIS IS WRONG! We need to measure the length of the `set` because
         // now some tokens can be both in the added_tokens_encoder and in the vocab
         if with_added_tokens {
-            self.get_vocab(with_added_tokens).len()
+            self.get_vocab(true).len()
         } else {
             self.model.get_vocab_size()
         }
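
Illustrative sketch (not part of the patch) of the behaviour the updated trainer tests rely on. It assumes the trainer is constructed with special_tokens=["1", "2"]; the constructor call itself is not shown in the hunks above.

from tokenizers import AddedToken, trainers

# Assumed setup: special tokens passed as plain strings.
trainer = trainers.BpeTrainer(special_tokens=["1", "2"])

# After the trainers.rs change, every special token is stored with special=True,
# so the expected AddedTokens in the tests must be built with special=True as well.
assert trainer.special_tokens == [
    AddedToken("1", special=True),
    AddedToken("2", special=True),
]

Passing an AddedToken instance instead of a plain string now also comes back with special=True, since the trainers force the flag (token.special = true in trainers.rs).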