make tests happy
@@ -89,6 +89,7 @@ from .tokenizers import (
     pre_tokenizers,
     processors,
     trainers,
+    __version__,
 )
 from .implementations import (
     BertWordPieceTokenizer,
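With __version__ now re-exported alongside the submodules from the compiled .tokenizers extension, the version string is readable from the top-level package. A minimal check, assuming a wheel built from this commit is installed:

    import tokenizers

    # Version string exposed by the Rust extension module.
    print(tokenizers.__version__)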
@@ -75,7 +75,7 @@ impl PyAddedToken {
             single_word: None,
             lstrip: None,
             rstrip: None,
-            normalized: None,
+            normalized: Some(!special.unwrap_or(true)),
         }
     }
 
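The new default ties normalized to the special flag: a token constructed as special now defaults to normalized = false, while an ordinary added token defaults to normalized = true. A sketch of the resulting behavior through the Python bindings, assuming a release that includes this change (older versions may default differently):

    from tokenizers import AddedToken

    # Special tokens default to normalized=False, ordinary added tokens to normalized=True.
    print(AddedToken("<mask>", special=True).normalized)   # expected: False
    print(AddedToken("hello", special=False).normalized)   # expected: True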
@@ -226,7 +226,7 @@ impl PyBpeTrainer {
                    if let Ok(content) = token.extract::<String>() {
                        Ok(tk::tokenizer::AddedToken::from(content, true))
                    } else if let Ok(mut token) = token.extract::<PyRefMut<PyAddedToken>>() {
-                        token.special = false;
+                        token.special = true;
                        Ok(token.get_token())
                    } else {
                        Err(exceptions::PyTypeError::new_err(
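Since the setter now forces token.special = true, every entry handed to a trainer's special_tokens ends up flagged as special, whether it was given as a plain string or as an AddedToken. The trainer tests below assert exactly this; a sketch of the same behavior through the Python API, assuming a version with this change:

    from tokenizers import AddedToken
    from tokenizers.trainers import BpeTrainer

    # Strings and AddedToken instances are both converted to special tokens.
    trainer = BpeTrainer(special_tokens=["<unk>", AddedToken("<pad>")])
    print(all(tok.special for tok in trainer.special_tokens))  # expected: True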
@@ -24,7 +24,7 @@ class TestAddedToken:
         assert added_token.special == False
         added_token.special = True
         assert added_token.special == True
-
+        added_token.special = False
         assert str(added_token) == "<mask>"
         assert (
             repr(added_token)
@@ -34,8 +34,8 @@ class TestBpeTrainer:
         assert trainer.min_frequency == 12
         assert trainer.show_progress == False
         assert trainer.special_tokens == [
-            AddedToken("1"),
-            AddedToken("2"),
+            AddedToken("1", special = True),
+            AddedToken("2", special = True),
         ]
         assert trainer.limit_alphabet == 13
         assert sorted(trainer.initial_alphabet) == ["a", "b", "c"]
@@ -91,8 +91,8 @@ class TestWordPieceTrainer:
         assert trainer.min_frequency == 12
         assert trainer.show_progress == False
         assert trainer.special_tokens == [
-            AddedToken("1"),
-            AddedToken("2"),
+            AddedToken("1", special = True),
+            AddedToken("2", special = True),
         ]
         assert trainer.limit_alphabet == 13
         assert sorted(trainer.initial_alphabet) == ["a", "b", "c"]
@@ -131,8 +131,8 @@ class TestWordLevelTrainer:
         assert trainer.min_frequency == 12
         assert trainer.show_progress == False
         assert trainer.special_tokens == [
-            AddedToken("1"),
-            AddedToken("2"),
+            AddedToken("1", special = True),
+            AddedToken("2", special = True),
         ]
 
         # Modify these
@@ -272,8 +272,8 @@ class TestUnigram:
         assert trainer.vocab_size == 12345
         assert trainer.show_progress == False
         assert trainer.special_tokens == [
-            AddedToken("1", normalized=False),
-            AddedToken("2", lstrip=True, normalized=False),
+            AddedToken("1", normalized=False, special = True),
+            AddedToken("2", lstrip=True, normalized=False, special = True),
         ]
         assert sorted(trainer.initial_alphabet) == ["a", "b", "c"]
 
@@ -609,7 +609,7 @@ mod tests {
         );
         assert_eq!(vocab.len(), 3);
 
-        assert_eq!(vocab.get_vocab_r()[&0], added_token);
+        assert_eq!(vocab.get_added_tokens_decoder()[&0], added_token);
     }
 
     #[test]
@@ -650,7 +650,7 @@ mod tests {
         assert_eq!(vocab.len(), 3); // New token was added
         assert!(vocab.is_special_token("test"));
         assert_eq!(
-            *vocab.get_vocab_r(),
+            *vocab.get_added_tokens_decoder(),
             HashMap::from([
                 (0, AddedToken::from("test", true)),
                 (2, AddedToken::from("added_token_1", true)),
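These two hunks track the rename of get_vocab_r to get_added_tokens_decoder, the mapping from added-token ids back to their AddedToken definitions. Recent releases expose the same lookup on the Python Tokenizer; a sketch, assuming the Python method name matches the Rust rename:

    from tokenizers import Tokenizer, AddedToken
    from tokenizers.models import BPE

    tokenizer = Tokenizer(BPE())
    tokenizer.add_special_tokens([AddedToken("<mask>", special=True)])
    # Maps each added-token id to its AddedToken definition.
    print(tokenizer.get_added_tokens_decoder())  # e.g. {0: AddedToken("<mask>", ...)}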
@@ -669,7 +669,7 @@ where
         // TODO ArthurZ THIS IS WRONG! We need to measure the length of the `set` because
         // now some tokens can be both in the added_tokens_encoder and in the vocab
         if with_added_tokens {
-            self.get_vocab(with_added_tokens).len()
+            self.get_vocab(true).len()
         } else {
             self.model.get_vocab_size()
         }
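Passing true here counts the added tokens through the merged vocabulary, which avoids double-counting tokens that live both in the model vocab and in the added-tokens map (the concern raised in the TODO above). From Python this shows up in get_vocab_size; a sketch, with illustrative vocab contents:

    from tokenizers import Tokenizer
    from tokenizers.models import WordLevel

    tokenizer = Tokenizer(WordLevel({"[UNK]": 0, "hello": 1}, unk_token="[UNK]"))
    tokenizer.add_tokens(["extra"])
    # Model vocab only vs. model vocab merged with added tokens.
    print(tokenizer.get_vocab_size(with_added_tokens=False))  # expected: 2
    print(tokenizer.get_vocab_size(with_added_tokens=True))   # expected: 3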