mirror of
https://github.com/mii443/tokenizers.git
synced 2025-08-22 16:25:30 +00:00
Handle special tokens
This commit is contained in:
@ -233,6 +233,10 @@ impl Tokenizer {
|
||||
Ok(self.tokenizer.add_tokens(&tokens))
|
||||
}
|
||||
|
||||
fn add_special_tokens(&mut self, tokens: Vec<&str>) -> PyResult<usize> {
|
||||
Ok(self.tokenizer.add_special_tokens(&tokens))
|
||||
}
|
||||
|
||||
fn train(&mut self, trainer: &Trainer, files: Vec<String>) -> PyResult<()> {
|
||||
trainer.trainer.execute(|trainer| {
|
||||
if let Err(e) = self.tokenizer.train(trainer, files) {
|
||||
|
@ -140,6 +140,7 @@ pub struct Tokenizer {
|
||||
added_tokens: HashMap<AddedToken, u32>,
|
||||
added_tokens_r: HashMap<u32, AddedToken>,
|
||||
split_re: Option<regex::Regex>,
|
||||
special_tokens: HashMap<String, u32>,
|
||||
|
||||
// General processing parameters
|
||||
trunc: Option<TruncationParams>,
|
||||
@ -159,6 +160,7 @@ impl Tokenizer {
|
||||
added_tokens: HashMap::new(),
|
||||
added_tokens_r: HashMap::new(),
|
||||
split_re: None,
|
||||
special_tokens: HashMap::new(),
|
||||
|
||||
trunc: None,
|
||||
padding: None,
|
||||
@ -483,6 +485,24 @@ impl Tokenizer {
|
||||
Ok(final_encoding)
|
||||
}
|
||||
|
||||
/// Register the given tokens as special tokens. This is especially useful for removing
|
||||
/// these special tokens while decoding
|
||||
pub fn add_special_tokens(&mut self, tokens: &[&str]) -> usize {
|
||||
let added_tokens = tokens
|
||||
.iter()
|
||||
.map(|t| AddedToken::from((*t).to_owned()))
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
let added = self.add_tokens(&added_tokens);
|
||||
for token in tokens {
|
||||
if let Some(id) = self.token_to_id(token) {
|
||||
self.special_tokens.entry((*token).to_owned()).or_insert(id);
|
||||
}
|
||||
}
|
||||
|
||||
added
|
||||
}
|
||||
|
||||
/// Add the given tokens to the added vocabulary
|
||||
pub fn add_tokens(&mut self, tokens: &[AddedToken]) -> usize {
|
||||
let mut ignored = 0;
|
||||
|
Reference in New Issue
Block a user