diff --git a/tokenizers/src/cli.rs b/tokenizers/src/cli.rs
index 1dbd3ccd..5ef60825 100644
--- a/tokenizers/src/cli.rs
+++ b/tokenizers/src/cli.rs
@@ -6,7 +6,7 @@ use clap::{App, AppSettings, Arg, ArgMatches, SubCommand};
 use std::io::{self, BufRead, Write};
 use tokenizers::models::bpe::{Error, BPE};
 use tokenizers::pre_tokenizers::byte_level::ByteLevel;
-use tokenizers::tokenizer::Tokenizer;
+use tokenizers::tokenizer::{EncodeInput, Tokenizer};
 
 fn shell(matches: &ArgMatches) -> Result<(), Error> {
     let vocab = matches
@@ -33,21 +33,12 @@ fn shell(matches: &ArgMatches) -> Result<(), Error> {
         let buffer = buffer.trim_end();
 
         let timer = std::time::Instant::now();
-        let encoded = tokenizer.encode(buffer);
+        let encoded = tokenizer.encode(EncodeInput::Single(buffer.to_owned()));
         let elapsed = timer.elapsed();
         println!("\nInput:\t\t{}", buffer);
-        println!(
-            "Tokens:\t\t{:?}",
-            encoded.iter().map(|t| &t.value).collect::<Vec<_>>()
-        );
-        println!(
-            "IDs:\t\t{:?}",
-            encoded.iter().map(|t| t.id).collect::<Vec<_>>()
-        );
-        println!(
-            "Offsets:\t{:?}",
-            encoded.iter().map(|t| t.offsets).collect::<Vec<_>>()
-        );
+        println!("Tokens:\t\t{:?}", encoded.get_tokens());
+        println!("IDs:\t\t{:?}", encoded.get_ids());
+        println!("Offsets:\t{:?}", encoded.get_offsets());
         println!("Tokenized in {:?}", elapsed);
     }
 }
diff --git a/tokenizers/src/tokenizer.rs b/tokenizers/src/tokenizer.rs
index 3dbccc9c..6f720cf0 100644
--- a/tokenizers/src/tokenizer.rs
+++ b/tokenizers/src/tokenizer.rs
@@ -22,12 +22,12 @@ use std::{
 
 /// A Normalizer takes care of pre-processing strings
 pub trait Normalizer {
-    // TODO: Use Cow here to avoid useless allocation if nothing is modified
-    fn normalize(&self, s: &str) -> String;
+    fn normalize(&self, s: String) -> String;
 }
 
 /// A PreTokenizer takes care of pre-tokenizing strings before this goes to the model
 pub trait PreTokenizer {
+    // TODO: Should return offsets with each substring
     fn pre_tokenize(&self, s: &str) -> Vec<String>;
 }
 
@@ -42,7 +42,7 @@ pub trait Model {
 /// A PostProcessor has the responsibility to post process an encoded output of the Tokenizer.
 /// Truncating, Padding, etc... are PostProcessor steps
 pub trait PostProcessor {
-    fn process(&self, tokens: Vec<Token>) -> Vec<Token>;
+    fn process(&self, encoding: Encoding, pair_encoding: Option<Encoding>) -> Encoding;
 }
 
 /// A Decoder has the responsibility to merge the given Vec<String> in a String
@@ -57,13 +57,12 @@ pub trait Trainer: Sync {
     fn process_tokens(&self, words: &mut HashMap<String, u32>, tokens: Vec<String>);
 }
 
-/// A Token represents the output of the Tokenizer
+/// A Token
 #[derive(Debug, PartialEq)]
 pub struct Token {
     pub id: u32,
     pub value: String,
     pub offsets: (usize, usize),
-    // TODO: Find out the best way to define the customizable part (For post processing steps)
 }
 impl Token {
     pub fn new(id: u32, value: String, offsets: (usize, usize)) -> Self {
@@ -71,16 +70,57 @@ impl Token {
     }
 }
 
+/// The Encoding struct represents the output of the Tokenizer
+#[derive(Default)]
+pub struct Encoding {
+    original: String,
+    normalized: String,
+    ids: Vec<u32>,
+    type_ids: Vec<u32>,
+    tokens: Vec<String>,
+    offsets: Vec<(usize, usize)>,
+}
+impl Encoding {
+    pub fn get_original(&self) -> &str {
+        &self.original
+    }
+
+    pub fn get_normalized(&self) -> &str {
+        &self.normalized
+    }
+
+    pub fn get_tokens(&self) -> &[String] {
+        &self.tokens[..]
+    }
+
+    pub fn get_ids(&self) -> &[u32] {
+        &self.ids
+    }
+
+    pub fn get_type_ids(&self) -> &[u32] {
+        &self.type_ids
+    }
+
+    pub fn get_offsets(&self) -> &[(usize, usize)] {
+        &self.offsets
+    }
+}
+
+pub enum EncodeInput {
+    Single(String),
+    Dual(String, String),
+}
+
 ///
 /// ## Tokenizer
 ///
 /// A Tokenizer is capable of encoding/decoding any text
 ///
 pub struct Tokenizer {
-    normalizers: Vec<Box<dyn Normalizer + Sync>>,
+    normalizer: Option<Box<dyn Normalizer + Sync>>,
     pre_tokenizer: Option<Box<dyn PreTokenizer + Sync>>,
     model: Box<dyn Model + Sync>,
-    post_processors: Vec<Box<dyn PostProcessor + Sync>>,
+    post_processor: Option<Box<dyn PostProcessor + Sync>>,
     decoder: Option<Box<dyn Decoder + Sync>>,
 }
 
@@ -88,17 +128,17 @@ impl Tokenizer {
     /// Instanciate a new Tokenizer, with the given Model
     pub fn new(model: Box<dyn Model + Sync>) -> Self {
         Tokenizer {
-            normalizers: vec![],
+            normalizer: None,
             pre_tokenizer: None,
             model,
-            post_processors: vec![],
+            post_processor: None,
             decoder: None,
         }
     }
 
     /// Set the normalizers
-    pub fn with_normalizers(&mut self, normalizers: Vec<Box<dyn Normalizer + Sync>>) -> &Self {
-        self.normalizers = normalizers;
+    pub fn with_normalizers(&mut self, normalizer: Box<dyn Normalizer + Sync>) -> &Self {
+        self.normalizer = Some(normalizer);
         self
     }
 
@@ -108,12 +148,9 @@ impl Tokenizer {
         self
     }
 
-    /// Set the post processors
-    pub fn with_post_processors(
-        &mut self,
-        post_processors: Vec<Box<dyn PostProcessor + Sync>>,
-    ) -> &Self {
-        self.post_processors = post_processors;
+    /// Set the post processor
+    pub fn with_post_processor(&mut self, post_processor: Box<dyn PostProcessor + Sync>) -> &Self {
+        self.post_processor = Some(post_processor);
         self
     }
 
@@ -140,18 +177,62 @@ impl Tokenizer {
     }
 
     /// Encode the given sentence
-    pub fn encode(&self, sentence: &str) -> Vec<Token> {
-        let normalized = self.normalize(sentence);
-        let pre_tokenized = self.pre_tokenize(&normalized);
+    pub fn encode(&self, input: EncodeInput) -> Encoding {
+        let generate_output = move |sentence: String, type_id: u32| -> Encoding {
+            // 1. Normalization
+            // TODO: Make sure we have the offsets update necessary to go from the original text to
+            // the normalized one
+            let original = sentence.clone();
+            let normalized = self.normalize(sentence);
 
-        self.model.tokenize(pre_tokenized)
+            // 2. Pre tokenization
+            let pre_tokenized = self.pre_tokenize(&normalized);
+
+            // 3. Model
+            let output = self.model.tokenize(pre_tokenized);
+            let length = output.len();
+
+            let (ids, tokens, offsets) = output.into_iter().fold(
+                (
+                    Vec::with_capacity(length),
+                    Vec::with_capacity(length),
+                    Vec::with_capacity(length),
+                ),
+                |(mut ids, mut tokens, mut offsets), t| {
+                    ids.push(t.id);
+                    tokens.push(t.value);
+                    offsets.push(t.offsets);
+                    (ids, tokens, offsets)
+                },
+            );
+
+            Encoding {
+                original,
+                normalized,
+                ids,
+                type_ids: vec![type_id; length],
+                tokens,
+                offsets,
+            }
+        };
+
+        let (sentence, pair) = match input {
+            EncodeInput::Single(s1) => (s1, None),
+            EncodeInput::Dual(s1, s2) => (s1, Some(s2)),
+        };
+
+        let encoding = generate_output(sentence, 0);
+        let pair_encoding = pair.map(|pair| generate_output(pair, 1));
+
+        // 4. Post processing
+        self.post_process(encoding, pair_encoding)
     }
 
     /// Encode all the sentences in parallel, using multiple threads
-    pub fn encode_batch(&self, sentences: Vec<&str>) -> Vec<Vec<Token>> {
-        sentences
-            .par_iter()
-            .map(|sentence| self.encode(sentence))
+    pub fn encode_batch(&self, inputs: Vec<EncodeInput>) -> Vec<Encoding> {
+        inputs
+            .into_par_iter()
+            .map(|input| self.encode(input))
             .collect()
     }
 
@@ -190,7 +271,7 @@ impl Tokenizer {
             for line in file.lines() {
                 let line = line?;
-                let normalized = self.normalize(&line);
+                let normalized = self.normalize(line);
                 let pre_tokenized = self.pre_tokenize(&normalized);
 
                 trainer.process_tokens(&mut words, pre_tokenized);
             }
@@ -223,11 +304,43 @@ impl Tokenizer {
     }
 
     /// Normalization logic, go through all normalizers
-    fn normalize(&self, sentence: &str) -> String {
-        if self.normalizers.len() == 0 {
-            sentence.to_owned()
+    fn normalize(&self, sentence: String) -> String {
+        if let Some(normalizer) = &self.normalizer {
+            normalizer.normalize(sentence)
         } else {
-            unimplemented!("Normalization has not been implemented yet")
+            sentence.to_owned()
+        }
+    }
+
+    /// Post processing logic, handling the case where there is no PostProcessor set
+    fn post_process(&self, encoding: Encoding, pair_encoding: Option<Encoding>) -> Encoding {
+        if let Some(processor) = &self.post_processor {
+            processor.process(encoding, pair_encoding)
+        } else {
+            match pair_encoding {
+                None => encoding,
+                Some(pair) => Encoding {
+                    original: format!("{}{}", encoding.original, pair.original),
+                    normalized: format!("{}{}", encoding.normalized, pair.normalized),
+                    ids: [&encoding.ids[..], &pair.ids[..]].concat(),
+                    type_ids: [&encoding.type_ids[..], &pair.type_ids[..]].concat(),
+                    tokens: [&encoding.tokens[..], &pair.tokens[..]].concat(),
+                    offsets: [
+                        &encoding.offsets[..],
+                        &pair
+                            .offsets
+                            .into_iter()
+                            .map(|(start, end)| {
+                                (
+                                    start + encoding.original.len(),
+                                    end + encoding.original.len(),
+                                )
+                            })
+                            .collect::<Vec<_>>(),
+                    ]
+                    .concat(),
+                },
+            }
         }
     }
 }
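
For reference, a minimal sketch of how the new API introduced by this diff is meant to be called, in the spirit of the cli.rs change above. The print_encoding helper, the sample strings, and the assumption that Encoding is reachable as tokenizers::tokenizer::Encoding are illustrative only and not part of the diff:

use tokenizers::tokenizer::{EncodeInput, Encoding, Tokenizer};

// Hypothetical helper mirroring what cli.rs prints, built on the new getters.
fn print_encoding(encoding: &Encoding) {
    println!("Tokens:\t\t{:?}", encoding.get_tokens());
    println!("IDs:\t\t{:?}", encoding.get_ids());
    println!("Type IDs:\t{:?}", encoding.get_type_ids());
    println!("Offsets:\t{:?}", encoding.get_offsets());
}

fn demo(tokenizer: &Tokenizer) {
    // Single sentence: every token gets type_id 0.
    let single = tokenizer.encode(EncodeInput::Single("Hello world".to_owned()));
    print_encoding(&single);

    // Sentence pair: tokens from the second sentence get type_id 1. With no
    // PostProcessor set, post_process simply concatenates the two encodings,
    // shifting the pair's offsets by the length of the first original string.
    let pair = tokenizer.encode(EncodeInput::Dual(
        "Hello world".to_owned(),
        "How are you?".to_owned(),
    ));
    print_encoding(&pair);
}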