diff --git a/tokenizers/src/processors/bert.rs b/tokenizers/src/processors/bert.rs
index e67c3cc6..c34ce429 100644
--- a/tokenizers/src/processors/bert.rs
+++ b/tokenizers/src/processors/bert.rs
@@ -1,56 +1,30 @@
 use crate::tokenizer::{Encoding, PostProcessor, Result};
-use crate::utils::{truncate_encodings, TruncationStrategy};
 
 pub struct BertProcessing {
-    max_len: usize,
-    trunc_strategy: TruncationStrategy,
-    trunc_stride: usize,
-
     sep: (String, u32),
     cls: (String, u32),
 }
 
 impl BertProcessing {
-    pub fn new(
-        max_len: usize,
-        trunc_strategy: TruncationStrategy,
-        trunc_stride: usize,
-        sep: (String, u32),
-        cls: (String, u32),
-    ) -> Self {
-        BertProcessing {
-            max_len,
-            trunc_strategy,
-            trunc_stride,
-            sep,
-            cls,
-        }
+    pub fn new(sep: (String, u32), cls: (String, u32)) -> Self {
+        BertProcessing { sep, cls }
     }
 }
 
 impl PostProcessor for BertProcessing {
-    fn process(&self, encoding: Encoding, pair_encoding: Option<Encoding>) -> Result<Encoding> {
-        let special_token_len = if pair_encoding.is_some() { 3 } else { 2 };
-        let total_len = encoding.get_ids().len()
-            + pair_encoding
-                .as_ref()
-                .map(|e| e.get_ids().len())
-                .unwrap_or(0)
-            + special_token_len;
-
-        let need_trunc = if total_len > self.max_len {
-            total_len - self.max_len
+    fn added_tokens(
+        &self,
+        _encoding: &Encoding,
+        pair_encoding: &Option<Encoding>,
+    ) -> Result<usize> {
+        if pair_encoding.is_some() {
+            Ok(3)
         } else {
-            0
-        };
-        let (mut encoding, pair_encoding) = truncate_encodings(
-            encoding,
-            pair_encoding,
-            need_trunc,
-            self.trunc_strategy,
-            self.trunc_stride,
-        )?;
+            Ok(2)
+        }
+    }
 
+    fn process(&self, mut encoding: Encoding, pair_encoding: Option<Encoding>) -> Result<Encoding> {
         // Prepare ids
         let ids = [&[self.cls.1], &encoding.get_ids()[..], &[self.sep.1]].concat();
         let pair_ids = pair_encoding
diff --git a/tokenizers/src/tokenizer/encoding.rs b/tokenizers/src/tokenizer/encoding.rs
index be2c87c6..981f99a9 100644
--- a/tokenizers/src/tokenizer/encoding.rs
+++ b/tokenizers/src/tokenizer/encoding.rs
@@ -1,4 +1,5 @@
 /// The various possible padding directions
+#[derive(Debug, Clone)]
 pub enum PaddingDirection {
     Left,
     Right,
@@ -155,8 +156,8 @@
         pad_length: usize,
         pad_id: u32,
         pad_type_id: u32,
-        pad_token: String,
-        direction: PaddingDirection,
+        pad_token: &str,
+        direction: &PaddingDirection,
     ) {
         if self.ids.len() > pad_length {
             // We just do nothing if the wanted padding length is smaller than us
@@ -166,13 +167,13 @@
 
         let ids_pad = vec![pad_id; pad_length];
         let type_ids_pad = vec![pad_type_id; pad_length];
-        let tokens_pad = vec![pad_token; pad_length];
+        let tokens_pad = vec![pad_token.to_owned(); pad_length];
         let attention_pad = vec![0; pad_length];
         let special_pad = vec![1; pad_length];
         let offsets_pad = vec![(0, 0); pad_length];
 
         match direction {
-            Left => {
+            PaddingDirection::Left => {
                 self.ids = [&ids_pad[..], &self.ids[..]].concat();
                 self.type_ids = [&type_ids_pad[..], &self.type_ids[..]].concat();
                 self.tokens = [&tokens_pad[..], &self.tokens[..]].concat();
@@ -181,7 +182,7 @@
                     [&special_pad[..], &self.special_tokens_mask[..]].concat();
                 self.offsets = [&offsets_pad[..], &self.offsets[..]].concat();
             }
-            Right => {
+            PaddingDirection::Right => {
                 self.ids = [&self.ids[..], &ids_pad[..]].concat();
                 self.type_ids = [&self.type_ids[..], &type_ids_pad[..]].concat();
                 self.tokens = [&self.tokens[..], &tokens_pad[..]].concat();
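The bert.rs change above splits the responsibilities: the post-processor no longer truncates anything itself, it only reports through `added_tokens` how many special tokens `process` will add, and the `Tokenizer` (mod.rs below) subtracts that from the truncation budget before truncating. A rough sketch of that arithmetic, using the counts hard-coded in `BertProcessing::added_tokens` and an illustrative `max_length` of 128:

```rust
fn main() {
    // Counts reported by BertProcessing::added_tokens above:
    // [CLS] + [SEP] for a single sequence, one extra [SEP] for a pair.
    let n_added_tokens = |has_pair: bool| if has_pair { 3usize } else { 2 };

    // post_process (mod.rs below) shrinks the truncation budget by that amount,
    // so the raw ids get 125 slots for a pair and 126 for a single sequence;
    // adding the special tokens back brings the total to exactly max_length.
    let max_length = 128usize;
    assert_eq!(max_length - n_added_tokens(true), 125);
    assert_eq!(max_length - n_added_tokens(false), 126);
}
```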
diff --git a/tokenizers/src/tokenizer/mod.rs b/tokenizers/src/tokenizer/mod.rs
index 2e079fd3..cad63e38 100644
--- a/tokenizers/src/tokenizer/mod.rs
+++ b/tokenizers/src/tokenizer/mod.rs
@@ -12,6 +12,9 @@
 //! - PostProcessor: Takes care of the processing after tokenization. (Like truncating, padding,
 //!   ...)
 //!
+use crate::utils::{
+    pad_encodings, truncate_encodings, PaddingParams, PaddingStrategy, TruncationParams,
+};
 use rayon::prelude::*;
 use std::{
     collections::HashMap,
@@ -20,7 +23,7 @@ use std::{
 };
 
 mod encoding;
-pub use encoding::Encoding;
+pub use encoding::*;
 
 pub type Result<T> = std::result::Result<T, Box<dyn std::error::Error + Send + Sync>>;
 
@@ -46,6 +49,7 @@
 /// A PostProcessor has the responsibility to post process an encoded output of the Tokenizer.
 /// Truncating, Padding, etc... are PostProcessor steps
 pub trait PostProcessor {
+    fn added_tokens(&self, encoding: &Encoding, pair_encoding: &Option<Encoding>) -> Result<usize>;
     fn process(&self, encoding: Encoding, pair_encoding: Option<Encoding>) -> Result<Encoding>;
 }
 
@@ -125,6 +129,10 @@ pub struct Tokenizer {
     added_tokens: HashMap<AddedToken, u32>,
     added_tokens_r: HashMap<u32, AddedToken>,
     split_re: Option<regex::Regex>,
+
+    // General processing parameters
+    trunc: Option<TruncationParams>,
+    padding: Option<PaddingParams>,
 }
 
 impl Tokenizer {
@@ -139,11 +147,13 @@
             added_tokens: HashMap::new(),
             added_tokens_r: HashMap::new(),
             split_re: None,
+            trunc: None,
+            padding: None,
         }
     }
 
-    /// Set the normalizers
-    pub fn with_normalizers(&mut self, normalizer: Box<dyn Normalizer + Sync>) -> &Self {
+    /// Set the normalizer
+    pub fn with_normalizer(&mut self, normalizer: Box<dyn Normalizer + Sync>) -> &Self {
         self.normalizer = Some(normalizer);
         self
     }
@@ -172,6 +182,18 @@
         self
     }
 
+    /// Set the truncation parameters
+    pub fn with_truncation(&mut self, trunc: Option<TruncationParams>) -> &Self {
+        self.trunc = trunc;
+        self
+    }
+
+    /// Set the padding strategy
+    pub fn with_padding(&mut self, padding: Option<PaddingParams>) -> &Self {
+        self.padding = padding;
+        self
+    }
+
     /// Get the size of the vocabulary
     pub fn get_vocab_size(&self) -> usize {
         self.model.get_vocab_size()
@@ -283,10 +305,17 @@
 
     /// Encode all the sentences in parallel, using multiple threads
     pub fn encode_batch(&self, inputs: Vec<EncodeInput>) -> Result<Vec<Encoding>> {
-        inputs
+        let encodings = inputs
            .into_par_iter()
            .map(|input| self.encode(input))
-            .collect()
+            .collect::<Result<Vec<Encoding>>>()?;
+
+        if let Some(params) = &self.padding {
+            // We do the padding here to make sure we handle the batch padding
+            pad_encodings(encodings, &params)
+        } else {
+            Ok(encodings)
+        }
     }
 
     /// Decode the given ids, back to a String
@@ -376,20 +405,61 @@
     /// Post processing logic, handling the case where there is no PostProcessor set
     fn post_process(
         &self,
-        mut encoding: Encoding,
+        encoding: Encoding,
         pair_encoding: Option<Encoding>,
     ) -> Result<Encoding> {
-        if let Some(processor) = &self.post_processor {
-            processor.process(encoding, pair_encoding)
+        // 1. First we truncate if needed
+        let (mut encoding, pair_encoding) = {
+            if let Some(trunc) = &self.trunc {
+                let n_added_tokens = if let Some(processor) = &self.post_processor {
+                    processor.added_tokens(&encoding, &pair_encoding)?
+                } else {
+                    0
+                };
+
+                if n_added_tokens > 0 {
+                    let params = TruncationParams {
+                        max_length: trunc.max_length - n_added_tokens,
+                        ..*trunc
+                    };
+                    truncate_encodings(encoding, pair_encoding, &params)?
+                } else {
+                    truncate_encodings(encoding, pair_encoding, &trunc)?
+                }
+            } else {
+                (encoding, pair_encoding)
+            }
+        };
+
+        // 2. Then we post process
+        let mut final_encoding = if let Some(processor) = &self.post_processor {
+            processor.process(encoding, pair_encoding)?
         } else {
             match pair_encoding {
-                None => Ok(encoding),
+                None => encoding,
                 Some(pair) => {
                     encoding.merge_with(pair);
-                    Ok(encoding)
+                    encoding
                 }
             }
+        };
+
+        // 3. Then we pad if needed
+        if let Some(params) = &self.padding {
+            // We can only pad for a given size. If the Strategy is BatchLongest, it will be done
+            // when we handle a batch
+            if let PaddingStrategy::Fixed(size) = params.strategy {
+                final_encoding.pad(
+                    size,
+                    params.pad_id,
+                    params.pad_type_id,
+                    &params.pad_token,
+                    &params.direction,
+                );
+            }
         }
+
+        Ok(final_encoding)
     }
 
     /// Add the given tokens to the added vocabulary
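Taken together, the mod.rs changes turn truncation and padding into ordinary, optional `Tokenizer` settings configured up front rather than hard-wired into a post-processor. A minimal usage sketch, assuming `tokenizer` was built elsewhere and that the params types below are reachable under `tokenizers::utils` (the exact re-export path is not shown in this diff):

```rust
use tokenizers::tokenizer::{PaddingDirection, Tokenizer};
use tokenizers::utils::{PaddingParams, PaddingStrategy, TruncationParams, TruncationStrategy};

fn configure(tokenizer: &mut Tokenizer) {
    // Truncate every encoded sequence (or pair) down to 128 ids, longest first.
    tokenizer.with_truncation(Some(TruncationParams {
        max_length: 128,
        strategy: TruncationStrategy::LongestFirst,
        stride: 0,
    }));

    // BatchLongest can only be resolved in encode_batch, where the whole batch
    // is visible; single encode calls pad only under a Fixed(..) strategy.
    tokenizer.with_padding(Some(PaddingParams {
        strategy: PaddingStrategy::BatchLongest,
        direction: PaddingDirection::Right,
        pad_id: 0,
        pad_type_id: 0,
        pad_token: String::from("[PAD]"),
    }));
}
```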
diff --git a/tokenizers/src/utils.rs b/tokenizers/src/utils.rs
index 6f45ce5f..dc4709d9 100644
--- a/tokenizers/src/utils.rs
+++ b/tokenizers/src/utils.rs
@@ -1,15 +1,35 @@
-use crate::tokenizer::{Encoding, Result};
+use crate::tokenizer::{Encoding, PaddingDirection, Result};
+
+#[derive(Debug, Clone)]
+pub struct TruncationParams {
+    pub max_length: usize,
+    pub strategy: TruncationStrategy,
+    pub stride: usize,
+}
+
+#[derive(Debug, Clone)]
+pub struct PaddingParams {
+    pub strategy: PaddingStrategy,
+    pub direction: PaddingDirection,
+    pub pad_id: u32,
+    pub pad_type_id: u32,
+    pub pad_token: String,
+}
+
+#[derive(Debug, Clone)]
+pub enum PaddingStrategy {
+    BatchLongest,
+    Fixed(usize),
+}
 
 #[derive(Debug)]
 pub enum Error {
-    SequenceTooSmall,
     SecondSequenceNotProvided,
 }
 
 impl std::fmt::Display for Error {
     fn fmt(&self, fmt: &mut std::fmt::Formatter) -> std::fmt::Result {
         match self {
-            Error::SequenceTooSmall => write!(fmt, "Truncation error: Sequence is too small"),
             Error::SecondSequenceNotProvided => {
                 write!(fmt, "Truncation error: Second sequence not provided")
             }
@@ -18,7 +38,7 @@ impl std::fmt::Display for Error {
 }
 impl std::error::Error for Error {}
 
-#[derive(Clone, Copy, PartialEq)]
+#[derive(Debug, Clone, Copy, PartialEq)]
 pub enum TruncationStrategy {
     LongestFirst,
     OnlyFirst,
@@ -28,19 +48,27 @@ pub enum TruncationStrategy {
 pub fn truncate_encodings(
     mut encoding: Encoding,
     mut pair_encoding: Option<Encoding>,
-    to_remove: usize,
-    strategy: TruncationStrategy,
-    stride: usize,
+    params: &TruncationParams,
 ) -> Result<(Encoding, Option<Encoding>)> {
-    if to_remove == 0 {
+    if params.max_length == 0 {
         return Ok((encoding, pair_encoding));
     }
 
-    match strategy {
+    match params.strategy {
         TruncationStrategy::LongestFirst => {
+            let total_length = encoding.get_ids().len()
+                + pair_encoding
+                    .as_ref()
+                    .map(|e| e.get_ids().len())
+                    .unwrap_or(0);
+            let to_remove = if total_length > params.max_length {
+                total_length - params.max_length
+            } else {
+                0
+            };
+
             let mut n_first = 0;
             let mut n_second = 0;
-
             for _ in 0..to_remove {
                 if pair_encoding.is_none()
                     || encoding.get_ids().len()
                         > pair_encoding.as_ref().unwrap().get_ids().len()
@@ -51,13 +79,13 @@ pub fn truncate_encodings(
                 }
             }
 
-            encoding.truncate(encoding.get_ids().len() - n_first, stride);
+            encoding.truncate(encoding.get_ids().len() - n_first, params.stride);
             if let Some(encoding) = pair_encoding.as_mut() {
-                encoding.truncate(encoding.get_ids().len() - n_second, stride);
+                encoding.truncate(encoding.get_ids().len() - n_second, params.stride);
             }
         }
         TruncationStrategy::OnlyFirst | TruncationStrategy::OnlySecond => {
-            let target = if strategy == TruncationStrategy::OnlyFirst {
+            let target = if params.strategy == TruncationStrategy::OnlyFirst {
                 Ok(&mut encoding)
             } else if let Some(encoding) = pair_encoding.as_mut() {
                 Ok(encoding)
@@ -65,13 +93,37 @@ pub fn truncate_encodings(
                 Err(Box::new(Error::SecondSequenceNotProvided))
             }?;
 
-            if target.get_ids().len() <= to_remove {
-                return Err(Box::new(Error::SequenceTooSmall));
+            if target.get_ids().len() > params.max_length {
+                target.truncate(params.max_length, params.stride);
             }
-
-            target.truncate(target.get_ids().len() - to_remove, stride);
         }
     }
 
     Ok((encoding, pair_encoding))
 }
+
+pub fn pad_encodings(
+    mut encodings: Vec<Encoding>,
+    params: &PaddingParams,
+) -> Result<Vec<Encoding>> {
+    if encodings.is_empty() {
+        return Ok(encodings);
+    }
+
+    let pad_length = match params.strategy {
+        PaddingStrategy::Fixed(size) => size,
+        PaddingStrategy::BatchLongest => encodings.iter().map(|e| e.get_ids().len()).max().unwrap(),
+    };
+
+    for encoding in encodings.iter_mut() {
+        encoding.pad(
+            pad_length,
+            params.pad_id,
+            params.pad_type_id,
+            &params.pad_token,
+            &params.direction,
+        );
+    }
+
+    Ok(encodings)
+}
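One behavioural detail of the new helpers worth noting: `Encoding::pad` (encoding.rs above) does nothing when a sequence is already at least `pad_length` long, so `PaddingStrategy::Fixed(n)` grows the short members of a batch to `n` but never shortens long ones; shortening remains the job of `truncate_encodings`. A small sketch of the resulting lengths, with illustrative batch sizes:

```rust
fn main() {
    // Mirrors the pad_length computation in pad_encodings: under BatchLongest
    // the target is the longest sequence in the batch; under Fixed(n) it is n.
    let batch_lens = vec![7usize, 12, 9];
    let batch_longest = batch_lens.iter().copied().max().unwrap();
    assert_eq!(batch_longest, 12);

    // Encoding::pad is a no-op for sequences at or above the target, so
    // Fixed(10) grows the 7 and the 9 but leaves the 12 untouched.
    let fixed = 10usize;
    let lens_after: Vec<usize> = batch_lens.iter().map(|&l| l.max(fixed)).collect();
    assert_eq!(lens_after, vec![10, 12, 10]);
}
```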