diff --git a/tokenizers/src/models/bpe/mod.rs b/tokenizers/src/models/bpe/mod.rs
new file mode 100644
index 00000000..54e0e644
--- /dev/null
+++ b/tokenizers/src/models/bpe/mod.rs
@@ -0,0 +1,8 @@
+mod model;
+mod word;
+
+pub type Pair = (u32, u32);
+
+// Re-export
+pub use model::*;
+pub use word::*;
diff --git a/tokenizers/src/models/bpe/model.rs b/tokenizers/src/models/bpe/model.rs
new file mode 100644
index 00000000..009349e0
--- /dev/null
+++ b/tokenizers/src/models/bpe/model.rs
@@ -0,0 +1,91 @@
+use super::{Pair, Word};
+use crate::tokenizer::{Model, Token};
+use std::collections::HashMap;
+
+pub struct BPE {
+    /// The vocabulary assigns a number to each token
+    vocab: HashMap<String, u32>,
+    /// Reversed vocabulary, to rebuild sentences
+    vocab_r: HashMap<u32, String>,
+    /// Contains the mapping between Pairs and their (rank, new_id)
+    merges: HashMap<Pair, (u32, u32)>,
+}
+
+impl BPE {
+    pub fn new(
+        vocab: HashMap<String, u32>,
+        vocab_r: HashMap<u32, String>,
+        merges: HashMap<Pair, (u32, u32)>,
+    ) -> Self {
+        BPE {
+            vocab,
+            vocab_r,
+            merges,
+        }
+    }
+}
+
+impl Model for BPE {
+    fn tokenize(&self, sentence: Vec<String>) -> Vec<Token> {
+        let mut encoded: Vec<Token> = Vec::with_capacity(sentence.len());
+
+        for w in sentence {
+            let mut word = Word::new();
+            for c in w.chars() {
+                match self.vocab.get(&c.to_string()) {
+                    // TODO: Handle UNK
+                    None => println!("{} is an unknown character. Skip it.", c.escape_unicode()),
+                    Some(id) => word.add(*id),
+                }
+            }
+
+            loop {
+                if word.get_chars().len() < 2 {
+                    break;
+                }
+
+                let ((rank, new_id), pair) = word
+                    .get_chars()
+                    .windows(2)
+                    .map(|window| {
+                        let pair = (window[0], window[1]);
+                        let rank = self
+                            .merges
+                            .get(&pair)
+                            .unwrap_or(&(std::u32::MAX, std::u32::MAX));
+                        (rank, pair)
+                    })
+                    .min()
+                    .unwrap();
+
+                if *rank == std::u32::MAX {
+                    // We are done merging this word
+                    break;
+                }
+
+                // Let's merge
+                word.merge(pair.0, pair.1, *new_id);
+            }
+
+            // Offsets are word-based, we need to translate them to be sentence-based
+            let last_offset = encoded.last().map(|token| token.offsets.1).unwrap_or(0);
+
+            let tokens = word
+                .get_chars()
+                .iter()
+                .zip(word.get_offsets())
+                .map(|(id, offsets)| {
+                    Token::new(
+                        *id,
+                        self.vocab_r[id].clone(),
+                        (last_offset + offsets.0, last_offset + offsets.1),
+                    )
+                })
+                .collect::<Vec<_>>();
+
+            encoded.extend(tokens);
+        }
+
+        encoded
+    }
+}
diff --git a/tokenizers/src/models/bpe/word.rs b/tokenizers/src/models/bpe/word.rs
new file mode 100644
index 00000000..9a5c0f23
--- /dev/null
+++ b/tokenizers/src/models/bpe/word.rs
@@ -0,0 +1,76 @@
+use super::Pair;
+
+// TODO: Add tests
+pub struct Word {
+    chars: Vec<u32>,
+    sizes: Vec<usize>,
+}
+impl Word {
+    pub fn new() -> Self {
+        Word {
+            chars: vec![],
+            sizes: vec![],
+        }
+    }
+
+    pub fn add(&mut self, c: u32) {
+        self.chars.push(c);
+        self.sizes.push(1);
+    }
+
+    pub fn merge(&mut self, c1: u32, c2: u32, replacement: u32) -> Vec<(Pair, i32)> {
+        let mut changes: Vec<(Pair, i32)> = vec![];
+        let mut i = 0;
+        loop {
+            if i >= self.chars.len() {
+                break;
+            }
+
+            // Found a pair
+            if self.chars[i] == c1 && i + 1 < self.chars.len() && self.chars[i + 1] == c2 {
+                let first = self.chars[i];
+                let second = self.chars[i + 1];
+
+                // If there are other characters before the pair
+                if i > 0 {
+                    changes.push(((self.chars[i - 1], first), -1));
+                    changes.push(((self.chars[i - 1], replacement), 1));
+                }
+
+                // Remove in place
+                self.chars.insert(i, replacement); // Insert replacement before first char of pair
+                self.chars.remove(i + 1); // Remove first char of pair
+                self.chars.remove(i + 1); // And then the second
+
+                // Update sizes
+                let new_size = self.sizes[i] + self.sizes[i + 1];
+                self.sizes[i] = new_size;
+                self.sizes.remove(i + 1);
+
+                // If there are other characters after the pair
+                if i < self.chars.len() - 1 {
+                    changes.push(((second, self.chars[i + 1]), -1));
+                    changes.push(((replacement, self.chars[i + 1]), 1));
+                }
+            }
+
+            i += 1;
+        }
+
+        changes
+    }
+
+    pub fn get_chars(&self) -> &Vec<u32> {
+        &self.chars
+    }
+
+    pub fn get_offsets(&self) -> Vec<(usize, usize)> {
+        let mut offsets = vec![];
+        let mut pos = 0;
+        for size in &self.sizes {
+            offsets.push((pos, pos + size));
+            pos += size;
+        }
+        offsets
+    }
+}
diff --git a/tokenizers/src/models/mod.rs b/tokenizers/src/models/mod.rs
index e69de29b..de7237b0 100644
--- a/tokenizers/src/models/mod.rs
+++ b/tokenizers/src/models/mod.rs
@@ -0,0 +1 @@
+pub mod bpe;
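
Regarding the "// TODO: Add tests" in word.rs, here is a sketch of what a first unit test for Word::merge could look like. It is not part of the patch; the ids are arbitrary toy values, not taken from a real vocabulary:

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn merge_replaces_pair_and_tracks_sizes() {
        // Build the word "hug" with toy ids: h -> 0, u -> 1, g -> 2
        let mut word = Word::new();
        word.add(0);
        word.add(1);
        word.add(2);

        // Merge the pair (0, 1) into a new symbol 3
        let changes = word.merge(0, 1, 3);

        // The pair was collapsed in place
        assert_eq!(word.get_chars(), &vec![3, 2]);
        // The merged symbol now spans two of the original characters
        assert_eq!(word.get_offsets(), vec![(0, 2), (2, 3)]);
        // Pair count updates: (1, 2) disappeared, (3, 2) appeared
        assert_eq!(changes, vec![((1, 2), -1), ((3, 2), 1)]);
    }
}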
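And a rough sketch of how the pieces are meant to fit together end to end. This is also not part of the patch: build_toy_bpe is a hypothetical helper, and it assumes the Model trait is in scope so tokenize can be called:

use std::collections::HashMap;

// Hypothetical toy vocabulary: "h" -> 0, "u" -> 1, "g" -> 2, "hu" -> 3, "hug" -> 4
fn build_toy_bpe() -> BPE {
    let vocab: HashMap<String, u32> = [("h", 0), ("u", 1), ("g", 2), ("hu", 3), ("hug", 4)]
        .iter()
        .map(|(s, i)| (s.to_string(), *i))
        .collect();
    let vocab_r: HashMap<u32, String> = vocab.iter().map(|(s, i)| (*i, s.clone())).collect();
    // pair -> (rank, new_id): merge ("h", "u") first, then ("hu", "g")
    let merges: HashMap<Pair, (u32, u32)> = [((0, 1), (0, 3)), ((3, 2), (1, 4))]
        .iter()
        .cloned()
        .collect();
    BPE::new(vocab, vocab_r, merges)
}

fn demo() {
    let bpe = build_toy_bpe();
    let tokens = bpe.tokenize(vec!["hug".to_string(), "hug".to_string()]);
    // Both words collapse to the single token "hug" (id 4), and the offsets
    // are sentence-based: (0, 3) for the first word, (3, 6) for the second,
    // thanks to the last_offset translation in tokenize.
    assert_eq!(tokens.len(), 2);
}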