mirror of https://github.com/mii443/tokenizers.git (synced 2025-08-22 16:25:30 +00:00)
Rust - Hotfix special tokens with wrong id
@@ -757,7 +757,7 @@ impl Tokenizer {
         let id = if let Some(id) = self.token_to_id(&token.content) {
             id
         } else {
-            let new_id = (self.model.get_vocab_size() + self.added_tokens.len()) as u32;
+            let new_id = (self.model.get_vocab_size() + self.added_tokens_map.len()) as u32;
             self.added_tokens_map.insert(token.content.clone(), new_id);

             if !self.special_tokens_set.contains(&token.content) {
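Why the one-line change fixes the wrong ids (inferred from the hunk above, not spelled out in the commit): the guard on `special_tokens_set` suggests special tokens are recorded in `added_tokens_map` but never pushed onto the `added_tokens` vec, so computing a fresh id from `added_tokens.len()` ignores previously added special tokens and can hand out the same id twice. A minimal, self-contained sketch of that collision, using a hypothetical `Toy` struct rather than the real `Tokenizer`:

    use std::collections::HashMap;

    // Hypothetical stand-in for the relevant Tokenizer fields; the field
    // names mirror the diff, but this is not the real struct.
    struct Toy {
        vocab_size: usize,                      // stands in for self.model.get_vocab_size()
        added_tokens: Vec<String>,              // only non-special tokens are pushed here
        added_tokens_map: HashMap<String, u32>, // every added token lands here
    }

    impl Toy {
        fn add(&mut self, token: &str, special: bool, buggy: bool) -> u32 {
            // The old sizing counted only the vec; the fix counts the map.
            let count = if buggy {
                self.added_tokens.len()
            } else {
                self.added_tokens_map.len()
            };
            let new_id = (self.vocab_size + count) as u32;
            self.added_tokens_map.insert(token.to_string(), new_id);
            if !special {
                self.added_tokens.push(token.to_string());
            }
            new_id
        }
    }

    fn main() {
        let mut t = Toy {
            vocab_size: 0,
            added_tokens: vec![],
            added_tokens_map: HashMap::new(),
        };
        // With the old sizing both special tokens get id 0 and collide;
        // with the fixed sizing they would get 0 and 1.
        assert_eq!(t.add("<cls>", true, true), 0);
        assert_eq!(t.add("<sep>", true, true), 0);
    }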
@@ -3,6 +3,31 @@ mod common;
 use common::*;
 use tokenizers::tokenizer::{AddedToken, EncodeInput};

+#[test]
+fn add_tokens() {
+    let mut tokenizer = get_empty();
+
+    assert_eq!(
+        tokenizer.add_special_tokens(&[
+            AddedToken::from("<cls>".into()),
+            AddedToken::from("<sep>".into())
+        ]),
+        2
+    );
+    assert_eq!(tokenizer.token_to_id("<cls>"), Some(0));
+    assert_eq!(tokenizer.token_to_id("<sep>"), Some(1));
+
+    assert_eq!(
+        tokenizer.add_tokens(&[
+            AddedToken::from("hello".into()),
+            AddedToken::from("world".into())
+        ]),
+        2
+    );
+    assert_eq!(tokenizer.token_to_id("hello"), Some(2));
+    assert_eq!(tokenizer.token_to_id("world"), Some(3));
+}
+
 #[test]
 fn lstrip_tokens() {
     let mut tokenizer = get_byte_level(true, false);
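The expected ids in the new test follow directly from the fix: `get_empty()` (added in the hunk below) wraps a default BPE model whose vocabulary is empty, so `get_vocab_size()` should be 0 and ids are handed out sequentially from 0 across the shared `added_tokens_map`, whether a token is added as special (`<cls>` = 0, `<sep>` = 1) or regular (`hello` = 2, `world` = 3). Under the old sizing, `hello` and `world` would have collided with the special tokens at ids 0 and 1.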
@@ -7,6 +7,11 @@ use tokenizers::pre_tokenizers::byte_level::ByteLevel;
 use tokenizers::processors::bert::BertProcessing;
 use tokenizers::tokenizer::Tokenizer;

+#[allow(dead_code)]
+pub fn get_empty() -> Tokenizer {
+    Tokenizer::new(Box::new(BPE::default()))
+}
+
 #[allow(dead_code)]
 pub fn get_byte_level(add_prefix_space: bool, trim_offsets: bool) -> Tokenizer {
     let mut tokenizer = Tokenizer::new(Box::new(
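The `#[allow(dead_code)]` on the new helper matches the attribute already carried by `get_byte_level`: the shared `common` module is compiled into each integration-test binary separately, so any helper a particular binary never calls would otherwise trigger a dead-code warning.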