mirror of
https://github.com/mii443/tokenizers.git
synced 2025-08-22 16:25:30 +00:00
set capacity on BPE cache, change Mutex to RwLock, create BpeBuilder (#24)
* set capacity on BPE cache, create BpeBuilder * add doc comment * switch from Mutex to RwLock * vocab_and_merges
This commit is contained in:
@ -1,18 +1,34 @@
|
|||||||
use std::collections::HashMap;
|
use std::collections::HashMap;
|
||||||
use std::hash::Hash;
|
use std::hash::Hash;
|
||||||
use std::sync::Mutex;
|
use std::sync::RwLock;
|
||||||
|
|
||||||
|
/// The default capacity for a new `Cache`.
|
||||||
|
pub static DEFAULT_CACHE_CAPACITY: usize = 10_000;
|
||||||
|
|
||||||
/// Provides a simple multithread cache that will try to retrieve values
|
/// Provides a simple multithread cache that will try to retrieve values
|
||||||
/// but won't block if someone else is already using it.
|
/// but won't block if someone else is already using it.
|
||||||
/// The goal is clearly not the accuracy of the content, both get and set
|
/// The goal is clearly not the accuracy of the content, both get and set
|
||||||
/// are not guaranteed to actually get or set.
|
/// are not guaranteed to actually get or set.
|
||||||
#[derive(Default)]
|
|
||||||
pub struct Cache<K, V>
|
pub struct Cache<K, V>
|
||||||
where
|
where
|
||||||
K: Eq + Hash + Clone,
|
K: Eq + Hash + Clone,
|
||||||
V: Clone,
|
V: Clone,
|
||||||
{
|
{
|
||||||
map: Mutex<HashMap<K, V>>,
|
map: RwLock<HashMap<K, V>>,
|
||||||
|
pub capacity: usize,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<K, V> Default for Cache<K, V>
|
||||||
|
where
|
||||||
|
K: Eq + Hash + Clone,
|
||||||
|
V: Clone,
|
||||||
|
{
|
||||||
|
fn default() -> Self {
|
||||||
|
Self {
|
||||||
|
map: RwLock::new(HashMap::with_capacity(DEFAULT_CACHE_CAPACITY)),
|
||||||
|
capacity: DEFAULT_CACHE_CAPACITY,
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<K, V> Cache<K, V>
|
impl<K, V> Cache<K, V>
|
||||||
@ -20,9 +36,21 @@ where
|
|||||||
K: Eq + Hash + Clone,
|
K: Eq + Hash + Clone,
|
||||||
V: Clone,
|
V: Clone,
|
||||||
{
|
{
|
||||||
pub fn new() -> Self {
|
/// Create new `Cache` with the given capacity.
|
||||||
Cache {
|
pub fn new(capacity: usize) -> Self {
|
||||||
map: Mutex::new(HashMap::new()),
|
let map = RwLock::new(HashMap::with_capacity(capacity));
|
||||||
|
Cache { map, capacity }
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Create a fresh `Cache` with the same configuration.
|
||||||
|
pub fn fresh(&self) -> Self {
|
||||||
|
Self::new(self.capacity)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Try clearing the cache.
|
||||||
|
pub fn try_clear(&self) {
|
||||||
|
if let Ok(ref mut cache) = self.map.try_write() {
|
||||||
|
cache.clear();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -30,8 +58,7 @@ where
|
|||||||
where
|
where
|
||||||
I: Iterator<Item = K>,
|
I: Iterator<Item = K>,
|
||||||
{
|
{
|
||||||
let mut lock = self.map.try_lock();
|
if let Ok(ref mut cache) = self.map.try_read() {
|
||||||
if let Ok(ref mut cache) = lock {
|
|
||||||
Some(keys_iter.map(|k| cache.get(&k).cloned()).collect())
|
Some(keys_iter.map(|k| cache.get(&k).cloned()).collect())
|
||||||
} else {
|
} else {
|
||||||
None
|
None
|
||||||
@ -43,9 +70,12 @@ where
|
|||||||
I: Iterator<Item = K>,
|
I: Iterator<Item = K>,
|
||||||
J: Iterator<Item = Option<V>>,
|
J: Iterator<Item = Option<V>>,
|
||||||
{
|
{
|
||||||
let mut lock = self.map.try_lock();
|
if let Ok(ref mut cache) = self.map.try_write() {
|
||||||
if let Ok(ref mut cache) = lock {
|
|
||||||
for (key, value) in keys_iter.zip(values_iter).filter(|(_, v)| v.is_some()) {
|
for (key, value) in keys_iter.zip(values_iter).filter(|(_, v)| v.is_some()) {
|
||||||
|
// If already at capacity, don't add any more values.
|
||||||
|
if cache.len() >= self.capacity {
|
||||||
|
break;
|
||||||
|
}
|
||||||
cache.insert(key, value.unwrap());
|
cache.insert(key, value.unwrap());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -10,6 +10,92 @@ use std::{
|
|||||||
path::{Path, PathBuf},
|
path::{Path, PathBuf},
|
||||||
};
|
};
|
||||||
|
|
||||||
|
#[derive(Default)]
|
||||||
|
struct Config {
|
||||||
|
vocab: Option<HashMap<String, u32>>,
|
||||||
|
vocab_r: Option<HashMap<u32, String>>,
|
||||||
|
merges: Option<HashMap<Pair, (u32, u32)>>,
|
||||||
|
cache_capacity: Option<usize>,
|
||||||
|
dropout: Option<f32>,
|
||||||
|
unk_token: Option<u32>,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// A `BpeBuilder` can be used to create a `BPE` model with a custom configuration.
|
||||||
|
#[derive(Default)]
|
||||||
|
pub struct BpeBuilder {
|
||||||
|
config: Config,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl BpeBuilder {
|
||||||
|
/// Constructs a new `BpeBuilder`.
|
||||||
|
pub fn new() -> Self {
|
||||||
|
Self::default()
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Set the vocab (token -> ID) and merges mappings.
|
||||||
|
pub fn vocab_and_merges(
|
||||||
|
mut self,
|
||||||
|
vocab: HashMap<String, u32>,
|
||||||
|
merges: HashMap<Pair, (u32, u32)>,
|
||||||
|
) -> Self {
|
||||||
|
self.config.vocab = Some(vocab);
|
||||||
|
self.config.merges = Some(merges);
|
||||||
|
self
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Set the cache's capacity.
|
||||||
|
pub fn cache_capacity(mut self, capacity: usize) -> Self {
|
||||||
|
self.config.cache_capacity = Some(capacity);
|
||||||
|
self
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Use [dropout](https://arxiv.org/abs/1910.13267) with the model.
|
||||||
|
pub fn dropout(mut self, dropout: f32) -> Self {
|
||||||
|
self.config.dropout = Some(dropout);
|
||||||
|
self
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Set the `UNK` token for the vocab.
|
||||||
|
pub fn unk_token(mut self, unk_token: u32) -> Self {
|
||||||
|
self.config.unk_token = Some(unk_token);
|
||||||
|
self
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Returns a `BPE` model that uses the `BpeBuilder`'s configuration.
|
||||||
|
pub fn build(self) -> Result<BPE> {
|
||||||
|
// Validate dropout.
|
||||||
|
if let Some(p) = self.config.dropout {
|
||||||
|
if p < 0.0 || p > 1.0 {
|
||||||
|
return Err(Error::InvalidDropout.into());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
let vocab = self.config.vocab.unwrap_or_else(HashMap::new);
|
||||||
|
let vocab_r = match self.config.vocab_r {
|
||||||
|
Some(vocab_r) => vocab_r,
|
||||||
|
None => vocab
|
||||||
|
.iter()
|
||||||
|
.map(|(key, val)| (*val, key.to_owned()))
|
||||||
|
.collect(),
|
||||||
|
};
|
||||||
|
let merges = self.config.merges.unwrap_or_else(HashMap::new);
|
||||||
|
let cache = match self.config.cache_capacity {
|
||||||
|
Some(capacity) => Cache::new(capacity),
|
||||||
|
None => Cache::default(),
|
||||||
|
};
|
||||||
|
|
||||||
|
Ok(BPE {
|
||||||
|
vocab,
|
||||||
|
vocab_r,
|
||||||
|
merges,
|
||||||
|
cache,
|
||||||
|
dropout: self.config.dropout,
|
||||||
|
unk_token: self.config.unk_token,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// A Byte Pair Encoding model.
|
||||||
pub struct BPE {
|
pub struct BPE {
|
||||||
/// The vocabulary assigns a number to each token.
|
/// The vocabulary assigns a number to each token.
|
||||||
vocab: HashMap<String, u32>,
|
vocab: HashMap<String, u32>,
|
||||||
@ -28,24 +114,19 @@ pub struct BPE {
|
|||||||
|
|
||||||
impl Default for BPE {
|
impl Default for BPE {
|
||||||
fn default() -> Self {
|
fn default() -> Self {
|
||||||
Self {
|
Self::builder().build().unwrap()
|
||||||
vocab: HashMap::new(),
|
|
||||||
vocab_r: HashMap::new(),
|
|
||||||
merges: HashMap::new(),
|
|
||||||
cache: Cache::new(),
|
|
||||||
dropout: None,
|
|
||||||
unk_token: None,
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Clone for BPE {
|
impl Clone for BPE {
|
||||||
|
// `Clone` can't be derive because it's not implemented for `Cache`.
|
||||||
|
// To keep things simple when we clone, the new BPE will start with a fresh cache.
|
||||||
fn clone(&self) -> Self {
|
fn clone(&self) -> Self {
|
||||||
Self {
|
Self {
|
||||||
vocab: self.vocab.clone(),
|
vocab: self.vocab.clone(),
|
||||||
vocab_r: self.vocab_r.clone(),
|
vocab_r: self.vocab_r.clone(),
|
||||||
merges: self.merges.clone(),
|
merges: self.merges.clone(),
|
||||||
cache: Cache::new(),
|
cache: self.cache.fresh(),
|
||||||
dropout: self.dropout,
|
dropout: self.dropout,
|
||||||
unk_token: self.unk_token,
|
unk_token: self.unk_token,
|
||||||
}
|
}
|
||||||
@ -53,39 +134,20 @@ impl Clone for BPE {
|
|||||||
}
|
}
|
||||||
|
|
||||||
impl BPE {
|
impl BPE {
|
||||||
pub fn new(
|
/// Initialize a `BpeBuilder`.
|
||||||
vocab: HashMap<String, u32>,
|
pub fn builder() -> BpeBuilder {
|
||||||
vocab_r: HashMap<u32, String>,
|
BpeBuilder::new()
|
||||||
merges: HashMap<Pair, (u32, u32)>,
|
|
||||||
) -> Self {
|
|
||||||
BPE {
|
|
||||||
vocab,
|
|
||||||
vocab_r,
|
|
||||||
merges,
|
|
||||||
..Default::default()
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Initialize a BPE model with [dropout](https://arxiv.org/abs/1910.13267).
|
/// Create a new BPE model with the given vocab and merges.
|
||||||
pub fn with_dropout(
|
pub fn new(vocab: HashMap<String, u32>, merges: HashMap<Pair, (u32, u32)>) -> Self {
|
||||||
vocab: HashMap<String, u32>,
|
Self::builder()
|
||||||
vocab_r: HashMap<u32, String>,
|
.vocab_and_merges(vocab, merges)
|
||||||
merges: HashMap<Pair, (u32, u32)>,
|
.build()
|
||||||
dropout: f32,
|
.unwrap()
|
||||||
) -> Result<Self> {
|
|
||||||
if dropout < 0.0 || dropout > 1.0 {
|
|
||||||
Err(Error::InvalidDropout.into())
|
|
||||||
} else {
|
|
||||||
Ok(BPE {
|
|
||||||
vocab,
|
|
||||||
vocab_r,
|
|
||||||
merges,
|
|
||||||
dropout: if dropout == 0.0 { None } else { Some(dropout) },
|
|
||||||
..Default::default()
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Initialize a BPE model from vocab and merges file.
|
||||||
pub fn from_files(vocab: &str, merges: &str) -> Result<Self> {
|
pub fn from_files(vocab: &str, merges: &str) -> Result<Self> {
|
||||||
// Read vocab.json
|
// Read vocab.json
|
||||||
let vocab_file = File::open(vocab)?;
|
let vocab_file = File::open(vocab)?;
|
||||||
@ -138,12 +200,12 @@ impl BPE {
|
|||||||
merges.insert(pair, (rank as u32, *new_id));
|
merges.insert(pair, (rank as u32, *new_id));
|
||||||
}
|
}
|
||||||
|
|
||||||
Ok(BPE {
|
Ok(BPE::new(vocab, merges))
|
||||||
vocab: vocab.clone(),
|
}
|
||||||
vocab_r: vocab.into_iter().map(|(token, id)| (id, token)).collect(),
|
|
||||||
merges,
|
/// Try resetting the cache. This fails if a lock can't be acquired.
|
||||||
..Default::default()
|
pub fn try_clear_cache(&self) {
|
||||||
})
|
self.cache.try_clear()
|
||||||
}
|
}
|
||||||
|
|
||||||
fn merge_word(&self, w: &str) -> Word {
|
fn merge_word(&self, w: &str) -> Word {
|
||||||
@ -341,10 +403,6 @@ mod tests {
|
|||||||
.iter()
|
.iter()
|
||||||
.cloned()
|
.cloned()
|
||||||
.collect();
|
.collect();
|
||||||
let vocab_r: HashMap<u32, String> = vocab
|
|
||||||
.iter()
|
|
||||||
.map(|(key, val)| (*val, key.to_owned()))
|
|
||||||
.collect();
|
|
||||||
let merges: HashMap<Pair, (u32, u32)> = [
|
let merges: HashMap<Pair, (u32, u32)> = [
|
||||||
((vocab["r"], vocab["e"]), (1u32, vocab["re"])), // 'r-e' -> 're'
|
((vocab["r"], vocab["e"]), (1u32, vocab["re"])), // 'r-e' -> 're'
|
||||||
((vocab["a"], vocab["t"]), (2u32, vocab["at"])), // 'a-t' -> 'at'
|
((vocab["a"], vocab["t"]), (2u32, vocab["at"])), // 'a-t' -> 'at'
|
||||||
@ -358,7 +416,7 @@ mod tests {
|
|||||||
.iter()
|
.iter()
|
||||||
.cloned()
|
.cloned()
|
||||||
.collect();
|
.collect();
|
||||||
let mut bpe = BPE::new(vocab, vocab_r, merges);
|
let mut bpe = BPE::new(vocab, merges);
|
||||||
|
|
||||||
let sentence: Vec<(String, Offsets)> = vec![("unrelated".into(), (0, 9))];
|
let sentence: Vec<(String, Offsets)> = vec![("unrelated".into(), (0, 9))];
|
||||||
|
|
||||||
|
@ -293,11 +293,7 @@ impl Trainer for BpeTrainer {
|
|||||||
self.finalize_progress(&progress, merges.len());
|
self.finalize_progress(&progress, merges.len());
|
||||||
|
|
||||||
Ok(Box::new(BPE::new(
|
Ok(Box::new(BPE::new(
|
||||||
word_to_id.clone(),
|
word_to_id,
|
||||||
word_to_id
|
|
||||||
.into_iter()
|
|
||||||
.map(|(token, id)| (id, token))
|
|
||||||
.collect(),
|
|
||||||
merges
|
merges
|
||||||
.into_iter()
|
.into_iter()
|
||||||
.enumerate()
|
.enumerate()
|
||||||
|
Reference in New Issue
Block a user