diff --git a/tokenizers/src/tokenizer.rs b/tokenizers/src/tokenizer.rs
index 81b300aa..ba68b213 100644
--- a/tokenizers/src/tokenizer.rs
+++ b/tokenizers/src/tokenizer.rs
@@ -15,26 +15,27 @@
 use rayon::prelude::*;
 use std::{
     collections::HashMap,
-    error::Error,
     fs::File,
     io::{BufRead, BufReader},
 };
 
+pub type Result<T> = std::result::Result<T, Box<dyn std::error::Error + Send + Sync>>;
+
 /// A Normalizer takes care of pre-processing strings
 pub trait Normalizer {
-    fn normalize(&self, s: String) -> String;
+    fn normalize(&self, s: String) -> Result<String>;
 }
 
 /// A PreTokenizer takes care of pre-tokenizing strings before this goes to the model
 pub trait PreTokenizer {
     // TODO: Should return offsets with each substring
-    fn pre_tokenize(&self, s: &str) -> Vec<String>;
+    fn pre_tokenize(&self, s: &str) -> Result<Vec<String>>;
 }
 
 /// Represents a `Model` used during Tokenization (Like BPE or Word or Unigram)
 pub trait Model {
-    fn tokenize(&self, tokens: Vec<String>) -> Vec<Token>;
-    fn decode(&self, ids: Vec<u32>) -> Vec<String>;
+    fn tokenize(&self, tokens: Vec<String>) -> Result<Vec<Token>>;
+    fn decode(&self, ids: Vec<u32>) -> Result<Vec<String>>;
     fn token_to_id(&self, token: &str) -> Option<u32>;
     fn id_to_token(&self, id: u32) -> Option<String>;
     fn get_vocab_size(&self) -> usize;
@@ -43,18 +44,18 @@ pub trait Model {
 /// A PostProcessor has the responsibility to post process an encoded output of the Tokenizer.
 /// Truncating, Padding, etc... are PostProcessor steps
 pub trait PostProcessor {
-    fn process(&self, encoding: Encoding, pair_encoding: Option<Encoding>) -> Encoding;
+    fn process(&self, encoding: Encoding, pair_encoding: Option<Encoding>) -> Result<Encoding>;
 }
 
 /// A Decoder has the responsibility to merge the given Vec<String> in a String
 pub trait Decoder {
-    fn decode(&self, tokens: Vec<String>) -> String;
+    fn decode(&self, tokens: Vec<String>) -> Result<String>;
 }
 
 /// A Trainer has the responsibility to train a Model. We feed it with lines/sentences
 /// and it returns a Model when done.
 pub trait Trainer: Sync {
-    fn train(&self, words: HashMap<String, u32>) -> Result<Box<dyn Model + Sync>, Box<dyn Error>>;
+    fn train(&self, words: HashMap<String, u32>) -> Result<Box<dyn Model + Sync>>;
     fn process_tokens(&self, words: &mut HashMap<String, u32>, tokens: Vec<String>);
 }
 
@@ -82,6 +83,24 @@ pub struct Encoding {
     offsets: Vec<(usize, usize)>,
 }
 impl Encoding {
+    pub fn new(
+        original: String,
+        normalized: String,
+        ids: Vec<u32>,
+        type_ids: Vec<u32>,
+        tokens: Vec<String>,
+        offsets: Vec<(usize, usize)>,
+    ) -> Self {
+        Encoding {
+            original,
+            normalized,
+            ids,
+            type_ids,
+            tokens,
+            offsets,
+        }
+    }
+
     pub fn get_original(&self) -> &str {
         &self.original
     }
@@ -183,19 +202,19 @@ impl Tokenizer {
     }
 
     /// Encode the given sentence
-    pub fn encode(&self, input: EncodeInput) -> Encoding {
-        let generate_output = move |sentence: String, type_id: u32| -> Encoding {
+    pub fn encode(&self, input: EncodeInput) -> Result<Encoding> {
+        let generate_output = move |sentence: String, type_id: u32| -> Result<Encoding> {
             // 1. Normalization
             // TODO: Make sure we have the offsets update necessary to go from the original text to
             // the normalized one
             let original = sentence.clone();
-            let normalized = self.normalize(sentence);
+            let normalized = self.normalize(sentence)?;
 
             // 2. Pre tokenization
-            let pre_tokenized = self.pre_tokenize(&normalized);
+            let pre_tokenized = self.pre_tokenize(&normalized)?;
 
             // 3. Model
-            let output = self.model.tokenize(pre_tokenized);
+            let output = self.model.tokenize(pre_tokenized)?;
             let length = output.len();
 
             let (ids, tokens, offsets) = output.into_iter().fold(
@@ -212,14 +231,14 @@ impl Tokenizer {
                 },
             );
 
-            Encoding {
+            Ok(Encoding {
                 original,
                 normalized,
                 ids,
                 type_ids: vec![type_id; length],
                 tokens,
                 offsets,
-            }
+            })
         };
 
         let (sentence, pair) = match input {
@@ -227,15 +246,18 @@ impl Tokenizer {
             EncodeInput::Dual(s1, s2) => (s1, Some(s2)),
         };
 
-        let encoding = generate_output(sentence, 0);
-        let pair_encoding = pair.map(|pair| generate_output(pair, 1));
+        let encoding = generate_output(sentence, 0)?;
+        let pair_encoding = match pair {
+            Some(pair) => Some(generate_output(pair, 1)?),
+            None => None,
+        };
 
         // 4. Post processing
        self.post_process(encoding, pair_encoding)
    }
 
     /// Encode all the sentences in parallel, using multiple threads
-    pub fn encode_batch(&self, inputs: Vec<EncodeInput>) -> Vec<Encoding> {
+    pub fn encode_batch(&self, inputs: Vec<EncodeInput>) -> Result<Vec<Encoding>> {
         inputs
             .into_par_iter()
             .map(|input| self.encode(input))
@@ -243,18 +265,18 @@ impl Tokenizer {
     }
 
     /// Decode the given ids, back to a String
-    pub fn decode(&self, ids: Vec<u32>) -> String {
-        let tokens = self.model.decode(ids);
+    pub fn decode(&self, ids: Vec<u32>) -> Result<String> {
+        let tokens = self.model.decode(ids)?;
 
         if let Some(decoder) = &self.decoder {
             decoder.decode(tokens)
         } else {
-            tokens.join(" ")
+            Ok(tokens.join(" "))
         }
     }
 
     /// Decode all sentences in parallel
-    pub fn decode_batch(&self, sentences: Vec<Vec<u32>>) -> Vec<String> {
+    pub fn decode_batch(&self, sentences: Vec<Vec<u32>>) -> Result<Vec<String>> {
         sentences
             .into_par_iter()
             .map(|sentence| self.decode(sentence))
@@ -262,14 +284,10 @@ impl Tokenizer {
     }
 
     /// Train a model and replace our current Model, using the given Trainer
-    pub fn train(
-        &mut self,
-        trainer: &Box<dyn Trainer>,
-        files: Vec<String>,
-    ) -> Result<(), Box<dyn Error>> {
+    pub fn train(&mut self, trainer: &Box<dyn Trainer>, files: Vec<String>) -> Result<()> {
         let results = files
             .par_iter()
-            .map(|file| -> std::io::Result<HashMap<String, u32>> {
+            .map(|file| -> Result<HashMap<String, u32>> {
                 let mut words = HashMap::new();
 
                 let file: std::fs::File = File::open(file)?;
@@ -277,8 +295,8 @@
 
                 for line in file.lines() {
                     let line = line?;
-                    let normalized = self.normalize(line);
-                    let pre_tokenized = self.pre_tokenize(&normalized);
+                    let normalized = self.normalize(line)?;
+                    let pre_tokenized = self.pre_tokenize(&normalized)?;
                     trainer.process_tokens(&mut words, pre_tokenized);
                 }
 
@@ -302,30 +320,34 @@ impl Tokenizer {
     }
 
     /// PreTokenization logic, handling the case where there is no PreTokenizer set
-    fn pre_tokenize(&self, sentence: &str) -> Vec<String> {
+    fn pre_tokenize(&self, sentence: &str) -> Result<Vec<String>> {
         match &self.pre_tokenizer {
-            None => vec![sentence.to_owned()],
+            None => Ok(vec![sentence.to_owned()]),
             Some(pre_tokenizer) => pre_tokenizer.pre_tokenize(sentence),
         }
     }
 
     /// Normalization logic, go through all normalizers
-    fn normalize(&self, sentence: String) -> String {
+    fn normalize(&self, sentence: String) -> Result<String> {
         if let Some(normalizer) = &self.normalizer {
             normalizer.normalize(sentence)
         } else {
-            sentence.to_owned()
+            Ok(sentence.to_owned())
        }
    }
 
     /// Post processing logic, handling the case where there is no PostProcessor set
-    fn post_process(&self, encoding: Encoding, pair_encoding: Option<Encoding>) -> Encoding {
+    fn post_process(
+        &self,
+        encoding: Encoding,
+        pair_encoding: Option<Encoding>,
+    ) -> Result<Encoding> {
         if let Some(processor) = &self.post_processor {
             processor.process(encoding, pair_encoding)
         } else {
             match pair_encoding {
-                None => encoding,
-                Some(pair) => Encoding {
+                None => Ok(encoding),
+                Some(pair) => Ok(Encoding {
                     original: format!("{}{}", encoding.original, pair.original),
                     normalized: format!("{}{}", encoding.normalized, pair.normalized),
                     ids: [&encoding.ids[..], &pair.ids[..]].concat(),
@@ -345,7 +367,7 @@ impl Tokenizer {
                             .collect::<Vec<(usize, usize)>>(),
                     ]
                     .concat(),
-                },
+                }),
             }
         }
     }
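
For context, a minimal sketch (not part of the patch) of how calling code would consume the new Result-based API. The `EncodeInput::Single` variant, the `tokenizers::tokenizer` module path, and the sample ids passed to `decode` are assumptions for illustration only; `encode`, `decode`, `get_original`, and the `Result` alias are the items changed or added above.

    use tokenizers::tokenizer::{EncodeInput, Result, Tokenizer};

    // Hypothetical caller; the Tokenizer is assumed to be constructed elsewhere.
    fn encode_and_decode(tokenizer: &Tokenizer, text: String) -> Result<()> {
        // `encode` now returns `Result<Encoding>`, so a failure in normalization,
        // pre-tokenization, the model, or post-processing propagates with `?`.
        let encoding = tokenizer.encode(EncodeInput::Single(text))?;
        println!("original: {}", encoding.get_original());

        // `decode` likewise returns `Result<String>` instead of a bare String
        // (arbitrary ids here, purely for illustration).
        let decoded = tokenizer.decode(vec![0, 1, 2])?;
        println!("decoded: {}", decoded);

        Ok(())
    }

Using a boxed error trait object in the shared `Result<T>` alias keeps the Normalizer, PreTokenizer, Model, PostProcessor, Decoder, and Trainer traits object-safe while letting each implementation surface its own error type.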