BpeTrainer handles initial alphabet

This commit is contained in:
Anthony MOI
2020-01-02 15:01:22 -05:00
parent c8a5d2e32a
commit 7b12b3cca5

View File

@ -27,6 +27,7 @@ pub struct BpeTrainer {
pub show_progress: bool,
pub special_tokens: Vec<String>,
pub limit_alphabet: Option<usize>,
pub initial_alphabet: HashSet<char>,
}
impl Default for BpeTrainer {
@ -37,6 +38,7 @@ impl Default for BpeTrainer {
show_progress: true,
special_tokens: vec![],
limit_alphabet: None,
initial_alphabet: HashSet::new(),
}
}
}
@ -93,12 +95,13 @@ impl BpeTrainer {
}
/// Compute the initial alphabet and limit it if relevant
fn limit_alphabet(
fn compute_alphabet(
&self,
wc: &HashMap<String, u32>,
w2id: &mut HashMap<String, u32>,
id2w: &mut Vec<String>,
) {
// Compute the alphabet from seen words
let mut alphabet: HashMap<char, usize> = HashMap::new();
for (word, count) in wc {
for c in word.chars() {
@ -109,6 +112,19 @@ impl BpeTrainer {
}
}
// Also include anything from the provided initial alphabet
for c in &self.initial_alphabet {
alphabet
.entry(*c)
.and_modify(|cnt| *cnt = std::usize::MAX)
.or_insert(std::usize::MAX);
}
let mut kept = alphabet.iter().collect::<Vec<_>>();
// Compute the number of chars to remove from the alphabet
// If `limit_alphabet < initial_alphabet.len()`, some of these initial characters
// will be removed as well. They all share the same maximal count, so which of them
// get dropped is not specified.
let to_remove = self
.limit_alphabet
.map(|limit| {
@ -120,15 +136,13 @@ impl BpeTrainer {
})
.unwrap_or(0);
let mut kept = alphabet.iter().collect::<Vec<_>>();
kept.sort_unstable_by_key(|k| *k.1);
// Remove the unwanted chars
if to_remove > 0 {
kept.sort_unstable_by_key(|k| *k.1);
kept.drain(..to_remove);
}
// Keep the initial alphabet
// Keep the initial alphabet (sorted for determinism)
kept.sort_unstable_by_key(|k| (*k.0) as u32);
kept.into_iter().for_each(|(c, _)| {
let s = c.to_string();
@ -156,9 +170,9 @@ impl Trainer for BpeTrainer {
self.add_special_tokens(&mut word_to_id, &mut id_to_word);
//
// 2. Limit the initial alphabet if relevant
// 2. Compute the initial alphabet
//
self.limit_alphabet(&word_counts, &mut word_to_id, &mut id_to_word);
self.compute_alphabet(&word_counts, &mut word_to_id, &mut id_to_word);
//
// 3. Tokenize words