Add BPE tokenization

Anthony MOI
2019-11-17 00:27:30 -05:00
parent b2ba864248
commit 1c7dcebca7
4 changed files with 176 additions and 0 deletions

View File

@ -0,0 +1,8 @@
mod model;
mod word;

pub type Pair = (u32, u32);

// Re-export
pub use model::*;
pub use word::*;

View File

@ -0,0 +1,91 @@
use super::{Pair, Word};
use crate::tokenizer::{Model, Token};
use std::collections::HashMap;

pub struct BPE {
    /// The vocabulary assigns a number to each token
    vocab: HashMap<String, u32>,
    /// Reversed vocabulary, to rebuild sentences
    vocab_r: HashMap<u32, String>,
    /// Contains the mapping between Pairs and their (rank, new_id)
    merges: HashMap<Pair, (u32, u32)>,
}

impl BPE {
    pub fn new(
        vocab: HashMap<String, u32>,
        vocab_r: HashMap<u32, String>,
        merges: HashMap<Pair, (u32, u32)>,
    ) -> Self {
        BPE {
            vocab,
            vocab_r,
            merges,
        }
    }
}

impl Model for BPE {
    fn tokenize(&self, sentence: Vec<String>) -> Vec<Token> {
        let mut encoded: Vec<Token> = Vec::with_capacity(sentence.len());

        for w in sentence {
            let mut word = Word::new();
            for c in w.chars() {
                match self.vocab.get(&c.to_string()) {
                    // TODO: Handle UNK
                    None => println!("{} is an unknown character. Skip it.", c.escape_unicode()),
                    Some(id) => word.add(*id),
                }
            }

            loop {
                if word.get_chars().len() < 2 {
                    break;
                }

                // Find the adjacent pair with the lowest merge rank
                // (pairs absent from `merges` rank as u32::MAX)
                let ((rank, new_id), pair) = word
                    .get_chars()
                    .windows(2)
                    .map(|window| {
                        let pair = (window[0], window[1]);
                        let rank = self
                            .merges
                            .get(&pair)
                            .unwrap_or(&(std::u32::MAX, std::u32::MAX));
                        (rank, pair)
                    })
                    .min()
                    .unwrap();

                if *rank == std::u32::MAX {
                    // We are done merging this word
                    break;
                }

                // Let's merge
                word.merge(pair.0, pair.1, *new_id);
            }

            // Offsets are word-based; we need to translate them to be sentence-based
            let last_offset = encoded.last().map(|token| token.offsets.1).unwrap_or(0);
            let tokens = word
                .get_chars()
                .iter()
                .zip(word.get_offsets())
                .map(|(id, offsets)| {
                    Token::new(
                        *id,
                        self.vocab_r[id].clone(),
                        (last_offset + offsets.0, last_offset + offsets.1),
                    )
                })
                .collect::<Vec<_>>();
            encoded.extend(tokens);
        }

        encoded
    }
}

View File

@ -0,0 +1,76 @@
use super::Pair;

// TODO: Add tests
pub struct Word {
    chars: Vec<u32>,
    sizes: Vec<usize>,
}

impl Word {
    pub fn new() -> Self {
        Word {
            chars: vec![],
            sizes: vec![],
        }
    }

    pub fn add(&mut self, c: u32) {
        self.chars.push(c);
        self.sizes.push(1);
    }

    /// Merge every occurrence of the pair (c1, c2) into `replacement`, returning
    /// the (pair, count delta) changes this causes among neighbouring pairs.
    pub fn merge(&mut self, c1: u32, c2: u32, replacement: u32) -> Vec<(Pair, i32)> {
        let mut changes: Vec<(Pair, i32)> = vec![];
        let mut i = 0;
        loop {
            if i >= self.chars.len() {
                break;
            }

            // Found a pair
            if self.chars[i] == c1 && i + 1 < self.chars.len() && self.chars[i + 1] == c2 {
                let first = self.chars[i];
                let second = self.chars[i + 1];

                // If there are other characters before the pair
                if i > 0 {
                    changes.push(((self.chars[i - 1], first), -1));
                    changes.push(((self.chars[i - 1], replacement), 1));
                }

                // Remove in place
                self.chars.insert(i, replacement); // Insert replacement before first char of pair
                self.chars.remove(i + 1); // Remove first char of pair
                self.chars.remove(i + 1); // And then the second

                // Update sizes
                let new_size = self.sizes[i] + self.sizes[i + 1];
                self.sizes[i] = new_size;
                self.sizes.remove(i + 1);

                // If there are other characters after the pair
                if i < self.chars.len() - 1 {
                    changes.push(((second, self.chars[i + 1]), -1));
                    changes.push(((replacement, self.chars[i + 1]), 1));
                }
            }

            i += 1;
        }

        changes
    }

    pub fn get_chars(&self) -> &Vec<u32> {
        &self.chars
    }

    pub fn get_offsets(&self) -> Vec<(usize, usize)> {
        let mut offsets = vec![];
        let mut pos = 0;
        for size in &self.sizes {
            offsets.push((pos, pos + size));
            pos += size;
        }

        offsets
    }
}
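
The file above carries a "TODO: Add tests" note. A minimal test sketch for Word::merge and Word::get_offsets could look like the following; the ids (0..3) and the expected values are worked out by hand from the code above and are not part of the commit:

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn merge_pair_and_track_offsets() {
        // Word "abc" encoded as the ids [0, 1, 2]
        let mut word = Word::new();
        word.add(0);
        word.add(1);
        word.add(2);

        // Merge the pair (0, 1) into the new id 3
        let changes = word.merge(0, 1, 3);

        // The pair is collapsed in place...
        assert_eq!(word.get_chars(), &vec![3, 2]);
        // ...its offsets now span the two original characters...
        assert_eq!(word.get_offsets(), vec![(0, 2), (2, 3)]);
        // ...and the neighbouring pair counts are updated accordingly
        assert_eq!(changes, vec![((1, 2), -1), ((3, 2), 1)]);
    }
}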

View File

@ -0,0 +1 @@
pub mod bpe;
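
Putting the pieces together, a hypothetical end-to-end use of the new model might look like the sketch below. The toy vocabulary, the single merge rule, and the import paths are assumptions for illustration; Token and the Model trait live in crate::tokenizer and are not shown in this commit:

use std::collections::HashMap;

use crate::models::bpe::{Pair, BPE}; // hypothetical path to the module added here
use crate::tokenizer::Model;         // the trait implemented above

fn demo() {
    // Toy vocabulary: three characters plus the merged token "hu"
    let vocab: HashMap<String, u32> = [("h", 0u32), ("u", 1), ("g", 2), ("hu", 3)]
        .iter()
        .map(|(s, id)| (s.to_string(), *id))
        .collect();
    let vocab_r: HashMap<u32, String> = vocab.iter().map(|(s, id)| (*id, s.clone())).collect();

    // One merge rule: the pair ("h", "u") has rank 0 and produces token id 3 ("hu")
    let mut merges: HashMap<Pair, (u32, u32)> = HashMap::new();
    merges.insert((0, 1), (0, 3));

    let bpe = BPE::new(vocab, vocab_r, merges);

    // "hug" is first split into ["h", "u", "g"]; the merge loop then collapses ("h", "u"),
    // so the output should be two tokens: "hu" (offsets 0..2) and "g" (offsets 2..3)
    let tokens = bpe.tokenize(vec!["hug".to_string()]);
    assert_eq!(tokens.len(), 2);
}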