Making process_encodings not eat up the encodings any more. (#1051)

* Making `process_encodings` not eat up the encodings any more.

* Fixing clippy.
Nicolas Patry, 2022-08-25 11:49:18 +02:00, committed by GitHub
parent c174b5bd34
commit 37f7bae0f7
6 changed files with 424 additions and 397 deletions
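The substance of the change is the `process_encodings` contract: a post-processor no longer pops the incoming `Vec<Encoding>` and merges a pair down to a single result; it now returns one processed `Encoding` per input encoding, leaving any merging to the caller. A minimal, self-contained sketch of that shape change (the `Enc` type and both functions are illustrative stand-ins, not the crate's real `Encoding` API):

// Simplified stand-in for `tokenizers::Encoding`, only to show the shape change.
#[derive(Clone)]
struct Enc(Vec<u32>);

// Old behaviour (sketch): the inputs are consumed and merged into a single encoding.
fn process_old(encodings: Vec<Enc>, cls: u32, sep: u32) -> Vec<Enc> {
    let mut merged = vec![cls];
    for e in encodings {
        merged.extend(e.0);
        merged.push(sep);
    }
    vec![Enc(merged)] // always length 1: the inputs were "eaten up"
}

// New behaviour (sketch): one output per input; merging, if wanted, happens later.
fn process_new(encodings: Vec<Enc>, cls: u32, sep: u32) -> Vec<Enc> {
    encodings
        .into_iter()
        .enumerate()
        .map(|(i, e)| {
            let mut ids = Vec::new();
            if i == 0 {
                ids.push(cls); // only the first sequence gets the leading special token
            }
            ids.extend(e.0);
            ids.push(sep);
            Enc(ids)
        })
        .collect() // same length as the input
}

fn main() {
    let inputs = vec![Enc(vec![5, 6]), Enc(vec![7])];
    assert_eq!(process_old(inputs.clone(), 101, 102).len(), 1);
    assert_eq!(process_new(inputs, 101, 102).len(), 2);
}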

Cargo.toml

@@ -34,6 +34,10 @@ harness = false
 name = "bert_benchmark"
 harness = false
 
+[[bench]]
+name = "layout_benchmark"
+harness = false
+
 [dependencies]
 lazy_static = "1.4"
 rand = "0.8"

benches/layout_benchmark.rs (new file)

@@ -0,0 +1,76 @@
#[macro_use]
extern crate criterion;

use std::fs::File;
use std::io::{BufRead, BufReader};
use std::path::Path;
use std::time::{Duration, Instant};

use criterion::black_box;
use criterion::Criterion;
use tokenizers::processors::template::TemplateProcessing;
use tokenizers::{EncodeInput, Encoding, PostProcessor, Tokenizer};

/// Simple TemplateProcessing
fn create_processor() -> TemplateProcessing {
    TemplateProcessing::builder()
        .try_single("[CLS]:0 $A:0 [SEP]:0")
        .unwrap()
        .try_pair("[CLS]:0 $A:0 [SEP]:0 $B:1 [SEP]:1")
        .unwrap()
        .special_tokens(vec![("[CLS]", 0), ("[SEP]", 1)])
        .build()
        .unwrap()
}

pub fn bench_layout(c: &mut Criterion) {
    let processor = create_processor();
    let tokenizer = Tokenizer::from_file("data/albert-base-v1-tokenizer.json").unwrap();
    let mut encodeds: Vec<Encoding> = vec![];
    for line in BufReader::new(File::open(Path::new("data/big.txt")).unwrap()).lines() {
        let line: EncodeInput = line.unwrap().into();
        let encoded: Encoding = tokenizer.encode(line, false).unwrap();
        encodeds.push(encoded);
    }

    c.bench_function("TemplateProcessing single encode", |b| {
        b.iter_custom(|iters| {
            let mut duration = Duration::new(0, 0);
            for i in 0..iters as usize {
                let encoded_index = i % encodeds.len();
                let encoded: Encoding = encodeds[encoded_index].clone();

                let start = Instant::now();
                let _ = black_box(processor.process(encoded, None, false));
                duration = duration.checked_add(start.elapsed()).unwrap();
            }
            duration
        })
    });

    c.bench_function("TemplateProcessing pair encode", |b| {
        b.iter_custom(|iters| {
            let mut duration = Duration::new(0, 0);
            for i in 0..iters as usize {
                let encoded_index = i % encodeds.len();
                let encoded: Encoding = encodeds[encoded_index].clone();

                let encoded_index2 = (i + 1) % encodeds.len();
                let pair: Encoding = encodeds[encoded_index2].clone();

                let start = Instant::now();
                let _ = black_box(processor.process(encoded, Some(pair), false));
                duration = duration.checked_add(start.elapsed()).unwrap();
            }
            duration
        })
    });
}

criterion_group! {
    name = layout_benches;
    config = Criterion::default().sample_size(20);
    targets = bench_layout
}

criterion_main!(layout_benches);
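For reference, this benchmark should be runnable with Criterion's usual Cargo integration (`cargo bench --bench layout_benchmark` from the `tokenizers` crate), provided the `data/albert-base-v1-tokenizer.json` and `data/big.txt` files it reads are in place. Outside the harness, the same processor can be exercised directly; a short sketch using the same template as above, with made-up token ids:

use tokenizers::processors::template::TemplateProcessing;
use tokenizers::{Encoding, PostProcessor, Token};

fn main() {
    // Same processor as in the benchmark; [CLS] maps to id 0 and [SEP] to id 1 here.
    let processor = TemplateProcessing::builder()
        .try_single("[CLS]:0 $A:0 [SEP]:0")
        .unwrap()
        .try_pair("[CLS]:0 $A:0 [SEP]:0 $B:1 [SEP]:1")
        .unwrap()
        .special_tokens(vec![("[CLS]", 0), ("[SEP]", 1)])
        .build()
        .unwrap();

    let a = Encoding::from_tokens(vec![Token::new(12, "hello".into(), (0, 5))], 0);
    let b = Encoding::from_tokens(vec![Token::new(14, "world".into(), (0, 5))], 0);

    // `process` applies the pair template and merges everything into one Encoding.
    let out = processor.process(a, Some(b), true).unwrap();
    assert_eq!(out.get_ids(), &[0, 12, 1, 14, 1]);
}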

src/processors/bert.rs

@@ -1,4 +1,4 @@
-use crate::tokenizer::{Encoding, PostProcessor, ProcessorError, Result};
+use crate::tokenizer::{Encoding, PostProcessor, Result};
 use serde::{Deserialize, Serialize};
 use std::collections::HashMap;
 use std::iter::FromIterator;
@@ -49,53 +49,11 @@ impl PostProcessor for BertProcessing {
             return Ok(encodings);
         }
 
-        let (mut encoding, pair_encoding): (Encoding, Option<Encoding>) = match encodings.len() {
-            1 => (
-                encodings
-                    .pop()
-                    .ok_or(ProcessorError::InvalidEncodingsVecLength)?,
-                None,
-            ),
-            2 => {
-                let pair = encodings
-                    .pop()
-                    .ok_or(ProcessorError::InvalidEncodingsVecLength)?;
-                let encoding = encodings
-                    .pop()
-                    .ok_or(ProcessorError::InvalidEncodingsVecLength)?;
-                (encoding, Some(pair))
-            }
-            _ => return Err(Box::new(ProcessorError::InvalidEncodingsVecLength)),
-        };
-
-        let ids = [&[self.cls.1], encoding.get_ids(), &[self.sep.1]].concat();
-        let type_ids = [&[0], encoding.get_type_ids(), &[0]].concat();
-        let tokens = [
-            &[self.cls.0.clone()],
-            encoding.get_tokens(),
-            &[self.sep.0.clone()],
-        ]
-        .concat();
-        let words = [&[None], encoding.get_word_ids(), &[None]].concat();
-        let offsets = [&[(0, 0)], encoding.get_offsets(), &[(0, 0)]].concat();
-        let special_tokens = [&[1u32], &vec![0; encoding.get_ids().len()][..], &[1]].concat();
-        let attention_mask = vec![1; ids.len()];
-
-        // For compatibility with `TemplateProcessing`, the sequence_ranges shouldn't contain
-        // the special tokens.
-        let sequence_ranges = HashMap::from_iter(vec![(0, 1..ids.len() - 1)]);
-        let mut new_encoding = Encoding::new(
-            ids,
-            type_ids,
-            tokens,
-            words,
-            offsets,
-            special_tokens,
-            attention_mask,
-            encoding
-                .take_overflowing()
-                .into_iter()
-                .map(|encoding| {
+        let encodings: Vec<Encoding> = encodings
+            .iter_mut()
+            .enumerate()
+            .map(|(i, encoding)| {
+                if i == 0 {
                     let ids = [&[self.cls.1], encoding.get_ids(), &[self.sep.1]].concat();
                     let type_ids = [&[0], encoding.get_type_ids(), &[0]].concat();
                     let tokens = [
@@ -110,8 +68,8 @@ impl PostProcessor for BertProcessing {
                         [&[1u32], &vec![0; encoding.get_ids().len()][..], &[1]].concat();
                     let attention_mask = vec![1; ids.len()];
 
-                    // For compatibility with `TemplateProcessing`, the sequence_ranges shouldn't
-                    // contain the special tokens.
+                    // For compatibility with `TemplateProcessing`, the sequence_ranges shouldn't contain
+                    // the special tokens.
                     let sequence_ranges = HashMap::from_iter(vec![(0, 1..ids.len() - 1)]);
                     Encoding::new(
                         ids,
@@ -121,72 +79,105 @@ impl PostProcessor for BertProcessing {
                         offsets,
                         special_tokens,
                         attention_mask,
-                        vec![],
-                        sequence_ranges,
-                    )
-                })
-                .collect(),
-            sequence_ranges,
-        );
-
-        if let Some(mut encoding) = pair_encoding {
-            let pair_ids = [encoding.get_ids(), &[self.sep.1]].concat();
-            let pair_type_ids = [encoding.get_type_ids(), &[1]].concat();
-            let pair_tokens = [encoding.get_tokens(), &[self.sep.0.clone()]].concat();
-            let pair_words = [encoding.get_word_ids(), &[None]].concat();
-            let pair_offsets = [encoding.get_offsets(), &[(0, 0)]].concat();
-            let pair_special_tokens =
-                [&vec![0u32; encoding.get_type_ids().len()][..], &[1]].concat();
-            let pair_attention_mask = vec![1; pair_ids.len()];
-
-            // For compatibility with `TemplateProcessing`, the sequence_ranges shouldn't contain
-            // the special tokens.
-            let pair_sequence_ranges = HashMap::from_iter(vec![(1, 0..pair_ids.len() - 1)]);
-            let new_pair_encoding = Encoding::new(
-                pair_ids,
-                pair_type_ids,
-                pair_tokens,
-                pair_words,
-                pair_offsets,
-                pair_special_tokens,
-                pair_attention_mask,
-                encoding
-                    .take_overflowing()
-                    .into_iter()
-                    .map(|encoding| {
-                        let pair_ids = [encoding.get_ids(), &[self.sep.1]].concat();
-                        let pair_type_ids = [encoding.get_type_ids(), &[1]].concat();
-                        let pair_tokens = [encoding.get_tokens(), &[self.sep.0.clone()]].concat();
-                        let pair_words = [encoding.get_word_ids(), &[None]].concat();
-                        let pair_offsets = [encoding.get_offsets(), &[(0, 0)]].concat();
-                        let pair_special_tokens =
-                            [&vec![0u32; encoding.get_type_ids().len()][..], &[1]].concat();
-                        let pair_attention_mask = vec![1; pair_ids.len()];
-
-                        // For compatibility with `TemplateProcessing`, the sequence_ranges
-                        // shouldn't contain the special tokens.
-                        let pair_sequence_ranges =
-                            HashMap::from_iter(vec![(1, 0..pair_ids.len() - 1)]);
-                        Encoding::new(
-                            pair_ids,
-                            pair_type_ids,
-                            pair_tokens,
-                            pair_words,
-                            pair_offsets,
-                            pair_special_tokens,
-                            pair_attention_mask,
-                            vec![],
-                            pair_sequence_ranges,
-                        )
-                    })
-                    .collect(),
-                pair_sequence_ranges,
-            );
-
-            new_encoding.merge_with(new_pair_encoding, false);
-        }
-
-        Ok(vec![new_encoding])
+                        encoding
+                            .take_overflowing()
+                            .into_iter()
+                            .map(|encoding| {
+                                let ids =
+                                    [&[self.cls.1], encoding.get_ids(), &[self.sep.1]].concat();
+                                let type_ids = [&[0], encoding.get_type_ids(), &[0]].concat();
+                                let tokens = [
+                                    &[self.cls.0.clone()],
+                                    encoding.get_tokens(),
+                                    &[self.sep.0.clone()],
+                                ]
+                                .concat();
+                                let words = [&[None], encoding.get_word_ids(), &[None]].concat();
+                                let offsets =
+                                    [&[(0, 0)], encoding.get_offsets(), &[(0, 0)]].concat();
+                                let special_tokens =
+                                    [&[1u32], &vec![0; encoding.get_ids().len()][..], &[1]]
+                                        .concat();
+                                let attention_mask = vec![1; ids.len()];
+
+                                // For compatibility with `TemplateProcessing`, the sequence_ranges shouldn't
+                                // contain the special tokens.
+                                let sequence_ranges =
+                                    HashMap::from_iter(vec![(0, 1..ids.len() - 1)]);
+                                Encoding::new(
+                                    ids,
+                                    type_ids,
+                                    tokens,
+                                    words,
+                                    offsets,
+                                    special_tokens,
+                                    attention_mask,
+                                    vec![],
+                                    sequence_ranges,
+                                )
+                            })
+                            .collect(),
+                        sequence_ranges,
+                    )
+                } else {
+                    let pair_ids = [encoding.get_ids(), &[self.sep.1]].concat();
+                    let pair_type_ids = [encoding.get_type_ids(), &[1]].concat();
+                    let pair_tokens = [encoding.get_tokens(), &[self.sep.0.clone()]].concat();
+                    let pair_words = [encoding.get_word_ids(), &[None]].concat();
+                    let pair_offsets = [encoding.get_offsets(), &[(0, 0)]].concat();
+                    let pair_special_tokens =
+                        [&vec![0u32; encoding.get_type_ids().len()][..], &[1]].concat();
+                    let pair_attention_mask = vec![1; pair_ids.len()];
+
+                    // For compatibility with `TemplateProcessing`, the sequence_ranges shouldn't contain
+                    // the special tokens.
+                    let pair_sequence_ranges = HashMap::from_iter(vec![(1, 0..pair_ids.len() - 1)]);
+                    Encoding::new(
+                        pair_ids,
+                        pair_type_ids,
+                        pair_tokens,
+                        pair_words,
+                        pair_offsets,
+                        pair_special_tokens,
+                        pair_attention_mask,
+                        encoding
+                            .take_overflowing()
+                            .into_iter()
+                            .map(|encoding| {
+                                let pair_ids = [encoding.get_ids(), &[self.sep.1]].concat();
+                                let pair_type_ids = [encoding.get_type_ids(), &[1]].concat();
+                                let pair_tokens =
+                                    [encoding.get_tokens(), &[self.sep.0.clone()]].concat();
+                                let pair_words = [encoding.get_word_ids(), &[None]].concat();
+                                let pair_offsets = [encoding.get_offsets(), &[(0, 0)]].concat();
+                                let pair_special_tokens =
+                                    [&vec![0u32; encoding.get_type_ids().len()][..], &[1]].concat();
+                                let pair_attention_mask = vec![1; pair_ids.len()];

+                                // For compatibility with `TemplateProcessing`, the sequence_ranges
+                                // shouldn't contain the special tokens.
+                                let pair_sequence_ranges =
+                                    HashMap::from_iter(vec![(1, 0..pair_ids.len() - 1)]);
+                                Encoding::new(
+                                    pair_ids,
+                                    pair_type_ids,
+                                    pair_tokens,
+                                    pair_words,
+                                    pair_offsets,
+                                    pair_special_tokens,
+                                    pair_attention_mask,
+                                    vec![],
+                                    pair_sequence_ranges,
+                                )
+                            })
+                            .collect(),
+                        pair_sequence_ranges,
+                    )
+                }
+            })
+            .collect();
+
+        Ok(encodings)
     }
 }
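Behaviour-wise the rewrite above is intended to be transparent: a single sequence still comes out as `[CLS] A [SEP]`, and a pair as `[CLS] A [SEP] B [SEP]` once the per-sequence encodings are merged. A small sketch of how `BertProcessing` is typically driven (the token ids 101/102 and 5/6 are placeholders chosen for the example, not values from this diff):

use tokenizers::processors::bert::BertProcessing;
use tokenizers::{Encoding, PostProcessor, Token};

fn main() {
    // `new` takes the (token, id) pairs for sep and cls; the ids are arbitrary here.
    let processor = BertProcessing::new(("[SEP]".into(), 102), ("[CLS]".into(), 101));

    let a = Encoding::from_tokens(vec![Token::new(5, "hello".into(), (0, 5))], 0);
    let b = Encoding::from_tokens(vec![Token::new(6, "pair".into(), (0, 4))], 0);

    // `process` goes through `process_encodings` and merges the per-sequence results.
    let single = processor.process(a.clone(), None, true).unwrap();
    assert_eq!(single.get_ids(), &[101, 5, 102]);

    let pair = processor.process(a, Some(b), true).unwrap();
    assert_eq!(pair.get_ids(), &[101, 5, 102, 6, 102]);
}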

src/processors/roberta.rs

@@ -1,5 +1,5 @@
 use crate::processors::byte_level::process_offsets;
-use crate::tokenizer::{Encoding, PostProcessor, ProcessorError, Result};
+use crate::tokenizer::{Encoding, PostProcessor, Result};
 use serde::{Deserialize, Serialize};
 use std::collections::HashMap;
 use std::iter::FromIterator;
@@ -74,53 +74,11 @@ impl PostProcessor for RobertaProcessing {
             return Ok(encodings);
         }
 
-        let (mut encoding, pair_encoding): (Encoding, Option<Encoding>) = match encodings.len() {
-            1 => (
-                encodings
-                    .pop()
-                    .ok_or(ProcessorError::InvalidEncodingsVecLength)?,
-                None,
-            ),
-            2 => {
-                let pair = encodings
-                    .pop()
-                    .ok_or(ProcessorError::InvalidEncodingsVecLength)?;
-                let encoding = encodings
-                    .pop()
-                    .ok_or(ProcessorError::InvalidEncodingsVecLength)?;
-                (encoding, Some(pair))
-            }
-            _ => return Err(Box::new(ProcessorError::InvalidEncodingsVecLength)),
-        };
-
-        let ids = [&[self.cls.1], encoding.get_ids(), &[self.sep.1]].concat();
-        let type_ids = [&[0], encoding.get_type_ids(), &[0]].concat();
-        let tokens = [
-            &[self.cls.0.clone()],
-            encoding.get_tokens(),
-            &[self.sep.0.clone()],
-        ]
-        .concat();
-        let words = [&[None], encoding.get_word_ids(), &[None]].concat();
-        let offsets = [&[(0, 0)], encoding.get_offsets(), &[(0, 0)]].concat();
-        let special_tokens = [&[1u32], &vec![0; encoding.get_ids().len()][..], &[1]].concat();
-        let attention_mask = vec![1; ids.len()];
-
-        // For compatibility with `TemplateProcessing`, the sequence_ranges shouldn't contain
-        // the special tokens.
-        let sequence_ranges = HashMap::from_iter(vec![(0, 1..ids.len() - 1)]);
-        let mut new_encoding = Encoding::new(
-            ids,
-            type_ids,
-            tokens,
-            words,
-            offsets,
-            special_tokens,
-            attention_mask,
-            encoding
-                .take_overflowing()
-                .into_iter()
-                .map(|encoding| {
+        let encodings: Vec<Encoding> = encodings
+            .iter_mut()
+            .enumerate()
+            .map(|(i, encoding)| {
+                if i == 0 {
                     let ids = [&[self.cls.1], encoding.get_ids(), &[self.sep.1]].concat();
                     let type_ids = [&[0], encoding.get_type_ids(), &[0]].concat();
                     let tokens = [
@@ -135,8 +93,8 @@ impl PostProcessor for RobertaProcessing {
                         [&[1u32], &vec![0; encoding.get_ids().len()][..], &[1]].concat();
                     let attention_mask = vec![1; ids.len()];
 
-                    // For compatibility with `TemplateProcessing`, the sequence_ranges shouldn't
-                    // contain the special tokens.
+                    // For compatibility with `TemplateProcessing`, the sequence_ranges shouldn't contain
+                    // the special tokens.
                     let sequence_ranges = HashMap::from_iter(vec![(0, 1..ids.len() - 1)]);
                     Encoding::new(
                         ids,
@@ -146,82 +104,118 @@ impl PostProcessor for RobertaProcessing {
                         offsets,
                         special_tokens,
                         attention_mask,
-                        vec![],
-                        sequence_ranges,
-                    )
-                })
-                .collect(),
-            sequence_ranges,
-        );
-
-        if let Some(mut encoding) = pair_encoding {
-            let pair_ids = [&[self.sep.1], encoding.get_ids(), &[self.sep.1]].concat();
-            let pair_type_ids = vec![0; encoding.get_ids().len() + 2];
-            let pair_tokens = [
-                &[self.sep.0.clone()],
-                encoding.get_tokens(),
-                &[self.sep.0.clone()],
-            ]
-            .concat();
-            let pair_words = [&[None], encoding.get_word_ids(), &[None]].concat();
-            let pair_offsets = [&[(0, 0)], encoding.get_offsets(), &[(0, 0)]].concat();
-            let pair_special_tokens =
-                [&[1], &vec![0u32; encoding.get_type_ids().len()][..], &[1]].concat();
-            let pair_attention_mask = vec![1; pair_ids.len()];
-
-            // For compatibility with `TemplateProcessing`, the sequence_ranges shouldn't contain
-            // the special tokens.
-            let pair_sequence_ranges = HashMap::from_iter(vec![(1, 1..pair_ids.len() - 1)]);
-            let new_pair_encoding = Encoding::new(
-                pair_ids,
-                pair_type_ids,
-                pair_tokens,
-                pair_words,
-                pair_offsets,
-                pair_special_tokens,
-                pair_attention_mask,
-                encoding
-                    .take_overflowing()
-                    .into_iter()
-                    .map(|encoding| {
-                        let pair_ids = [&[self.sep.1], encoding.get_ids(), &[self.sep.1]].concat();
-                        let pair_type_ids = vec![0; encoding.get_ids().len() + 2];
-                        let pair_tokens = [
-                            &[self.sep.0.clone()],
-                            encoding.get_tokens(),
-                            &[self.sep.0.clone()],
-                        ]
-                        .concat();
-                        let pair_words = [&[None], encoding.get_word_ids(), &[None]].concat();
-                        let pair_offsets = [&[(0, 0)], encoding.get_offsets(), &[(0, 0)]].concat();
-                        let pair_special_tokens =
-                            [&[1], &vec![0u32; encoding.get_type_ids().len()][..], &[1]].concat();
-                        let pair_attention_mask = vec![1; pair_ids.len()];
-
-                        // For compatibility with `TemplateProcessing`, the sequence_ranges
-                        // shouldn't contain the special tokens.
-                        let pair_sequence_ranges =
-                            HashMap::from_iter(vec![(1, 1..pair_ids.len() - 1)]);
-                        Encoding::new(
-                            pair_ids,
-                            pair_type_ids,
-                            pair_tokens,
-                            pair_words,
-                            pair_offsets,
-                            pair_special_tokens,
-                            pair_attention_mask,
-                            vec![],
-                            pair_sequence_ranges,
-                        )
-                    })
-                    .collect(),
-                pair_sequence_ranges,
-            );
-
-            new_encoding.merge_with(new_pair_encoding, false);
-        }
-
-        Ok(vec![new_encoding])
+                        encoding
+                            .take_overflowing()
+                            .into_iter()
+                            .map(|encoding| {
+                                let ids =
+                                    [&[self.cls.1], encoding.get_ids(), &[self.sep.1]].concat();
+                                let type_ids = [&[0], encoding.get_type_ids(), &[0]].concat();
+                                let tokens = [
+                                    &[self.cls.0.clone()],
+                                    encoding.get_tokens(),
+                                    &[self.sep.0.clone()],
+                                ]
+                                .concat();
+                                let words = [&[None], encoding.get_word_ids(), &[None]].concat();
+                                let offsets =
+                                    [&[(0, 0)], encoding.get_offsets(), &[(0, 0)]].concat();
+                                let special_tokens =
+                                    [&[1u32], &vec![0; encoding.get_ids().len()][..], &[1]]
+                                        .concat();
+                                let attention_mask = vec![1; ids.len()];
+
+                                // For compatibility with `TemplateProcessing`, the sequence_ranges shouldn't
+                                // contain the special tokens.
+                                let sequence_ranges =
+                                    HashMap::from_iter(vec![(0, 1..ids.len() - 1)]);
+                                Encoding::new(
+                                    ids,
+                                    type_ids,
+                                    tokens,
+                                    words,
+                                    offsets,
+                                    special_tokens,
+                                    attention_mask,
+                                    vec![],
+                                    sequence_ranges,
+                                )
+                            })
+                            .collect(),
+                        sequence_ranges,
+                    )
+                } else {
+                    let pair_ids = [&[self.sep.1], encoding.get_ids(), &[self.sep.1]].concat();
+                    let pair_type_ids = vec![0; encoding.get_ids().len() + 2];
+                    let pair_tokens = [
+                        &[self.sep.0.clone()],
+                        encoding.get_tokens(),
+                        &[self.sep.0.clone()],
+                    ]
+                    .concat();
+                    let pair_words = [&[None], encoding.get_word_ids(), &[None]].concat();
+                    let pair_offsets = [&[(0, 0)], encoding.get_offsets(), &[(0, 0)]].concat();
+                    let pair_special_tokens =
+                        [&[1], &vec![0u32; encoding.get_type_ids().len()][..], &[1]].concat();
+                    let pair_attention_mask = vec![1; pair_ids.len()];
+
+                    // For compatibility with `TemplateProcessing`, the sequence_ranges shouldn't contain
+                    // the special tokens.
+                    let pair_sequence_ranges = HashMap::from_iter(vec![(1, 1..pair_ids.len() - 1)]);
+                    Encoding::new(
+                        pair_ids,
+                        pair_type_ids,
+                        pair_tokens,
+                        pair_words,
+                        pair_offsets,
+                        pair_special_tokens,
+                        pair_attention_mask,
+                        encoding
+                            .take_overflowing()
+                            .into_iter()
+                            .map(|encoding| {
+                                let pair_ids =
+                                    [&[self.sep.1], encoding.get_ids(), &[self.sep.1]].concat();
+                                let pair_type_ids = vec![0; encoding.get_ids().len() + 2];
+                                let pair_tokens = [
+                                    &[self.sep.0.clone()],
+                                    encoding.get_tokens(),
+                                    &[self.sep.0.clone()],
+                                ]
+                                .concat();
+                                let pair_words =
+                                    [&[None], encoding.get_word_ids(), &[None]].concat();
+                                let pair_offsets =
+                                    [&[(0, 0)], encoding.get_offsets(), &[(0, 0)]].concat();
+                                let pair_special_tokens =
+                                    [&[1], &vec![0u32; encoding.get_type_ids().len()][..], &[1]]
+                                        .concat();
+                                let pair_attention_mask = vec![1; pair_ids.len()];
+
+                                // For compatibility with `TemplateProcessing`, the sequence_ranges
+                                // shouldn't contain the special tokens.
+                                let pair_sequence_ranges =
+                                    HashMap::from_iter(vec![(1, 1..pair_ids.len() - 1)]);
+                                Encoding::new(
+                                    pair_ids,
+                                    pair_type_ids,
+                                    pair_tokens,
+                                    pair_words,
+                                    pair_offsets,
+                                    pair_special_tokens,
+                                    pair_attention_mask,
+                                    vec![],
+                                    pair_sequence_ranges,
+                                )
+                            })
+                            .collect(),
+                        pair_sequence_ranges,
+                    )
+                }
+            })
+            .collect();
+
+        Ok(encodings)
     }
 }

src/processors/template.rs

@@ -55,14 +55,14 @@
 //!
 //! [`TemplateProcessing`]: struct.TemplateProcessing.html
 //!
-use crate::{tokenizer::ProcessorError, Encoding, PostProcessor, Result};
+use crate::{Encoding, PostProcessor, Result};
 use itertools::Itertools;
 use serde::{Deserialize, Serialize};
 use std::collections::{HashMap, HashSet};
 use std::convert::{TryFrom, TryInto};
 use std::result::Result as StdResult;
 
-/// Represents both sequences received as input of the PostProcessor
+/// Represents any sequences received as input of the PostProcessor
 #[derive(Debug, Clone, PartialEq, Serialize, Deserialize, Eq)]
 pub enum Sequence {
     /// This is the first sequence, the one that is always specified
@@ -479,145 +479,102 @@ impl TemplateProcessing {
     fn apply_template(
         &self,
         template: &[Piece],
-        mut encoding: Encoding,
-        mut pair: Option<Encoding>,
+        mut encodings: Vec<Encoding>,
         add_special_tokens: bool,
-    ) -> Result<Encoding> {
-        // Compute the new size
-        let mut new_len = 0;
-        for piece in template {
-            new_len += match piece {
-                Piece::Sequence {
-                    id: Sequence::A, ..
-                } => encoding.len(),
-                Piece::Sequence {
-                    id: Sequence::B, ..
-                } => pair
-                    .as_ref()
-                    .ok_or("Template expected a pair sequence, but none provided")?
-                    .len(),
-                Piece::SpecialToken { id, .. } => {
-                    if add_special_tokens {
-                        self.special_tokens
-                            .0
-                            .get(id)
-                            .ok_or_else(|| format!("Missing SpecialToken with id {}", id))?
-                            .ids
-                            .len()
-                    } else {
-                        0
-                    }
-                }
-            };
-        }
-
-        // Then build the new Encoding
-        let mut ids = Vec::with_capacity(new_len);
-        let mut type_ids = Vec::with_capacity(new_len);
-        let mut tokens = Vec::with_capacity(new_len);
-        let mut words = Vec::with_capacity(new_len);
-        let mut offsets = Vec::with_capacity(new_len);
-        let mut special_tokens_mask = Vec::with_capacity(new_len);
-        let mut attention_mask = Vec::with_capacity(new_len);
-        let mut sequence_ranges = HashMap::new();
-
-        let pair_overflowing = pair.as_mut().map_or(vec![], |e| e.take_overflowing());
-        let mut overflowing = encoding
-            .take_overflowing()
-            .into_iter()
-            .flat_map(|encoding| {
-                // 1. The pair itself
-                let mut overflowings = vec![self.apply_template(
-                    template,
-                    encoding.clone(),
-                    pair.clone(),
-                    add_special_tokens,
-                )];
-
-                // 2. Its overflowings
-                for other_o in &pair_overflowing {
-                    overflowings.push(self.apply_template(
-                        template,
-                        encoding.clone(),
-                        Some(other_o.clone()),
-                        add_special_tokens,
-                    ));
-                }
-
-                overflowings
-            })
-            .collect::<Result<Vec<_>>>()?;
-
-        // We also need to combine the first sequence with all other overflowings
-        overflowing.extend(
-            pair_overflowing
-                .into_iter()
-                .map(|pair| {
-                    self.apply_template(template, encoding.clone(), Some(pair), add_special_tokens)
-                })
-                .collect::<Result<Vec<_>>>()?,
-        );
-
-        for piece in template {
-            match piece {
-                Piece::Sequence {
-                    id: Sequence::A,
-                    type_id,
-                } => {
-                    let seq_start = ids.len();
-                    let seq_end = seq_start + encoding.len();
-                    sequence_ranges.insert(0, seq_start..seq_end);
-                    ids.extend(encoding.get_ids());
-                    type_ids.extend(std::iter::repeat(type_id).take(encoding.len()));
-                    tokens.extend(encoding.get_tokens().iter().map(|s| s.to_owned()));
-                    words.extend(encoding.get_word_ids());
-                    offsets.extend(encoding.get_offsets());
-                    special_tokens_mask.extend(encoding.get_special_tokens_mask());
-                    attention_mask.extend(encoding.get_attention_mask());
-                }
-                Piece::Sequence {
-                    id: Sequence::B,
-                    type_id,
-                } => {
-                    let pair = pair.as_ref().expect("Missing pair sequence, checked above");
-                    let seq_start = ids.len();
-                    let seq_end = seq_start + pair.len();
-                    sequence_ranges.insert(1, seq_start..seq_end);
-                    ids.extend(pair.get_ids());
-                    type_ids.extend(std::iter::repeat(type_id).take(pair.len()));
-                    tokens.extend(pair.get_tokens().iter().map(|s| s.to_owned()));
-                    words.extend(pair.get_word_ids());
-                    offsets.extend(pair.get_offsets());
-                    special_tokens_mask.extend(pair.get_special_tokens_mask());
-                    attention_mask.extend(pair.get_attention_mask());
-                }
-                Piece::SpecialToken { id, type_id } => {
-                    if add_special_tokens {
-                        let tok = &self.special_tokens.0[id]; // We already checked existance above
-                        let len = tok.ids.len();
-
-                        ids.extend(&tok.ids);
-                        type_ids.extend(std::iter::repeat(type_id).take(len));
-                        tokens.extend(tok.tokens.clone());
-                        words.extend(std::iter::repeat(None).take(len));
-                        offsets.extend(std::iter::repeat((0, 0)).take(len));
-                        special_tokens_mask.extend(std::iter::repeat(1).take(len));
-                        attention_mask.extend(std::iter::repeat(1).take(len));
-                    }
-                }
-            }
-        }
-
-        Ok(Encoding::new(
-            ids,
-            type_ids,
-            tokens,
-            words,
-            offsets,
-            special_tokens_mask,
-            attention_mask,
-            overflowing,
-            sequence_ranges,
-        ))
+    ) -> Result<Vec<Encoding>> {
+        let final_encodings: Vec<Encoding> = template
+            .iter()
+            .flat_map(|piece| {
+                match piece {
+                    Piece::Sequence { id, type_id } => {
+                        let i = if *id == Sequence::A { 0 } else { 1 };
+                        let encoding = &mut encodings[i];
+                        encoding.set_type_ids(vec![*type_id; encoding.len()]);
+                        encoding.set_sequence_id(i);
+                        Some(encoding.clone())
+                    }
+                    Piece::SpecialToken { id, type_id } => {
+                        if add_special_tokens {
+                            let tok = &self.special_tokens.0[id]; // We already checked existance above
+                            let len = tok.ids.len();
+
+                            let encoding = Encoding::new(
+                                tok.ids.clone(),
+                                std::iter::repeat(*type_id).take(len).collect(),
+                                tok.tokens.clone(),
+                                // words
+                                std::iter::repeat(None).take(len).collect(),
+                                // offsets
+                                std::iter::repeat((0, 0)).take(len).collect(),
+                                // special_tokens_mask
+                                std::iter::repeat(1).take(len).collect(),
+                                // attention_mask
+                                std::iter::repeat(1).take(len).collect(),
+                                // overflowing
+                                vec![],
+                                // sequence_range
+                                HashMap::new(),
+                            );
+                            Some(encoding)
+                        } else {
+                            None
+                        }
+                    }
+                }
+            })
+            .collect();
+
+        //let mut pair = if encodings.len() > 1 {
+        //    Some(encodings.pop().unwrap())
+        //} else {
+        //    None
+        //};
+        //let mut encoding = encodings.pop().unwrap();
+
+        //let pair_overflowing = pair.as_mut().map_or(vec![], |e| e.take_overflowing());
+        //let mut overflowing: Vec<Encoding> = encoding
+        //    .take_overflowing()
+        //    .iter()
+        //    .map(|encoding| -> Result<Vec<Encoding>> {
+        //        // 1. The pair itself
+        //        let mut overflowings = self.apply_template(
+        //            template,
+        //            if encodings.len() > 1 {
+        //                vec![encoding.clone(), encodings[1].clone()]
+        //            } else {
+        //                vec![encoding.clone()]
+        //            },
+        //            add_special_tokens,
+        //        )?;
+
+        //        // 2. Its overflowings
+        //        for other_o in &pair_overflowing {
+        //            overflowings.extend(self.apply_template(
+        //                template,
+        //                vec![encoding.clone(), other_o.clone()],
+        //                add_special_tokens,
+        //            )?);
+        //        }
+
+        //        Ok(overflowings)
+        //    })
+        //    .collect::<Result<Vec<Vec<Encoding>>>>()?
+        //    .into_iter()
+        //    .flatten()
+        //    .collect();
+
+        //// We also need to combine the first sequence with all other overflowings
+        //overflowing.extend(
+        //    pair_overflowing
+        //        .into_iter()
+        //        .map(|pair| {
+        //            self.apply_template(template, vec![encoding.clone(), pair], add_special_tokens)
+        //        })
+        //        .collect::<Result<Vec<_>>>()?
+        //        .into_iter()
+        //        .flatten(),
+        //);
+
+        Ok(final_encodings)
     }
 }
@@ -632,39 +589,34 @@ impl PostProcessor for TemplateProcessing {
     fn process_encodings(
         &self,
-        mut encodings: Vec<Encoding>,
+        encodings: Vec<Encoding>,
         add_special_tokens: bool,
     ) -> Result<Vec<Encoding>> {
-        let (encoding, pair): (Encoding, Option<Encoding>) = match encodings.len() {
-            1 => (
-                encodings
-                    .pop()
-                    .ok_or(ProcessorError::InvalidEncodingsVecLength)?,
-                None,
-            ),
-            2 => {
-                let pair = encodings
-                    .pop()
-                    .ok_or(ProcessorError::InvalidEncodingsVecLength)?;
-                let encoding = encodings
-                    .pop()
-                    .ok_or(ProcessorError::InvalidEncodingsVecLength)?;
-                (encoding, Some(pair))
-            }
-            _ => return Err(Box::new(ProcessorError::InvalidEncodingsVecLength)),
-        };
-
-        let encoding = self.apply_template(
-            if pair.is_some() {
-                &self.pair.0
-            } else {
-                &self.single.0
-            },
-            encoding,
-            pair,
-            add_special_tokens,
-        )?;
-
-        Ok(vec![encoding])
+        // let (encoding, pair): (Encoding, Option<Encoding>) = match encodings.len() {
+        //     1 => (
+        //         encodings
+        //             .pop()
+        //             .ok_or(ProcessorError::InvalidEncodingsVecLength)?,
+        //         None,
+        //     ),
+        //     2 => {
+        //         let pair = encodings
+        //             .pop()
+        //             .ok_or(ProcessorError::InvalidEncodingsVecLength)?;
+        //         let encoding = encodings
+        //             .pop()
+        //             .ok_or(ProcessorError::InvalidEncodingsVecLength)?;
+        //         (encoding, Some(pair))
+        //     }
+        //     _ => return Err(Box::new(ProcessorError::InvalidEncodingsVecLength)),
+        // };
+        let template = match encodings.len() {
+            2 => &self.pair.0,
+            1 => &self.single.0,
+            _ => todo!(),
+        };
+        let encodings = self.apply_template(template, encodings, add_special_tokens)?;
+        Ok(encodings)
     }
 }
@@ -884,7 +836,6 @@ mod tests {
         );
         let pair = Encoding::from_tokens(vec![Token::new(15, "pair".into(), (0, 4))], 0);
         let single_encoding = processor.process(encoding.clone(), None, true).unwrap();
-        let pair_encoding = processor.process(encoding, Some(pair), true).unwrap();
         assert_eq!(
             single_encoding,
             Encoding::new(
@@ -906,6 +857,7 @@ mod tests {
         );
         assert_eq!(single_encoding.token_to_sequence(2), Some(0));
        assert_eq!(single_encoding.token_to_sequence(3), None);
+        let pair_encoding = processor.process(encoding, Some(pair), true).unwrap();
         assert_eq!(
             pair_encoding,
             Encoding::new(

src/tokenizer/encoding.rs

@@ -152,6 +152,10 @@ impl Encoding {
         &self.type_ids
     }
 
+    pub fn set_type_ids(&mut self, type_ids: Vec<u32>) {
+        self.type_ids = type_ids;
+    }
+
     pub fn get_offsets(&self) -> &[Offsets] {
         &self.offsets
     }
@@ -383,6 +387,12 @@
     pub fn merge<I: IntoIterator<Item = Encoding>>(encodings: I, growing_offsets: bool) -> Self {
         let mut encoding = Encoding::default();
 
+        // TODO this is suboptimal as we're doing this iteratively instead of preallocating
+        // all the encodings sizes all at once and only copying into this preallocated vector
+        // https://github.com/huggingface/tokenizers/pull/1049
+        // In order to fix, we just need to preallocate all vectors, then copy everything
+        // into it (and deal with overlowings correctly)
         for sub in encodings {
             encoding.merge_with(sub, growing_offsets);
         }
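The TODO above spells out the intended follow-up: size everything once and copy, instead of growing the merged encoding field by field inside the loop. A hypothetical sketch of that idea on a single field (a plain `Vec<u32>` of ids stands in for the many parallel vectors an `Encoding` carries; `merge_ids` is not a function in the crate):

// Hypothetical illustration of the preallocation strategy from the TODO, for one field only.
// The real fix would need to do this for every Encoding field and still handle overflowing.
fn merge_ids(encodings: &[Vec<u32>]) -> Vec<u32> {
    let total: usize = encodings.iter().map(|e| e.len()).sum();
    let mut ids = Vec::with_capacity(total); // one allocation instead of repeated regrowth
    for e in encodings {
        ids.extend_from_slice(e);
    }
    ids
}

fn main() {
    let merged = merge_ids(&[vec![1, 2], vec![3], vec![4, 5, 6]]);
    assert_eq!(merged, vec![1, 2, 3, 4, 5, 6]);
    assert!(merged.capacity() >= 6);
}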