mirror of
https://github.com/mii443/tokenizers.git
synced 2025-08-22 16:25:30 +00:00
add support for get_added_tokens_decoder
This commit is contained in:
@ -42,6 +42,14 @@ class BaseTokenizer:
|
||||
"""
|
||||
return self._tokenizer.get_vocab(with_added_tokens=with_added_tokens)
|
||||
|
||||
def get_added_tokens_decoder(self) -> Dict[int, AddedToken]:
|
||||
"""Return the added tokens decoder, i.e. the added vocabulary keyed by int
|
||||
|
||||
Returns:
|
||||
A dictionary mapping ints to the AddedTokens of the added vocabulary
|
||||
"""
|
||||
return self._tokenizer.get_added_tokens_decoder()
|
||||
|
||||
def get_vocab_size(self, with_added_tokens: bool = True) -> int:
|
||||
"""Return the size of vocabulary, with or without added tokens.
|
||||
|
||||
|
@ -662,6 +662,17 @@ impl PyTokenizer {
|
||||
self.tokenizer.get_vocab(with_added_tokens)
|
||||
}
|
||||
|
||||
/// Get the added tokens decoder
|
||||
///
|
||||
/// Returns:
|
||||
/// :obj:`Dict[int, AddedToken]`: The added tokens decoder, mapping ints to AddedTokens
|
||||
#[pyo3(signature = ())]
|
||||
#[pyo3(text_signature = "(self)")]
|
||||
fn get_added_tokens_decoder(&self) -> HashMap<u32, PyAddedToken> {
|
||||
self.tokenizer.get_added_tokens_decoder().into_iter().map(|(key, value)| (key, value.into())).collect()
|
||||
}
|
||||
|
||||
|
||||
/// Get the size of the underlying vocabulary
|
||||
///
|
||||
/// Args:
|
||||
|
@ -373,6 +373,10 @@ class TestTokenizer:
|
||||
# Can retrieve vocab without added tokens
|
||||
vocab = tokenizer.get_vocab(with_added_tokens=False)
|
||||
assert vocab == {}
|
||||
|
||||
# Can retrieve added token decoder
|
||||
vocab = tokenizer.get_added_tokens_decoder()
|
||||
assert vocab == {0: AddedToken("my", rstrip=False, lstrip=False, single_word=False, normalized=True, special=False),1: AddedToken("name", rstrip=False, lstrip=False, single_word=False, normalized=True, special=False),2: AddedToken("is", rstrip=False, lstrip=False, single_word=False, normalized=True, special=False),3: AddedToken("john", rstrip=False, lstrip=False, single_word=False, normalized=True, special=False),4: AddedToken("pair", rstrip=False, lstrip=False, single_word=False, normalized=True, special=False)}
|
||||
|
||||
def test_get_vocab_size(self):
|
||||
tokenizer = Tokenizer(BPE())
|
||||
|
@ -192,7 +192,7 @@ impl AddedVocabulary {
|
||||
}
|
||||
|
||||
/// Get the additional vocabulary with the AddedTokens
|
||||
pub fn get_vocab_r(&self) -> &HashMap<u32, AddedToken> {
|
||||
pub fn get_added_tokens_decoder(&self) -> &HashMap<u32, AddedToken> {
|
||||
&self.added_tokens_map_r
|
||||
}
|
||||
|
||||
@ -260,7 +260,7 @@ impl AddedVocabulary {
|
||||
self.added_tokens_map.values().cloned().max().map_or(
|
||||
model.get_vocab_size() as u32,
|
||||
|max| {
|
||||
if max >= (model.get_vocab_size() as u32) || model.get_vocab_size() == 0 {
|
||||
if (max >= model.get_vocab_size() as u32) || model.get_vocab_size() == 0 {
|
||||
max + 1
|
||||
} else {
|
||||
model.get_vocab_size() as u32
|
||||
|
@ -659,11 +659,9 @@ where
|
||||
final_vocab
|
||||
}
|
||||
|
||||
/// Get the added vocabulary only
|
||||
|
||||
/// Get the added tokens decoder
|
||||
pub fn get_added_tokens_decoder(&self) -> &HashMap<u32, AddedToken> {
|
||||
self.added_vocabulary.get_vocab_r()
|
||||
pub fn get_added_tokens_decoder(&self) -> HashMap<u32, AddedToken> {
|
||||
self.added_vocabulary.get_added_tokens_decoder().clone()
|
||||
}
|
||||
|
||||
/// Get the size of the vocabulary
|
||||
|
Reference in New Issue
Block a user