mirror of
https://github.com/mii443/tokenizers.git
synced 2025-08-22 16:25:30 +00:00
BpeTrainer handles special tokens and limiting alphabet
This commit is contained in:
@ -27,6 +27,8 @@ pub struct BpeTrainer {
|
|||||||
pub min_frequency: u32,
|
pub min_frequency: u32,
|
||||||
pub vocab_size: usize,
|
pub vocab_size: usize,
|
||||||
pub show_progress: bool,
|
pub show_progress: bool,
|
||||||
|
pub special_tokens: Vec<String>,
|
||||||
|
pub limit_alphabet: Option<usize>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Default for BpeTrainer {
|
impl Default for BpeTrainer {
|
||||||
@ -35,6 +37,8 @@ impl Default for BpeTrainer {
|
|||||||
min_frequency: 0,
|
min_frequency: 0,
|
||||||
vocab_size: 30000,
|
vocab_size: 30000,
|
||||||
show_progress: true,
|
show_progress: true,
|
||||||
|
special_tokens: vec![],
|
||||||
|
limit_alphabet: Some(100),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -47,6 +51,90 @@ impl BpeTrainer {
|
|||||||
..Default::default()
|
..Default::default()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Setup a progress bar if asked to show progress
|
||||||
|
fn setup_progress(&self) -> Option<ProgressBar> {
|
||||||
|
if self.show_progress {
|
||||||
|
let p = ProgressBar::new(0);
|
||||||
|
p.set_style(
|
||||||
|
ProgressStyle::default_bar()
|
||||||
|
.template("[{elapsed_precise}] {msg:30} {bar:40} {pos:>7}/{len:7}"),
|
||||||
|
);
|
||||||
|
Some(p)
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Set the progress bar in the finish state
|
||||||
|
fn finalize_progress(&self, p: &Option<ProgressBar>, final_len: usize) {
|
||||||
|
if let Some(p) = p {
|
||||||
|
p.set_length(final_len as u64);
|
||||||
|
p.finish();
|
||||||
|
println!();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Update the progress bar with the new provided length and message
|
||||||
|
fn update_progress(&self, p: &Option<ProgressBar>, len: usize, message: &str) {
|
||||||
|
if let Some(p) = p {
|
||||||
|
p.set_message(message);
|
||||||
|
p.set_length(len as u64);
|
||||||
|
p.reset();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Add the provided special tokens to the initial vocabulary
|
||||||
|
fn add_special_tokens(&self, w2id: &mut HashMap<String, u32>, id2w: &mut Vec<String>) {
|
||||||
|
for token in &self.special_tokens {
|
||||||
|
if !w2id.contains_key(token) {
|
||||||
|
id2w.push(token.to_owned());
|
||||||
|
w2id.insert(token.to_owned(), (id2w.len() - 1) as u32);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Compute the initial alphabet and limit it if relevant
|
||||||
|
fn limit_alphabet(
|
||||||
|
&self,
|
||||||
|
wc: &HashMap<String, u32>,
|
||||||
|
w2id: &mut HashMap<String, u32>,
|
||||||
|
id2w: &mut Vec<String>,
|
||||||
|
) {
|
||||||
|
let mut alphabet: HashMap<char, usize> = HashMap::new();
|
||||||
|
for (word, count) in wc {
|
||||||
|
for c in word.chars() {
|
||||||
|
alphabet
|
||||||
|
.entry(c)
|
||||||
|
.and_modify(|cnt| *cnt += *count as usize)
|
||||||
|
.or_insert(*count as usize);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
let to_remove = if let Some(limit) = self.limit_alphabet {
|
||||||
|
alphabet.len() - limit
|
||||||
|
} else {
|
||||||
|
0
|
||||||
|
};
|
||||||
|
|
||||||
|
let mut kept = alphabet.iter().collect::<Vec<_>>();
|
||||||
|
kept.sort_unstable_by_key(|k| *k.1);
|
||||||
|
|
||||||
|
// Remove the unwanted chars
|
||||||
|
if to_remove > 0 {
|
||||||
|
kept.drain(..to_remove);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Keep the initial alphabet
|
||||||
|
kept.sort_unstable_by_key(|k| (*k.0) as u32);
|
||||||
|
kept.into_iter().for_each(|(c, _)| {
|
||||||
|
let s = c.to_string();
|
||||||
|
if !w2id.contains_key(&s) {
|
||||||
|
id2w.push(s.clone());
|
||||||
|
w2id.insert(s, (id2w.len() - 1) as u32);
|
||||||
|
}
|
||||||
|
});
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Trainer for BpeTrainer {
|
impl Trainer for BpeTrainer {
|
||||||
@ -57,32 +145,31 @@ impl Trainer for BpeTrainer {
|
|||||||
let mut word_to_id: HashMap<String, u32> = HashMap::new();
|
let mut word_to_id: HashMap<String, u32> = HashMap::new();
|
||||||
let mut id_to_word: Vec<String> = vec![];
|
let mut id_to_word: Vec<String> = vec![];
|
||||||
|
|
||||||
let progress = if self.show_progress {
|
let progress = self.setup_progress();
|
||||||
let p = ProgressBar::new(word_counts.len() as u64);
|
|
||||||
p.set_style(
|
|
||||||
ProgressStyle::default_bar()
|
|
||||||
.template("[{elapsed_precise}] {msg:30} {bar:40} {pos:>7}/{len:7}"),
|
|
||||||
);
|
|
||||||
p.set_message("Tokenize words");
|
|
||||||
Some(p)
|
|
||||||
} else {
|
|
||||||
None
|
|
||||||
};
|
|
||||||
|
|
||||||
//
|
//
|
||||||
// 1. Tokenize words
|
// 1. Add all special tokens to the vocabulary
|
||||||
//
|
//
|
||||||
|
self.add_special_tokens(&mut word_to_id, &mut id_to_word);
|
||||||
|
|
||||||
|
//
|
||||||
|
// 2. Limit the initial alphabet if relevant
|
||||||
|
//
|
||||||
|
self.limit_alphabet(&word_counts, &mut word_to_id, &mut id_to_word);
|
||||||
|
|
||||||
|
//
|
||||||
|
// 3. Tokenize words
|
||||||
|
//
|
||||||
|
self.update_progress(&progress, word_counts.len(), "Tokenize words");
|
||||||
for (word, count) in &word_counts {
|
for (word, count) in &word_counts {
|
||||||
let mut current_word = Word::new();
|
let mut current_word = Word::new();
|
||||||
counts.push(*count as i32);
|
counts.push(*count as i32);
|
||||||
|
|
||||||
for c in word.chars() {
|
for c in word.chars() {
|
||||||
let s = c.to_string();
|
let s = c.to_string();
|
||||||
if !word_to_id.contains_key(&s) {
|
if let Some(id) = word_to_id.get(&s) {
|
||||||
id_to_word.push(s.clone());
|
current_word.add(*id);
|
||||||
word_to_id.insert(s.clone(), (id_to_word.len() - 1) as u32);
|
|
||||||
}
|
}
|
||||||
current_word.add(word_to_id[&s]);
|
|
||||||
}
|
}
|
||||||
words.push(current_word);
|
words.push(current_word);
|
||||||
|
|
||||||
@ -90,15 +177,12 @@ impl Trainer for BpeTrainer {
|
|||||||
p.inc(1);
|
p.inc(1);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
self.finalize_progress(&progress, words.len());
|
||||||
|
|
||||||
//
|
//
|
||||||
// 2. Count pairs in words
|
// 4. Count pairs in words
|
||||||
//
|
//
|
||||||
if let Some(p) = &progress {
|
self.update_progress(&progress, words.len(), "Count pairs");
|
||||||
p.set_message("Count pairs");
|
|
||||||
p.set_length(words.len() as u64);
|
|
||||||
p.reset();
|
|
||||||
}
|
|
||||||
let mut pair_counts: HashMap<Pair, (i32, Pair)> = HashMap::new();
|
let mut pair_counts: HashMap<Pair, (i32, Pair)> = HashMap::new();
|
||||||
let mut where_to_update: HashMap<Pair, HashSet<usize>> = HashMap::new();
|
let mut where_to_update: HashMap<Pair, HashSet<usize>> = HashMap::new();
|
||||||
for (index, word) in words.iter().enumerate() {
|
for (index, word) in words.iter().enumerate() {
|
||||||
@ -128,15 +212,12 @@ impl Trainer for BpeTrainer {
|
|||||||
p.inc(1);
|
p.inc(1);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
self.finalize_progress(&progress, words.len());
|
||||||
|
|
||||||
//
|
//
|
||||||
// 3. Do merges
|
// 5. Do merges
|
||||||
//
|
//
|
||||||
if let Some(p) = &progress {
|
self.update_progress(&progress, self.vocab_size, "Compute merges");
|
||||||
p.set_message("Compute merges");
|
|
||||||
p.set_length(self.vocab_size as u64);
|
|
||||||
p.reset();
|
|
||||||
}
|
|
||||||
let mut merges: Vec<(Pair, u32)> = vec![];
|
let mut merges: Vec<(Pair, u32)> = vec![];
|
||||||
loop {
|
loop {
|
||||||
// Stop as soon as we have a big enough vocabulary
|
// Stop as soon as we have a big enough vocabulary
|
||||||
@ -206,6 +287,7 @@ impl Trainer for BpeTrainer {
|
|||||||
p.inc(1);
|
p.inc(1);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
self.finalize_progress(&progress, merges.len());
|
||||||
|
|
||||||
Ok(Box::new(BPE::new(
|
Ok(Box::new(BPE::new(
|
||||||
word_to_id.clone(),
|
word_to_id.clone(),
|
||||||
|
Reference in New Issue
Block a user