set capacity on BPE cache, change Mutex to RwLock, create BpeBuilder (#24)

* set capacity on BPE cache, create BpeBuilder

* add doc comment

* switch from Mutex to RwLock

* vocab_and_merges
This commit is contained in:
Evan Pete Walsh
2020-01-02 09:26:50 -08:00
committed by GitHub
parent e3cf6a7b00
commit 8ae0f2efdb
3 changed files with 148 additions and 64 deletions

View File

@ -1,18 +1,34 @@
use std::collections::HashMap;
use std::hash::Hash;
use std::sync::Mutex;
use std::sync::RwLock;
/// The default capacity for a new `Cache`.
///
/// `Cache::default()` uses this value both to pre-allocate its map and as its
/// `capacity`.
pub static DEFAULT_CACHE_CAPACITY: usize = 10_000;
/// Provides a simple multithreaded cache that will try to retrieve values
/// but won't block if someone else is already using it.
/// The goal is clearly not the accuracy of the content: neither get nor set
/// is guaranteed to actually get or set.
#[derive(Default)]
pub struct Cache<K, V>
where
K: Eq + Hash + Clone,
V: Clone,
{
map: Mutex<HashMap<K, V>>,
map: RwLock<HashMap<K, V>>,
pub capacity: usize,
}
impl<K, V> Default for Cache<K, V>
where
    K: Eq + Hash + Clone,
    V: Clone,
{
    /// Creates an empty cache whose map is pre-allocated for
    /// `DEFAULT_CACHE_CAPACITY` entries and whose `capacity` is that same value.
    fn default() -> Self {
        let capacity = DEFAULT_CACHE_CAPACITY;
        let map = RwLock::new(HashMap::with_capacity(capacity));
        Self { map, capacity }
    }
}
impl<K, V> Cache<K, V>
@ -20,9 +36,21 @@ where
K: Eq + Hash + Clone,
V: Clone,
{
pub fn new() -> Self {
Cache {
map: Mutex::new(HashMap::new()),
/// Create a new, empty `Cache` with the given capacity.
///
/// The backing map is pre-allocated for `capacity` entries and the same value
/// is stored in the `capacity` field.
pub fn new(capacity: usize) -> Self {
    Cache {
        map: RwLock::new(HashMap::with_capacity(capacity)),
        capacity,
    }
}
/// Build an empty `Cache` configured with this cache's `capacity`.
///
/// No entries are copied from `self`.
pub fn fresh(&self) -> Self {
    let capacity = self.capacity;
    Self::new(capacity)
}
/// Try clearing the cache.
pub fn try_clear(&self) {
if let Ok(ref mut cache) = self.map.try_write() {
cache.clear();
}
}
@ -30,8 +58,7 @@ where
where
I: Iterator<Item = K>,
{
let mut lock = self.map.try_lock();
if let Ok(ref mut cache) = lock {
if let Ok(ref mut cache) = self.map.try_read() {
Some(keys_iter.map(|k| cache.get(&k).cloned()).collect())
} else {
None
@ -43,9 +70,12 @@ where
I: Iterator<Item = K>,
J: Iterator<Item = Option<V>>,
{
let mut lock = self.map.try_lock();
if let Ok(ref mut cache) = lock {
if let Ok(ref mut cache) = self.map.try_write() {
for (key, value) in keys_iter.zip(values_iter).filter(|(_, v)| v.is_some()) {
// If already at capacity, don't add any more values.
if cache.len() >= self.capacity {
break;
}
cache.insert(key, value.unwrap());
}
}

View File

@ -10,6 +10,92 @@ use std::{
path::{Path, PathBuf},
};
/// Internal, all-optional settings collected by `BpeBuilder`.
///
/// `BpeBuilder::build` substitutes defaults for the maps and the cache
/// capacity when they are left unset.
#[derive(Default)]
struct Config {
/// Token -> ID mapping; `build` uses an empty map when unset.
vocab: Option<HashMap<String, u32>>,
/// ID -> token mapping; when unset, `build` derives it by inverting `vocab`.
vocab_r: Option<HashMap<u32, String>>,
/// Merge table keyed by `Pair`; `build` uses an empty map when unset.
merges: Option<HashMap<Pair, (u32, u32)>>,
/// Capacity handed to `Cache::new`; when unset, `build` uses `Cache::default()`.
cache_capacity: Option<usize>,
/// Dropout value; `build` rejects values below 0.0 or above 1.0.
dropout: Option<f32>,
/// ID of the `UNK` token, passed through to the model unchanged.
unk_token: Option<u32>,
}
/// A `BpeBuilder` can be used to create a `BPE` model with a custom configuration.
///
/// Start from `BpeBuilder::new()` (or `BPE::builder()`), chain the setter
/// methods, and finish with `build()`.
#[derive(Default)]
pub struct BpeBuilder {
// Settings accumulated by the setter methods and consumed by `build`.
config: Config,
}
impl BpeBuilder {
/// Constructs a new `BpeBuilder`.
pub fn new() -> Self {
Self::default()
}
/// Set the vocab (token -> ID) and merges mappings.
pub fn vocab_and_merges(
mut self,
vocab: HashMap<String, u32>,
merges: HashMap<Pair, (u32, u32)>,
) -> Self {
self.config.vocab = Some(vocab);
self.config.merges = Some(merges);
self
}
/// Set the cache's capacity.
pub fn cache_capacity(mut self, capacity: usize) -> Self {
self.config.cache_capacity = Some(capacity);
self
}
/// Use [dropout](https://arxiv.org/abs/1910.13267) with the model.
pub fn dropout(mut self, dropout: f32) -> Self {
self.config.dropout = Some(dropout);
self
}
/// Set the `UNK` token for the vocab.
pub fn unk_token(mut self, unk_token: u32) -> Self {
self.config.unk_token = Some(unk_token);
self
}
/// Returns a `BPE` model that uses the `BpeBuilder`'s configuration.
pub fn build(self) -> Result<BPE> {
// Validate dropout.
if let Some(p) = self.config.dropout {
if p < 0.0 || p > 1.0 {
return Err(Error::InvalidDropout.into());
}
}
let vocab = self.config.vocab.unwrap_or_else(HashMap::new);
let vocab_r = match self.config.vocab_r {
Some(vocab_r) => vocab_r,
None => vocab
.iter()
.map(|(key, val)| (*val, key.to_owned()))
.collect(),
};
let merges = self.config.merges.unwrap_or_else(HashMap::new);
let cache = match self.config.cache_capacity {
Some(capacity) => Cache::new(capacity),
None => Cache::default(),
};
Ok(BPE {
vocab,
vocab_r,
merges,
cache,
dropout: self.config.dropout,
unk_token: self.config.unk_token,
})
}
}
/// A Byte Pair Encoding model.
pub struct BPE {
/// The vocabulary assigns a number to each token.
vocab: HashMap<String, u32>,
@ -28,24 +114,19 @@ pub struct BPE {
impl Default for BPE {
fn default() -> Self {
Self {
vocab: HashMap::new(),
vocab_r: HashMap::new(),
merges: HashMap::new(),
cache: Cache::new(),
dropout: None,
unk_token: None,
}
Self::builder().build().unwrap()
}
}
impl Clone for BPE {
// `Clone` can't be derived because it's not implemented for `Cache`.
// To keep things simple when we clone, the new BPE will start with a fresh cache.
fn clone(&self) -> Self {
Self {
vocab: self.vocab.clone(),
vocab_r: self.vocab_r.clone(),
merges: self.merges.clone(),
cache: Cache::new(),
cache: self.cache.fresh(),
dropout: self.dropout,
unk_token: self.unk_token,
}
@ -53,39 +134,20 @@ impl Clone for BPE {
}
impl BPE {
pub fn new(
vocab: HashMap<String, u32>,
vocab_r: HashMap<u32, String>,
merges: HashMap<Pair, (u32, u32)>,
) -> Self {
BPE {
vocab,
vocab_r,
merges,
..Default::default()
}
/// Start configuring a `BPE` model: returns a `BpeBuilder` with no options set.
pub fn builder() -> BpeBuilder {
    // Same value `BpeBuilder::new()` produces; it simply forwards to `Default`.
    BpeBuilder::default()
}
/// Initialize a BPE model with [dropout](https://arxiv.org/abs/1910.13267).
pub fn with_dropout(
vocab: HashMap<String, u32>,
vocab_r: HashMap<u32, String>,
merges: HashMap<Pair, (u32, u32)>,
dropout: f32,
) -> Result<Self> {
if dropout < 0.0 || dropout > 1.0 {
Err(Error::InvalidDropout.into())
} else {
Ok(BPE {
vocab,
vocab_r,
merges,
dropout: if dropout == 0.0 { None } else { Some(dropout) },
..Default::default()
})
}
/// Create a new BPE model with the given vocab and merges.
///
/// Every other option keeps the builder's default.
pub fn new(vocab: HashMap<String, u32>, merges: HashMap<Pair, (u32, u32)>) -> Self {
    let builder = Self::builder().vocab_and_merges(vocab, merges);
    // `build` only reports an error for an invalid dropout, and none is set here.
    builder.build().unwrap()
}
/// Initialize a BPE model from vocab and merges file.
pub fn from_files(vocab: &str, merges: &str) -> Result<Self> {
// Read vocab.json
let vocab_file = File::open(vocab)?;
@ -138,12 +200,12 @@ impl BPE {
merges.insert(pair, (rank as u32, *new_id));
}
Ok(BPE {
vocab: vocab.clone(),
vocab_r: vocab.into_iter().map(|(token, id)| (id, token)).collect(),
merges,
..Default::default()
})
Ok(BPE::new(vocab, merges))
}
/// Try resetting the cache.
///
/// This is best-effort: it delegates to `Cache::try_clear`, which leaves the
/// cache untouched when the write lock can't be acquired. Nothing is reported
/// to the caller in that case.
pub fn try_clear_cache(&self) {
    self.cache.try_clear();
}
fn merge_word(&self, w: &str) -> Word {
@ -341,10 +403,6 @@ mod tests {
.iter()
.cloned()
.collect();
let vocab_r: HashMap<u32, String> = vocab
.iter()
.map(|(key, val)| (*val, key.to_owned()))
.collect();
let merges: HashMap<Pair, (u32, u32)> = [
((vocab["r"], vocab["e"]), (1u32, vocab["re"])), // 'r-e' -> 're'
((vocab["a"], vocab["t"]), (2u32, vocab["at"])), // 'a-t' -> 'at'
@ -358,7 +416,7 @@ mod tests {
.iter()
.cloned()
.collect();
let mut bpe = BPE::new(vocab, vocab_r, merges);
let mut bpe = BPE::new(vocab, merges);
let sentence: Vec<(String, Offsets)> = vec![("unrelated".into(), (0, 9))];

View File

@ -293,11 +293,7 @@ impl Trainer for BpeTrainer {
self.finalize_progress(&progress, merges.len());
Ok(Box::new(BPE::new(
word_to_id.clone(),
word_to_id
.into_iter()
.map(|(token, id)| (id, token))
.collect(),
word_to_id,
merges
.into_iter()
.enumerate()