mirror of
https://github.com/mii443/tokenizers.git
synced 2025-08-23 00:35:35 +00:00
Add BPE tokenization
This commit is contained in:
8
tokenizers/src/models/bpe/mod.rs
Normal file
8
tokenizers/src/models/bpe/mod.rs
Normal file
@ -0,0 +1,8 @@
|
|||||||
|
mod model;
|
||||||
|
mod word;
|
||||||
|
|
||||||
|
pub type Pair = (u32, u32);
|
||||||
|
|
||||||
|
// Re-export
|
||||||
|
pub use model::*;
|
||||||
|
pub use word::*;
|
91
tokenizers/src/models/bpe/model.rs
Normal file
91
tokenizers/src/models/bpe/model.rs
Normal file
@ -0,0 +1,91 @@
|
|||||||
|
use super::{Pair, Word};
|
||||||
|
use crate::tokenizer::{Model, Token};
|
||||||
|
use std::collections::HashMap;
|
||||||
|
|
||||||
|
pub struct BPE {
|
||||||
|
/// The vocabulary assigns a number to each token
|
||||||
|
vocab: HashMap<String, u32>,
|
||||||
|
/// Reversed vocabulary, to rebuild sentences
|
||||||
|
vocab_r: HashMap<u32, String>,
|
||||||
|
/// Contains the mapping between Pairs and their (rank, new_id)
|
||||||
|
merges: HashMap<Pair, (u32, u32)>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl BPE {
|
||||||
|
pub fn new(
|
||||||
|
vocab: HashMap<String, u32>,
|
||||||
|
vocab_r: HashMap<u32, String>,
|
||||||
|
merges: HashMap<Pair, (u32, u32)>,
|
||||||
|
) -> Self {
|
||||||
|
BPE {
|
||||||
|
vocab,
|
||||||
|
vocab_r,
|
||||||
|
merges,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Model for BPE {
|
||||||
|
fn tokenize(&self, sentence: Vec<String>) -> Vec<Token> {
|
||||||
|
let mut encoded: Vec<Token> = Vec::with_capacity(sentence.len());
|
||||||
|
|
||||||
|
for w in sentence {
|
||||||
|
let mut word = Word::new();
|
||||||
|
for c in w.chars() {
|
||||||
|
match self.vocab.get(&c.to_string()) {
|
||||||
|
// TODO: Handle UNK
|
||||||
|
None => println!("{} is an unknown character. Skip it.", c.escape_unicode()),
|
||||||
|
Some(id) => word.add(*id),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
loop {
|
||||||
|
if word.get_chars().len() < 2 {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
|
let ((rank, new_id), pair) = word
|
||||||
|
.get_chars()
|
||||||
|
.windows(2)
|
||||||
|
.map(|window| {
|
||||||
|
let pair = (window[0], window[1]);
|
||||||
|
let rank = self
|
||||||
|
.merges
|
||||||
|
.get(&pair)
|
||||||
|
.unwrap_or(&(std::u32::MAX, std::u32::MAX));
|
||||||
|
(rank, pair)
|
||||||
|
})
|
||||||
|
.min()
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
if *rank == std::u32::MAX {
|
||||||
|
// We are done merging this word
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Let's merge
|
||||||
|
word.merge(pair.0, pair.1, *new_id);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Offsets are word-based, we need to translate them to be sentence-based
|
||||||
|
let last_offset = encoded.last().map(|token| token.offsets.1).unwrap_or(0);
|
||||||
|
|
||||||
|
let tokens = word
|
||||||
|
.get_chars()
|
||||||
|
.iter()
|
||||||
|
.zip(word.get_offsets())
|
||||||
|
.map(|(id, offsets)| {
|
||||||
|
Token::new(
|
||||||
|
*id,
|
||||||
|
self.vocab_r[id].clone(),
|
||||||
|
(last_offset + offsets.0, last_offset + offsets.1),
|
||||||
|
)
|
||||||
|
})
|
||||||
|
.collect::<Vec<_>>();
|
||||||
|
|
||||||
|
encoded.extend(tokens);
|
||||||
|
}
|
||||||
|
|
||||||
|
encoded
|
||||||
|
}
|
||||||
|
}
|
76
tokenizers/src/models/bpe/word.rs
Normal file
76
tokenizers/src/models/bpe/word.rs
Normal file
@ -0,0 +1,76 @@
|
|||||||
|
use super::Pair;
|
||||||
|
|
||||||
|
// TODO: Add tests
|
||||||
|
pub struct Word {
|
||||||
|
chars: Vec<u32>,
|
||||||
|
sizes: Vec<usize>,
|
||||||
|
}
|
||||||
|
impl Word {
|
||||||
|
pub fn new() -> Self {
|
||||||
|
Word {
|
||||||
|
chars: vec![],
|
||||||
|
sizes: vec![],
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn add(&mut self, c: u32) {
|
||||||
|
self.chars.push(c);
|
||||||
|
self.sizes.push(1);
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn merge(&mut self, c1: u32, c2: u32, replacement: u32) -> Vec<(Pair, i32)> {
|
||||||
|
let mut changes: Vec<(Pair, i32)> = vec![];
|
||||||
|
let mut i = 0;
|
||||||
|
loop {
|
||||||
|
if i >= self.chars.len() {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Found a pair
|
||||||
|
if self.chars[i] == c1 && i + 1 < self.chars.len() && self.chars[i + 1] == c2 {
|
||||||
|
let first = self.chars[i];
|
||||||
|
let second = self.chars[i + 1];
|
||||||
|
|
||||||
|
// If there are other characters before the pair
|
||||||
|
if i > 0 {
|
||||||
|
changes.push(((self.chars[i - 1], first), -1));
|
||||||
|
changes.push(((self.chars[i - 1], replacement), 1));
|
||||||
|
}
|
||||||
|
|
||||||
|
// Remove in place
|
||||||
|
self.chars.insert(i, replacement); // Insert replacement before first char of pair
|
||||||
|
self.chars.remove(i + 1); // Remove first char of pair
|
||||||
|
self.chars.remove(i + 1); // And then the second
|
||||||
|
|
||||||
|
// Update sizes
|
||||||
|
let new_size = self.sizes[i] + self.sizes[i + 1];
|
||||||
|
self.sizes[i] = new_size;
|
||||||
|
self.sizes.remove(i + 1);
|
||||||
|
|
||||||
|
// If there are other characters after the pair
|
||||||
|
if i < self.chars.len() - 1 {
|
||||||
|
changes.push(((second, self.chars[i + 1]), -1));
|
||||||
|
changes.push(((replacement, self.chars[i + 1]), 1));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
i += 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
changes
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn get_chars(&self) -> &Vec<u32> {
|
||||||
|
&self.chars
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn get_offsets(&self) -> Vec<(usize, usize)> {
|
||||||
|
let mut offsets = vec![];
|
||||||
|
let mut pos = 0;
|
||||||
|
for size in &self.sizes {
|
||||||
|
offsets.push((pos, pos + size));
|
||||||
|
pos += size;
|
||||||
|
}
|
||||||
|
offsets
|
||||||
|
}
|
||||||
|
}
|
@ -0,0 +1 @@
|
|||||||
|
pub mod bpe;
|
||||||
|
Reference in New Issue
Block a user