Merge pull request #23 from huggingface/cache

Avoid creating unnecessary vectors when accessing cache
This commit is contained in:
MOI Anthony
2020-01-01 14:47:28 -05:00
committed by GitHub
2 changed files with 18 additions and 21 deletions

View File

@@ -26,20 +26,27 @@ where
}
}
pub fn get_values(&self, keys: &[K]) -> Vec<Option<V>> {
pub fn get_values<I>(&self, keys_iter: I) -> Option<Vec<Option<V>>>
where
I: Iterator<Item = K>,
{
let mut lock = self.map.try_lock();
if let Ok(ref mut cache) = lock {
keys.iter().map(|k| cache.get(k).cloned()).collect()
Some(keys_iter.map(|k| cache.get(&k).cloned()).collect())
} else {
keys.iter().map(|_| None).collect()
None
}
}
pub fn set_values(&self, keys: Vec<K>, values: Vec<V>) {
pub fn set_values<I, J>(&self, keys_iter: I, values_iter: J)
where
I: Iterator<Item = K>,
J: Iterator<Item = Option<V>>,
{
let mut lock = self.map.try_lock();
if let Ok(ref mut cache) = lock {
for (key, value) in keys.into_iter().zip(values) {
cache.insert(key, value);
for (key, value) in keys_iter.zip(values_iter).filter(|(_, v)| v.is_some()) {
cache.insert(key, value.unwrap());
}
}
}

View File

@@ -212,14 +212,9 @@ impl Model for BPE {
let mut encoded: Vec<Token> = Vec::with_capacity(sentence.len());
let mut cached_words = match self.dropout {
None => Some(
self.cache.get_values(
&sentence
.iter()
.map(|(s, _)| s.to_owned())
.collect::<Vec<_>>(),
),
),
None => self
.cache
.get_values(sentence.iter().map(|(s, _)| s.clone())),
Some(_) => None, // If using dropout we don't want to use the cache.
};
@@ -250,13 +245,8 @@ impl Model for BPE {
// Also update cache
if let Some(cache) = cached_words {
let (keys, values) = sentence
.into_iter()
.zip(cache)
.filter(|(_, v)| v.is_some())
.map(|(k, v)| (k.0, v.unwrap()))
.unzip::<_, _, Vec<String>, Vec<Word>>();
self.cache.set_values(keys, values);
let keys_iter = sentence.into_iter().map(|(s, _)| s);
self.cache.set_values(keys_iter, cache.into_iter());
}
Ok(encoded)