BpeTrainer handles special tokens and limiting alphabet

This commit is contained in:
Anthony MOI
2020-01-01 12:54:41 -05:00
parent ebf22198f3
commit a7a5f9a67f

View File

@ -27,6 +27,8 @@ pub struct BpeTrainer {
pub min_frequency: u32,
pub vocab_size: usize,
pub show_progress: bool,
pub special_tokens: Vec<String>,
pub limit_alphabet: Option<usize>,
}
impl Default for BpeTrainer {
@ -35,6 +37,8 @@ impl Default for BpeTrainer {
min_frequency: 0,
vocab_size: 30000,
show_progress: true,
special_tokens: vec![],
limit_alphabet: Some(100),
}
}
}
@ -47,6 +51,90 @@ impl BpeTrainer {
..Default::default()
}
}
/// Setup a progress bar if asked to show progress
fn setup_progress(&self) -> Option<ProgressBar> {
if self.show_progress {
let p = ProgressBar::new(0);
p.set_style(
ProgressStyle::default_bar()
.template("[{elapsed_precise}] {msg:30} {bar:40} {pos:>7}/{len:7}"),
);
Some(p)
} else {
None
}
}
/// Set the progress bar in the finish state
fn finalize_progress(&self, p: &Option<ProgressBar>, final_len: usize) {
if let Some(p) = p {
p.set_length(final_len as u64);
p.finish();
println!();
}
}
/// Update the progress bar with the new provided length and message
fn update_progress(&self, p: &Option<ProgressBar>, len: usize, message: &str) {
if let Some(p) = p {
p.set_message(message);
p.set_length(len as u64);
p.reset();
}
}
/// Add the provided special tokens to the initial vocabulary
fn add_special_tokens(&self, w2id: &mut HashMap<String, u32>, id2w: &mut Vec<String>) {
for token in &self.special_tokens {
if !w2id.contains_key(token) {
id2w.push(token.to_owned());
w2id.insert(token.to_owned(), (id2w.len() - 1) as u32);
}
}
}
/// Compute the initial alphabet and limit it if relevant
fn limit_alphabet(
&self,
wc: &HashMap<String, u32>,
w2id: &mut HashMap<String, u32>,
id2w: &mut Vec<String>,
) {
let mut alphabet: HashMap<char, usize> = HashMap::new();
for (word, count) in wc {
for c in word.chars() {
alphabet
.entry(c)
.and_modify(|cnt| *cnt += *count as usize)
.or_insert(*count as usize);
}
}
let to_remove = if let Some(limit) = self.limit_alphabet {
alphabet.len() - limit
} else {
0
};
let mut kept = alphabet.iter().collect::<Vec<_>>();
kept.sort_unstable_by_key(|k| *k.1);
// Remove the unwanted chars
if to_remove > 0 {
kept.drain(..to_remove);
}
// Keep the initial alphabet
kept.sort_unstable_by_key(|k| (*k.0) as u32);
kept.into_iter().for_each(|(c, _)| {
let s = c.to_string();
if !w2id.contains_key(&s) {
id2w.push(s.clone());
w2id.insert(s, (id2w.len() - 1) as u32);
}
});
}
}
impl Trainer for BpeTrainer {
@ -57,32 +145,31 @@ impl Trainer for BpeTrainer {
let mut word_to_id: HashMap<String, u32> = HashMap::new();
let mut id_to_word: Vec<String> = vec![];
let progress = if self.show_progress {
let p = ProgressBar::new(word_counts.len() as u64);
p.set_style(
ProgressStyle::default_bar()
.template("[{elapsed_precise}] {msg:30} {bar:40} {pos:>7}/{len:7}"),
);
p.set_message("Tokenize words");
Some(p)
} else {
None
};
let progress = self.setup_progress();
//
// 1. Tokenize words
// 1. Add all special tokens to the vocabulary
//
self.add_special_tokens(&mut word_to_id, &mut id_to_word);
//
// 2. Limit the initial alphabet if relevant
//
self.limit_alphabet(&word_counts, &mut word_to_id, &mut id_to_word);
//
// 3. Tokenize words
//
self.update_progress(&progress, word_counts.len(), "Tokenize words");
for (word, count) in &word_counts {
let mut current_word = Word::new();
counts.push(*count as i32);
for c in word.chars() {
let s = c.to_string();
if !word_to_id.contains_key(&s) {
id_to_word.push(s.clone());
word_to_id.insert(s.clone(), (id_to_word.len() - 1) as u32);
if let Some(id) = word_to_id.get(&s) {
current_word.add(*id);
}
current_word.add(word_to_id[&s]);
}
words.push(current_word);
@ -90,15 +177,12 @@ impl Trainer for BpeTrainer {
p.inc(1);
}
}
self.finalize_progress(&progress, words.len());
//
// 2. Count pairs in words
// 4. Count pairs in words
//
if let Some(p) = &progress {
p.set_message("Count pairs");
p.set_length(words.len() as u64);
p.reset();
}
self.update_progress(&progress, words.len(), "Count pairs");
let mut pair_counts: HashMap<Pair, (i32, Pair)> = HashMap::new();
let mut where_to_update: HashMap<Pair, HashSet<usize>> = HashMap::new();
for (index, word) in words.iter().enumerate() {
@ -128,15 +212,12 @@ impl Trainer for BpeTrainer {
p.inc(1);
}
}
self.finalize_progress(&progress, words.len());
//
// 3. Do merges
// 5. Do merges
//
if let Some(p) = &progress {
p.set_message("Compute merges");
p.set_length(self.vocab_size as u64);
p.reset();
}
self.update_progress(&progress, self.vocab_size, "Compute merges");
let mut merges: Vec<(Pair, u32)> = vec![];
loop {
// Stop as soon as we have a big enough vocabulary
@ -206,6 +287,7 @@ impl Trainer for BpeTrainer {
p.inc(1);
}
}
self.finalize_progress(&progress, merges.len());
Ok(Box::new(BPE::new(
word_to_id.clone(),