Handle special tokens

Anthony MOI
2019-12-19 19:48:16 -05:00
parent 7f032b62df
commit a8d68d516d
2 changed files with 24 additions and 0 deletions

@@ -233,6 +233,10 @@ impl Tokenizer {
         Ok(self.tokenizer.add_tokens(&tokens))
     }
 
+    fn add_special_tokens(&mut self, tokens: Vec<&str>) -> PyResult<usize> {
+        Ok(self.tokenizer.add_special_tokens(&tokens))
+    }
+
     fn train(&mut self, trainer: &Trainer, files: Vec<String>) -> PyResult<()> {
         trainer.trainer.execute(|trainer| {
             if let Err(e) = self.tokenizer.train(trainer, files) {

@@ -140,6 +140,7 @@ pub struct Tokenizer {
     added_tokens: HashMap<AddedToken, u32>,
     added_tokens_r: HashMap<u32, AddedToken>,
     split_re: Option<regex::Regex>,
+    special_tokens: HashMap<String, u32>,
 
     // General processing parameters
     trunc: Option<TruncationParams>,
@@ -159,6 +160,7 @@ impl Tokenizer {
             added_tokens: HashMap::new(),
             added_tokens_r: HashMap::new(),
             split_re: None,
+            special_tokens: HashMap::new(),
 
             trunc: None,
             padding: None,
@@ -483,6 +485,24 @@ impl Tokenizer {
         Ok(final_encoding)
     }
 
+    /// Register the given tokens as special tokens. This is especially useful for removing
+    /// these special tokens while decoding
+    pub fn add_special_tokens(&mut self, tokens: &[&str]) -> usize {
+        let added_tokens = tokens
+            .iter()
+            .map(|t| AddedToken::from((*t).to_owned()))
+            .collect::<Vec<_>>();
+
+        let added = self.add_tokens(&added_tokens);
+        for token in tokens {
+            if let Some(id) = self.token_to_id(token) {
+                self.special_tokens.entry((*token).to_owned()).or_insert(id);
+            }
+        }
+
+        added
+    }
+
     /// Add the given tokens to the added vocabulary
     pub fn add_tokens(&mut self, tokens: &[AddedToken]) -> usize {
         let mut ignored = 0;
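
For reference, below is a minimal, self-contained Rust sketch of the behavior this commit introduces. It is not the real `tokenizers` crate API: the `Tokenizer` struct, `vocab` field, and helper bodies are simplified stand-ins, but `add_special_tokens` has the same shape as the committed code, first routing tokens through the regular added-vocabulary path and then recording their ids in the dedicated `special_tokens` map.

use std::collections::HashMap;

struct Tokenizer {
    vocab: HashMap<String, u32>,          // stand-in for the added vocabulary
    special_tokens: HashMap<String, u32>, // mirrors the new field
}

impl Tokenizer {
    fn new() -> Self {
        Tokenizer {
            vocab: HashMap::new(),
            special_tokens: HashMap::new(),
        }
    }

    fn token_to_id(&self, token: &str) -> Option<u32> {
        self.vocab.get(token).copied()
    }

    // Simplified stand-in for the existing `add_tokens`: assigns the next
    // free id to each token that is not already known.
    fn add_tokens(&mut self, tokens: &[&str]) -> usize {
        let mut added = 0;
        for token in tokens {
            let next_id = self.vocab.len() as u32;
            self.vocab.entry((*token).to_owned()).or_insert_with(|| {
                added += 1;
                next_id
            });
        }
        added
    }

    // Same shape as the method added in this commit: add the tokens like any
    // other, then record each one's id in `special_tokens`. `or_insert`
    // keeps the first registration if a token was already marked special.
    fn add_special_tokens(&mut self, tokens: &[&str]) -> usize {
        let added = self.add_tokens(tokens);
        for token in tokens {
            if let Some(id) = self.token_to_id(token) {
                self.special_tokens.entry((*token).to_owned()).or_insert(id);
            }
        }
        added
    }
}

fn main() {
    let mut tokenizer = Tokenizer::new();
    let added = tokenizer.add_special_tokens(&["[CLS]", "[SEP]"]);
    assert_eq!(added, 2);
    // The tokens are now registered, so a decoder could skip them when
    // turning ids back into text.
    assert!(tokenizer.special_tokens.contains_key("[CLS]"));
    println!("registered {} special tokens", added);
}

Keeping special tokens in a separate map, rather than flagging entries inside `added_tokens`, means the check a decoder needs ("is this token special?") is a single hash lookup, regardless of how the token was originally added.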