mirror of
https://github.com/mii443/tokenizers.git
synced 2025-08-22 16:25:30 +00:00
set capacity on BPE cache, change Mutex to RwLock, create BpeBuilder (#24)
* set capacity on BPE cache, create BpeBuilder * add doc comment * switch from Mutex to RwLock * vocab_and_merges
This commit is contained in:
@ -1,18 +1,34 @@
|
||||
use std::collections::HashMap;
|
||||
use std::hash::Hash;
|
||||
use std::sync::Mutex;
|
||||
use std::sync::RwLock;
|
||||
|
||||
/// The default capacity for a new `Cache`.
|
||||
pub static DEFAULT_CACHE_CAPACITY: usize = 10_000;
|
||||
|
||||
/// Provides a simple multithread cache that will try to retrieve values
|
||||
/// but won't block if someone else is already using it.
|
||||
/// The goal is clearly not the accuracy of the content: both get and set
|
||||
/// are not guaranteed to actually get or set.
|
||||
#[derive(Default)]
|
||||
pub struct Cache<K, V>
|
||||
where
|
||||
K: Eq + Hash + Clone,
|
||||
V: Clone,
|
||||
{
|
||||
map: Mutex<HashMap<K, V>>,
|
||||
map: RwLock<HashMap<K, V>>,
|
||||
pub capacity: usize,
|
||||
}
|
||||
|
||||
impl<K, V> Default for Cache<K, V>
|
||||
where
|
||||
K: Eq + Hash + Clone,
|
||||
V: Clone,
|
||||
{
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
map: RwLock::new(HashMap::with_capacity(DEFAULT_CACHE_CAPACITY)),
|
||||
capacity: DEFAULT_CACHE_CAPACITY,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<K, V> Cache<K, V>
|
||||
@ -20,9 +36,21 @@ where
|
||||
K: Eq + Hash + Clone,
|
||||
V: Clone,
|
||||
{
|
||||
pub fn new() -> Self {
|
||||
Cache {
|
||||
map: Mutex::new(HashMap::new()),
|
||||
/// Create new `Cache` with the given capacity.
|
||||
pub fn new(capacity: usize) -> Self {
|
||||
let map = RwLock::new(HashMap::with_capacity(capacity));
|
||||
Cache { map, capacity }
|
||||
}
|
||||
|
||||
/// Create a fresh `Cache` with the same configuration.
|
||||
pub fn fresh(&self) -> Self {
|
||||
Self::new(self.capacity)
|
||||
}
|
||||
|
||||
/// Try clearing the cache.
|
||||
pub fn try_clear(&self) {
|
||||
if let Ok(ref mut cache) = self.map.try_write() {
|
||||
cache.clear();
|
||||
}
|
||||
}
|
||||
|
||||
@ -30,8 +58,7 @@ where
|
||||
where
|
||||
I: Iterator<Item = K>,
|
||||
{
|
||||
let mut lock = self.map.try_lock();
|
||||
if let Ok(ref mut cache) = lock {
|
||||
if let Ok(ref mut cache) = self.map.try_read() {
|
||||
Some(keys_iter.map(|k| cache.get(&k).cloned()).collect())
|
||||
} else {
|
||||
None
|
||||
@ -43,9 +70,12 @@ where
|
||||
I: Iterator<Item = K>,
|
||||
J: Iterator<Item = Option<V>>,
|
||||
{
|
||||
let mut lock = self.map.try_lock();
|
||||
if let Ok(ref mut cache) = lock {
|
||||
if let Ok(ref mut cache) = self.map.try_write() {
|
||||
for (key, value) in keys_iter.zip(values_iter).filter(|(_, v)| v.is_some()) {
|
||||
// If already at capacity, don't add any more values.
|
||||
if cache.len() >= self.capacity {
|
||||
break;
|
||||
}
|
||||
cache.insert(key, value.unwrap());
|
||||
}
|
||||
}
|
||||
|
@ -10,6 +10,92 @@ use std::{
|
||||
path::{Path, PathBuf},
|
||||
};
|
||||
|
||||
#[derive(Default)]
|
||||
struct Config {
|
||||
vocab: Option<HashMap<String, u32>>,
|
||||
vocab_r: Option<HashMap<u32, String>>,
|
||||
merges: Option<HashMap<Pair, (u32, u32)>>,
|
||||
cache_capacity: Option<usize>,
|
||||
dropout: Option<f32>,
|
||||
unk_token: Option<u32>,
|
||||
}
|
||||
|
||||
/// A `BpeBuilder` can be used to create a `BPE` model with a custom configuration.
|
||||
#[derive(Default)]
|
||||
pub struct BpeBuilder {
|
||||
config: Config,
|
||||
}
|
||||
|
||||
impl BpeBuilder {
|
||||
/// Constructs a new `BpeBuilder`.
|
||||
pub fn new() -> Self {
|
||||
Self::default()
|
||||
}
|
||||
|
||||
/// Set the vocab (token -> ID) and merges mappings.
|
||||
pub fn vocab_and_merges(
|
||||
mut self,
|
||||
vocab: HashMap<String, u32>,
|
||||
merges: HashMap<Pair, (u32, u32)>,
|
||||
) -> Self {
|
||||
self.config.vocab = Some(vocab);
|
||||
self.config.merges = Some(merges);
|
||||
self
|
||||
}
|
||||
|
||||
/// Set the cache's capacity.
|
||||
pub fn cache_capacity(mut self, capacity: usize) -> Self {
|
||||
self.config.cache_capacity = Some(capacity);
|
||||
self
|
||||
}
|
||||
|
||||
/// Use [dropout](https://arxiv.org/abs/1910.13267) with the model.
|
||||
pub fn dropout(mut self, dropout: f32) -> Self {
|
||||
self.config.dropout = Some(dropout);
|
||||
self
|
||||
}
|
||||
|
||||
/// Set the `UNK` token for the vocab.
|
||||
pub fn unk_token(mut self, unk_token: u32) -> Self {
|
||||
self.config.unk_token = Some(unk_token);
|
||||
self
|
||||
}
|
||||
|
||||
/// Returns a `BPE` model that uses the `BpeBuilder`'s configuration.
|
||||
pub fn build(self) -> Result<BPE> {
|
||||
// Validate dropout.
|
||||
if let Some(p) = self.config.dropout {
|
||||
if p < 0.0 || p > 1.0 {
|
||||
return Err(Error::InvalidDropout.into());
|
||||
}
|
||||
}
|
||||
|
||||
let vocab = self.config.vocab.unwrap_or_else(HashMap::new);
|
||||
let vocab_r = match self.config.vocab_r {
|
||||
Some(vocab_r) => vocab_r,
|
||||
None => vocab
|
||||
.iter()
|
||||
.map(|(key, val)| (*val, key.to_owned()))
|
||||
.collect(),
|
||||
};
|
||||
let merges = self.config.merges.unwrap_or_else(HashMap::new);
|
||||
let cache = match self.config.cache_capacity {
|
||||
Some(capacity) => Cache::new(capacity),
|
||||
None => Cache::default(),
|
||||
};
|
||||
|
||||
Ok(BPE {
|
||||
vocab,
|
||||
vocab_r,
|
||||
merges,
|
||||
cache,
|
||||
dropout: self.config.dropout,
|
||||
unk_token: self.config.unk_token,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
/// A Byte Pair Encoding model.
|
||||
pub struct BPE {
|
||||
/// The vocabulary assigns a number to each token.
|
||||
vocab: HashMap<String, u32>,
|
||||
@ -28,24 +114,19 @@ pub struct BPE {
|
||||
|
||||
impl Default for BPE {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
vocab: HashMap::new(),
|
||||
vocab_r: HashMap::new(),
|
||||
merges: HashMap::new(),
|
||||
cache: Cache::new(),
|
||||
dropout: None,
|
||||
unk_token: None,
|
||||
}
|
||||
Self::builder().build().unwrap()
|
||||
}
|
||||
}
|
||||
|
||||
impl Clone for BPE {
|
||||
// `Clone` can't be derived because it's not implemented for `Cache`.
|
||||
// To keep things simple when we clone, the new BPE will start with a fresh cache.
|
||||
fn clone(&self) -> Self {
|
||||
Self {
|
||||
vocab: self.vocab.clone(),
|
||||
vocab_r: self.vocab_r.clone(),
|
||||
merges: self.merges.clone(),
|
||||
cache: Cache::new(),
|
||||
cache: self.cache.fresh(),
|
||||
dropout: self.dropout,
|
||||
unk_token: self.unk_token,
|
||||
}
|
||||
@ -53,39 +134,20 @@ impl Clone for BPE {
|
||||
}
|
||||
|
||||
impl BPE {
|
||||
pub fn new(
|
||||
vocab: HashMap<String, u32>,
|
||||
vocab_r: HashMap<u32, String>,
|
||||
merges: HashMap<Pair, (u32, u32)>,
|
||||
) -> Self {
|
||||
BPE {
|
||||
vocab,
|
||||
vocab_r,
|
||||
merges,
|
||||
..Default::default()
|
||||
}
|
||||
/// Initialize a `BpeBuilder`.
|
||||
pub fn builder() -> BpeBuilder {
|
||||
BpeBuilder::new()
|
||||
}
|
||||
|
||||
/// Initialize a BPE model with [dropout](https://arxiv.org/abs/1910.13267).
|
||||
pub fn with_dropout(
|
||||
vocab: HashMap<String, u32>,
|
||||
vocab_r: HashMap<u32, String>,
|
||||
merges: HashMap<Pair, (u32, u32)>,
|
||||
dropout: f32,
|
||||
) -> Result<Self> {
|
||||
if dropout < 0.0 || dropout > 1.0 {
|
||||
Err(Error::InvalidDropout.into())
|
||||
} else {
|
||||
Ok(BPE {
|
||||
vocab,
|
||||
vocab_r,
|
||||
merges,
|
||||
dropout: if dropout == 0.0 { None } else { Some(dropout) },
|
||||
..Default::default()
|
||||
})
|
||||
}
|
||||
/// Create a new BPE model with the given vocab and merges.
|
||||
pub fn new(vocab: HashMap<String, u32>, merges: HashMap<Pair, (u32, u32)>) -> Self {
|
||||
Self::builder()
|
||||
.vocab_and_merges(vocab, merges)
|
||||
.build()
|
||||
.unwrap()
|
||||
}
|
||||
|
||||
/// Initialize a BPE model from vocab and merges file.
|
||||
pub fn from_files(vocab: &str, merges: &str) -> Result<Self> {
|
||||
// Read vocab.json
|
||||
let vocab_file = File::open(vocab)?;
|
||||
@ -138,12 +200,12 @@ impl BPE {
|
||||
merges.insert(pair, (rank as u32, *new_id));
|
||||
}
|
||||
|
||||
Ok(BPE {
|
||||
vocab: vocab.clone(),
|
||||
vocab_r: vocab.into_iter().map(|(token, id)| (id, token)).collect(),
|
||||
merges,
|
||||
..Default::default()
|
||||
})
|
||||
Ok(BPE::new(vocab, merges))
|
||||
}
|
||||
|
||||
/// Try resetting the cache. This fails if a lock can't be acquired.
|
||||
pub fn try_clear_cache(&self) {
|
||||
self.cache.try_clear()
|
||||
}
|
||||
|
||||
fn merge_word(&self, w: &str) -> Word {
|
||||
@ -341,10 +403,6 @@ mod tests {
|
||||
.iter()
|
||||
.cloned()
|
||||
.collect();
|
||||
let vocab_r: HashMap<u32, String> = vocab
|
||||
.iter()
|
||||
.map(|(key, val)| (*val, key.to_owned()))
|
||||
.collect();
|
||||
let merges: HashMap<Pair, (u32, u32)> = [
|
||||
((vocab["r"], vocab["e"]), (1u32, vocab["re"])), // 'r-e' -> 're'
|
||||
((vocab["a"], vocab["t"]), (2u32, vocab["at"])), // 'a-t' -> 'at'
|
||||
@ -358,7 +416,7 @@ mod tests {
|
||||
.iter()
|
||||
.cloned()
|
||||
.collect();
|
||||
let mut bpe = BPE::new(vocab, vocab_r, merges);
|
||||
let mut bpe = BPE::new(vocab, merges);
|
||||
|
||||
let sentence: Vec<(String, Offsets)> = vec![("unrelated".into(), (0, 9))];
|
||||
|
||||
|
@ -293,11 +293,7 @@ impl Trainer for BpeTrainer {
|
||||
self.finalize_progress(&progress, merges.len());
|
||||
|
||||
Ok(Box::new(BPE::new(
|
||||
word_to_id.clone(),
|
||||
word_to_id
|
||||
.into_iter()
|
||||
.map(|(token, id)| (id, token))
|
||||
.collect(),
|
||||
word_to_id,
|
||||
merges
|
||||
.into_iter()
|
||||
.enumerate()
|
||||
|
Reference in New Issue
Block a user