Mirror of https://github.com/mii443/tokenizers.git (synced 2025-08-22 16:25:30 +00:00).
Commit message: Add BPE tokenization
This commit is contained in:
8
tokenizers/src/models/bpe/mod.rs
Normal file
8
tokenizers/src/models/bpe/mod.rs
Normal file
@ -0,0 +1,8 @@
|
||||
// Byte-Pair Encoding (BPE) model components.

mod model;
mod word;

/// A pair of symbol ids, used as the key of the merge table.
pub type Pair = (u32, u32);

// Re-export
pub use model::*;
pub use word::*;
|
91
tokenizers/src/models/bpe/model.rs
Normal file
91
tokenizers/src/models/bpe/model.rs
Normal file
@ -0,0 +1,91 @@
|
||||
use super::{Pair, Word};
|
||||
use crate::tokenizer::{Model, Token};
|
||||
use std::collections::HashMap;
|
||||
|
||||
/// A Byte-Pair Encoding model: encodes words by repeatedly merging the
/// best-ranked pair of adjacent symbols found in the merge table.
pub struct BPE {
    /// The vocabulary assigns a number to each token
    vocab: HashMap<String, u32>,
    /// Reversed vocabulary, to rebuild sentences
    vocab_r: HashMap<u32, String>,
    /// Contains the mapping between Pairs and their (rank, new_id)
    merges: HashMap<Pair, (u32, u32)>,
}
|
||||
|
||||
impl BPE {
|
||||
pub fn new(
|
||||
vocab: HashMap<String, u32>,
|
||||
vocab_r: HashMap<u32, String>,
|
||||
merges: HashMap<Pair, (u32, u32)>,
|
||||
) -> Self {
|
||||
BPE {
|
||||
vocab,
|
||||
vocab_r,
|
||||
merges,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Model for BPE {
    /// Tokenizes a pre-split sentence: each word is encoded independently by
    /// greedily applying the known merges, and the resulting tokens are
    /// concatenated with sentence-level offsets.
    fn tokenize(&self, sentence: Vec<String>) -> Vec<Token> {
        let mut encoded: Vec<Token> = Vec::with_capacity(sentence.len());

        for w in sentence {
            // Seed the word with one symbol per character.
            let mut word = Word::new();
            for c in w.chars() {
                match self.vocab.get(&c.to_string()) {
                    // TODO: Handle UNK
                    // NOTE(review): skipping an unknown character shifts the
                    // offsets of all following symbols in this word — confirm
                    // this is the intended behavior until UNK is handled.
                    None => println!("{} is an unknown character. Skip it.", c.escape_unicode()),
                    Some(id) => word.add(*id),
                }
            }

            loop {
                // A single symbol cannot be merged any further.
                if word.get_chars().len() < 2 {
                    break;
                }

                // Pick the best merge among all adjacent pairs. `min` compares
                // ((rank, new_id), pair) lexicographically, so the lowest rank
                // wins (ties broken by new_id, then by the pair itself). Pairs
                // absent from the merge table get the u32::MAX sentinel.
                // `unwrap` is safe: len >= 2 guarantees windows(2) is non-empty.
                let ((rank, new_id), pair) = word
                    .get_chars()
                    .windows(2)
                    .map(|window| {
                        let pair = (window[0], window[1]);
                        let rank = self
                            .merges
                            .get(&pair)
                            .unwrap_or(&(std::u32::MAX, std::u32::MAX));
                        (rank, pair)
                    })
                    .min()
                    .unwrap();

                if *rank == std::u32::MAX {
                    // We are done merging this word
                    break;
                }

                // Let's merge
                word.merge(pair.0, pair.1, *new_id);
            }

            // Offsets are word-based, we need to translate them to be sentence-based
            let last_offset = encoded.last().map(|token| token.offsets.1).unwrap_or(0);

            let tokens = word
                .get_chars()
                .iter()
                .zip(word.get_offsets())
                .map(|(id, offsets)| {
                    Token::new(
                        *id,
                        self.vocab_r[id].clone(),
                        // Shift word-local offsets by the end of the previous
                        // token so offsets are contiguous across the sentence.
                        (last_offset + offsets.0, last_offset + offsets.1),
                    )
                })
                .collect::<Vec<_>>();

            encoded.extend(tokens);
        }

        encoded
    }
}
|
76
tokenizers/src/models/bpe/word.rs
Normal file
76
tokenizers/src/models/bpe/word.rs
Normal file
@ -0,0 +1,76 @@
|
||||
use super::Pair;
|
||||
|
||||
// TODO: Add tests
|
||||
pub struct Word {
|
||||
chars: Vec<u32>,
|
||||
sizes: Vec<usize>,
|
||||
}
|
||||
impl Word {
|
||||
pub fn new() -> Self {
|
||||
Word {
|
||||
chars: vec![],
|
||||
sizes: vec![],
|
||||
}
|
||||
}
|
||||
|
||||
pub fn add(&mut self, c: u32) {
|
||||
self.chars.push(c);
|
||||
self.sizes.push(1);
|
||||
}
|
||||
|
||||
pub fn merge(&mut self, c1: u32, c2: u32, replacement: u32) -> Vec<(Pair, i32)> {
|
||||
let mut changes: Vec<(Pair, i32)> = vec![];
|
||||
let mut i = 0;
|
||||
loop {
|
||||
if i >= self.chars.len() {
|
||||
break;
|
||||
}
|
||||
|
||||
// Found a pair
|
||||
if self.chars[i] == c1 && i + 1 < self.chars.len() && self.chars[i + 1] == c2 {
|
||||
let first = self.chars[i];
|
||||
let second = self.chars[i + 1];
|
||||
|
||||
// If there are other characters before the pair
|
||||
if i > 0 {
|
||||
changes.push(((self.chars[i - 1], first), -1));
|
||||
changes.push(((self.chars[i - 1], replacement), 1));
|
||||
}
|
||||
|
||||
// Remove in place
|
||||
self.chars.insert(i, replacement); // Insert replacement before first char of pair
|
||||
self.chars.remove(i + 1); // Remove first char of pair
|
||||
self.chars.remove(i + 1); // And then the second
|
||||
|
||||
// Update sizes
|
||||
let new_size = self.sizes[i] + self.sizes[i + 1];
|
||||
self.sizes[i] = new_size;
|
||||
self.sizes.remove(i + 1);
|
||||
|
||||
// If there are other characters after the pair
|
||||
if i < self.chars.len() - 1 {
|
||||
changes.push(((second, self.chars[i + 1]), -1));
|
||||
changes.push(((replacement, self.chars[i + 1]), 1));
|
||||
}
|
||||
}
|
||||
|
||||
i += 1;
|
||||
}
|
||||
|
||||
changes
|
||||
}
|
||||
|
||||
pub fn get_chars(&self) -> &Vec<u32> {
|
||||
&self.chars
|
||||
}
|
||||
|
||||
pub fn get_offsets(&self) -> Vec<(usize, usize)> {
|
||||
let mut offsets = vec![];
|
||||
let mut pos = 0;
|
||||
for size in &self.sizes {
|
||||
offsets.push((pos, pos + size));
|
||||
pos += size;
|
||||
}
|
||||
offsets
|
||||
}
|
||||
}
|
@ -0,0 +1 @@
|
||||
// Tokenization model implementations.
pub mod bpe;
|
||||
|
Reference in New Issue
Block a user