diff --git a/tokenizers/Cargo.toml b/tokenizers/Cargo.toml
index fea21397..222aa027 100644
--- a/tokenizers/Cargo.toml
+++ b/tokenizers/Cargo.toml
@@ -34,6 +34,10 @@ harness = false
 name = "bert_benchmark"
 harness = false
 
+[[bench]]
+name = "layout_benchmark"
+harness = false
+
 [dependencies]
 lazy_static = "1.4"
 rand = "0.8"
diff --git a/tokenizers/benches/layout_benchmark.rs b/tokenizers/benches/layout_benchmark.rs
new file mode 100644
index 00000000..4738d5fa
--- /dev/null
+++ b/tokenizers/benches/layout_benchmark.rs
@@ -0,0 +1,76 @@
+#[macro_use]
+extern crate criterion;
+
+use std::fs::File;
+use std::io::{BufRead, BufReader};
+use std::path::Path;
+use std::time::{Duration, Instant};
+
+use criterion::black_box;
+use criterion::Criterion;
+use tokenizers::processors::template::TemplateProcessing;
+use tokenizers::{EncodeInput, Encoding, PostProcessor, Tokenizer};
+
+/// Simple TemplateProcessing
+fn create_processor() -> TemplateProcessing {
+    TemplateProcessing::builder()
+        .try_single("[CLS]:0 $A:0 [SEP]:0")
+        .unwrap()
+        .try_pair("[CLS]:0 $A:0 [SEP]:0 $B:1 [SEP]:1")
+        .unwrap()
+        .special_tokens(vec![("[CLS]", 0), ("[SEP]", 1)])
+        .build()
+        .unwrap()
+}
+
+pub fn bench_layout(c: &mut Criterion) {
+    let processor = create_processor();
+    let tokenizer = Tokenizer::from_file("data/albert-base-v1-tokenizer.json").unwrap();
+    let mut encodeds: Vec<Encoding> = vec![];
+    for line in BufReader::new(File::open(Path::new("data/big.txt")).unwrap()).lines() {
+        let line: EncodeInput = line.unwrap().into();
+
+        let encoded: Encoding = tokenizer.encode(line, false).unwrap();
+        encodeds.push(encoded);
+    }
+
+    c.bench_function("TemplateProcessing single encode", |b| {
+        b.iter_custom(|iters| {
+            let mut duration = Duration::new(0, 0);
+            for i in 0..iters as usize {
+                let encoded_index = i % encodeds.len();
+                let encoded: Encoding = encodeds[encoded_index].clone();
+
+                let start = Instant::now();
+                let _ = black_box(processor.process(encoded, None, false));
+                duration = duration.checked_add(start.elapsed()).unwrap();
+            }
+            duration
+        })
+    });
+    c.bench_function("TemplateProcessing pair encode", |b| {
+        b.iter_custom(|iters| {
+            let mut duration = Duration::new(0, 0);
+            for i in 0..iters as usize {
+                let encoded_index = i % encodeds.len();
+                let encoded: Encoding = encodeds[encoded_index].clone();
+
+                let encoded_index2 = (i + 1) % encodeds.len();
+                let pair: Encoding = encodeds[encoded_index2].clone();
+
+                let start = Instant::now();
+                let _ = black_box(processor.process(encoded, Some(pair), false));
+                duration = duration.checked_add(start.elapsed()).unwrap();
+            }
+            duration
+        })
+    });
+}
+
+criterion_group! {
+    name = layout_benches;
+    config = Criterion::default().sample_size(20);
+    targets = bench_layout
+}
+
+criterion_main!(layout_benches);
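The benchmark pre-encodes every line of `data/big.txt` once, then uses `iter_custom` so that only the `process` call is timed — cloning the cached encoding and tokenization stay out of the measurement. It can be run with `cargo bench --bench layout_benchmark` from the `tokenizers` crate root. For reference, a minimal sketch of the call being measured, outside the criterion harness (it assumes the same `data/` files the benchmark reads):

```rust
use tokenizers::processors::template::TemplateProcessing;
use tokenizers::{PostProcessor, Tokenizer};

fn main() {
    // Same processor the benchmark builds in `create_processor`.
    let processor = TemplateProcessing::builder()
        .try_single("[CLS]:0 $A:0 [SEP]:0")
        .unwrap()
        .try_pair("[CLS]:0 $A:0 [SEP]:0 $B:1 [SEP]:1")
        .unwrap()
        .special_tokens(vec![("[CLS]", 0), ("[SEP]", 1)])
        .build()
        .unwrap();
    let tokenizer = Tokenizer::from_file("data/albert-base-v1-tokenizer.json").unwrap();
    let encoded = tokenizer.encode("Hello there!", false).unwrap();

    // This is the call wrapped in `black_box` above; `None` selects the
    // single-sequence template instead of the pair template.
    let processed = processor.process(encoded, None, true).unwrap();
    println!("{} ids after post-processing", processed.get_ids().len());
}
```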
diff --git a/tokenizers/src/processors/bert.rs b/tokenizers/src/processors/bert.rs
index d93f0384..39a834cd 100644
--- a/tokenizers/src/processors/bert.rs
+++ b/tokenizers/src/processors/bert.rs
@@ -1,4 +1,4 @@
-use crate::tokenizer::{Encoding, PostProcessor, ProcessorError, Result};
+use crate::tokenizer::{Encoding, PostProcessor, Result};
 use serde::{Deserialize, Serialize};
 use std::collections::HashMap;
 use std::iter::FromIterator;
@@ -49,53 +49,11 @@ impl PostProcessor for BertProcessing {
             return Ok(encodings);
         }
 
-        let (mut encoding, pair_encoding): (Encoding, Option<Encoding>) = match encodings.len() {
-            1 => (
-                encodings
-                    .pop()
-                    .ok_or(ProcessorError::InvalidEncodingsVecLength)?,
-                None,
-            ),
-            2 => {
-                let pair = encodings
-                    .pop()
-                    .ok_or(ProcessorError::InvalidEncodingsVecLength)?;
-                let encoding = encodings
-                    .pop()
-                    .ok_or(ProcessorError::InvalidEncodingsVecLength)?;
-                (encoding, Some(pair))
-            }
-            _ => return Err(Box::new(ProcessorError::InvalidEncodingsVecLength)),
-        };
-
-        let ids = [&[self.cls.1], encoding.get_ids(), &[self.sep.1]].concat();
-        let type_ids = [&[0], encoding.get_type_ids(), &[0]].concat();
-        let tokens = [
-            &[self.cls.0.clone()],
-            encoding.get_tokens(),
-            &[self.sep.0.clone()],
-        ]
-        .concat();
-        let words = [&[None], encoding.get_word_ids(), &[None]].concat();
-        let offsets = [&[(0, 0)], encoding.get_offsets(), &[(0, 0)]].concat();
-        let special_tokens = [&[1u32], &vec![0; encoding.get_ids().len()][..], &[1]].concat();
-        let attention_mask = vec![1; ids.len()];
-
-        // For compatibility with `TemplateProcessing`, the sequence_ranges shouldn't contain
-        // the special tokens.
-        let sequence_ranges = HashMap::from_iter(vec![(0, 1..ids.len() - 1)]);
-        let mut new_encoding = Encoding::new(
-            ids,
-            type_ids,
-            tokens,
-            words,
-            offsets,
-            special_tokens,
-            attention_mask,
-            encoding
-                .take_overflowing()
-                .into_iter()
-                .map(|encoding| {
+        let encodings: Vec<Encoding> = encodings
+            .iter_mut()
+            .enumerate()
+            .map(|(i, encoding)| {
+                if i == 0 {
                     let ids = [&[self.cls.1], encoding.get_ids(), &[self.sep.1]].concat();
                     let type_ids = [&[0], encoding.get_type_ids(), &[0]].concat();
                     let tokens = [
@@ -110,8 +68,8 @@ impl PostProcessor for BertProcessing {
                         [&[1u32], &vec![0; encoding.get_ids().len()][..], &[1]].concat();
                     let attention_mask = vec![1; ids.len()];
 
-                    // For compatibility with `TemplateProcessing`, the sequence_ranges shouldn't
-                    // contain the special tokens.
+                    // For compatibility with `TemplateProcessing`, the sequence_ranges shouldn't contain
+                    // the special tokens.
                     let sequence_ranges = HashMap::from_iter(vec![(0, 1..ids.len() - 1)]);
                     Encoding::new(
                         ids,
@@ -121,72 +79,105 @@ impl PostProcessor for BertProcessing {
                         offsets,
                         special_tokens,
                         attention_mask,
-                        vec![],
+                        encoding
+                            .take_overflowing()
+                            .into_iter()
+                            .map(|encoding| {
+                                let ids =
+                                    [&[self.cls.1], encoding.get_ids(), &[self.sep.1]].concat();
+                                let type_ids = [&[0], encoding.get_type_ids(), &[0]].concat();
+                                let tokens = [
+                                    &[self.cls.0.clone()],
+                                    encoding.get_tokens(),
+                                    &[self.sep.0.clone()],
+                                ]
+                                .concat();
+                                let words = [&[None], encoding.get_word_ids(), &[None]].concat();
+                                let offsets =
+                                    [&[(0, 0)], encoding.get_offsets(), &[(0, 0)]].concat();
+                                let special_tokens =
+                                    [&[1u32], &vec![0; encoding.get_ids().len()][..], &[1]]
+                                        .concat();
+                                let attention_mask = vec![1; ids.len()];
+
+                                // For compatibility with `TemplateProcessing`, the sequence_ranges shouldn't
+                                // contain the special tokens.
+                                let sequence_ranges =
+                                    HashMap::from_iter(vec![(0, 1..ids.len() - 1)]);
+                                Encoding::new(
+                                    ids,
+                                    type_ids,
+                                    tokens,
+                                    words,
+                                    offsets,
+                                    special_tokens,
+                                    attention_mask,
+                                    vec![],
+                                    sequence_ranges,
+                                )
+                            })
+                            .collect(),
                         sequence_ranges,
                     )
-                })
-                .collect(),
-            sequence_ranges,
-        );
+                } else {
+                    let pair_ids = [encoding.get_ids(), &[self.sep.1]].concat();
+                    let pair_type_ids = [encoding.get_type_ids(), &[1]].concat();
+                    let pair_tokens = [encoding.get_tokens(), &[self.sep.0.clone()]].concat();
+                    let pair_words = [encoding.get_word_ids(), &[None]].concat();
+                    let pair_offsets = [encoding.get_offsets(), &[(0, 0)]].concat();
+                    let pair_special_tokens =
+                        [&vec![0u32; encoding.get_type_ids().len()][..], &[1]].concat();
+                    let pair_attention_mask = vec![1; pair_ids.len()];
-
-        if let Some(mut encoding) = pair_encoding {
-            let pair_ids = [encoding.get_ids(), &[self.sep.1]].concat();
-            let pair_type_ids = [encoding.get_type_ids(), &[1]].concat();
-            let pair_tokens = [encoding.get_tokens(), &[self.sep.0.clone()]].concat();
-            let pair_words = [encoding.get_word_ids(), &[None]].concat();
-            let pair_offsets = [encoding.get_offsets(), &[(0, 0)]].concat();
-            let pair_special_tokens =
-                [&vec![0u32; encoding.get_type_ids().len()][..], &[1]].concat();
-            let pair_attention_mask = vec![1; pair_ids.len()];
+
+                    // For compatibility with `TemplateProcessing`, the sequence_ranges shouldn't contain
+                    // the special tokens.
+                    let pair_sequence_ranges = HashMap::from_iter(vec![(1, 0..pair_ids.len() - 1)]);
+                    Encoding::new(
+                        pair_ids,
+                        pair_type_ids,
+                        pair_tokens,
+                        pair_words,
+                        pair_offsets,
+                        pair_special_tokens,
+                        pair_attention_mask,
+                        encoding
+                            .take_overflowing()
+                            .into_iter()
+                            .map(|encoding| {
+                                let pair_ids = [encoding.get_ids(), &[self.sep.1]].concat();
+                                let pair_type_ids = [encoding.get_type_ids(), &[1]].concat();
+                                let pair_tokens =
+                                    [encoding.get_tokens(), &[self.sep.0.clone()]].concat();
+                                let pair_words = [encoding.get_word_ids(), &[None]].concat();
+                                let pair_offsets = [encoding.get_offsets(), &[(0, 0)]].concat();
+                                let pair_special_tokens =
+                                    [&vec![0u32; encoding.get_type_ids().len()][..], &[1]].concat();
+                                let pair_attention_mask = vec![1; pair_ids.len()];
-
-            // For compatibility with `TemplateProcessing`, the sequence_ranges shouldn't contain
-            // the special tokens.
-            let pair_sequence_ranges = HashMap::from_iter(vec![(1, 0..pair_ids.len() - 1)]);
-            let new_pair_encoding = Encoding::new(
-                pair_ids,
-                pair_type_ids,
-                pair_tokens,
-                pair_words,
-                pair_offsets,
-                pair_special_tokens,
-                pair_attention_mask,
-                encoding
-                    .take_overflowing()
-                    .into_iter()
-                    .map(|encoding| {
-                        let pair_ids = [encoding.get_ids(), &[self.sep.1]].concat();
-                        let pair_type_ids = [encoding.get_type_ids(), &[1]].concat();
-                        let pair_tokens = [encoding.get_tokens(), &[self.sep.0.clone()]].concat();
-                        let pair_words = [encoding.get_word_ids(), &[None]].concat();
-                        let pair_offsets = [encoding.get_offsets(), &[(0, 0)]].concat();
-                        let pair_special_tokens =
-                            [&vec![0u32; encoding.get_type_ids().len()][..], &[1]].concat();
-                        let pair_attention_mask = vec![1; pair_ids.len()];
+
+                                // For compatibility with `TemplateProcessing`, the sequence_ranges
+                                // shouldn't contain the special tokens.
+                                let pair_sequence_ranges =
+                                    HashMap::from_iter(vec![(1, 0..pair_ids.len() - 1)]);
+                                Encoding::new(
+                                    pair_ids,
+                                    pair_type_ids,
+                                    pair_tokens,
+                                    pair_words,
+                                    pair_offsets,
+                                    pair_special_tokens,
+                                    pair_attention_mask,
+                                    vec![],
+                                    pair_sequence_ranges,
+                                )
+                            })
+                            .collect(),
+                        pair_sequence_ranges,
+                    )
+                }
+            })
+            .collect();
-
-                        // For compatibility with `TemplateProcessing`, the sequence_ranges
-                        // shouldn't contain the special tokens.
-                        let pair_sequence_ranges =
-                            HashMap::from_iter(vec![(1, 0..pair_ids.len() - 1)]);
-                        Encoding::new(
-                            pair_ids,
-                            pair_type_ids,
-                            pair_tokens,
-                            pair_words,
-                            pair_offsets,
-                            pair_special_tokens,
-                            pair_attention_mask,
-                            vec![],
-                            pair_sequence_ranges,
-                        )
-                    })
-                    .collect(),
-                pair_sequence_ranges,
-            );
-
-            new_encoding.merge_with(new_pair_encoding, false);
-        }
-
-        Ok(vec![new_encoding])
+        Ok(encodings)
     }
 }
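The shape change is the important part of this hunk: `process_encodings` used to collapse everything into one merged `Encoding` (via `merge_with`); it now returns one post-processed `Encoding` per input sequence — index 0 gets `[CLS] … [SEP]`, every later sequence gets `… [SEP]` — and leaves merging to the caller. A hypothetical sketch of the new contract (import paths approximate):

```rust
use tokenizers::processors::bert::BertProcessing;
use tokenizers::{Encoding, PostProcessor, Token};

fn main() -> tokenizers::Result<()> {
    // BertProcessing::new takes (sep, cls).
    let processor = BertProcessing::new(("[SEP]".into(), 102), ("[CLS]".into(), 101));

    let a = Encoding::from_tokens(vec![Token::new(7, "hello".into(), (0, 5))], 0);
    let b = Encoding::from_tokens(vec![Token::new(8, "world".into(), (0, 5))], 0);

    // Two encodings in, two encodings out.
    let out = processor.process_encodings(vec![a, b], true)?;
    assert_eq!(out.len(), 2);
    assert_eq!(out[0].get_tokens(), &["[CLS]", "hello", "[SEP]"]);
    assert_eq!(out[1].get_tokens(), &["world", "[SEP]"]);
    Ok(())
}
```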
diff --git a/tokenizers/src/processors/roberta.rs b/tokenizers/src/processors/roberta.rs
index 41be29b5..ab83e462 100644
--- a/tokenizers/src/processors/roberta.rs
+++ b/tokenizers/src/processors/roberta.rs
@@ -1,5 +1,5 @@
 use crate::processors::byte_level::process_offsets;
-use crate::tokenizer::{Encoding, PostProcessor, ProcessorError, Result};
+use crate::tokenizer::{Encoding, PostProcessor, Result};
 use serde::{Deserialize, Serialize};
 use std::collections::HashMap;
 use std::iter::FromIterator;
@@ -74,53 +74,11 @@ impl PostProcessor for RobertaProcessing {
             return Ok(encodings);
         }
 
-        let (mut encoding, pair_encoding): (Encoding, Option<Encoding>) = match encodings.len() {
-            1 => (
-                encodings
-                    .pop()
-                    .ok_or(ProcessorError::InvalidEncodingsVecLength)?,
-                None,
-            ),
-            2 => {
-                let pair = encodings
-                    .pop()
-                    .ok_or(ProcessorError::InvalidEncodingsVecLength)?;
-                let encoding = encodings
-                    .pop()
-                    .ok_or(ProcessorError::InvalidEncodingsVecLength)?;
-                (encoding, Some(pair))
-            }
-            _ => return Err(Box::new(ProcessorError::InvalidEncodingsVecLength)),
-        };
-
-        let ids = [&[self.cls.1], encoding.get_ids(), &[self.sep.1]].concat();
-        let type_ids = [&[0], encoding.get_type_ids(), &[0]].concat();
-        let tokens = [
-            &[self.cls.0.clone()],
-            encoding.get_tokens(),
-            &[self.sep.0.clone()],
-        ]
-        .concat();
-        let words = [&[None], encoding.get_word_ids(), &[None]].concat();
-        let offsets = [&[(0, 0)], encoding.get_offsets(), &[(0, 0)]].concat();
-        let special_tokens = [&[1u32], &vec![0; encoding.get_ids().len()][..], &[1]].concat();
-        let attention_mask = vec![1; ids.len()];
-
-        // For compatibility with `TemplateProcessing`, the sequence_ranges shouldn't contain
-        // the special tokens.
-        let sequence_ranges = HashMap::from_iter(vec![(0, 1..ids.len() - 1)]);
-        let mut new_encoding = Encoding::new(
-            ids,
-            type_ids,
-            tokens,
-            words,
-            offsets,
-            special_tokens,
-            attention_mask,
-            encoding
-                .take_overflowing()
-                .into_iter()
-                .map(|encoding| {
+        let encodings: Vec<Encoding> = encodings
+            .iter_mut()
+            .enumerate()
+            .map(|(i, encoding)| {
+                if i == 0 {
                     let ids = [&[self.cls.1], encoding.get_ids(), &[self.sep.1]].concat();
                     let type_ids = [&[0], encoding.get_type_ids(), &[0]].concat();
                     let tokens = [
@@ -135,8 +93,8 @@ impl PostProcessor for RobertaProcessing {
                         [&[1u32], &vec![0; encoding.get_ids().len()][..], &[1]].concat();
                     let attention_mask = vec![1; ids.len()];
 
-                    // For compatibility with `TemplateProcessing`, the sequence_ranges shouldn't
-                    // contain the special tokens.
+                    // For compatibility with `TemplateProcessing`, the sequence_ranges shouldn't contain
+                    // the special tokens.
                     let sequence_ranges = HashMap::from_iter(vec![(0, 1..ids.len() - 1)]);
                     Encoding::new(
                         ids,
@@ -146,82 +104,118 @@ impl PostProcessor for RobertaProcessing {
                         offsets,
                         special_tokens,
                         attention_mask,
-                        vec![],
+                        encoding
+                            .take_overflowing()
+                            .into_iter()
+                            .map(|encoding| {
+                                let ids =
+                                    [&[self.cls.1], encoding.get_ids(), &[self.sep.1]].concat();
+                                let type_ids = [&[0], encoding.get_type_ids(), &[0]].concat();
+                                let tokens = [
+                                    &[self.cls.0.clone()],
+                                    encoding.get_tokens(),
+                                    &[self.sep.0.clone()],
+                                ]
+                                .concat();
+                                let words = [&[None], encoding.get_word_ids(), &[None]].concat();
+                                let offsets =
+                                    [&[(0, 0)], encoding.get_offsets(), &[(0, 0)]].concat();
+                                let special_tokens =
+                                    [&[1u32], &vec![0; encoding.get_ids().len()][..], &[1]]
+                                        .concat();
+                                let attention_mask = vec![1; ids.len()];
+
+                                // For compatibility with `TemplateProcessing`, the sequence_ranges shouldn't
+                                // contain the special tokens.
+                                let sequence_ranges =
+                                    HashMap::from_iter(vec![(0, 1..ids.len() - 1)]);
+                                Encoding::new(
+                                    ids,
+                                    type_ids,
+                                    tokens,
+                                    words,
+                                    offsets,
+                                    special_tokens,
+                                    attention_mask,
+                                    vec![],
+                                    sequence_ranges,
+                                )
+                            })
+                            .collect(),
                         sequence_ranges,
                     )
-                })
-                .collect(),
-            sequence_ranges,
-        );
+                } else {
+                    let pair_ids = [&[self.sep.1], encoding.get_ids(), &[self.sep.1]].concat();
+                    let pair_type_ids = vec![0; encoding.get_ids().len() + 2];
+                    let pair_tokens = [
+                        &[self.sep.0.clone()],
+                        encoding.get_tokens(),
+                        &[self.sep.0.clone()],
+                    ]
+                    .concat();
+                    let pair_words = [&[None], encoding.get_word_ids(), &[None]].concat();
+                    let pair_offsets = [&[(0, 0)], encoding.get_offsets(), &[(0, 0)]].concat();
+                    let pair_special_tokens =
+                        [&[1], &vec![0u32; encoding.get_type_ids().len()][..], &[1]].concat();
+                    let pair_attention_mask = vec![1; pair_ids.len()];
-
-        if let Some(mut encoding) = pair_encoding {
-            let pair_ids = [&[self.sep.1], encoding.get_ids(), &[self.sep.1]].concat();
-            let pair_type_ids = vec![0; encoding.get_ids().len() + 2];
-            let pair_tokens = [
-                &[self.sep.0.clone()],
-                encoding.get_tokens(),
-                &[self.sep.0.clone()],
-            ]
-            .concat();
-            let pair_words = [&[None], encoding.get_word_ids(), &[None]].concat();
-            let pair_offsets = [&[(0, 0)], encoding.get_offsets(), &[(0, 0)]].concat();
-            let pair_special_tokens =
-                [&[1], &vec![0u32; encoding.get_type_ids().len()][..], &[1]].concat();
-            let pair_attention_mask = vec![1; pair_ids.len()];
+
+                    // For compatibility with `TemplateProcessing`, the sequence_ranges shouldn't contain
+                    // the special tokens.
+                    let pair_sequence_ranges = HashMap::from_iter(vec![(1, 1..pair_ids.len() - 1)]);
+                    Encoding::new(
+                        pair_ids,
+                        pair_type_ids,
+                        pair_tokens,
+                        pair_words,
+                        pair_offsets,
+                        pair_special_tokens,
+                        pair_attention_mask,
+                        encoding
+                            .take_overflowing()
+                            .into_iter()
+                            .map(|encoding| {
+                                let pair_ids =
+                                    [&[self.sep.1], encoding.get_ids(), &[self.sep.1]].concat();
+                                let pair_type_ids = vec![0; encoding.get_ids().len() + 2];
+                                let pair_tokens = [
+                                    &[self.sep.0.clone()],
+                                    encoding.get_tokens(),
+                                    &[self.sep.0.clone()],
+                                ]
+                                .concat();
+                                let pair_words =
+                                    [&[None], encoding.get_word_ids(), &[None]].concat();
+                                let pair_offsets =
+                                    [&[(0, 0)], encoding.get_offsets(), &[(0, 0)]].concat();
+                                let pair_special_tokens =
+                                    [&[1], &vec![0u32; encoding.get_type_ids().len()][..], &[1]]
+                                        .concat();
+                                let pair_attention_mask = vec![1; pair_ids.len()];
-
-            // For compatibility with `TemplateProcessing`, the sequence_ranges shouldn't contain
-            // the special tokens.
-            let pair_sequence_ranges = HashMap::from_iter(vec![(1, 1..pair_ids.len() - 1)]);
-            let new_pair_encoding = Encoding::new(
-                pair_ids,
-                pair_type_ids,
-                pair_tokens,
-                pair_words,
-                pair_offsets,
-                pair_special_tokens,
-                pair_attention_mask,
-                encoding
-                    .take_overflowing()
-                    .into_iter()
-                    .map(|encoding| {
-                        let pair_ids = [&[self.sep.1], encoding.get_ids(), &[self.sep.1]].concat();
-                        let pair_type_ids = vec![0; encoding.get_ids().len() + 2];
-                        let pair_tokens = [
-                            &[self.sep.0.clone()],
-                            encoding.get_tokens(),
-                            &[self.sep.0.clone()],
-                        ]
-                        .concat();
-                        let pair_words = [&[None], encoding.get_word_ids(), &[None]].concat();
-                        let pair_offsets = [&[(0, 0)], encoding.get_offsets(), &[(0, 0)]].concat();
-                        let pair_special_tokens =
-                            [&[1], &vec![0u32; encoding.get_type_ids().len()][..], &[1]].concat();
-                        let pair_attention_mask = vec![1; pair_ids.len()];
+
+                                // For compatibility with `TemplateProcessing`, the sequence_ranges
+                                // shouldn't contain the special tokens.
+                                let pair_sequence_ranges =
+                                    HashMap::from_iter(vec![(1, 1..pair_ids.len() - 1)]);
+                                Encoding::new(
+                                    pair_ids,
+                                    pair_type_ids,
+                                    pair_tokens,
+                                    pair_words,
+                                    pair_offsets,
+                                    pair_special_tokens,
+                                    pair_attention_mask,
+                                    vec![],
+                                    pair_sequence_ranges,
+                                )
+                            })
+                            .collect(),
+                        pair_sequence_ranges,
+                    )
+                }
+            })
+            .collect();
-
-                        // For compatibility with `TemplateProcessing`, the sequence_ranges
-                        // shouldn't contain the special tokens.
-                        let pair_sequence_ranges =
-                            HashMap::from_iter(vec![(1, 1..pair_ids.len() - 1)]);
-                        Encoding::new(
-                            pair_ids,
-                            pair_type_ids,
-                            pair_tokens,
-                            pair_words,
-                            pair_offsets,
-                            pair_special_tokens,
-                            pair_attention_mask,
-                            vec![],
-                            pair_sequence_ranges,
-                        )
-                    })
-                    .collect(),
-                pair_sequence_ranges,
-            );
-
-            new_encoding.merge_with(new_pair_encoding, false);
-        }
-
-        Ok(vec![new_encoding])
+        Ok(encodings)
     }
 }
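`RobertaProcessing` gets the same restructuring as `BertProcessing`, with RoBERTa's own conventions kept intact: a second sequence is wrapped as `</s> … </s>` and its `type_ids` stay 0. A hypothetical sketch of the resulting shape (default `<s>`/`</s>` ids assumed, import paths approximate):

```rust
use tokenizers::processors::roberta::RobertaProcessing;
use tokenizers::{Encoding, PostProcessor, Token};

fn main() -> tokenizers::Result<()> {
    // RobertaProcessing::new takes (sep, cls).
    let processor = RobertaProcessing::new(("</s>".into(), 2), ("<s>".into(), 0));
    let a = Encoding::from_tokens(vec![Token::new(7, "hello".into(), (0, 5))], 0);
    let b = Encoding::from_tokens(vec![Token::new(8, "world".into(), (0, 5))], 0);

    let out = processor.process_encodings(vec![a, b], true)?;
    // First sequence: "<s> hello </s>"; second: "</s> world </s>".
    assert_eq!(out[1].get_tokens(), &["</s>", "world", "</s>"]);
    // Unlike BERT, RoBERTa keeps type_id 0 for the pair as well.
    assert!(out[1].get_type_ids().iter().all(|&t| t == 0));
    Ok(())
}
```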
diff --git a/tokenizers/src/processors/template.rs b/tokenizers/src/processors/template.rs
index 860461af..262323cf 100644
--- a/tokenizers/src/processors/template.rs
+++ b/tokenizers/src/processors/template.rs
@@ -55,14 +55,14 @@
 //!
 //! [`TemplateProcessing`]: struct.TemplateProcessing.html
 //!
-use crate::{tokenizer::ProcessorError, Encoding, PostProcessor, Result};
+use crate::{Encoding, PostProcessor, Result};
 use itertools::Itertools;
 use serde::{Deserialize, Serialize};
 use std::collections::{HashMap, HashSet};
 use std::convert::{TryFrom, TryInto};
 use std::result::Result as StdResult;
 
-/// Represents both sequences received as input of the PostProcessor
+/// Represents any sequences received as input of the PostProcessor
 #[derive(Debug, Clone, PartialEq, Serialize, Deserialize, Eq)]
 pub enum Sequence {
     /// This is the first sequence, the one that is always specified
@@ -479,145 +479,102 @@ impl TemplateProcessing {
     fn apply_template(
         &self,
         template: &[Piece],
-        mut encoding: Encoding,
-        mut pair: Option<Encoding>,
+        mut encodings: Vec<Encoding>,
         add_special_tokens: bool,
-    ) -> Result<Encoding> {
-        // Compute the new size
-        let mut new_len = 0;
-        for piece in template {
-            new_len += match piece {
-                Piece::Sequence {
-                    id: Sequence::A, ..
-                } => encoding.len(),
-                Piece::Sequence {
-                    id: Sequence::B, ..
-                } => pair
-                    .as_ref()
-                    .ok_or("Template expected a pair sequence, but none provided")?
-                    .len(),
-                Piece::SpecialToken { id, .. } => {
-                    if add_special_tokens {
-                        self.special_tokens
-                            .0
-                            .get(id)
-                            .ok_or_else(|| format!("Missing SpecialToken with id {}", id))?
-                            .ids
-                            .len()
-                    } else {
-                        0
+    ) -> Result<Vec<Encoding>> {
+        let final_encodings: Vec<Encoding> = template
+            .iter()
+            .flat_map(|piece| {
+                match piece {
+                    Piece::Sequence { id, type_id } => {
+                        let i = if *id == Sequence::A { 0 } else { 1 };
+                        let encoding = &mut encodings[i];
+                        encoding.set_type_ids(vec![*type_id; encoding.len()]);
+                        encoding.set_sequence_id(i);
+                        Some(encoding.clone())
+                    }
+                    Piece::SpecialToken { id, type_id } => {
+                        if add_special_tokens {
+                            let tok = &self.special_tokens.0[id]; // We already checked existence above
+                            let len = tok.ids.len();
+
+                            let encoding = Encoding::new(
+                                tok.ids.clone(),
+                                std::iter::repeat(*type_id).take(len).collect(),
+                                tok.tokens.clone(),
+                                // words
+                                std::iter::repeat(None).take(len).collect(),
+                                // offsets
+                                std::iter::repeat((0, 0)).take(len).collect(),
+                                // special_tokens_mask
+                                std::iter::repeat(1).take(len).collect(),
+                                // attention_mask
+                                std::iter::repeat(1).take(len).collect(),
+                                // overflowing
+                                vec![],
+                                // sequence_range
+                                HashMap::new(),
+                            );
+                            Some(encoding)
+                        } else {
+                            None
+                        }
                     }
                 }
-            };
-        }
-
-        // Then build the new Encoding
-        let mut ids = Vec::with_capacity(new_len);
-        let mut type_ids = Vec::with_capacity(new_len);
-        let mut tokens = Vec::with_capacity(new_len);
-        let mut words = Vec::with_capacity(new_len);
-        let mut offsets = Vec::with_capacity(new_len);
-        let mut special_tokens_mask = Vec::with_capacity(new_len);
-        let mut attention_mask = Vec::with_capacity(new_len);
-        let mut sequence_ranges = HashMap::new();
-
-        let pair_overflowing = pair.as_mut().map_or(vec![], |e| e.take_overflowing());
-        let mut overflowing = encoding
-            .take_overflowing()
-            .into_iter()
-            .flat_map(|encoding| {
-                // 1. The pair itself
-                let mut overflowings = vec![self.apply_template(
-                    template,
-                    encoding.clone(),
-                    pair.clone(),
-                    add_special_tokens,
-                )];
-
-                // 2. Its overflowings
-                for other_o in &pair_overflowing {
-                    overflowings.push(self.apply_template(
-                        template,
-                        encoding.clone(),
-                        Some(other_o.clone()),
-                        add_special_tokens,
-                    ));
-                }
-
-                overflowings
             })
-            .collect::<Result<Vec<_>>>()?;
-        // We also need to combine the first sequence with all other overflowings
-        overflowing.extend(
-            pair_overflowing
-                .into_iter()
-                .map(|pair| {
-                    self.apply_template(template, encoding.clone(), Some(pair), add_special_tokens)
-                })
-                .collect::<Result<Vec<_>>>()?,
-        );
+            .collect();
 
-        for piece in template {
-            match piece {
-                Piece::Sequence {
-                    id: Sequence::A,
-                    type_id,
-                } => {
-                    let seq_start = ids.len();
-                    let seq_end = seq_start + encoding.len();
-                    sequence_ranges.insert(0, seq_start..seq_end);
-                    ids.extend(encoding.get_ids());
-                    type_ids.extend(std::iter::repeat(type_id).take(encoding.len()));
-                    tokens.extend(encoding.get_tokens().iter().map(|s| s.to_owned()));
-                    words.extend(encoding.get_word_ids());
-                    offsets.extend(encoding.get_offsets());
-                    special_tokens_mask.extend(encoding.get_special_tokens_mask());
-                    attention_mask.extend(encoding.get_attention_mask());
-                }
-                Piece::Sequence {
-                    id: Sequence::B,
-                    type_id,
-                } => {
-                    let pair = pair.as_ref().expect("Missing pair sequence, checked above");
-                    let seq_start = ids.len();
-                    let seq_end = seq_start + pair.len();
-                    sequence_ranges.insert(1, seq_start..seq_end);
-                    ids.extend(pair.get_ids());
-                    type_ids.extend(std::iter::repeat(type_id).take(pair.len()));
-                    tokens.extend(pair.get_tokens().iter().map(|s| s.to_owned()));
-                    words.extend(pair.get_word_ids());
-                    offsets.extend(pair.get_offsets());
-                    special_tokens_mask.extend(pair.get_special_tokens_mask());
-                    attention_mask.extend(pair.get_attention_mask());
-                }
-                Piece::SpecialToken { id, type_id } => {
-                    if add_special_tokens {
-                        let tok = &self.special_tokens.0[id]; // We already checked existance above
-                        let len = tok.ids.len();
+        //let mut pair = if encodings.len() > 1 {
+        //    Some(encodings.pop().unwrap())
+        //} else {
+        //    None
+        //};
+        //let mut encoding = encodings.pop().unwrap();
+
-
-                        ids.extend(&tok.ids);
-                        type_ids.extend(std::iter::repeat(type_id).take(len));
-                        tokens.extend(tok.tokens.clone());
-                        words.extend(std::iter::repeat(None).take(len));
-                        offsets.extend(std::iter::repeat((0, 0)).take(len));
-                        special_tokens_mask.extend(std::iter::repeat(1).take(len));
-                        attention_mask.extend(std::iter::repeat(1).take(len));
-                    }
-                }
-            }
-        }
+        //let pair_overflowing = pair.as_mut().map_or(vec![], |e| e.take_overflowing());
+        //let mut overflowing: Vec<Encoding> = encoding
+        //    .take_overflowing()
+        //    .iter()
+        //    .map(|encoding| -> Result<Vec<Encoding>> {
+        //        // 1. The pair itself
+        //        let mut overflowings = self.apply_template(
+        //            template,
+        //            if encodings.len() > 1 {
+        //                vec![encoding.clone(), encodings[1].clone()]
+        //            } else {
+        //                vec![encoding.clone()]
+        //            },
+        //            add_special_tokens,
+        //        )?;
-
-        Ok(Encoding::new(
-            ids,
-            type_ids,
-            tokens,
-            words,
-            offsets,
-            special_tokens_mask,
-            attention_mask,
-            overflowing,
-            sequence_ranges,
-        ))
+
+        //        // 2. Its overflowings
+        //        for other_o in &pair_overflowing {
+        //            overflowings.extend(self.apply_template(
+        //                template,
+        //                vec![encoding.clone(), other_o.clone()],
+        //                add_special_tokens,
+        //            )?);
+        //        }
+
+        //        Ok(overflowings)
+        //    })
+        //    .collect::<Result<Vec<Vec<Encoding>>>>()?
+        //    .into_iter()
+        //    .flatten()
+        //    .collect();
+        //// We also need to combine the first sequence with all other overflowings
+        //overflowing.extend(
+        //    pair_overflowing
+        //        .into_iter()
+        //        .map(|pair| {
+        //            self.apply_template(template, vec![encoding.clone(), pair], add_special_tokens)
+        //        })
+        //        .collect::<Result<Vec<_>>>()?
+        //        .into_iter()
+        //        .flatten(),
+        //);
+
+        Ok(final_encodings)
     }
 }
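`apply_template` now walks the template pieces and emits one small `Encoding` per piece: sequence pieces are stamped in place with their `type_id` and sequence id, special-token pieces become standalone encodings, and the old size pre-computation plus the overflowing cross-product (kept above as commented-out code for a follow-up) are gone. A sketch of the observable behaviour through the public `process` API, which still returns a single merged `Encoding` (inputs mirror the crate's own test below):

```rust
use tokenizers::processors::template::TemplateProcessing;
use tokenizers::{Encoding, PostProcessor, Token};

fn main() {
    let processor = TemplateProcessing::builder()
        .try_single("[CLS]:0 $A:0 [SEP]:0")
        .unwrap()
        .try_pair("[CLS]:0 $A:0 [SEP]:0 $B:1 [SEP]:1")
        .unwrap()
        .special_tokens(vec![("[CLS]", 1), ("[SEP]", 0)])
        .build()
        .unwrap();

    let encoding = Encoding::from_tokens(vec![Token::new(12, "Hello".into(), (0, 5))], 0);
    let pair = Encoding::from_tokens(vec![Token::new(15, "pair".into(), (0, 4))], 0);

    // Each piece of the pair template becomes its own encoding internally;
    // `process` merges them back into one.
    let out = processor.process(encoding, Some(pair), true).unwrap();
    assert_eq!(out.get_tokens(), &["[CLS]", "Hello", "[SEP]", "pair", "[SEP]"]);
    assert_eq!(out.get_type_ids(), &[0, 0, 0, 1, 1]);
}
```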
@@ -632,39 +589,34 @@ impl PostProcessor for TemplateProcessing {
     }
 
     fn process_encodings(
         &self,
-        mut encodings: Vec<Encoding>,
+        encodings: Vec<Encoding>,
         add_special_tokens: bool,
     ) -> Result<Vec<Encoding>> {
-        let (encoding, pair): (Encoding, Option<Encoding>) = match encodings.len() {
-            1 => (
-                encodings
-                    .pop()
-                    .ok_or(ProcessorError::InvalidEncodingsVecLength)?,
-                None,
-            ),
-            2 => {
-                let pair = encodings
-                    .pop()
-                    .ok_or(ProcessorError::InvalidEncodingsVecLength)?;
-                let encoding = encodings
-                    .pop()
-                    .ok_or(ProcessorError::InvalidEncodingsVecLength)?;
-                (encoding, Some(pair))
-            }
-            _ => return Err(Box::new(ProcessorError::InvalidEncodingsVecLength)),
+        // let (encoding, pair): (Encoding, Option<Encoding>) = match encodings.len() {
+        //     1 => (
+        //         encodings
+        //             .pop()
+        //             .ok_or(ProcessorError::InvalidEncodingsVecLength)?,
+        //         None,
+        //     ),
+        //     2 => {
+        //         let pair = encodings
+        //             .pop()
+        //             .ok_or(ProcessorError::InvalidEncodingsVecLength)?;
+        //         let encoding = encodings
+        //             .pop()
+        //             .ok_or(ProcessorError::InvalidEncodingsVecLength)?;
+        //         (encoding, Some(pair))
+        //     }
+        //     _ => return Err(Box::new(ProcessorError::InvalidEncodingsVecLength)),
+        // };
+        let template = match encodings.len() {
+            2 => &self.pair.0,
+            1 => &self.single.0,
+            _ => todo!(),
         };
-
-        let encoding = self.apply_template(
-            if pair.is_some() {
-                &self.pair.0
-            } else {
-                &self.single.0
-            },
-            encoding,
-            pair,
-            add_special_tokens,
-        )?;
-        Ok(vec![encoding])
+        let encodings = self.apply_template(template, encodings, add_special_tokens)?;
+        Ok(encodings)
     }
 }
@@ -884,7 +836,6 @@ mod tests {
         );
         let pair = Encoding::from_tokens(vec![Token::new(15, "pair".into(), (0, 4))], 0);
         let single_encoding = processor.process(encoding.clone(), None, true).unwrap();
-        let pair_encoding = processor.process(encoding, Some(pair), true).unwrap();
         assert_eq!(
             single_encoding,
            Encoding::new(
@@ -906,6 +857,7 @@
         );
         assert_eq!(single_encoding.token_to_sequence(2), Some(0));
         assert_eq!(single_encoding.token_to_sequence(3), None);
+        let pair_encoding = processor.process(encoding, Some(pair), true).unwrap();
         assert_eq!(
             pair_encoding,
             Encoding::new(
diff --git a/tokenizers/src/tokenizer/encoding.rs b/tokenizers/src/tokenizer/encoding.rs
index 3098e9b1..b1b4e03c 100644
--- a/tokenizers/src/tokenizer/encoding.rs
+++ b/tokenizers/src/tokenizer/encoding.rs
@@ -152,6 +152,10 @@ impl Encoding {
         &self.type_ids
     }
 
+    pub fn set_type_ids(&mut self, type_ids: Vec<u32>) {
+        self.type_ids = type_ids;
+    }
+
     pub fn get_offsets(&self) -> &[Offsets] {
         &self.offsets
     }
@@ -383,6 +387,12 @@ impl Encoding {
     pub fn merge<I: IntoIterator<Item = Encoding>>(encodings: I, growing_offsets: bool) -> Self {
         let mut encoding = Encoding::default();
 
+        // TODO this is suboptimal as we're doing this iteratively instead of preallocating
+        // all the encodings sizes all at once and only copying into this preallocated vector
+        // https://github.com/huggingface/tokenizers/pull/1049
+
+        // In order to fix, we just need to preallocate all vectors, then copy everything
+        // into it (and deal with overflowings correctly)
         for sub in encodings {
             encoding.merge_with(sub, growing_offsets);
         }
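The new `Encoding::set_type_ids` setter is what lets `apply_template` re-stamp a whole sequence in place instead of rebuilding it field by field. A minimal sketch of the accessor pair:

```rust
use tokenizers::{Encoding, Token};

fn main() {
    let mut enc = Encoding::from_tokens(vec![Token::new(3, "hi".into(), (0, 2))], 0);
    assert_eq!(enc.get_type_ids(), &[0]);

    // Overwrite the whole type_id vector, as `apply_template` now does
    // with `vec![*type_id; encoding.len()]`.
    enc.set_type_ids(vec![1; enc.get_ids().len()]);
    assert_eq!(enc.get_type_ids(), &[1]);
}
```

As the TODO in `Encoding::merge` notes, the final merge still grows its vectors once per sub-encoding; preallocating from the summed lengths (and handling overflowing correctly) is left for a follow-up.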