make tests happy

commit 08af8ea9c3
parent 531b06f6db
Author: Arthur Zucker
Date:   2023-09-05 15:37:09 +00:00

7 changed files with 15 additions and 14 deletions

@@ -89,6 +89,7 @@ from .tokenizers import (
     pre_tokenizers,
     processors,
     trainers,
+    __version__,
 )
 from .implementations import (
     BertWordPieceTokenizer,

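The hunk above re-exports __version__ from the compiled extension at the package root. A minimal
check, assuming an installed tokenizers build that includes this change:

    from tokenizers import __version__

    print(__version__)  # version string of the installed build, e.g. "0.13.x"
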
@@ -75,7 +75,7 @@ impl PyAddedToken {
             single_word: None,
             lstrip: None,
             rstrip: None,
-            normalized: None,
+            normalized: Some(!special.unwrap_or(true)),
         }
     }

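The new default ties normalized to special: Some(!special.unwrap_or(true)) means an explicitly
special token defaults to normalized=False and an explicitly non-special one to normalized=True.
A sketch of the resulting Python-visible behavior (not a test from the repo, and assuming a build
where the AddedToken constructor accepts the special keyword):

    from tokenizers import AddedToken

    assert AddedToken("<mask>", special=True).normalized is False  # special => not normalized
    assert AddedToken("word", special=False).normalized is True    # non-special => normalized
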
@@ -226,7 +226,7 @@ impl PyBpeTrainer {
                     if let Ok(content) = token.extract::<String>() {
                         Ok(tk::tokenizer::AddedToken::from(content, true))
                     } else if let Ok(mut token) = token.extract::<PyRefMut<PyAddedToken>>() {
-                        token.special = false;
+                        token.special = true;
                         Ok(token.get_token())
                     } else {
                         Err(exceptions::PyTypeError::new_err(

@@ -24,7 +24,8 @@ class TestAddedToken:
         assert added_token.special == False
         added_token.special = True
         assert added_token.special == True
+        added_token.special = False
         assert str(added_token) == "<mask>"
         assert (
             repr(added_token)

@@ -34,8 +34,8 @@ class TestBpeTrainer:
         assert trainer.min_frequency == 12
         assert trainer.show_progress == False
         assert trainer.special_tokens == [
-            AddedToken("1"),
-            AddedToken("2"),
+            AddedToken("1", special = True),
+            AddedToken("2", special = True),
         ]
         assert trainer.limit_alphabet == 13
         assert sorted(trainer.initial_alphabet) == ["a", "b", "c"]
@@ -91,8 +91,8 @@ class TestWordPieceTrainer:
         assert trainer.min_frequency == 12
         assert trainer.show_progress == False
         assert trainer.special_tokens == [
-            AddedToken("1"),
-            AddedToken("2"),
+            AddedToken("1", special = True),
+            AddedToken("2", special = True),
         ]
         assert trainer.limit_alphabet == 13
         assert sorted(trainer.initial_alphabet) == ["a", "b", "c"]
@@ -131,8 +131,8 @@ class TestWordLevelTrainer:
         assert trainer.min_frequency == 12
         assert trainer.show_progress == False
         assert trainer.special_tokens == [
-            AddedToken("1"),
-            AddedToken("2"),
+            AddedToken("1", special = True),
+            AddedToken("2", special = True),
         ]
         # Modify these
@@ -272,8 +272,8 @@ class TestUnigram:
         assert trainer.vocab_size == 12345
         assert trainer.show_progress == False
         assert trainer.special_tokens == [
-            AddedToken("1", normalized=False),
-            AddedToken("2", lstrip=True, normalized=False),
+            AddedToken("1", normalized=False, special = True),
+            AddedToken("2", lstrip=True, normalized=False, special = True),
         ]
         assert sorted(trainer.initial_alphabet) == ["a", "b", "c"]

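All four trainer tests change for the same reason: the special_tokens getter now returns tokens
carrying special=True, and the equality check is evidently sensitive to that flag (these asserts
are what the commit makes pass), so the expected values must be constructed the same way.
Mirroring the BPE case as a sketch:

    from tokenizers import AddedToken
    from tokenizers.trainers import BpeTrainer

    trainer = BpeTrainer(special_tokens=["1", "2"])
    assert trainer.special_tokens == [AddedToken("1", special=True), AddedToken("2", special=True)]
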
@@ -609,7 +609,7 @@ mod tests {
         );
         assert_eq!(vocab.len(), 3);
-        assert_eq!(vocab.get_vocab_r()[&0], added_token);
+        assert_eq!(vocab.get_added_tokens_decoder()[&0], added_token);
     }

     #[test]
@@ -650,7 +650,7 @@ mod tests {
         assert_eq!(vocab.len(), 3); // New token was added
         assert!(vocab.is_special_token("test"));
         assert_eq!(
-            *vocab.get_vocab_r(),
+            *vocab.get_added_tokens_decoder(),
             HashMap::from([
                 (0, AddedToken::from("test", true)),
                 (2, AddedToken::from("added_token_1", true)),

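The rename makes the intent of the reverse map explicit: it is the id -> AddedToken decoder for
added tokens only, not the full vocabulary. The Python bindings expose the same view; a sketch,
assuming a build where Tokenizer.get_added_tokens_decoder is available:

    from tokenizers import Tokenizer, AddedToken
    from tokenizers.models import BPE

    tokenizer = Tokenizer(BPE())
    tokenizer.add_special_tokens([AddedToken("[PAD]", special=True)])
    print(tokenizer.get_added_tokens_decoder())  # {0: AddedToken("[PAD]", ..., special=True)}
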
@@ -669,7 +669,7 @@ where
         // TODO ArthurZ THIS IS WRONG! We need to measure the length of the `set` because
         // now some tokens can be both in the added_tokens_encoder and in the vocab
         if with_added_tokens {
-            self.get_vocab(with_added_tokens).len()
+            self.get_vocab(true).len()
         } else {
             self.model.get_vocab_size()
         }
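Counting len(get_vocab(true)) instead of summing the model size and the added-token count avoids
double-counting tokens that live both in the model vocabulary and in the added-tokens set, which
is exactly the overlap the TODO above describes. The invariant can be checked from Python; a
sketch, assuming the post-change behavior:

    from tokenizers import Tokenizer
    from tokenizers.models import WordLevel

    # "b" sits in the model vocab and is also added afterwards; it must only be counted once.
    tokenizer = Tokenizer(WordLevel({"a": 0, "b": 1}, unk_token="a"))
    tokenizer.add_tokens(["b", "c"])
    assert tokenizer.get_vocab_size(with_added_tokens=True) == len(tokenizer.get_vocab(with_added_tokens=True))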