mirror of
https://github.com/mii443/tokenizers.git
synced 2025-08-31 12:39:21 +00:00
Merge pull request #23 from huggingface/cache
Avoid creating unnecessary vectors when accessing cache
This commit is contained in:
@ -26,20 +26,27 @@ where
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn get_values(&self, keys: &[K]) -> Vec<Option<V>> {
|
pub fn get_values<I>(&self, keys_iter: I) -> Option<Vec<Option<V>>>
|
||||||
|
where
|
||||||
|
I: Iterator<Item = K>,
|
||||||
|
{
|
||||||
let mut lock = self.map.try_lock();
|
let mut lock = self.map.try_lock();
|
||||||
if let Ok(ref mut cache) = lock {
|
if let Ok(ref mut cache) = lock {
|
||||||
keys.iter().map(|k| cache.get(k).cloned()).collect()
|
Some(keys_iter.map(|k| cache.get(&k).cloned()).collect())
|
||||||
} else {
|
} else {
|
||||||
keys.iter().map(|_| None).collect()
|
None
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn set_values(&self, keys: Vec<K>, values: Vec<V>) {
|
pub fn set_values<I, J>(&self, keys_iter: I, values_iter: J)
|
||||||
|
where
|
||||||
|
I: Iterator<Item = K>,
|
||||||
|
J: Iterator<Item = Option<V>>,
|
||||||
|
{
|
||||||
let mut lock = self.map.try_lock();
|
let mut lock = self.map.try_lock();
|
||||||
if let Ok(ref mut cache) = lock {
|
if let Ok(ref mut cache) = lock {
|
||||||
for (key, value) in keys.into_iter().zip(values) {
|
for (key, value) in keys_iter.zip(values_iter).filter(|(_, v)| v.is_some()) {
|
||||||
cache.insert(key, value);
|
cache.insert(key, value.unwrap());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -212,14 +212,9 @@ impl Model for BPE {
|
|||||||
|
|
||||||
let mut encoded: Vec<Token> = Vec::with_capacity(sentence.len());
|
let mut encoded: Vec<Token> = Vec::with_capacity(sentence.len());
|
||||||
let mut cached_words = match self.dropout {
|
let mut cached_words = match self.dropout {
|
||||||
None => Some(
|
None => self
|
||||||
self.cache.get_values(
|
.cache
|
||||||
&sentence
|
.get_values(sentence.iter().map(|(s, _)| s.clone())),
|
||||||
.iter()
|
|
||||||
.map(|(s, _)| s.to_owned())
|
|
||||||
.collect::<Vec<_>>(),
|
|
||||||
),
|
|
||||||
),
|
|
||||||
Some(_) => None, // If using dropout we don't want to use a cached.
|
Some(_) => None, // If using dropout we don't want to use a cached.
|
||||||
};
|
};
|
||||||
|
|
||||||
@ -250,13 +245,8 @@ impl Model for BPE {
|
|||||||
|
|
||||||
// Also update cache
|
// Also update cache
|
||||||
if let Some(cache) = cached_words {
|
if let Some(cache) = cached_words {
|
||||||
let (keys, values) = sentence
|
let keys_iter = sentence.into_iter().map(|(s, _)| s);
|
||||||
.into_iter()
|
self.cache.set_values(keys_iter, cache.into_iter());
|
||||||
.zip(cache)
|
|
||||||
.filter(|(_, v)| v.is_some())
|
|
||||||
.map(|(k, v)| (k.0, v.unwrap()))
|
|
||||||
.unzip::<_, _, Vec<String>, Vec<Word>>();
|
|
||||||
self.cache.set_values(keys, values);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
Ok(encoded)
|
Ok(encoded)
|
||||||
|
Reference in New Issue
Block a user