mirror of
https://github.com/mii443/tokenizers.git
synced 2025-08-22 16:25:30 +00:00
BpeTrainer handles special tokens and limiting alphabet
This commit is contained in:
@ -27,6 +27,8 @@ pub struct BpeTrainer {
|
||||
pub min_frequency: u32,
|
||||
pub vocab_size: usize,
|
||||
pub show_progress: bool,
|
||||
pub special_tokens: Vec<String>,
|
||||
pub limit_alphabet: Option<usize>,
|
||||
}
|
||||
|
||||
impl Default for BpeTrainer {
|
||||
@ -35,6 +37,8 @@ impl Default for BpeTrainer {
|
||||
min_frequency: 0,
|
||||
vocab_size: 30000,
|
||||
show_progress: true,
|
||||
special_tokens: vec![],
|
||||
limit_alphabet: Some(100),
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -47,6 +51,90 @@ impl BpeTrainer {
|
||||
..Default::default()
|
||||
}
|
||||
}
|
||||
|
||||
/// Setup a progress bar if asked to show progress
|
||||
fn setup_progress(&self) -> Option<ProgressBar> {
|
||||
if self.show_progress {
|
||||
let p = ProgressBar::new(0);
|
||||
p.set_style(
|
||||
ProgressStyle::default_bar()
|
||||
.template("[{elapsed_precise}] {msg:30} {bar:40} {pos:>7}/{len:7}"),
|
||||
);
|
||||
Some(p)
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
/// Set the progress bar in the finish state
|
||||
fn finalize_progress(&self, p: &Option<ProgressBar>, final_len: usize) {
|
||||
if let Some(p) = p {
|
||||
p.set_length(final_len as u64);
|
||||
p.finish();
|
||||
println!();
|
||||
}
|
||||
}
|
||||
|
||||
/// Update the progress bar with the new provided length and message
|
||||
fn update_progress(&self, p: &Option<ProgressBar>, len: usize, message: &str) {
|
||||
if let Some(p) = p {
|
||||
p.set_message(message);
|
||||
p.set_length(len as u64);
|
||||
p.reset();
|
||||
}
|
||||
}
|
||||
|
||||
/// Add the provided special tokens to the initial vocabulary
|
||||
fn add_special_tokens(&self, w2id: &mut HashMap<String, u32>, id2w: &mut Vec<String>) {
|
||||
for token in &self.special_tokens {
|
||||
if !w2id.contains_key(token) {
|
||||
id2w.push(token.to_owned());
|
||||
w2id.insert(token.to_owned(), (id2w.len() - 1) as u32);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Compute the initial alphabet and limit it if relevant
|
||||
fn limit_alphabet(
|
||||
&self,
|
||||
wc: &HashMap<String, u32>,
|
||||
w2id: &mut HashMap<String, u32>,
|
||||
id2w: &mut Vec<String>,
|
||||
) {
|
||||
let mut alphabet: HashMap<char, usize> = HashMap::new();
|
||||
for (word, count) in wc {
|
||||
for c in word.chars() {
|
||||
alphabet
|
||||
.entry(c)
|
||||
.and_modify(|cnt| *cnt += *count as usize)
|
||||
.or_insert(*count as usize);
|
||||
}
|
||||
}
|
||||
|
||||
let to_remove = if let Some(limit) = self.limit_alphabet {
|
||||
alphabet.len() - limit
|
||||
} else {
|
||||
0
|
||||
};
|
||||
|
||||
let mut kept = alphabet.iter().collect::<Vec<_>>();
|
||||
kept.sort_unstable_by_key(|k| *k.1);
|
||||
|
||||
// Remove the unwanted chars
|
||||
if to_remove > 0 {
|
||||
kept.drain(..to_remove);
|
||||
}
|
||||
|
||||
// Keep the initial alphabet
|
||||
kept.sort_unstable_by_key(|k| (*k.0) as u32);
|
||||
kept.into_iter().for_each(|(c, _)| {
|
||||
let s = c.to_string();
|
||||
if !w2id.contains_key(&s) {
|
||||
id2w.push(s.clone());
|
||||
w2id.insert(s, (id2w.len() - 1) as u32);
|
||||
}
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
impl Trainer for BpeTrainer {
|
||||
@ -57,32 +145,31 @@ impl Trainer for BpeTrainer {
|
||||
let mut word_to_id: HashMap<String, u32> = HashMap::new();
|
||||
let mut id_to_word: Vec<String> = vec![];
|
||||
|
||||
let progress = if self.show_progress {
|
||||
let p = ProgressBar::new(word_counts.len() as u64);
|
||||
p.set_style(
|
||||
ProgressStyle::default_bar()
|
||||
.template("[{elapsed_precise}] {msg:30} {bar:40} {pos:>7}/{len:7}"),
|
||||
);
|
||||
p.set_message("Tokenize words");
|
||||
Some(p)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
let progress = self.setup_progress();
|
||||
|
||||
//
|
||||
// 1. Tokenize words
|
||||
// 1. Add all special tokens to the vocabulary
|
||||
//
|
||||
self.add_special_tokens(&mut word_to_id, &mut id_to_word);
|
||||
|
||||
//
|
||||
// 2. Limit the initial alphabet if relevant
|
||||
//
|
||||
self.limit_alphabet(&word_counts, &mut word_to_id, &mut id_to_word);
|
||||
|
||||
//
|
||||
// 3. Tokenize words
|
||||
//
|
||||
self.update_progress(&progress, word_counts.len(), "Tokenize words");
|
||||
for (word, count) in &word_counts {
|
||||
let mut current_word = Word::new();
|
||||
counts.push(*count as i32);
|
||||
|
||||
for c in word.chars() {
|
||||
let s = c.to_string();
|
||||
if !word_to_id.contains_key(&s) {
|
||||
id_to_word.push(s.clone());
|
||||
word_to_id.insert(s.clone(), (id_to_word.len() - 1) as u32);
|
||||
if let Some(id) = word_to_id.get(&s) {
|
||||
current_word.add(*id);
|
||||
}
|
||||
current_word.add(word_to_id[&s]);
|
||||
}
|
||||
words.push(current_word);
|
||||
|
||||
@ -90,15 +177,12 @@ impl Trainer for BpeTrainer {
|
||||
p.inc(1);
|
||||
}
|
||||
}
|
||||
self.finalize_progress(&progress, words.len());
|
||||
|
||||
//
|
||||
// 2. Count pairs in words
|
||||
// 4. Count pairs in words
|
||||
//
|
||||
if let Some(p) = &progress {
|
||||
p.set_message("Count pairs");
|
||||
p.set_length(words.len() as u64);
|
||||
p.reset();
|
||||
}
|
||||
self.update_progress(&progress, words.len(), "Count pairs");
|
||||
let mut pair_counts: HashMap<Pair, (i32, Pair)> = HashMap::new();
|
||||
let mut where_to_update: HashMap<Pair, HashSet<usize>> = HashMap::new();
|
||||
for (index, word) in words.iter().enumerate() {
|
||||
@ -128,15 +212,12 @@ impl Trainer for BpeTrainer {
|
||||
p.inc(1);
|
||||
}
|
||||
}
|
||||
self.finalize_progress(&progress, words.len());
|
||||
|
||||
//
|
||||
// 3. Do merges
|
||||
// 5. Do merges
|
||||
//
|
||||
if let Some(p) = &progress {
|
||||
p.set_message("Compute merges");
|
||||
p.set_length(self.vocab_size as u64);
|
||||
p.reset();
|
||||
}
|
||||
self.update_progress(&progress, self.vocab_size, "Compute merges");
|
||||
let mut merges: Vec<(Pair, u32)> = vec![];
|
||||
loop {
|
||||
// Stop as soon as we have a big enough vocabulary
|
||||
@ -206,6 +287,7 @@ impl Trainer for BpeTrainer {
|
||||
p.inc(1);
|
||||
}
|
||||
}
|
||||
self.finalize_progress(&progress, merges.len());
|
||||
|
||||
Ok(Box::new(BPE::new(
|
||||
word_to_id.clone(),
|
||||
|
Reference in New Issue
Block a user