mirror of
https://github.com/mii443/tokenizers.git
synced 2025-08-23 00:35:35 +00:00
BpeTrainer handles initial alphabet
This commit is contained in:
@ -27,6 +27,7 @@ pub struct BpeTrainer {
|
|||||||
pub show_progress: bool,
|
pub show_progress: bool,
|
||||||
pub special_tokens: Vec<String>,
|
pub special_tokens: Vec<String>,
|
||||||
pub limit_alphabet: Option<usize>,
|
pub limit_alphabet: Option<usize>,
|
||||||
|
pub initial_alphabet: HashSet<char>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Default for BpeTrainer {
|
impl Default for BpeTrainer {
|
||||||
@ -37,6 +38,7 @@ impl Default for BpeTrainer {
|
|||||||
show_progress: true,
|
show_progress: true,
|
||||||
special_tokens: vec![],
|
special_tokens: vec![],
|
||||||
limit_alphabet: None,
|
limit_alphabet: None,
|
||||||
|
initial_alphabet: HashSet::new(),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -93,12 +95,13 @@ impl BpeTrainer {
|
|||||||
}
|
}
|
||||||
|
|
||||||
/// Compute the initial alphabet and limit it if relevant
|
/// Compute the initial alphabet and limit it if relevant
|
||||||
fn limit_alphabet(
|
fn compute_alphabet(
|
||||||
&self,
|
&self,
|
||||||
wc: &HashMap<String, u32>,
|
wc: &HashMap<String, u32>,
|
||||||
w2id: &mut HashMap<String, u32>,
|
w2id: &mut HashMap<String, u32>,
|
||||||
id2w: &mut Vec<String>,
|
id2w: &mut Vec<String>,
|
||||||
) {
|
) {
|
||||||
|
// Compute the alphabet from seen words
|
||||||
let mut alphabet: HashMap<char, usize> = HashMap::new();
|
let mut alphabet: HashMap<char, usize> = HashMap::new();
|
||||||
for (word, count) in wc {
|
for (word, count) in wc {
|
||||||
for c in word.chars() {
|
for c in word.chars() {
|
||||||
@ -109,6 +112,19 @@ impl BpeTrainer {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Also include anything from the provided initial alphabet
|
||||||
|
for c in &self.initial_alphabet {
|
||||||
|
alphabet
|
||||||
|
.entry(*c)
|
||||||
|
.and_modify(|cnt| *cnt = std::usize::MAX)
|
||||||
|
.or_insert(std::usize::MAX);
|
||||||
|
}
|
||||||
|
|
||||||
|
let mut kept = alphabet.iter().collect::<Vec<_>>();
|
||||||
|
|
||||||
|
// Compute the number of chars to remove from the alphabet
|
||||||
|
// If `limit_alphabet < initial_alphabet.len()`, some of these initial characters
|
||||||
|
// will be removed
|
||||||
let to_remove = self
|
let to_remove = self
|
||||||
.limit_alphabet
|
.limit_alphabet
|
||||||
.map(|limit| {
|
.map(|limit| {
|
||||||
@ -120,15 +136,13 @@ impl BpeTrainer {
|
|||||||
})
|
})
|
||||||
.unwrap_or(0);
|
.unwrap_or(0);
|
||||||
|
|
||||||
let mut kept = alphabet.iter().collect::<Vec<_>>();
|
|
||||||
kept.sort_unstable_by_key(|k| *k.1);
|
|
||||||
|
|
||||||
// Remove the unwanted chars
|
// Remove the unwanted chars
|
||||||
if to_remove > 0 {
|
if to_remove > 0 {
|
||||||
|
kept.sort_unstable_by_key(|k| *k.1);
|
||||||
kept.drain(..to_remove);
|
kept.drain(..to_remove);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Keep the initial alphabet
|
// Keep the initial alphabet (sorted for determinism)
|
||||||
kept.sort_unstable_by_key(|k| (*k.0) as u32);
|
kept.sort_unstable_by_key(|k| (*k.0) as u32);
|
||||||
kept.into_iter().for_each(|(c, _)| {
|
kept.into_iter().for_each(|(c, _)| {
|
||||||
let s = c.to_string();
|
let s = c.to_string();
|
||||||
@ -156,9 +170,9 @@ impl Trainer for BpeTrainer {
|
|||||||
self.add_special_tokens(&mut word_to_id, &mut id_to_word);
|
self.add_special_tokens(&mut word_to_id, &mut id_to_word);
|
||||||
|
|
||||||
//
|
//
|
||||||
// 2. Limit the initial alphabet if relevant
|
// 2. Compute the initial alphabet
|
||||||
//
|
//
|
||||||
self.limit_alphabet(&word_counts, &mut word_to_id, &mut id_to_word);
|
self.compute_alphabet(&word_counts, &mut word_to_id, &mut id_to_word);
|
||||||
|
|
||||||
//
|
//
|
||||||
// 3. Tokenize words
|
// 3. Tokenize words
|
||||||
|
Reference in New Issue
Block a user