Rust - Extract AddedVocabulary management from Tokenizer
@@ -433,8 +433,8 @@ impl Model for BPE {
         self.vocab.get(token).copied()
     }
 
-    fn id_to_token(&self, id: u32) -> Option<String> {
-        self.vocab_r.get(&id).cloned()
+    fn id_to_token(&self, id: u32) -> Option<&str> {
+        self.vocab_r.get(&id).map(String::as_ref)
     }
 
     fn save(&self, folder: &Path, name: Option<&str>) -> Result<Vec<PathBuf>> {
@@ -169,8 +169,8 @@ impl Model for WordLevel {
         self.vocab.get(token).copied()
     }
 
-    fn id_to_token(&self, id: u32) -> Option<String> {
-        self.vocab_r.get(&id).cloned()
+    fn id_to_token(&self, id: u32) -> Option<&str> {
+        self.vocab_r.get(&id).map(String::as_ref)
     }
 
     fn get_vocab(&self) -> &HashMap<String, u32> {
@@ -283,8 +283,8 @@ impl Model for WordPiece {
         self.vocab.get(token).copied()
     }
 
-    fn id_to_token(&self, id: u32) -> Option<String> {
-        self.vocab_r.get(&id).cloned()
+    fn id_to_token(&self, id: u32) -> Option<&str> {
+        self.vocab_r.get(&id).map(String::as_ref)
    }
 
    fn save(&self, folder: &Path, name: Option<&str>) -> Result<Vec<PathBuf>> {
@@ -15,19 +15,20 @@ pub use crate::utils::padding::{pad_encodings, PaddingDirection, PaddingParams,
 pub use crate::utils::truncation::{truncate_encodings, TruncationParams, TruncationStrategy};
 use indicatif::{ProgressBar, ProgressStyle};
 use rayon::prelude::*;
-use serde::{Deserialize, Serialize};
 use std::{
-    collections::{HashMap, HashSet},
+    collections::HashMap,
     fs::File,
     io::prelude::*,
     io::BufReader,
     path::{Path, PathBuf},
 };
 
+mod added_vocabulary;
 mod encoding;
 mod normalizer;
 mod serialization;
 
+pub use added_vocabulary::*;
 pub use encoding::*;
 pub use normalizer::*;
 
@@ -56,7 +57,7 @@ pub trait PreTokenizer: Send + Sync {
 pub trait Model: Send + Sync {
     fn tokenize(&self, tokens: Vec<(String, Offsets)>) -> Result<Vec<Token>>;
     fn token_to_id(&self, token: &str) -> Option<u32>;
-    fn id_to_token(&self, id: u32) -> Option<String>;
+    fn id_to_token(&self, id: u32) -> Option<&str>;
     fn get_vocab(&self) -> &HashMap<String, u32>;
     fn get_vocab_size(&self) -> usize;
     fn save(&self, folder: &Path, name: Option<&str>) -> Result<Vec<PathBuf>>;
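Note: the trait now hands back a borrowed `&str` instead of an owned `String`, so an implementor can point straight into its reverse vocabulary map, as the BPE/WordLevel/WordPiece hunks above do. A minimal sketch of how the two lookup methods adapt (the `SimpleVocab` type is illustrative, not part of this commit):

use std::collections::HashMap;

struct SimpleVocab {
    vocab: HashMap<String, u32>,
    vocab_r: HashMap<u32, String>,
}

impl SimpleVocab {
    fn token_to_id(&self, token: &str) -> Option<u32> {
        self.vocab.get(token).copied()
    }

    // Borrow the stored String rather than cloning it; the returned &str
    // lives as long as &self, which is exactly what the new signature allows.
    fn id_to_token(&self, id: u32) -> Option<&str> {
        self.vocab_r.get(&id).map(String::as_ref)
    }
}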
@@ -182,102 +183,6 @@ impl<I1: Into<InputSequence>, I2: Into<InputSequence>> From<(I1, I2)> for Encode
     }
 }
 
-#[derive(Debug, Clone, Serialize, Deserialize)]
-pub struct AddedToken {
-    /// The content of the added token
-    pub content: String,
-    /// Whether this token must be a single word or can break words
-    pub single_word: bool,
-    /// Whether this token should strip whitespaces on its left
-    pub lstrip: bool,
-    /// Whether this token should strip whitespaces on its right
-    pub rstrip: bool,
-}
-impl AddedToken {
-    pub fn from(content: String) -> Self {
-        AddedToken {
-            content,
-            ..Default::default()
-        }
-    }
-    pub fn single_word(mut self, single_word: bool) -> Self {
-        self.single_word = single_word;
-        self
-    }
-    pub fn lstrip(mut self, lstrip: bool) -> Self {
-        self.lstrip = lstrip;
-        self
-    }
-    pub fn rstrip(mut self, rstrip: bool) -> Self {
-        self.rstrip = rstrip;
-        self
-    }
-    pub fn get_pattern(&self) -> String {
-        let mut r = if self.single_word {
-            let first_b = self
-                .content
-                .chars()
-                .next()
-                .map(|c| {
-                    if regex_syntax::is_word_character(c) {
-                        r"\b"
-                    } else {
-                        ""
-                    }
-                })
-                .unwrap();
-            let last_b = self
-                .content
-                .chars()
-                .last()
-                .map(|c| {
-                    if regex_syntax::is_word_character(c) {
-                        r"\b"
-                    } else {
-                        ""
-                    }
-                })
-                .unwrap();
-            format!(r"{}{}{}", first_b, regex::escape(&self.content), last_b)
-        } else {
-            regex::escape(&self.content)
-        };
-
-        if self.lstrip && self.rstrip {
-            r = format!(r"(\s)?{}(\s)?", r);
-        } else if self.lstrip {
-            r = format!(r"(\s)?{}", r);
-        } else if self.rstrip {
-            r = format!(r"{}(\s)?", r);
-        }
-
-        r
-    }
-}
-impl Default for AddedToken {
-    fn default() -> Self {
-        AddedToken {
-            content: String::new(),
-            single_word: false,
-            lstrip: false,
-            rstrip: false,
-        }
-    }
-}
-// We only want to hash on the content. AddedToken cannot be added multiple times with different
-// options
-impl std::hash::Hash for AddedToken {
-    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
-        self.content.hash(state);
-    }
-}
-impl std::cmp::PartialEq for AddedToken {
-    fn eq(&self, other: &Self) -> bool {
-        self.content == other.content
-    }
-}
-impl std::cmp::Eq for AddedToken {}
-
 /// A `Tokenizer` is capable of encoding/decoding any text.
 pub struct Tokenizer {
     // Tokenizer parts
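Note: `AddedToken` and its `get_pattern` helper are deleted from this file but, per the commit title, they move into the new `added_vocabulary` module rather than going away. For reference, the removed builder and pattern logic behave like this (a hypothetical test, assuming `AddedToken` is in scope; expected values derived from the removed code above):

#[test]
fn added_token_patterns() {
    // lstrip adds an optional leading whitespace group around the escaped content.
    let mask = AddedToken::from("[MASK]".to_string()).lstrip(true);
    assert_eq!(mask.get_pattern(), r"(\s)?\[MASK\]");

    // single_word wraps the content in \b word boundaries when its first and
    // last characters are word characters.
    let word = AddedToken::from("hello".to_string()).single_word(true);
    assert_eq!(word.get_pattern(), r"\bhello\b");
}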
@@ -288,21 +193,7 @@ pub struct Tokenizer {
     decoder: Option<Box<dyn Decoder>>,
 
     // Added Vocabulary capabilities
-    /// Contains the mapping from String to ID as the user intended it. This map
-    /// contains both special tokens and classic added tokens.
-    added_tokens_map: HashMap<String, u32>,
-    /// Contains the mapping from ID to AddedToken for all the added tokens, both special
-    /// and classic.
-    added_tokens_map_r: HashMap<u32, AddedToken>,
-    /// Contains only the classic AddedToken, in the specific order the user gave them.
-    added_tokens: Vec<AddedToken>,
-    /// Contains only the special AddedToken, in the specific order the user gave them.
-    special_tokens: Vec<AddedToken>,
-    /// A Set, containing all the special token for easy access while decoding. This let's
-    /// use remove them easily with an O(1) complexity.
-    special_tokens_set: HashSet<String>,
-    /// A RegexSet containing all the patterns used to split on AddedTokens
-    split_re: regex::RegexSet,
+    added_vocabulary: AddedVocabulary,
 
     // General processing parameters
     truncation: Option<TruncationParams>,
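Note: the six fields above collapse into a single `added_vocabulary: AddedVocabulary` field. The new `added_vocabulary.rs` file is not part of this excerpt; judging only from the call sites elsewhere in this diff, its public surface is roughly the sketch below (signatures inferred, not authoritative):

// Sketch inferred from call sites in this diff; the real module may differ.
pub struct AddedVocabulary {
    // presumably the maps, token lists, special-token set and RegexSet
    // that used to live directly on Tokenizer (the fields removed above)
}

impl AddedVocabulary {
    pub fn new() -> Self { todo!() }
    pub fn len(&self) -> usize { todo!() }
    pub fn get_vocab(&self) -> &HashMap<String, u32> { todo!() }
    pub fn token_to_id(&self, token: &str) -> Option<&u32> { todo!() } // callers do .copied()
    pub fn id_to_token(&self, id: u32) -> Option<&str> { todo!() }
    pub fn is_special_token(&self, token: &str) -> bool { todo!() }
    pub fn add_tokens(&mut self, tokens: &[AddedToken], model: &dyn Model) -> usize { todo!() }
    pub fn add_special_tokens(&mut self, tokens: &[AddedToken], model: &dyn Model) -> usize { todo!() }
    pub fn extract(&self, sentence: &str) -> Vec<(NormalizedString, Option<u32>)> { todo!() }
}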
@@ -327,12 +218,7 @@ impl Tokenizer {
             post_processor: None,
             decoder: None,
 
-            added_tokens_map: HashMap::new(),
-            added_tokens_map_r: HashMap::new(),
-            added_tokens: vec![],
-            special_tokens: vec![],
-            special_tokens_set: HashSet::new(),
-            split_re: regex::RegexSet::new::<_, &&str>(&[]).unwrap(),
+            added_vocabulary: AddedVocabulary::new(),
 
             truncation: None,
             padding: None,
@@ -461,12 +347,15 @@ impl Tokenizer {
     pub fn get_vocab(&self, with_added_tokens: bool) -> HashMap<String, u32> {
         let mut final_vocab = self.model.get_vocab().clone();
 
-        if with_added_tokens && !self.added_tokens_map.is_empty() {
-            final_vocab.reserve(self.added_tokens_map.len());
-            for (token, id) in &self.added_tokens_map {
+        if with_added_tokens {
+            let added_vocab = self.added_vocabulary.get_vocab();
+            if !added_vocab.is_empty() {
+                final_vocab.reserve(added_vocab.len());
+                for (token, id) in added_vocab {
                     final_vocab.insert(token.clone(), *id);
                 }
             }
+        }
 
         final_vocab
     }
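Note: hypothetical usage of the merged view; ids for added tokens are allocated after the model vocabulary (see the removed `add_tokens` logic further down), so asking for the added tokens only ever grows the map:

// `tokenizer` is an already-built Tokenizer with a few added tokens.
let model_only = tokenizer.get_vocab(false); // model vocabulary only
let full = tokenizer.get_vocab(true);        // model vocabulary + added tokens
assert!(full.len() >= model_only.len());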
@@ -475,7 +364,7 @@ impl Tokenizer {
     pub fn get_vocab_size(&self, with_added_tokens: bool) -> usize {
         self.model.get_vocab_size()
             + if with_added_tokens {
-                self.added_tokens_map.len()
+                self.added_vocabulary.len()
             } else {
                 0
             }
@@ -483,26 +372,24 @@ impl Tokenizer {
 
     /// Converts a token in the corresponding id.
    pub fn token_to_id(&self, token: &str) -> Option<u32> {
-        if let Some(id) = self.added_tokens_map.get(token) {
-            Some(*id)
-        } else {
-            self.model.token_to_id(token)
-        }
+        self.added_vocabulary
+            .token_to_id(token)
+            .copied()
+            .or_else(|| self.model.token_to_id(token))
     }
 
     /// Converts an id to the corresponding token.
-    pub fn id_to_token(&self, id: u32) -> Option<String> {
-        if let Some(token) = self.added_tokens_map_r.get(&id) {
-            Some(token.content.clone())
-        } else {
-            self.model.id_to_token(id)
-        }
+    pub fn id_to_token(&self, id: u32) -> Option<&str> {
+        self.added_vocabulary
+            .id_to_token(id)
+            .or_else(|| self.model.id_to_token(id))
     }
 
     /// Normalize the given sentence and return the corresponding normalized string
     pub fn normalize(&self, sentence: &str) -> Result<NormalizedString> {
         let mut normalized = self
-            .split_on_added_tokens(sentence)
+            .added_vocabulary
+            .extract(sentence)
             .into_iter()
             .map(|(sentence, id)| -> Result<NormalizedString> {
                 if id.is_some() {
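Note: the lookup order is preserved by the refactor: the added vocabulary is consulted first and the model is the fallback, with `id_to_token` now returning a borrowed `&str`. Hypothetical usage (the token name is an example):

// An added token such as "[MASK]" resolves through AddedVocabulary,
// while a regular subword falls through to the underlying model.
if let Some(id) = tokenizer.token_to_id("[MASK]") {
    assert_eq!(tokenizer.id_to_token(id), Some("[MASK]"));
}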
@@ -534,7 +421,7 @@ impl Tokenizer {
 
         let mut sequence_encodings = vec![];
         for subseq in sequence {
-            let results = self.split_on_added_tokens(&subseq).into_iter().map(
+            let results = self.added_vocabulary.extract(&subseq).into_iter().map(
                 |(sentence, id)| -> Result<(Encoding, NormalizedString)> {
                     if let Some(id) = id {
                         Ok((
@@ -666,15 +553,17 @@ impl Tokenizer {
         let tokens = ids
             .into_iter()
             .map(|id| {
-                let token = if let Some(token) = self.added_tokens_map_r.get(&id) {
-                    Some(token.content.to_owned())
+                let token = if let Some(token) = self.added_vocabulary.id_to_token(id) {
+                    Some(token)
                 } else {
                     self.model.id_to_token(id)
                 };
 
-                token.filter(|token| {
-                    !skip_special_tokens || !self.special_tokens_set.contains(token)
+                token
+                    .filter(|token| {
+                        !skip_special_tokens || !self.added_vocabulary.is_special_token(token)
                     })
+                    .map(|t| t.to_owned())
             })
             .filter(|token| token.is_some())
             .map(|id| id.unwrap())
@@ -863,174 +752,13 @@ impl Tokenizer {
     /// Register the given tokens as special tokens. This is especially useful for removing
     /// these special tokens while decoding
     pub fn add_special_tokens(&mut self, tokens: &[AddedToken]) -> usize {
-        for token in tokens {
-            if !self.special_tokens_set.contains(&token.content) {
-                self.special_tokens.push(token.to_owned());
-                self.special_tokens_set.insert(token.content.clone());
-            }
-        }
-        let added = self.add_tokens(&tokens);
-
-        self.refresh_added_tokens();
-
-        added
+        self.added_vocabulary
+            .add_special_tokens(tokens, self.model.as_ref())
     }
 
     /// Add the given tokens to the added vocabulary
     pub fn add_tokens(&mut self, tokens: &[AddedToken]) -> usize {
-        let mut ignored = 0;
-        for token in tokens {
-            if token.content.is_empty() {
-                ignored += 1;
-                continue;
-            }
-
-            let id = if let Some(id) = self.token_to_id(&token.content) {
-                ignored += 1;
-                id
-            } else {
-                let new_id = (self.model.get_vocab_size() + self.added_tokens_map.len()) as u32;
-                self.added_tokens_map.insert(token.content.clone(), new_id);
-
-                if !self.special_tokens_set.contains(&token.content) {
-                    self.added_tokens.push(token.clone());
-                }
-
-                new_id
-            };
-
-            // Update the current revert operation
-            self.added_tokens_map_r
-                .entry(id)
-                .and_modify(|t| *t = token.clone())
-                .or_insert_with(|| token.clone());
-        }
-
-        self.refresh_added_tokens();
-
-        // Return the number of added tokens
-        tokens.len() - ignored
-    }
-
-    fn refresh_added_tokens(&mut self) {
-        self.split_re = regex::RegexSet::new(
-            self.special_tokens
-                .iter()
-                .chain(self.added_tokens.iter())
-                .map(|token| token.get_pattern()),
-        )
-        .unwrap();
-    }
-
-    /// Split the given sentence on multiple parts, finding the added tokens and their id in
-    /// the process
-    fn split_on_added_tokens(&self, sentence: &str) -> Vec<(NormalizedString, Option<u32>)> {
-        let mut matches = self
-            .split_re
-            .matches(sentence)
-            .into_iter()
-            .flat_map(|idx| {
-                regex::Regex::new(&self.split_re.patterns()[idx])
-                    .unwrap()
-                    .find_iter(&sentence)
-                    .map(|m| (idx, (m.start(), m.end())))
-                    .collect::<Vec<_>>()
-            })
-            .collect::<Vec<_>>();
-
-        // We sort all the matches by their start and then by their pattern id
-        matches.sort_by(
-            |(idxa, (sa, _)), (idxb, (sb, _))| {
-                if sa != sb {
-                    sa.cmp(sb)
-                } else {
-                    idxa.cmp(idxb)
-                }
-            },
-        );
-
-        // Select the matches (if some are overlapping) we want to keep
-        let mut i = 0;
-        let mut current_offset = 0;
-        let mut splits = Vec::with_capacity(matches.len());
-        while i < matches.len() {
-            let (idx, (start, end)) = matches[i];
-
-            // current match is before the currentt offset, let's skip it
-            if start < current_offset {
-                i += 1;
-                continue;
-            }
-
-            // Find out if we have overlapping neighbors. If so, we keep the one with the lowest
-            // idx, and apply it, then continue. All others will be skipped since `current_offset`
-            // will have been increased
-            if i + 1 < matches.len() {
-                if let Some((idx, (s, e))) = matches[i..]
-                    .iter()
-                    .take_while(|(_, (s, e))| *s < end && start < *e)
-                    .min() // Order on idx first
-                    .copied()
-                {
-                    splits.push((idx, (s, e)));
-                    current_offset = e;
-                    i += 1;
-                    continue;
-                }
-            }
-
-            // We didn't find overlapping neighbors, apply ourself
-            splits.push((idx, (start, end)));
-            current_offset = end;
-            i += 1;
-        }
-
-        // We also insert the splits that are inbetween the added tokens, to split the entire string
-        let mut start_offset = 0;
-        let mut splits = splits
-            .into_iter()
-            .flat_map(|(idx, (start, end))| {
-                let mut splits = vec![];
-                if start_offset < start {
-                    splits.push((None, (start_offset, start)));
-                }
-                splits.push((Some(idx), (start, end)));
-                start_offset = end;
-
-                splits
-            })
-            .collect::<Vec<_>>();
-        if let Some((_, (_, end))) = splits.iter().last().copied() {
-            if end < sentence.len() {
-                splits.push((None, (end, sentence.len())));
-            }
-        }
-
-        if splits.is_empty() {
-            vec![(NormalizedString::from(sentence), None)]
-        } else {
-            splits
-                .into_iter()
-                .map(|(idx, (start, end))| unsafe {
-                    let s = sentence.get_unchecked(start..end).to_owned();
-                    let normalized = NormalizedString::from(&s);
-
-                    // Find out the associated AddedToken, and its id
-                    let id = if let Some(idx) = idx {
-                        let added = if idx >= self.special_tokens.len() {
-                            &self.added_tokens[idx - self.special_tokens.len()]
-                        } else {
-                            &self.special_tokens[idx]
-                        };
-
-                        self.token_to_id(&added.content)
-                    } else {
-                        None
-                    };
-
-                    (normalized, id)
-                })
-                .collect()
-        }
+        self.added_vocabulary
+            .add_tokens(tokens, self.model.as_ref())
     }
 }
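Note: the whole splitting machinery (`refresh_added_tokens`, `split_on_added_tokens` and its overlap resolution) is deleted here; the `extract` method called in the normalize/encode paths above presumably takes over that role inside `AddedVocabulary`. From the caller's point of view the Tokenizer API is unchanged (hypothetical usage, token contents are examples):

let n_added = tokenizer.add_tokens(&[AddedToken::from("my_token".to_string())]);
let n_special = tokenizer.add_special_tokens(&[AddedToken::from("[PAD]".to_string())]);
// Both calls now simply forward to AddedVocabulary along with &dyn Model,
// which is needed to allocate new ids after the model vocabulary.
assert!(n_added <= 1 && n_special <= 1);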
@@ -1,4 +1,4 @@
-use super::{AddedToken, Tokenizer};
+use super::{added_vocabulary::AddedTokenWithId, Tokenizer};
 use crate::models::bpe::BPE;
 use serde::{
     self,
@@ -9,18 +9,6 @@ use serde::{
 
 static SERIALIZATION_VERSION: &str = "1.0";
 
-#[derive(Debug, Serialize, Deserialize)]
-struct AddedTokenWithId {
-    /// The id assigned to this token
-    id: u32,
-    /// Whether this is a special token
-    special: bool,
-
-    #[serde(flatten)]
-    /// The target AddedToken
-    token: AddedToken,
-}
-
 impl Serialize for Tokenizer {
     fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
     where
@@ -36,18 +24,7 @@ impl Serialize for Tokenizer {
         tokenizer.serialize_field("padding", &self.padding)?;
 
         // Added tokens
-        let mut added_tokens = self
-            .added_tokens_map_r
-            .iter()
-            .map(|(id, token)| AddedTokenWithId {
-                id: *id,
-                special: self.special_tokens_set.contains(&token.content),
-                token: token.clone(),
-            })
-            .collect::<Vec<_>>();
-        // We need to have these added tokens ordered by ascending ID
-        added_tokens.sort_unstable_by_key(|o| o.id);
-        tokenizer.serialize_field("added_tokens", &added_tokens)?;
+        tokenizer.serialize_field("added_tokens", &self.added_vocabulary)?;
 
         // Then add our parts
         tokenizer.serialize_field("normalizer", &self.normalizer)?;
@@ -141,6 +118,8 @@ impl<'de> Visitor<'de> for TokenizerVisitor {
             };
         }
 
+        // We take care of deserializing the added_tokens (instead of `AddedVocabulary` directly
+        // because it let us check that associated IDs are still good, and warn the user otherwise
         for token in tokens {
             let tk = token.token.content.clone();
             if token.special {