diff --git a/tokenizers/CHANGELOG.md b/tokenizers/CHANGELOG.md
index 1f8195a3..7599e8bc 100644
--- a/tokenizers/CHANGELOG.md
+++ b/tokenizers/CHANGELOG.md
@@ -7,6 +7,8 @@ a high number of files as it avoids having too many progress bars on screen.
 - `ByteLevel` is also a `PostProcessor` now and handles trimming the offsets if activated. This
 avoids the unintuitive inclusion of the whitespaces in the produced offsets, even if these
 whitespaces are part of the actual token.
+- `encode` and `encode_batch` now take a new argument, specifying whether we should add the
+special tokens.
 
 ## Fixes:
 - Fix some issues with the offsets being wrong with the `ByteLevel` BPE:
diff --git a/tokenizers/src/cli.rs b/tokenizers/src/cli.rs
index 17df391f..a72cf1a0 100644
--- a/tokenizers/src/cli.rs
+++ b/tokenizers/src/cli.rs
@@ -45,7 +45,7 @@ fn shell(matches: &ArgMatches) -> Result<()> {
         let buffer = buffer.trim_end();
 
         let timer = std::time::Instant::now();
-        let encoded = tokenizer.encode(EncodeInput::Single(buffer.to_owned()))?;
+        let encoded = tokenizer.encode(EncodeInput::Single(buffer.to_owned()), false)?;
         let elapsed = timer.elapsed();
         println!("\nInput:\t\t{}", buffer);
         println!("Tokens:\t\t{:?}", encoded.get_tokens());
diff --git a/tokenizers/src/lib.rs b/tokenizers/src/lib.rs
index 3c075fb5..cc913e37 100644
--- a/tokenizers/src/lib.rs
+++ b/tokenizers/src/lib.rs
@@ -37,7 +37,7 @@
 //!
 //! let mut tokenizer = Tokenizer::new(Box::new(bpe));
 //!
-//! let encoding = tokenizer.encode(EncodeInput::Single("Hey there!".into()))?;
+//! let encoding = tokenizer.encode(EncodeInput::Single("Hey there!".into()), false)?;
 //! println!("{:?}", encoding.get_tokens());
 //!
 //! Ok(())
diff --git a/tokenizers/src/pre_tokenizers/byte_level.rs b/tokenizers/src/pre_tokenizers/byte_level.rs
index 5f0928e6..59b6e947 100644
--- a/tokenizers/src/pre_tokenizers/byte_level.rs
+++ b/tokenizers/src/pre_tokenizers/byte_level.rs
@@ -193,7 +193,12 @@ impl PostProcessor for ByteLevel {
         0
     }
-    fn process(&self, mut encoding: Encoding, pair_encoding: Option<Encoding>) -> Result<Encoding> {
+    fn process(
+        &self,
+        mut encoding: Encoding,
+        pair_encoding: Option<Encoding>,
+        _add_special_tokens: bool,
+    ) -> Result<Encoding> {
         let process_offsets = |encoding: &mut Encoding| {
             if !self.trim_offsets {
                 return;
             }
@@ -423,13 +428,18 @@ mod tests {
         );
 
         let bytelevel = ByteLevel::default().trim_offsets(true);
-        assert_eq!(expected, bytelevel.process(start.clone(), None).unwrap());
+        assert_eq!(
+            expected,
+            bytelevel.process(start.clone(), None, false).unwrap()
+        );
 
         let mut pair_expected = expected.clone();
         pair_expected.merge_with(expected);
         assert_eq!(
             pair_expected,
-            bytelevel.process(start.clone(), Some(start)).unwrap()
+            bytelevel
+                .process(start.clone(), Some(start), false)
+                .unwrap()
         );
     }
 }
diff --git a/tokenizers/src/processors/bert.rs b/tokenizers/src/processors/bert.rs
index dbfc026a..1cd7c535 100644
--- a/tokenizers/src/processors/bert.rs
+++ b/tokenizers/src/processors/bert.rs
@@ -20,7 +20,16 @@ impl PostProcessor for BertProcessing {
         }
     }
 
-    fn process(&self, mut encoding: Encoding, pair_encoding: Option<Encoding>) -> Result<Encoding> {
+    fn process(
+        &self,
+        mut encoding: Encoding,
+        pair_encoding: Option<Encoding>,
+        add_special_tokens: bool,
+    ) -> Result<Encoding> {
+        if !add_special_tokens {
+            return PostProcessor::default_process(encoding, pair_encoding, add_special_tokens);
+        }
+
         let ids = [&[self.cls.1], &encoding.get_ids()[..], &[self.sep.1]].concat();
         let type_ids = [&[0], &encoding.get_type_ids()[..], &[0]].concat();
         let tokens = [
diff --git a/tokenizers/src/processors/roberta.rs b/tokenizers/src/processors/roberta.rs
index 68145fac..5a0f5bc7 100644
--- a/tokenizers/src/processors/roberta.rs
+++ b/tokenizers/src/processors/roberta.rs
@@ -20,7 +20,16 @@ impl PostProcessor for RobertaProcessing {
         }
     }
 
-    fn process(&self, mut encoding: Encoding, pair_encoding: Option<Encoding>) -> Result<Encoding> {
+    fn process(
+        &self,
+        mut encoding: Encoding,
+        pair_encoding: Option<Encoding>,
+        add_special_tokens: bool,
+    ) -> Result<Encoding> {
+        if !add_special_tokens {
+            return PostProcessor::default_process(encoding, pair_encoding, add_special_tokens);
+        }
+
         let ids = [&[self.cls.1], &encoding.get_ids()[..], &[self.sep.1]].concat();
         let type_ids = [&[0], &encoding.get_type_ids()[..], &[0]].concat();
         let tokens = [
diff --git a/tokenizers/src/tokenizer/mod.rs b/tokenizers/src/tokenizer/mod.rs
index 198a784f..aff653e8 100644
--- a/tokenizers/src/tokenizer/mod.rs
+++ b/tokenizers/src/tokenizer/mod.rs
@@ -55,7 +55,27 @@ pub trait PostProcessor {
     /// Returns the number of tokens that will be added during the processing step
     fn added_tokens(&self, is_pair: bool) -> usize;
     /// Process both encodings and returns a new merged one
-    fn process(&self, encoding: Encoding, pair_encoding: Option<Encoding>) -> Result<Encoding>;
+    fn process(
+        &self,
+        encoding: Encoding,
+        pair_encoding: Option<Encoding>,
+        add_special_tokens: bool,
+    ) -> Result<Encoding>;
+}
+impl dyn PostProcessor {
+    pub fn default_process(
+        mut encoding: Encoding,
+        pair_encoding: Option<Encoding>,
+        _add_special_tokens: bool,
+    ) -> Result<Encoding> {
+        match pair_encoding {
+            None => Ok(encoding),
+            Some(pair) => {
+                encoding.merge_with(pair);
+                Ok(encoding)
+            }
+        }
+    }
 }
 
 /// A `Decoder` has the responsibility to merge the given `Vec<String>` in a `String`.
@@ -277,7 +297,7 @@ impl Tokenizer {
     }
 
     /// Encode the given sentence
-    pub fn encode(&self, input: EncodeInput) -> Result<Encoding> {
+    pub fn encode(&self, input: EncodeInput, add_special_tokens: bool) -> Result<Encoding> {
         let generate_output = move |sentence: String, type_id: u32| -> Result<Encoding> {
             // First we need to split into as many sequences as needed to avoid splitting
             // on our added tokens
@@ -362,14 +382,18 @@ impl Tokenizer {
         };
 
         // 4. Post processing
-        self.post_process(encoding, pair_encoding)
+        self.post_process(encoding, pair_encoding, add_special_tokens)
     }
 
     /// Encode all the sentences in parallel, using multiple threads
-    pub fn encode_batch(&self, inputs: Vec<EncodeInput>) -> Result<Vec<Encoding>> {
+    pub fn encode_batch(
+        &self,
+        inputs: Vec<EncodeInput>,
+        add_special_tokens: bool,
+    ) -> Result<Vec<Encoding>> {
         let encodings = inputs
             .into_par_iter()
-            .map(|input| self.encode(input))
+            .map(|input| self.encode(input, add_special_tokens))
             .collect::<Result<Vec<Encoding>>>()?;
 
         if let Some(params) = &self.padding {
@@ -513,9 +537,10 @@ impl Tokenizer {
         &self,
         encoding: Encoding,
         pair_encoding: Option<Encoding>,
+        add_special_tokens: bool,
     ) -> Result<Encoding> {
         // 1. First we truncate if needed
-        let (mut encoding, pair_encoding) = {
+        let (encoding, pair_encoding) = {
             if let Some(trunc) = &self.trunc {
                 let n_added_tokens = if let Some(processor) = &self.post_processor {
                     processor.added_tokens(pair_encoding.is_some())
@@ -523,7 +548,7 @@ impl Tokenizer {
                     0
                 };
 
-                if n_added_tokens > 0 {
+                if add_special_tokens && n_added_tokens > 0 {
                     let params = TruncationParams {
                         max_length: trunc.max_length - n_added_tokens,
                         ..*trunc
@@ -539,15 +564,9 @@ impl Tokenizer {
 
         // 2. Then We post process
         let mut final_encoding = if let Some(processor) = &self.post_processor {
-            processor.process(encoding, pair_encoding)?
+            processor.process(encoding, pair_encoding, add_special_tokens)?
         } else {
-            match pair_encoding {
-                None => encoding,
-                Some(pair) => {
-                    encoding.merge_with(pair);
-                    encoding
-                }
-            }
+            PostProcessor::default_process(encoding, pair_encoding, add_special_tokens)?
         };
 
         // 3. Then we pad if needed
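
A minimal usage sketch of the new caller-facing API, assuming a `Tokenizer` has already been built as in the `lib.rs` doc example touched above. The helper name `demo`, the sample sentences, the batch contents, and the exact `use` paths are illustrative assumptions, not part of this patch.

```rust
use tokenizers::tokenizer::{EncodeInput, Result, Tokenizer};

// Hypothetical helper: encodes one sentence and a small batch, toggling the
// new `add_special_tokens` flag introduced by this change.
fn demo(tokenizer: &Tokenizer) -> Result<()> {
    // With `true`, the attached PostProcessor adds its special tokens
    // (e.g. [CLS]/[SEP] for BertProcessing).
    let with_special = tokenizer.encode(EncodeInput::Single("Hey there!".into()), true)?;
    // With `false`, BertProcessing/RobertaProcessing fall back to
    // `PostProcessor::default_process`, which only merges the pair (if any).
    let without_special = tokenizer.encode(EncodeInput::Single("Hey there!".into()), false)?;
    println!("{:?}", with_special.get_tokens());
    println!("{:?}", without_special.get_tokens());

    // `encode_batch` takes the same flag and applies it to every input.
    let encodings = tokenizer.encode_batch(
        vec![
            EncodeInput::Single("Hello".into()),
            EncodeInput::Dual("A question".into(), "Some context".into()),
        ],
        true,
    )?;
    println!("{} encodings", encodings.len());

    Ok(())
}
```

Passing `true` also lets `post_process` reserve room for the special tokens during truncation (the new `add_special_tokens && n_added_tokens > 0` check), while `false` keeps the previous plain-merge behaviour.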
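
And a sketch of what the updated `PostProcessor` trait asks of an implementor, assuming the same `tokenizers::tokenizer` paths. `PassthroughProcessing` is a made-up processor used purely for illustration: it has no special tokens of its own, so it defers to the new `PostProcessor::default_process` fallback unconditionally, whereas `BertProcessing`/`RobertaProcessing` above only take that path when `add_special_tokens` is `false`.

```rust
use tokenizers::tokenizer::{Encoding, PostProcessor, Result};

// Made-up processor with no special tokens, shown only to illustrate the new
// `process` signature and the shared fallback.
pub struct PassthroughProcessing;

impl PostProcessor for PassthroughProcessing {
    fn added_tokens(&self, _is_pair: bool) -> usize {
        // Nothing is ever added, with or without special tokens.
        0
    }

    fn process(
        &self,
        encoding: Encoding,
        pair_encoding: Option<Encoding>,
        add_special_tokens: bool,
    ) -> Result<Encoding> {
        // Same call BertProcessing/RobertaProcessing use when the flag is off:
        // merge the pair encoding (if any) and return the result untouched.
        PostProcessor::default_process(encoding, pair_encoding, add_special_tokens)
    }
}
```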