BpeTrainer handles initial alphabet
@@ -27,6 +27,7 @@ pub struct BpeTrainer {
     pub show_progress: bool,
     pub special_tokens: Vec<String>,
     pub limit_alphabet: Option<usize>,
+    pub initial_alphabet: HashSet<char>,
 }

 impl Default for BpeTrainer {
@@ -37,6 +38,7 @@ impl Default for BpeTrainer {
             show_progress: true,
             special_tokens: vec![],
             limit_alphabet: None,
+            initial_alphabet: HashSet::new(),
         }
     }
 }
@@ -93,12 +95,13 @@ impl BpeTrainer {
     }

-    fn limit_alphabet(
+    /// Compute the initial alphabet and limit it if relevant
+    fn compute_alphabet(
         &self,
         wc: &HashMap<String, u32>,
         w2id: &mut HashMap<String, u32>,
         id2w: &mut Vec<String>,
     ) {
         // Compute the alphabet from seen words
         let mut alphabet: HashMap<char, usize> = HashMap::new();
         for (word, count) in wc {
             for c in word.chars() {
@@ -109,6 +112,19 @@ impl BpeTrainer {
             }
         }

+        // Also include anything from the provided initial alphabet
+        for c in &self.initial_alphabet {
+            alphabet
+                .entry(*c)
+                .and_modify(|cnt| *cnt = std::usize::MAX)
+                .or_insert(std::usize::MAX);
+        }
+
+        let mut kept = alphabet.iter().collect::<Vec<_>>();
+
+        // Compute the number of chars to remove from the alphabet
+        // If `limit_alphabet < initial_alphabet.len()`, some of these initial characters
+        // will be removed
         let to_remove = self
             .limit_alphabet
             .map(|limit| {
@@ -120,15 +136,13 @@ impl BpeTrainer {
             })
             .unwrap_or(0);

-        let mut kept = alphabet.iter().collect::<Vec<_>>();
-        kept.sort_unstable_by_key(|k| *k.1);
-
         // Remove the unwanted chars
         if to_remove > 0 {
+            kept.sort_unstable_by_key(|k| *k.1);
             kept.drain(..to_remove);
         }

-        // Keep the initial alphabet
+        // Keep the initial alphabet (sorted for determinism)
         kept.sort_unstable_by_key(|k| (*k.0) as u32);
         kept.into_iter().for_each(|(c, _)| {
             let s = c.to_string();
@@ -156,9 +170,9 @@ impl Trainer for BpeTrainer {
         self.add_special_tokens(&mut word_to_id, &mut id_to_word);

         //
-        // 2. Limit the initial alphabet if relevant
+        // 2. Compute the initial alphabet
         //
-        self.limit_alphabet(&word_counts, &mut word_to_id, &mut id_to_word);
+        self.compute_alphabet(&word_counts, &mut word_to_id, &mut id_to_word);

         //
         // 3. Tokenize words
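What the change boils down to: compute_alphabet first counts every character seen in the word counts, then injects each user-provided character with a count of std::usize::MAX so that limit_alphabet pruning drops it last (unless the limit is smaller than the initial alphabet itself), and finally sorts the surviving characters by code point so the result is deterministic. Below is a minimal standalone sketch of that logic; the free function, its parameter names, and the toy corpus are illustrative assumptions rather than the crate's actual API.

use std::collections::{HashMap, HashSet};

// Hypothetical standalone version of the logic above (names and signature are
// illustrative, not the crate's API): count characters from the word counts,
// force user-provided initial characters to usize::MAX so they survive
// frequency-based pruning, trim the rarest characters down to the limit, and
// sort by code point for a deterministic result.
fn compute_alphabet(
    word_counts: &HashMap<String, u32>,
    initial_alphabet: &HashSet<char>,
    limit_alphabet: Option<usize>,
) -> Vec<char> {
    // Count how often each character occurs across the corpus
    let mut alphabet: HashMap<char, usize> = HashMap::new();
    for (word, count) in word_counts {
        for c in word.chars() {
            *alphabet.entry(c).or_insert(0) += *count as usize;
        }
    }

    // Characters from the initial alphabet get the maximum count, so a
    // limit smaller than the full alphabet removes them last
    for c in initial_alphabet {
        alphabet
            .entry(*c)
            .and_modify(|cnt| *cnt = usize::MAX)
            .or_insert(usize::MAX);
    }

    let mut kept: Vec<(char, usize)> = alphabet.into_iter().collect();

    // Number of characters to drop; if the limit is smaller than
    // initial_alphabet.len(), some initial characters are still removed
    let to_remove = limit_alphabet
        .map(|limit| kept.len().saturating_sub(limit))
        .unwrap_or(0);

    if to_remove > 0 {
        // Sort ascending by count so the rarest characters are dropped first
        kept.sort_unstable_by_key(|(_, cnt)| *cnt);
        kept.drain(..to_remove);
    }

    // Sort by code point so the resulting ordering (and any ids assigned from
    // it) is deterministic
    kept.sort_unstable_by_key(|(c, _)| *c as u32);
    kept.into_iter().map(|(c, _)| c).collect()
}

fn main() {
    let mut word_counts: HashMap<String, u32> = HashMap::new();
    word_counts.insert("hello".to_string(), 5);
    word_counts.insert("world".to_string(), 3);

    // 'x' and 'y' never occur in the corpus but are requested explicitly
    let initial_alphabet: HashSet<char> = ['x', 'y'].iter().copied().collect();

    // With a limit of 6, the rarest corpus characters are pruned, but 'x' and
    // 'y' are kept because their counts were forced to usize::MAX
    let alphabet = compute_alphabet(&word_counts, &initial_alphabet, Some(6));
    println!("{:?}", alphabet); // ['e', 'h', 'l', 'o', 'x', 'y']
}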
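Because the diff shows the new field as pub and covered by the Default impl, a trainer with a custom initial alphabet could presumably be configured via struct-update syntax. The sketch below assumes the remaining BpeTrainer fields are also public and that the import path is as written; neither is confirmed by the excerpt.

use std::collections::HashSet;
// Assumed import path; the excerpt does not show the crate's module layout.
use tokenizers::models::bpe::BpeTrainer;

fn main() {
    // Guarantee a few accented characters end up in the vocabulary even if
    // they never appear in the training data, while capping the alphabet
    // at 100 characters.
    let initial_alphabet: HashSet<char> = ['é', 'à', 'ç'].iter().copied().collect();

    let trainer = BpeTrainer {
        initial_alphabet,
        limit_alphabet: Some(100),
        ..BpeTrainer::default()
    };
    println!("{:?}", trainer.limit_alphabet);
}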