diff --git a/tokenizers/src/models/bpe/cache.rs b/tokenizers/src/models/bpe/cache.rs index 5f657ac3..1e3adf17 100644 --- a/tokenizers/src/models/bpe/cache.rs +++ b/tokenizers/src/models/bpe/cache.rs @@ -1,18 +1,34 @@ use std::collections::HashMap; use std::hash::Hash; -use std::sync::Mutex; +use std::sync::RwLock; + +/// The default capacity for a new `Cache`. +pub static DEFAULT_CACHE_CAPACITY: usize = 10_000; /// Provides a simple multithread cache that will try to retrieve values /// but won't block if someone else is already using it. /// The goal is clearly not the accuracy of the content, both get and set /// are not guaranteed to actually get or set. -#[derive(Default)] pub struct Cache where K: Eq + Hash + Clone, V: Clone, { - map: Mutex>, + map: RwLock>, + pub capacity: usize, +} + +impl Default for Cache +where + K: Eq + Hash + Clone, + V: Clone, +{ + fn default() -> Self { + Self { + map: RwLock::new(HashMap::with_capacity(DEFAULT_CACHE_CAPACITY)), + capacity: DEFAULT_CACHE_CAPACITY, + } + } } impl Cache @@ -20,9 +36,21 @@ where K: Eq + Hash + Clone, V: Clone, { - pub fn new() -> Self { - Cache { - map: Mutex::new(HashMap::new()), + /// Create new `Cache` with the given capacity. + pub fn new(capacity: usize) -> Self { + let map = RwLock::new(HashMap::with_capacity(capacity)); + Cache { map, capacity } + } + + /// Create a fresh `Cache` with the same configuration. + pub fn fresh(&self) -> Self { + Self::new(self.capacity) + } + + /// Try clearing the cache. 
+ pub fn try_clear(&self) { + if let Ok(ref mut cache) = self.map.try_write() { + cache.clear(); } } @@ -30,8 +58,7 @@ where where I: Iterator, { - let mut lock = self.map.try_lock(); - if let Ok(ref mut cache) = lock { + if let Ok(ref mut cache) = self.map.try_read() { Some(keys_iter.map(|k| cache.get(&k).cloned()).collect()) } else { None @@ -43,9 +70,12 @@ where I: Iterator, J: Iterator>, { - let mut lock = self.map.try_lock(); - if let Ok(ref mut cache) = lock { + if let Ok(ref mut cache) = self.map.try_write() { for (key, value) in keys_iter.zip(values_iter).filter(|(_, v)| v.is_some()) { + // If already at capacity, don't add any more values. + if cache.len() >= self.capacity { + break; + } cache.insert(key, value.unwrap()); } } diff --git a/tokenizers/src/models/bpe/model.rs b/tokenizers/src/models/bpe/model.rs index 9117e3c9..8298ad49 100644 --- a/tokenizers/src/models/bpe/model.rs +++ b/tokenizers/src/models/bpe/model.rs @@ -10,6 +10,92 @@ use std::{ path::{Path, PathBuf}, }; +#[derive(Default)] +struct Config { + vocab: Option>, + vocab_r: Option>, + merges: Option>, + cache_capacity: Option, + dropout: Option, + unk_token: Option, +} + +/// A `BpeBuilder` can be used to create a `BPE` model with a custom configuration. +#[derive(Default)] +pub struct BpeBuilder { + config: Config, +} + +impl BpeBuilder { + /// Constructs a new `BpeBuilder`. + pub fn new() -> Self { + Self::default() + } + + /// Set the vocab (token -> ID) and merges mappings. + pub fn vocab_and_merges( + mut self, + vocab: HashMap, + merges: HashMap, + ) -> Self { + self.config.vocab = Some(vocab); + self.config.merges = Some(merges); + self + } + + /// Set the cache's capacity. + pub fn cache_capacity(mut self, capacity: usize) -> Self { + self.config.cache_capacity = Some(capacity); + self + } + + /// Use [dropout](https://arxiv.org/abs/1910.13267) with the model. 
+ pub fn dropout(mut self, dropout: f32) -> Self { + self.config.dropout = Some(dropout); + self + } + + /// Set the `UNK` token for the vocab. + pub fn unk_token(mut self, unk_token: u32) -> Self { + self.config.unk_token = Some(unk_token); + self + } + + /// Returns a `BPE` model that uses the `BpeBuilder`'s configuration. + pub fn build(self) -> Result { + // Validate dropout. + if let Some(p) = self.config.dropout { + if p < 0.0 || p > 1.0 { + return Err(Error::InvalidDropout.into()); + } + } + + let vocab = self.config.vocab.unwrap_or_else(HashMap::new); + let vocab_r = match self.config.vocab_r { + Some(vocab_r) => vocab_r, + None => vocab + .iter() + .map(|(key, val)| (*val, key.to_owned())) + .collect(), + }; + let merges = self.config.merges.unwrap_or_else(HashMap::new); + let cache = match self.config.cache_capacity { + Some(capacity) => Cache::new(capacity), + None => Cache::default(), + }; + + Ok(BPE { + vocab, + vocab_r, + merges, + cache, + dropout: self.config.dropout, + unk_token: self.config.unk_token, + }) + } +} + +/// A Byte Pair Encoding model. pub struct BPE { /// The vocabulary assigns a number to each token. vocab: HashMap, @@ -28,24 +114,19 @@ pub struct BPE { impl Default for BPE { fn default() -> Self { - Self { - vocab: HashMap::new(), - vocab_r: HashMap::new(), - merges: HashMap::new(), - cache: Cache::new(), - dropout: None, - unk_token: None, - } + Self::builder().build().unwrap() } } impl Clone for BPE { + // `Clone` can't be derived because it's not implemented for `Cache`. + // To keep things simple when we clone, the new BPE will start with a fresh cache.
+ fn clone(&self) -> Self { Self { vocab: self.vocab.clone(), vocab_r: self.vocab_r.clone(), merges: self.merges.clone(), - cache: Cache::new(), + cache: self.cache.fresh(), dropout: self.dropout, unk_token: self.unk_token, } @@ -53,39 +134,20 @@ impl Clone for BPE { } impl BPE { - pub fn new( - vocab: HashMap, - vocab_r: HashMap, - merges: HashMap, - ) -> Self { - BPE { - vocab, - vocab_r, - merges, - ..Default::default() - } + /// Initialize a `BpeBuilder`. + pub fn builder() -> BpeBuilder { + BpeBuilder::new() } - /// Initialize a BPE model with [dropout](https://arxiv.org/abs/1910.13267). - pub fn with_dropout( - vocab: HashMap, - vocab_r: HashMap, - merges: HashMap, - dropout: f32, - ) -> Result { - if dropout < 0.0 || dropout > 1.0 { - Err(Error::InvalidDropout.into()) - } else { - Ok(BPE { - vocab, - vocab_r, - merges, - dropout: if dropout == 0.0 { None } else { Some(dropout) }, - ..Default::default() - }) - } + /// Create a new BPE model with the given vocab and merges. + pub fn new(vocab: HashMap, merges: HashMap) -> Self { + Self::builder() + .vocab_and_merges(vocab, merges) + .build() + .unwrap() } + /// Initialize a BPE model from vocab and merges files. pub fn from_files(vocab: &str, merges: &str) -> Result { // Read vocab.json let vocab_file = File::open(vocab)?; @@ -138,12 +200,12 @@ impl BPE { merges.insert(pair, (rank as u32, *new_id)); } - Ok(BPE { - vocab: vocab.clone(), - vocab_r: vocab.into_iter().map(|(token, id)| (id, token)).collect(), - merges, - ..Default::default() - }) + Ok(BPE::new(vocab, merges)) + } + + /// Try resetting the cache. This is a no-op if the lock can't be acquired.
+ pub fn try_clear_cache(&self) { + self.cache.try_clear() } fn merge_word(&self, w: &str) -> Word { @@ -341,10 +403,6 @@ mod tests { .iter() .cloned() .collect(); - let vocab_r: HashMap = vocab - .iter() - .map(|(key, val)| (*val, key.to_owned())) - .collect(); let merges: HashMap = [ ((vocab["r"], vocab["e"]), (1u32, vocab["re"])), // 'r-e' -> 're' ((vocab["a"], vocab["t"]), (2u32, vocab["at"])), // 'a-t' -> 'at' @@ -358,7 +416,7 @@ mod tests { .iter() .cloned() .collect(); - let mut bpe = BPE::new(vocab, vocab_r, merges); + let mut bpe = BPE::new(vocab, merges); let sentence: Vec<(String, Offsets)> = vec![("unrelated".into(), (0, 9))]; diff --git a/tokenizers/src/models/bpe/trainer.rs b/tokenizers/src/models/bpe/trainer.rs index ec80793e..e8987290 100644 --- a/tokenizers/src/models/bpe/trainer.rs +++ b/tokenizers/src/models/bpe/trainer.rs @@ -293,11 +293,7 @@ impl Trainer for BpeTrainer { self.finalize_progress(&progress, merges.len()); Ok(Box::new(BPE::new( - word_to_id.clone(), - word_to_id - .into_iter() - .map(|(token, id)| (id, token)) - .collect(), + word_to_id, merges .into_iter() .enumerate()