Merge pull request #34 from huggingface/improve-cache

avoid unnecessary write locks in the BPE cache
This commit is contained in:
MOI Anthony
2020-01-03 19:35:54 -05:00
committed by GitHub
2 changed files with 30 additions and 7 deletions

View File

@ -65,6 +65,19 @@ where
I: Iterator<Item = K>,
J: Iterator<Item = Option<V>>,
{
// Before trying to acquire a write lock, we use a read handle to check
// whether we are already at capacity.
if let Ok(ref mut cache) = self.map.try_read() {
if cache.len() >= self.capacity {
// At capacity, so do nothing.
return;
}
} else {
// If we couldn't acquire a read handle then we probably won't be able to acquire
// a write handle one quadrillionth of a second later.
return;
}
// Not at capacity, so try acquiring a write handle.
if let Ok(ref mut cache) = self.map.try_write() {
for (key, value) in keys_iter.zip(values_iter).filter(|(_, v)| v.is_some()) {
// If already at capacity, don't add any more values.

View File

@ -302,6 +302,7 @@ impl Model for BPE {
},
Some(_) => None, // If using dropout we don't want to use the cache.
};
let mut should_update_cache = false;
for (i, (w, initial_offsets)) in sentence.iter().enumerate() {
let tokens = match cached_words {
@ -318,9 +319,16 @@ impl Model for BPE {
let tokens = self.word_to_tokens(&word, initial_offsets);
// Add to cache.
cache[i] = Some(word);
should_update_cache = true;
tokens
}
Some(word) => {
let tokens = self.word_to_tokens(word, initial_offsets);
// Remove this entry so we don't needlessly try to update
// it in the cache below.
cache[i] = None;
tokens
}
Some(word) => self.word_to_tokens(word, initial_offsets),
}
}
};
@ -328,13 +336,15 @@ impl Model for BPE {
encoded.extend(tokens);
}
// Also update cache
// Try updating the cache if we need to.
if let Some(cache) = cached_words {
let keys_iter = sentence.into_iter().map(|(s, _)| s);
self.cache
.as_ref()
.unwrap()
.set_values(keys_iter, cache.into_iter());
if should_update_cache {
let keys_iter = sentence.into_iter().map(|(s, _)| s);
self.cache
.as_ref()
.unwrap()
.set_values(keys_iter, cache.into_iter());
}
}
Ok(encoded)