Rust - Hotfix special tokens with wrong id
@@ -757,7 +757,7 @@ impl Tokenizer {
         let id = if let Some(id) = self.token_to_id(&token.content) {
             id
         } else {
-            let new_id = (self.model.get_vocab_size() + self.added_tokens.len()) as u32;
+            let new_id = (self.model.get_vocab_size() + self.added_tokens_map.len()) as u32;
             self.added_tokens_map.insert(token.content.clone(), new_id);

             if !self.special_tokens_set.contains(&token.content) {
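The one-line change above fixes how ids are assigned to newly added tokens: the count used to extend the vocabulary now comes from added_tokens_map, which records every added token, rather than from added_tokens, which (judging by the surrounding code and the new test below) only receives non-special tokens. The following is a minimal, self-contained Rust sketch of the failure mode; the Toy struct and its fields are hypothetical stand-ins that mirror the names in the diff, not the real Tokenizer.

use std::collections::HashMap;

// Hypothetical stand-in for the relevant Tokenizer state.
struct Toy {
    vocab_size: usize,                      // size of the underlying model vocab
    added_tokens: Vec<String>,              // regular added tokens only
    added_tokens_map: HashMap<String, u32>, // every added token, specials included
}

impl Toy {
    fn add(&mut self, token: &str, special: bool) -> u32 {
        // Before the fix this counted `self.added_tokens.len()`. Special
        // tokens never enter that list, so two specials added in a row
        // would both be assigned the same id. Counting the map instead
        // keeps the id sequence in step with everything actually added.
        let new_id = (self.vocab_size + self.added_tokens_map.len()) as u32;
        self.added_tokens_map.insert(token.to_string(), new_id);
        if !special {
            self.added_tokens.push(token.to_string());
        }
        new_id
    }
}

fn main() {
    let mut toy = Toy {
        vocab_size: 0,
        added_tokens: Vec::new(),
        added_tokens_map: HashMap::new(),
    };
    assert_eq!(toy.add("<cls>", true), 0);
    assert_eq!(toy.add("<sep>", true), 1); // was also 0 with the buggy count
    assert_eq!(toy.add("hello", false), 2);
    println!("ids assigned without collisions");
}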
@@ -3,6 +3,31 @@ mod common;
 use common::*;
 use tokenizers::tokenizer::{AddedToken, EncodeInput};

+#[test]
+fn add_tokens() {
+    let mut tokenizer = get_empty();
+
+    assert_eq!(
+        tokenizer.add_special_tokens(&[
+            AddedToken::from("<cls>".into()),
+            AddedToken::from("<sep>".into())
+        ]),
+        2
+    );
+    assert_eq!(tokenizer.token_to_id("<cls>"), Some(0));
+    assert_eq!(tokenizer.token_to_id("<sep>"), Some(1));
+
+    assert_eq!(
+        tokenizer.add_tokens(&[
+            AddedToken::from("hello".into()),
+            AddedToken::from("world".into())
+        ]),
+        2
+    );
+    assert_eq!(tokenizer.token_to_id("hello"), Some(2));
+    assert_eq!(tokenizer.token_to_id("world"), Some(3));
+}
+
 #[test]
 fn lstrip_tokens() {
     let mut tokenizer = get_byte_level(true, false);
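The new add_tokens test pins down concrete ids: the two special tokens take 0 and 1, and the regular tokens continue at 2 and 3. Before the fix, <sep> would presumably have collided with <cls> at id 0, since the count backing the id assignment never saw the first special token; asserting exact values guards against that regression.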
@@ -7,6 +7,11 @@ use tokenizers::pre_tokenizers::byte_level::ByteLevel;
 use tokenizers::processors::bert::BertProcessing;
 use tokenizers::tokenizer::Tokenizer;

+#[allow(dead_code)]
+pub fn get_empty() -> Tokenizer {
+    Tokenizer::new(Box::new(BPE::default()))
+}
+
 #[allow(dead_code)]
 pub fn get_byte_level(add_prefix_space: bool, trim_offsets: bool) -> Tokenizer {
     let mut tokenizer = Tokenizer::new(Box::new(
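get_empty gives the new test a Tokenizer over a default (empty-vocab) BPE model, which is what makes the exact ids 0 through 3 predictable; the hunk assumes BPE is already imported in this common module, as the import sits outside the hunk's context. The #[allow(dead_code)] attribute matches the existing helper: each integration-test binary compiles the shared common module independently, so helpers unused by a given binary would otherwise trigger warnings.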