BpeTrainer handles special tokens and limiting alphabet

This commit is contained in:
Anthony MOI
2020-01-01 12:54:41 -05:00
parent ebf22198f3
commit a7a5f9a67f

View File

@ -27,6 +27,8 @@ pub struct BpeTrainer {
pub min_frequency: u32, pub min_frequency: u32,
pub vocab_size: usize, pub vocab_size: usize,
pub show_progress: bool, pub show_progress: bool,
pub special_tokens: Vec<String>,
pub limit_alphabet: Option<usize>,
} }
impl Default for BpeTrainer { impl Default for BpeTrainer {
@ -35,6 +37,8 @@ impl Default for BpeTrainer {
min_frequency: 0, min_frequency: 0,
vocab_size: 30000, vocab_size: 30000,
show_progress: true, show_progress: true,
special_tokens: vec![],
limit_alphabet: Some(100),
} }
} }
} }
@ -47,6 +51,90 @@ impl BpeTrainer {
..Default::default() ..Default::default()
} }
} }
/// Build a progress bar when progress reporting is enabled, `None` otherwise.
///
/// The bar is created with a length of 0; callers size and label it per
/// phase via `update_progress`.
fn setup_progress(&self) -> Option<ProgressBar> {
    if !self.show_progress {
        return None;
    }
    let bar = ProgressBar::new(0);
    let style = ProgressStyle::default_bar()
        .template("[{elapsed_precise}] {msg:30} {bar:40} {pos:>7}/{len:7}");
    bar.set_style(style);
    Some(bar)
}
/// Put the progress bar (if any) into its finished state, pinning its final
/// length to `final_len`, and emit a newline so later output starts fresh.
fn finalize_progress(&self, p: &Option<ProgressBar>, final_len: usize) {
    let bar = match p {
        Some(bar) => bar,
        None => return,
    };
    bar.set_length(final_len as u64);
    bar.finish();
    println!();
}
/// Re-arm the progress bar (if any) for a new phase: set its message and
/// total length, then reset position and elapsed time back to zero.
fn update_progress(&self, p: &Option<ProgressBar>, len: usize, message: &str) {
    let bar = match p {
        Some(bar) => bar,
        None => return,
    };
    bar.set_message(message);
    bar.set_length(len as u64);
    bar.reset();
}
/// Seed the vocabulary with every configured special token, preserving the
/// configured order and skipping tokens that are already present.
fn add_special_tokens(&self, w2id: &mut HashMap<String, u32>, id2w: &mut Vec<String>) {
    for token in self.special_tokens.iter() {
        if w2id.contains_key(token) {
            continue;
        }
        // The new id is the index the token will occupy in `id2w`.
        let id = id2w.len() as u32;
        id2w.push(token.clone());
        w2id.insert(token.clone(), id);
    }
}
/// Compute the initial character alphabet from the word counts and add it to
/// the vocabulary, keeping only the `limit_alphabet` most frequent characters
/// when a limit is configured.
///
/// Characters are inserted in codepoint order so the resulting ids are
/// deterministic regardless of `HashMap` iteration order.
fn limit_alphabet(
    &self,
    wc: &HashMap<String, u32>,
    w2id: &mut HashMap<String, u32>,
    id2w: &mut Vec<String>,
) {
    // Total occurrence count of each character across the whole corpus.
    let mut alphabet: HashMap<char, usize> = HashMap::new();
    for (word, count) in wc {
        for c in word.chars() {
            *alphabet.entry(c).or_insert(0) += *count as usize;
        }
    }

    // How many of the least frequent characters to drop. `saturating_sub`
    // avoids the usize underflow (panic in debug, huge value in release)
    // that `alphabet.len() - limit` would hit whenever the corpus alphabet
    // is already smaller than the configured limit.
    let to_remove = self
        .limit_alphabet
        .map_or(0, |limit| alphabet.len().saturating_sub(limit));

    // Sort ascending by frequency so the rarest characters sit at the front,
    // then drain exactly the ones we want to discard.
    let mut kept = alphabet.iter().collect::<Vec<_>>();
    kept.sort_unstable_by_key(|k| *k.1);
    if to_remove > 0 {
        kept.drain(..to_remove);
    }

    // Insert the surviving characters in codepoint order.
    kept.sort_unstable_by_key(|k| (*k.0) as u32);
    for (c, _) in kept {
        let s = c.to_string();
        if !w2id.contains_key(&s) {
            let id = id2w.len() as u32;
            id2w.push(s.clone());
            w2id.insert(s, id);
        }
    }
}
} }
impl Trainer for BpeTrainer { impl Trainer for BpeTrainer {
@ -57,32 +145,31 @@ impl Trainer for BpeTrainer {
let mut word_to_id: HashMap<String, u32> = HashMap::new(); let mut word_to_id: HashMap<String, u32> = HashMap::new();
let mut id_to_word: Vec<String> = vec![]; let mut id_to_word: Vec<String> = vec![];
let progress = if self.show_progress { let progress = self.setup_progress();
let p = ProgressBar::new(word_counts.len() as u64);
p.set_style(
ProgressStyle::default_bar()
.template("[{elapsed_precise}] {msg:30} {bar:40} {pos:>7}/{len:7}"),
);
p.set_message("Tokenize words");
Some(p)
} else {
None
};
// //
// 1. Tokenize words // 1. Add all special tokens to the vocabulary
// //
self.add_special_tokens(&mut word_to_id, &mut id_to_word);
//
// 2. Limit the initial alphabet if relevant
//
self.limit_alphabet(&word_counts, &mut word_to_id, &mut id_to_word);
//
// 3. Tokenize words
//
self.update_progress(&progress, word_counts.len(), "Tokenize words");
for (word, count) in &word_counts { for (word, count) in &word_counts {
let mut current_word = Word::new(); let mut current_word = Word::new();
counts.push(*count as i32); counts.push(*count as i32);
for c in word.chars() { for c in word.chars() {
let s = c.to_string(); let s = c.to_string();
if !word_to_id.contains_key(&s) { if let Some(id) = word_to_id.get(&s) {
id_to_word.push(s.clone()); current_word.add(*id);
word_to_id.insert(s.clone(), (id_to_word.len() - 1) as u32);
} }
current_word.add(word_to_id[&s]);
} }
words.push(current_word); words.push(current_word);
@ -90,15 +177,12 @@ impl Trainer for BpeTrainer {
p.inc(1); p.inc(1);
} }
} }
self.finalize_progress(&progress, words.len());
// //
// 2. Count pairs in words // 4. Count pairs in words
// //
if let Some(p) = &progress { self.update_progress(&progress, words.len(), "Count pairs");
p.set_message("Count pairs");
p.set_length(words.len() as u64);
p.reset();
}
let mut pair_counts: HashMap<Pair, (i32, Pair)> = HashMap::new(); let mut pair_counts: HashMap<Pair, (i32, Pair)> = HashMap::new();
let mut where_to_update: HashMap<Pair, HashSet<usize>> = HashMap::new(); let mut where_to_update: HashMap<Pair, HashSet<usize>> = HashMap::new();
for (index, word) in words.iter().enumerate() { for (index, word) in words.iter().enumerate() {
@ -128,15 +212,12 @@ impl Trainer for BpeTrainer {
p.inc(1); p.inc(1);
} }
} }
self.finalize_progress(&progress, words.len());
// //
// 3. Do merges // 5. Do merges
// //
if let Some(p) = &progress { self.update_progress(&progress, self.vocab_size, "Compute merges");
p.set_message("Compute merges");
p.set_length(self.vocab_size as u64);
p.reset();
}
let mut merges: Vec<(Pair, u32)> = vec![]; let mut merges: Vec<(Pair, u32)> = vec![];
loop { loop {
// Stop as soon as we have a big enough vocabulary // Stop as soon as we have a big enough vocabulary
@ -206,6 +287,7 @@ impl Trainer for BpeTrainer {
p.inc(1); p.inc(1);
} }
} }
self.finalize_progress(&progress, merges.len());
Ok(Box::new(BPE::new( Ok(Box::new(BPE::new(
word_to_id.clone(), word_to_id.clone(),