Rust - Hotfix special tokens with wrong id

Author: Anthony MOI
Date:   2020-04-01 11:02:55 -04:00
parent 93a83127ae
commit b770f36428

3 changed files with 31 additions and 1 deletion


@@ -757,7 +757,7 @@ impl Tokenizer {
         let id = if let Some(id) = self.token_to_id(&token.content) {
             id
         } else {
-            let new_id = (self.model.get_vocab_size() + self.added_tokens.len()) as u32;
+            let new_id = (self.model.get_vocab_size() + self.added_tokens_map.len()) as u32;
             self.added_tokens_map.insert(token.content.clone(), new_id);
             if !self.special_tokens_set.contains(&token.content) {
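
The only functional change in the hunk above is swapping `self.added_tokens.len()` for `self.added_tokens_map.len()` when computing a fresh id. Judging from the surrounding lines, `added_tokens` tracks only non-special added tokens (the `if !self.special_tokens_set.contains(...)` guard presumably controls pushing into it), while `added_tokens_map` records an id for every added token, special or not. Once a special token was added, the two counts diverged and the next token could receive an already-used id. A minimal sketch of that failure mode, using simplified stand-in fields rather than the real `Tokenizer` internals:

use std::collections::HashMap;

// Simplified stand-ins: `added` mirrors `added_tokens` (non-special tokens
// only), `added_map` mirrors `added_tokens_map` (every added token -> id).
// This is an illustration of the bug, not the library's code.
struct Toy {
    vocab_size: usize,
    added: Vec<String>,
    added_map: HashMap<String, u32>,
}

impl Toy {
    fn add(&mut self, token: &str, special: bool, fixed: bool) -> u32 {
        // Buggy id counts only non-special tokens; fixed id counts them all.
        let count = if fixed { self.added_map.len() } else { self.added.len() };
        let new_id = (self.vocab_size + count) as u32;
        self.added_map.insert(token.to_string(), new_id);
        if !special {
            self.added.push(token.to_string());
        }
        new_id
    }
}

fn main() {
    let mut buggy = Toy { vocab_size: 0, added: vec![], added_map: HashMap::new() };
    assert_eq!(buggy.add("<cls>", true, false), 0);
    assert_eq!(buggy.add("hello", false, false), 0); // collision: "<cls>" already has id 0

    let mut fixed = Toy { vocab_size: 0, added: vec![], added_map: HashMap::new() };
    assert_eq!(fixed.add("<cls>", true, true), 0);
    assert_eq!(fixed.add("hello", false, true), 1); // distinct id, as expected
}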


@@ -3,6 +3,31 @@ mod common;
 use common::*;
 use tokenizers::tokenizer::{AddedToken, EncodeInput};
 
+#[test]
+fn add_tokens() {
+    let mut tokenizer = get_empty();
+
+    assert_eq!(
+        tokenizer.add_special_tokens(&[
+            AddedToken::from("<cls>".into()),
+            AddedToken::from("<sep>".into())
+        ]),
+        2
+    );
+    assert_eq!(tokenizer.token_to_id("<cls>"), Some(0));
+    assert_eq!(tokenizer.token_to_id("<sep>"), Some(1));
+
+    assert_eq!(
+        tokenizer.add_tokens(&[
+            AddedToken::from("hello".into()),
+            AddedToken::from("world".into())
+        ]),
+        2
+    );
+    assert_eq!(tokenizer.token_to_id("hello"), Some(2));
+    assert_eq!(tokenizer.token_to_id("world"), Some(3));
+}
+
 #[test]
 fn lstrip_tokens() {
     let mut tokenizer = get_byte_level(true, false);
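
The ids the new test expects follow directly from the fixed formula `new_id = get_vocab_size() + added_tokens_map.len()`: `get_empty()` (added in the next file) wraps a default BPE model, whose vocab is presumably empty, so `get_vocab_size()` is 0 and the four added tokens receive ids 0 through 3 in insertion order, regardless of whether they are special. A sketch of just that arithmetic, not library code:

// The id formula from the fix, applied to the test's four tokens.
fn next_id(vocab_size: usize, added_so_far: usize) -> u32 {
    (vocab_size + added_so_far) as u32
}

fn main() {
    let vocab_size = 0; // assumed: BPE::default() starts with an empty vocab
    assert_eq!(next_id(vocab_size, 0), 0); // "<cls>"
    assert_eq!(next_id(vocab_size, 1), 1); // "<sep>"
    assert_eq!(next_id(vocab_size, 2), 2); // "hello"
    assert_eq!(next_id(vocab_size, 3), 3); // "world"
}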


@@ -7,6 +7,11 @@ use tokenizers::pre_tokenizers::byte_level::ByteLevel;
 use tokenizers::processors::bert::BertProcessing;
 use tokenizers::tokenizer::Tokenizer;
 
+#[allow(dead_code)]
+pub fn get_empty() -> Tokenizer {
+    Tokenizer::new(Box::new(BPE::default()))
+}
+
 #[allow(dead_code)]
 pub fn get_byte_level(add_prefix_space: bool, trim_offsets: bool) -> Tokenizer {
     let mut tokenizer = Tokenizer::new(Box::new(