Merge pull request #128 from huitseeker/warts

Maintenance: simplifications & update
This commit is contained in:
MOI Anthony
2020-02-05 12:28:22 -05:00
committed by GitHub
9 changed files with 19 additions and 47 deletions

View File

@@ -502,11 +502,7 @@ declare_types! {
                 let guard = cx.lock();
                 let borrowed = this.borrow(&guard);
                 let normalizer = borrowed.tokenizer.get_normalizer();
-                if let Some(normalizer) = normalizer {
-                    Some(Container::from_ref(normalizer))
-                } else {
-                    None
-                }
+                normalizer.map(|normalizer| { Container::from_ref(normalizer) })
             };
             if let Some(normalizer) = normalizer {
@@ -561,11 +557,7 @@ declare_types! {
                 let guard = cx.lock();
                 let borrowed = this.borrow(&guard);
                 let pretok = borrowed.tokenizer.get_pre_tokenizer();
-                if let Some(pretok) = pretok {
-                    Some(Container::from_ref(pretok))
-                } else {
-                    None
-                }
+                pretok.map(|pretok| { Container::from_ref(pretok) })
             };
             if let Some(pretok) = pretok {
@@ -620,11 +612,7 @@ declare_types! {
                 let guard = cx.lock();
                 let borrowed = this.borrow(&guard);
                 let processor = borrowed.tokenizer.get_post_processor();
-                if let Some(processor) = processor {
-                    Some(Container::from_ref(processor))
-                } else {
-                    None
-                }
+                processor.map(|processor| { Container::from_ref(processor) })
             };
             if let Some(processor) = processor {
@@ -679,11 +667,7 @@ declare_types! {
                 let guard = cx.lock();
                 let borrowed = this.borrow(&guard);
                 let decoder = borrowed.tokenizer.get_decoder();
-                if let Some(decoder) = decoder {
-                    Some(Container::from_ref(decoder))
-                } else {
-                    None
-                }
+                decoder.map(|decoder| { Container::from_ref(decoder) })
             };
             if let Some(decoder) = decoder {
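
The four hunks above all collapse an `if let Some(..) { Some(..) } else { None }` block into Option::map. A minimal standalone sketch of the same pattern, using hypothetical names (Wrapper, lookup) in place of the real Container and tokenizer types:

    struct Wrapper(u32);

    fn lookup(present: bool) -> Option<u32> {
        if present { Some(42) } else { None }
    }

    fn main() {
        let value = lookup(true);

        // Verbose form: unwrap and re-wrap by hand.
        let wrapped_verbose = if let Some(v) = value {
            Some(Wrapper(v))
        } else {
            None
        };

        // After the refactor: Option::map does the re-wrapping.
        let wrapped = value.map(Wrapper);

        assert_eq!(wrapped_verbose.is_some(), wrapped.is_some());
    }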

View File

@@ -23,9 +23,6 @@ impl std::error::Error for PyError {}
 pub struct ToPyResult<T>(pub Result<T>);
 impl<T> std::convert::Into<PyResult<T>> for ToPyResult<T> {
     fn into(self) -> PyResult<T> {
-        match self.0 {
-            Ok(o) => Ok(o),
-            Err(e) => Err(exceptions::Exception::py_err(format!("{}", e))),
-        }
+        self.0.map_err(|e| { exceptions::Exception::py_err(format!("{}", e)) })
     }
 }
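
The same simplification applies to the Python bindings' error conversion above: matching on Ok/Err only to rebuild the Ok value unchanged is exactly what Result::map_err does. A rough sketch using a plain String error rather than the real pyo3 exception type:

    use std::io;

    fn convert(res: Result<u32, io::Error>) -> Result<u32, String> {
        // Before: match res { Ok(o) => Ok(o), Err(e) => Err(format!("{}", e)) }
        // After: only the error side changes, so map_err is enough.
        res.map_err(|e| format!("{}", e))
    }

    fn main() {
        assert_eq!(convert(Ok(3)), Ok(3));
        let err = io::Error::new(io::ErrorKind::Other, "boom");
        assert!(convert(Err(err)).is_err());
    }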

View File

@@ -40,7 +40,7 @@ serde_json = "1.0"
 clap = "2.33.0"
 unicode-normalization-alignments = "0.1.12"
 unicode_categories = "0.1.1"
-indicatif = "0.13.0"
+indicatif = "0.14.0"

 [dev-dependencies]
 criterion = "0.3.0"

View File

@@ -153,10 +153,7 @@ impl Clone for BPE {
     // `Clone` can't be derive because it's not implemented for `Cache`.
     // To keep things simple when we clone, the new BPE will start with a fresh cache.
     fn clone(&self) -> Self {
-        let fresh_cache = match self.cache {
-            Some(ref cache) => Some(cache.fresh()),
-            None => None,
-        };
+        let fresh_cache = self.cache.as_ref().map(|cache| cache.fresh());
         Self {
             vocab: self.vocab.clone(),
             vocab_r: self.vocab_r.clone(),
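
The manual match over self.cache is the same re-wrap pattern; as_ref().map(..) borrows the existing cache instead of moving it out of self while cloning. A simplified illustration with a stand-in Cache type (the real one is the BPE cache, which cannot derive Clone):

    struct Cache {
        capacity: usize,
    }

    impl Cache {
        // Stands in for the real Cache::fresh: keep the configuration, drop the contents.
        fn fresh(&self) -> Cache {
            Cache { capacity: self.capacity }
        }
    }

    struct Model {
        cache: Option<Cache>,
    }

    impl Clone for Model {
        fn clone(&self) -> Self {
            // as_ref turns &Option<Cache> into Option<&Cache>, so map can call
            // fresh() without taking ownership of the original cache.
            let fresh_cache = self.cache.as_ref().map(|cache| cache.fresh());
            Model { cache: fresh_cache }
        }
    }

    fn main() {
        let model = Model { cache: Some(Cache { capacity: 10_000 }) };
        assert!(model.clone().cache.is_some());
    }
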
@@ -359,10 +356,10 @@ impl Model for BPE {
         let mut encoded: Vec<Token> = Vec::with_capacity(sentence.len());
         let mut cached_words = match self.dropout {
-            None => match self.cache {
-                Some(ref cache) => cache.get_values(sentence.iter().map(|(s, _)| s.clone())),
-                None => None,
-            },
+            None => self
+                .cache
+                .as_ref()
+                .and_then(|cache| cache.get_values(sentence.iter().map(|(s, _)| s.clone()))),
             Some(_) => None, // If using dropout we don't want to use the cache.
         };
         let mut should_update_cache = false;
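
In this hunk the cache lookup itself returns an Option, so a plain map would produce Option<Option<..>>; and_then flattens that one level. A reduced sketch with a toy cache keyed by strings:

    use std::collections::HashMap;

    struct Cache {
        entries: HashMap<String, Vec<u32>>,
    }

    impl Cache {
        // Returns None on a cache miss, like the real lookup.
        fn get_values(&self, key: &str) -> Option<Vec<u32>> {
            self.entries.get(key).cloned()
        }
    }

    fn lookup(cache: &Option<Cache>, key: &str, dropout: Option<f32>) -> Option<Vec<u32>> {
        match dropout {
            // map would give Option<Option<Vec<u32>>>; and_then collapses it.
            None => cache.as_ref().and_then(|cache| cache.get_values(key)),
            // With dropout enabled the cache is bypassed on purpose.
            Some(_) => None,
        }
    }

    fn main() {
        let cache = Some(Cache { entries: HashMap::new() });
        assert_eq!(lookup(&cache, "hello", None), None); // empty cache: miss
        assert_eq!(lookup(&cache, "hello", Some(0.1)), None); // dropout: skip the cache
    }
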
@@ -446,10 +443,9 @@ impl Model for BPE {
         merges_file.write_all(
             &merges
                 .into_iter()
-                .map(|(pair, _)| {
+                .flat_map(|(pair, _)| {
                     format!("{} {}\n", self.vocab_r[&pair.0], self.vocab_r[&pair.1]).into_bytes()
                 })
-                .flatten()
                 .collect::<Vec<_>>()[..],
         )?;
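
This hunk, and every remaining one in the PR, replaces .map(..).flatten() with the equivalent .flat_map(..). A small self-contained example building one byte buffer from several lines, as the save path does:

    fn main() {
        let vocab = vec![("hello", 0), ("world", 1)];

        // Before: map each token to a Vec<u8>, then flatten the nested iterator.
        let bytes_before: Vec<u8> = vocab
            .iter()
            .map(|(token, _)| format!("{}\n", token).into_bytes())
            .flatten()
            .collect();

        // After: flat_map performs both steps at once.
        let bytes_after: Vec<u8> = vocab
            .iter()
            .flat_map(|(token, _)| format!("{}\n", token).into_bytes())
            .collect();

        assert_eq!(bytes_before, bytes_after);
    }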

View File

@@ -261,8 +261,7 @@ impl Model for WordPiece {
         vocab_file.write_all(
             &vocab
                 .into_iter()
-                .map(|(token, _)| format!("{}\n", token).as_bytes().to_owned())
-                .flatten()
+                .flat_map(|(token, _)| format!("{}\n", token).as_bytes().to_owned())
                 .collect::<Vec<_>>()[..],
         )?;

View File

@@ -194,9 +194,8 @@ mod tests {
         for sample in samples {
             let pre_tokenized = bl.pre_tokenize(&sample).unwrap();
             let separated_tokens = pre_tokenized
-                .into_iter()
-                .map(|(token, _)| token.split("").map(|t| t.into()).collect::<Vec<_>>())
-                .flatten()
+                .iter()
+                .flat_map(|(token, _)| token.split("").map(|t| t.into()))
                 .collect::<Vec<_>>();
             assert_eq!(sample, bl.decode(separated_tokens).unwrap());
         }

View File

@@ -57,9 +57,8 @@ impl PreTokenizer for Metaspace {
 impl Decoder for Metaspace {
     fn decode(&self, tokens: Vec<String>) -> Result<String> {
         Ok(tokens
-            .into_iter()
-            .map(|t| t.chars().collect::<Vec<_>>())
-            .flatten()
+            .iter()
+            .flat_map(|t| t.chars())
             .enumerate()
             .map(|(i, c)| {
                 if c == self.replacement {
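
Besides flat_map, this hunk (and the byte-level test above) switches from .into_iter() to .iter(): the closure only reads each String, so borrowing leaves the input vector usable afterwards. A hedged sketch of the same token-to-characters flattening, outside the real Decoder trait:

    fn join_chars(tokens: &[String]) -> String {
        tokens
            .iter() // borrow each String instead of consuming the vector
            .flat_map(|t| t.chars()) // replaces .map(|t| t.chars().collect::<Vec<_>>()).flatten()
            .collect()
    }

    fn main() {
        let tokens = vec!["▁Hey".to_string(), "▁friend".to_string()];
        assert_eq!(join_chars(&tokens), "▁Hey▁friend");
        // tokens is still available here because join_chars only borrowed it.
        assert_eq!(tokens.len(), 2);
    }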

View File

@@ -9,7 +9,7 @@ impl PreTokenizer for Whitespace {
         }
         Ok(RE
             .captures_iter(s)
-            .map(|captures| {
+            .flat_map(|captures| {
                 captures
                     .iter()
                     .map(|m| {
@@ -21,7 +21,6 @@ impl PreTokenizer for Whitespace {
                     })
                     .collect::<Vec<(String, Offsets)>>()
             })
-            .flatten()
             .collect())
     }
 }

View File

@@ -679,7 +679,7 @@ impl Tokenizer {
         let mut start_offset = 0;
         let mut splits = splits
             .into_iter()
-            .map(|(start, end)| {
+            .flat_map(|(start, end)| {
                 let mut splits = vec![];
                 if start_offset < start {
                     splits.push((start_offset, start));
@@ -689,7 +689,6 @@ impl Tokenizer {
                 splits
             })
-            .flatten()
             .collect::<Vec<_>>();
         if let Some((_, end)) = splits.iter().last().copied() {
             if end < sentence.len() {