Merge pull request #23 from huggingface/cache

Avoid creating unnecessary vectors when accessing cache
This commit is contained in:
MOI Anthony
2020-01-01 14:47:28 -05:00
committed by GitHub
2 changed files with 18 additions and 21 deletions

View File

@ -26,20 +26,27 @@ where
} }
} }
pub fn get_values(&self, keys: &[K]) -> Vec<Option<V>> { pub fn get_values<I>(&self, keys_iter: I) -> Option<Vec<Option<V>>>
where
I: Iterator<Item = K>,
{
let mut lock = self.map.try_lock(); let mut lock = self.map.try_lock();
if let Ok(ref mut cache) = lock { if let Ok(ref mut cache) = lock {
keys.iter().map(|k| cache.get(k).cloned()).collect() Some(keys_iter.map(|k| cache.get(&k).cloned()).collect())
} else { } else {
keys.iter().map(|_| None).collect() None
} }
} }
pub fn set_values(&self, keys: Vec<K>, values: Vec<V>) { pub fn set_values<I, J>(&self, keys_iter: I, values_iter: J)
where
I: Iterator<Item = K>,
J: Iterator<Item = Option<V>>,
{
let mut lock = self.map.try_lock(); let mut lock = self.map.try_lock();
if let Ok(ref mut cache) = lock { if let Ok(ref mut cache) = lock {
for (key, value) in keys.into_iter().zip(values) { for (key, value) in keys_iter.zip(values_iter).filter(|(_, v)| v.is_some()) {
cache.insert(key, value); cache.insert(key, value.unwrap());
} }
} }
} }

View File

@ -212,14 +212,9 @@ impl Model for BPE {
let mut encoded: Vec<Token> = Vec::with_capacity(sentence.len()); let mut encoded: Vec<Token> = Vec::with_capacity(sentence.len());
let mut cached_words = match self.dropout { let mut cached_words = match self.dropout {
None => Some( None => self
self.cache.get_values( .cache
&sentence .get_values(sentence.iter().map(|(s, _)| s.clone())),
.iter()
.map(|(s, _)| s.to_owned())
.collect::<Vec<_>>(),
),
),
Some(_) => None, // If using dropout we don't want to use a cached. Some(_) => None, // If using dropout we don't want to use a cached.
}; };
@ -250,13 +245,8 @@ impl Model for BPE {
// Also update cache // Also update cache
if let Some(cache) = cached_words { if let Some(cache) = cached_words {
let (keys, values) = sentence let keys_iter = sentence.into_iter().map(|(s, _)| s);
.into_iter() self.cache.set_values(keys_iter, cache.into_iter());
.zip(cache)
.filter(|(_, v)| v.is_some())
.map(|(k, v)| (k.0, v.unwrap()))
.unzip::<_, _, Vec<String>, Vec<Word>>();
self.cache.set_values(keys, values);
} }
Ok(encoded) Ok(encoded)