Oops - Fix trainer

This commit is contained in:
Anthony MOI
2020-01-01 13:36:42 -05:00
parent a7a5f9a67f
commit a5c5e5840f

View File

@ -16,9 +16,7 @@ use std::collections::{HashMap, HashSet};
///
/// let word_counts: HashMap<String, u32> = [
/// (String::from("Hello"), 1),
/// (String::from(","), 1),
/// (String::from("ĠWorld"), 1),
/// (String::from("!"), 1),
/// (String::from("World"), 1),
/// ].iter().cloned().collect();
/// let trainer = BpeTrainer::default();
/// let model = trainer.train(word_counts);
@ -38,7 +36,7 @@ impl Default for BpeTrainer {
vocab_size: 30000,
show_progress: true,
special_tokens: vec![],
limit_alphabet: Some(100),
limit_alphabet: None,
}
}
}
@ -111,11 +109,16 @@ impl BpeTrainer {
}
}
let to_remove = if let Some(limit) = self.limit_alphabet {
alphabet.len() - limit
} else {
0
};
let to_remove = self
.limit_alphabet
.map(|limit| {
if alphabet.len() > limit {
alphabet.len() - limit
} else {
0
}
})
.unwrap_or(0);
let mut kept = alphabet.iter().collect::<Vec<_>>();
kept.sort_unstable_by_key(|k| *k.1);