diff --git a/tokenizers/src/models/bpe/trainer.rs b/tokenizers/src/models/bpe/trainer.rs index cefc03a5..e4d4342c 100644 --- a/tokenizers/src/models/bpe/trainer.rs +++ b/tokenizers/src/models/bpe/trainer.rs @@ -27,6 +27,8 @@ pub struct BpeTrainer { pub min_frequency: u32, pub vocab_size: usize, pub show_progress: bool, + pub special_tokens: Vec<String>, + pub limit_alphabet: Option<usize>, } impl Default for BpeTrainer { @@ -35,6 +37,8 @@ impl Default for BpeTrainer { min_frequency: 0, vocab_size: 30000, show_progress: true, + special_tokens: vec![], + limit_alphabet: Some(100), } } } @@ -47,6 +51,90 @@ impl BpeTrainer { ..Default::default() } } + + /// Setup a progress bar if asked to show progress + fn setup_progress(&self) -> Option<ProgressBar> { + if self.show_progress { + let p = ProgressBar::new(0); + p.set_style( + ProgressStyle::default_bar() + .template("[{elapsed_precise}] {msg:30} {bar:40} {pos:>7}/{len:7}"), + ); + Some(p) + } else { + None + } + } + + /// Set the progress bar in the finish state + fn finalize_progress(&self, p: &Option<ProgressBar>, final_len: usize) { + if let Some(p) = p { + p.set_length(final_len as u64); + p.finish(); + println!(); + } + } + + /// Update the progress bar with the new provided length and message + fn update_progress(&self, p: &Option<ProgressBar>, len: usize, message: &str) { + if let Some(p) = p { + p.set_message(message); + p.set_length(len as u64); + p.reset(); + } + } + + /// Add the provided special tokens to the initial vocabulary + fn add_special_tokens(&self, w2id: &mut HashMap<String, u32>, id2w: &mut Vec<String>) { + for token in &self.special_tokens { + if !w2id.contains_key(token) { + id2w.push(token.to_owned()); + w2id.insert(token.to_owned(), (id2w.len() - 1) as u32); + } + } + } + + /// Compute the initial alphabet and limit it if relevant + fn limit_alphabet( + &self, + wc: &HashMap<String, u32>, + w2id: &mut HashMap<String, u32>, + id2w: &mut Vec<String>, + ) { + let mut alphabet: HashMap<char, usize> = HashMap::new(); + for (word, count) in wc { + for c in word.chars() { + alphabet + .entry(c) + .and_modify(|cnt| 
*cnt += *count as usize) .or_insert(*count as usize); } } + + let to_remove = if let Some(limit) = self.limit_alphabet { alphabet.len() - limit } else { 0 }; + + let mut kept = alphabet.iter().collect::<Vec<_>>(); + kept.sort_unstable_by_key(|k| *k.1); + + // Remove the unwanted chars + if to_remove > 0 { + kept.drain(..to_remove); + } + + // Keep the initial alphabet + kept.sort_unstable_by_key(|k| (*k.0) as u32); + kept.into_iter().for_each(|(c, _)| { + let s = c.to_string(); + if !w2id.contains_key(&s) { + id2w.push(s.clone()); + w2id.insert(s, (id2w.len() - 1) as u32); + } + }); + } } impl Trainer for BpeTrainer { @@ -57,32 +145,31 @@ impl Trainer for BpeTrainer { let mut word_to_id: HashMap<String, u32> = HashMap::new(); let mut id_to_word: Vec<String> = vec![]; - let progress = if self.show_progress { - let p = ProgressBar::new(word_counts.len() as u64); - p.set_style( - ProgressStyle::default_bar() - .template("[{elapsed_precise}] {msg:30} {bar:40} {pos:>7}/{len:7}"), - ); - p.set_message("Tokenize words"); - Some(p) - } else { - None - }; + let progress = self.setup_progress(); // - // 1. Tokenize words + // 1. Add all special tokens to the vocabulary // + self.add_special_tokens(&mut word_to_id, &mut id_to_word); + + // + // 2. Limit the initial alphabet if relevant + // + self.limit_alphabet(&word_counts, &mut word_to_id, &mut id_to_word); + + // + // 3. Tokenize words + // + self.update_progress(&progress, word_counts.len(), "Tokenize words"); for (word, count) in &word_counts { let mut current_word = Word::new(); counts.push(*count as i32); for c in word.chars() { let s = c.to_string(); - if !word_to_id.contains_key(&s) { - id_to_word.push(s.clone()); - word_to_id.insert(s.clone(), (id_to_word.len() - 1) as u32); + if let Some(id) = word_to_id.get(&s) { + current_word.add(*id); } - current_word.add(word_to_id[&s]); } words.push(current_word); @@ -90,15 +177,12 @@ impl Trainer for BpeTrainer { p.inc(1); } } + self.finalize_progress(&progress, words.len()); // - // 2. 
Count pairs in words + // 4. Count pairs in words // - if let Some(p) = &progress { - p.set_message("Count pairs"); - p.set_length(words.len() as u64); - p.reset(); - } + self.update_progress(&progress, words.len(), "Count pairs"); let mut pair_counts: HashMap<Pair, i32> = HashMap::new(); let mut where_to_update: HashMap<Pair, HashSet<usize>> = HashMap::new(); for (index, word) in words.iter().enumerate() { @@ -128,15 +212,12 @@ impl Trainer for BpeTrainer { p.inc(1); } } + self.finalize_progress(&progress, words.len()); // - // 3. Do merges + // 5. Do merges // - if let Some(p) = &progress { - p.set_message("Compute merges"); - p.set_length(self.vocab_size as u64); - p.reset(); - } + self.update_progress(&progress, self.vocab_size, "Compute merges"); let mut merges: Vec<(Pair, u32)> = vec![]; loop { // Stop as soon as we have a big enough vocabulary @@ -206,6 +287,7 @@ impl Trainer for BpeTrainer { p.inc(1); } } + self.finalize_progress(&progress, merges.len()); Ok(Box::new(BPE::new( word_to_id.clone(),