make cache optional (#37)

This commit is contained in:
Evan Pete Walsh
2020-01-03 11:48:13 -08:00
committed by GitHub
parent 1f961aa310
commit 1ed2a4f59b

View File

@ -43,7 +43,7 @@ impl BpeBuilder {
self
}
/// Set the cache's capacity.
/// Set the cache's capacity. If the capacity is set to 0, no cache will be used.
pub fn cache_capacity(mut self, capacity: usize) -> Self {
self.config.cache_capacity = Some(capacity);
self
@ -80,8 +80,9 @@ impl BpeBuilder {
};
let merges = self.config.merges.unwrap_or_else(HashMap::new);
let cache = match self.config.cache_capacity {
Some(capacity) => Cache::new(capacity),
None => Cache::default(),
Some(0) => None,
Some(capacity) => Some(Cache::new(capacity)),
None => Some(Cache::default()),
};
Ok(BPE {
@ -104,7 +105,7 @@ pub struct BPE {
/// Contains the mapping between Pairs and their (rank, new_id).
merges: HashMap<Pair, (u32, u32)>,
/// Contains the cache for optimizing the encoding step.
cache: Cache<String, Word>,
cache: Option<Cache<String, Word>>,
/// Dropout probability for merges. 0 = no dropout is the default. At 1.0, tokenization will
/// perform no merges, so the result will just be characters.
dropout: Option<f32>,
@ -122,11 +123,15 @@ impl Clone for BPE {
// `Clone` can't be derived because it's not implemented for `Cache`.
// To keep things simple when we clone, the new BPE will start with a fresh cache.
fn clone(&self) -> Self {
let fresh_cache = match self.cache {
Some(ref cache) => Some(cache.fresh()),
None => None,
};
Self {
vocab: self.vocab.clone(),
vocab_r: self.vocab_r.clone(),
merges: self.merges.clone(),
cache: self.cache.fresh(),
cache: fresh_cache,
dropout: self.dropout,
unk_token: self.unk_token,
}
@ -205,7 +210,9 @@ impl BPE {
/// Reset the cache.
pub fn clear_cache(&self) {
self.cache.clear()
if let Some(ref cache) = self.cache {
cache.clear()
}
}
fn merge_word(&self, w: &str) -> Word {
@ -289,10 +296,11 @@ impl Model for BPE {
let mut encoded: Vec<Token> = Vec::with_capacity(sentence.len());
let mut cached_words = match self.dropout {
None => self
.cache
.get_values(sentence.iter().map(|(s, _)| s.clone())),
Some(_) => None, // If using dropout we don't want to use a cached.
None => match self.cache {
Some(ref cache) => cache.get_values(sentence.iter().map(|(s, _)| s.clone())),
None => None,
},
Some(_) => None, // If using dropout we don't want to use the cache.
};
for (i, (w, initial_offsets)) in sentence.iter().enumerate() {
@ -323,7 +331,10 @@ impl Model for BPE {
// Also update cache
if let Some(cache) = cached_words {
let keys_iter = sentence.into_iter().map(|(s, _)| s);
self.cache.set_values(keys_iter, cache.into_iter());
self.cache
.as_ref()
.unwrap()
.set_values(keys_iter, cache.into_iter());
}
Ok(encoded)