mirror of
https://github.com/mii443/tokenizers.git
synced 2025-08-22 16:25:30 +00:00
add support for get_added_tokens_decoder
This commit is contained in:
@ -42,6 +42,14 @@ class BaseTokenizer:
|
|||||||
"""
|
"""
|
||||||
return self._tokenizer.get_vocab(with_added_tokens=with_added_tokens)
|
return self._tokenizer.get_vocab(with_added_tokens=with_added_tokens)
|
||||||
|
|
||||||
|
def get_added_tokens_decoder(self) -> Dict[int, AddedToken]:
|
||||||
|
"""Returns the added reverse vocabulary
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The added vocabulary mapping ints to AddedTokens
|
||||||
|
"""
|
||||||
|
return self._tokenizer.get_added_tokens_decoder()
|
||||||
|
|
||||||
def get_vocab_size(self, with_added_tokens: bool = True) -> int:
|
def get_vocab_size(self, with_added_tokens: bool = True) -> int:
|
||||||
"""Return the size of vocabulary, with or without added tokens.
|
"""Return the size of vocabulary, with or without added tokens.
|
||||||
|
|
||||||
|
@ -662,6 +662,17 @@ impl PyTokenizer {
|
|||||||
self.tokenizer.get_vocab(with_added_tokens)
|
self.tokenizer.get_vocab(with_added_tokens)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Get the added tokens decoder (the added vocabulary, keyed by token id)
|
||||||
|
///
|
||||||
|
/// Returns:
|
||||||
|
/// :obj:`Dict[int, AddedToken]`: The added vocabulary mapping ids to added tokens
|
||||||
|
#[pyo3(signature = ())]
|
||||||
|
#[pyo3(text_signature = "(self)")]
|
||||||
|
fn get_added_tokens_decoder(&self) -> HashMap<u32, PyAddedToken> {
|
||||||
|
self.tokenizer.get_added_tokens_decoder().into_iter().map(|(key, value)| (key, value.into())).collect()
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
/// Get the size of the underlying vocabulary
|
/// Get the size of the underlying vocabulary
|
||||||
///
|
///
|
||||||
/// Args:
|
/// Args:
|
||||||
|
@ -373,6 +373,10 @@ class TestTokenizer:
|
|||||||
# Can retrieve vocab without added tokens
|
# Can retrieve vocab without added tokens
|
||||||
vocab = tokenizer.get_vocab(with_added_tokens=False)
|
vocab = tokenizer.get_vocab(with_added_tokens=False)
|
||||||
assert vocab == {}
|
assert vocab == {}
|
||||||
|
|
||||||
|
# Can retrieve added token decoder
|
||||||
|
vocab = tokenizer.get_added_tokens_decoder()
|
||||||
|
assert vocab == {0: AddedToken("my", rstrip=False, lstrip=False, single_word=False, normalized=True, special=False),1: AddedToken("name", rstrip=False, lstrip=False, single_word=False, normalized=True, special=False),2: AddedToken("is", rstrip=False, lstrip=False, single_word=False, normalized=True, special=False),3: AddedToken("john", rstrip=False, lstrip=False, single_word=False, normalized=True, special=False),4: AddedToken("pair", rstrip=False, lstrip=False, single_word=False, normalized=True, special=False)}
|
||||||
|
|
||||||
def test_get_vocab_size(self):
|
def test_get_vocab_size(self):
|
||||||
tokenizer = Tokenizer(BPE())
|
tokenizer = Tokenizer(BPE())
|
||||||
|
@ -192,7 +192,7 @@ impl AddedVocabulary {
|
|||||||
}
|
}
|
||||||
|
|
||||||
/// Get the additional vocabulary with the AddedTokens
|
/// Get the additional vocabulary with the AddedTokens
|
||||||
pub fn get_vocab_r(&self) -> &HashMap<u32, AddedToken> {
|
pub fn get_added_tokens_decoder(&self) -> &HashMap<u32, AddedToken> {
|
||||||
&self.added_tokens_map_r
|
&self.added_tokens_map_r
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -260,7 +260,7 @@ impl AddedVocabulary {
|
|||||||
self.added_tokens_map.values().cloned().max().map_or(
|
self.added_tokens_map.values().cloned().max().map_or(
|
||||||
model.get_vocab_size() as u32,
|
model.get_vocab_size() as u32,
|
||||||
|max| {
|
|max| {
|
||||||
if max >= (model.get_vocab_size() as u32) || model.get_vocab_size() == 0 {
|
if (max >= model.get_vocab_size() as u32) || model.get_vocab_size() == 0 {
|
||||||
max + 1
|
max + 1
|
||||||
} else {
|
} else {
|
||||||
model.get_vocab_size() as u32
|
model.get_vocab_size() as u32
|
||||||
|
@ -659,11 +659,9 @@ where
|
|||||||
final_vocab
|
final_vocab
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Get the added vocabulary only
|
|
||||||
|
|
||||||
/// Get the added tokens decoder
|
/// Get the added tokens decoder
|
||||||
pub fn get_added_tokens_decoder(&self) -> &HashMap<u32, AddedToken> {
|
pub fn get_added_tokens_decoder(&self) -> HashMap<u32, AddedToken> {
|
||||||
self.added_vocabulary.get_vocab_r()
|
self.added_vocabulary.get_added_tokens_decoder().clone()
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Get the size of the vocabulary
|
/// Get the size of the vocabulary
|
||||||
|
Reference in New Issue
Block a user