diff --git a/tokenizers/src/models/bpe/model.rs b/tokenizers/src/models/bpe/model.rs
index d46118b0..6cb34fe1 100644
--- a/tokenizers/src/models/bpe/model.rs
+++ b/tokenizers/src/models/bpe/model.rs
@@ -1,5 +1,5 @@
 use super::{Cache, Error, Pair, Word};
-use crate::tokenizer::{Model, Token};
+use crate::tokenizer::{Model, Result, Token};
 use serde_json::Value;
 use std::{
     collections::HashMap,
@@ -37,7 +37,7 @@ impl BPE {
         BPE::new(HashMap::new(), HashMap::new(), HashMap::new())
     }
 
-    pub fn from_files(vocab: &str, merges: &str) -> Result<BPE, Error> {
+    pub fn from_files(vocab: &str, merges: &str) -> Result<BPE> {
         // Read vocab.json
         let vocab_file = File::open(vocab)?;
         let mut vocab_file = BufReader::new(vocab_file);
@@ -55,7 +55,7 @@ impl BPE {
                     }
                 }
             }
-            _ => return Err(Error::BadVocabulary),
+            _ => return Err(Box::new(Error::BadVocabulary)),
         };
 
         // Read merges file
@@ -100,9 +100,9 @@ impl Model for BPE {
         self.vocab.len()
     }
 
-    fn tokenize(&self, sentence: Vec<String>) -> Vec<Token> {
+    fn tokenize(&self, sentence: Vec<String>) -> Result<Vec<Token>> {
         if sentence.len() == 0 {
-            return vec![];
+            return Ok(vec![]);
         }
 
         let mut encoded: Vec<Token> = Vec::with_capacity(sentence.len());
@@ -181,15 +181,16 @@ impl Model for BPE {
             .unzip::<_, _, Vec<_>, Vec<_>>();
         self.cache.set_values(keys, values);
 
-        encoded
+        Ok(encoded)
     }
 
-    fn decode(&self, ids: Vec<u32>) -> Vec<String> {
-        ids.into_iter()
+    fn decode(&self, ids: Vec<u32>) -> Result<Vec<String>> {
+        Ok(ids
+            .into_iter()
             .map(|id| self.vocab_r.get(&id))
             .filter(|token| token.is_some())
             .map(|id| id.unwrap().clone())
-            .collect()
+            .collect())
     }
 
     fn token_to_id(&self, token: &str) -> Option<u32> {
diff --git a/tokenizers/src/models/bpe/trainer.rs b/tokenizers/src/models/bpe/trainer.rs
index d9310aa2..b1234e33 100644
--- a/tokenizers/src/models/bpe/trainer.rs
+++ b/tokenizers/src/models/bpe/trainer.rs
@@ -4,10 +4,9 @@
 //! In charge of training a BPE model
 //!
 use super::{Pair, Word, BPE};
-use crate::tokenizer::{Model, Trainer};
+use crate::tokenizer::{Model, Result, Trainer};
 use std::{
     collections::{HashMap, HashSet},
-    error::Error,
     time::Instant,
 };
 
@@ -50,10 +49,7 @@ impl BpeTrainer {
 
 impl Trainer for BpeTrainer {
     /// Train a BPE model
-    fn train(
-        &self,
-        word_counts: HashMap<String, u32>,
-    ) -> Result<Box<dyn Model + Sync>, Box<dyn Error>> {
+    fn train(&self, word_counts: HashMap<String, u32>) -> Result<Box<dyn Model + Sync>> {
         let mut words: Vec<Word> = vec![];
         let mut counts: Vec<u32> = vec![];
         let mut word_to_id: HashMap<String, u32> = HashMap::new();
diff --git a/tokenizers/src/pre_tokenizers/byte_level.rs b/tokenizers/src/pre_tokenizers/byte_level.rs
index c4bef007..04b03f8c 100644
--- a/tokenizers/src/pre_tokenizers/byte_level.rs
+++ b/tokenizers/src/pre_tokenizers/byte_level.rs
@@ -1,4 +1,4 @@
-use crate::tokenizer::{Decoder, PreTokenizer};
+use crate::tokenizer::{Decoder, PreTokenizer, Result};
 use regex::Regex;
 use std::collections::HashMap;
 
@@ -32,8 +32,9 @@ lazy_static! {
 pub struct ByteLevel;
 
 impl PreTokenizer for ByteLevel {
-    fn pre_tokenize(&self, s: &str) -> Vec<String> {
-        RE.captures_iter(s)
+    fn pre_tokenize(&self, s: &str) -> Result<Vec<String>> {
+        Ok(RE
+            .captures_iter(s)
             .map(|capture| {
                 let capture = capture.get(0).unwrap();
                 let start = capture.start();
@@ -78,20 +79,20 @@ impl PreTokenizer for ByteLevel {
                     .map(|b| std::char::from_u32(BYTES_CHAR[b]).unwrap())
                     .collect()
             })
-            .collect()
+            .collect())
     }
 }
 
 impl Decoder for ByteLevel {
-    fn decode(&self, tokens: Vec<String>) -> String {
-        String::from_utf8_lossy(
+    fn decode(&self, tokens: Vec<String>) -> Result<String> {
+        Ok(String::from_utf8_lossy(
             &tokens
                 .join("")
                 .chars()
                 .map(|c| CHAR_BYTES[&(c as u32)])
                 .collect::<Vec<u8>>(),
         )
-        .into_owned()
+        .into_owned())
     }
 }
 
@@ -104,7 +105,9 @@ mod tests {
     fn pre_tokenization() {
         let pre_tok = ByteLevel;
         assert_eq!(
-            pre_tok.pre_tokenize("Hello my friend, how is your day going?"),
+            pre_tok
+                .pre_tokenize("Hello my friend, how is your day going?")
+                .unwrap(),
             vec![
                 "Hello", "Ġmy", "Ġfriend", ",", "Ġhow", "Ġis", "Ġyour", "Ġday", "Ġgoing", "?"
             ]
@@ -116,14 +119,17 @@ mod tests {
         let decoder = ByteLevel;
         assert_eq!(
             "Hello my friend, how is your day going?",
-            decoder.decode(
-                vec![
-                    "Hello", "Ġmy", "Ġfriend", ",", "Ġhow", "Ġis", "Ġyour", "Ġday", "Ġgoing", "?"
-                ]
-                .into_iter()
-                .map(|s| s.into())
-                .collect::<Vec<String>>()
-            )
+            decoder
+                .decode(
+                    vec![
+                        "Hello", "Ġmy", "Ġfriend", ",", "Ġhow", "Ġis", "Ġyour", "Ġday", "Ġgoing",
+                        "?"
+                    ]
+                    .into_iter()
+                    .map(|s| s.into())
+                    .collect::<Vec<String>>()
+                )
+                .unwrap()
         );
     }
 
@@ -141,13 +147,13 @@ mod tests {
         let bl = ByteLevel;
 
         for sample in samples {
-            let pre_tokenized = bl.pre_tokenize(&sample);
+            let pre_tokenized = bl.pre_tokenize(&sample).unwrap();
            let separated_tokens = pre_tokenized
                 .into_iter()
                 .map(|token| token.split("").map(|t| t.into()).collect::<Vec<String>>())
                 .flatten()
                 .collect::<Vec<String>>();
-            assert_eq!(sample, bl.decode(separated_tokens));
+            assert_eq!(sample, bl.decode(separated_tokens).unwrap());
         }
     }
 }
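
Note on the `Result` type: the hunks above import `Result` from `crate::tokenizer`, but its definition lies outside this diff. A minimal sketch of the alias these signatures appear to rely on (the exact error bounds are an assumption, not confirmed by the diff):

```rust
// Sketch only: the diff shows the import of `Result` but not its definition.
// The `Send + Sync` bounds here are a guess and may differ in the actual
// `crate::tokenizer` module.
pub type Result<T> = std::result::Result<T, Box<dyn std::error::Error + Send + Sync>>;
```

With a boxed trait object as the error type, concrete errors must be boxed when returned directly (hence `Err(Box::new(Error::BadVocabulary))`), while `?` keeps working on `io::Error` and `serde_json::Error` through the standard `From` conversion into the boxed trait object. Callers of `tokenize`, `decode`, and `pre_tokenize` now either propagate errors with `?` or, as the updated tests do, call `.unwrap()`.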