Making process_encodings not eat up the encodings any more. (#1051)

* Making `process_encodings` not eat up the encodings any more.

* Fixing clippy.
Nicolas Patry committed 2022-08-25 11:49:18 +02:00 (committed by GitHub)
parent c174b5bd34
commit 37f7bae0f7
6 changed files with 424 additions and 397 deletions
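The shape of the change: `process_encodings` previously popped the one or two input sequences out of the `Vec` (hence "eating" it) and returned a single merged encoding in a one-element `Vec`; each implementation now rewrites the encodings in place and hands the whole `Vec` back, leaving the merge to the caller. A minimal standalone sketch of the new shape (simplified types and names, not the crate's real signatures):

#[derive(Clone, Debug, Default)]
struct Encoding {
    ids: Vec<u32>,
    type_ids: Vec<u32>,
}

trait PostProcessor {
    // New shape: Vec in, Vec out; nothing is popped off the input.
    fn process_encodings(&self, encodings: Vec<Encoding>) -> Vec<Encoding>;
}

struct AddSep {
    sep_id: u32,
}

impl PostProcessor for AddSep {
    fn process_encodings(&self, mut encodings: Vec<Encoding>) -> Vec<Encoding> {
        for (i, e) in encodings.iter_mut().enumerate() {
            e.ids.push(self.sep_id);
            // One uniform type id per sequence, as the real patch does
            // via the new Encoding::set_type_ids.
            e.type_ids = vec![i as u32; e.ids.len()];
        }
        encodings // same number of encodings out as in
    }
}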

tokenizers/Cargo.toml

@@ -34,6 +34,10 @@ harness = false
name = "bert_benchmark"
harness = false
[[bench]]
name = "layout_benchmark"
harness = false
[dependencies]
lazy_static = "1.4"
rand = "0.8"
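With the `[[bench]]` entry above, the new benchmark runs via `cargo bench --bench layout_benchmark`, executed from the `tokenizers` crate directory so that the relative `data/` paths it loads (see below) resolve.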

tokenizers/benches/layout_benchmark.rs

@@ -0,0 +1,76 @@
#[macro_use]
extern crate criterion;

use std::fs::File;
use std::io::{BufRead, BufReader};
use std::path::Path;
use std::time::{Duration, Instant};

use criterion::black_box;
use criterion::Criterion;
use tokenizers::processors::template::TemplateProcessing;
use tokenizers::{EncodeInput, Encoding, PostProcessor, Tokenizer};

/// Simple TemplateProcessing
fn create_processor() -> TemplateProcessing {
    TemplateProcessing::builder()
        .try_single("[CLS]:0 $A:0 [SEP]:0")
        .unwrap()
        .try_pair("[CLS]:0 $A:0 [SEP]:0 $B:1 [SEP]:1")
        .unwrap()
        .special_tokens(vec![("[CLS]", 0), ("[SEP]", 1)])
        .build()
        .unwrap()
}

pub fn bench_layout(c: &mut Criterion) {
    let processor = create_processor();
    let tokenizer = Tokenizer::from_file("data/albert-base-v1-tokenizer.json").unwrap();
    let mut encodeds: Vec<Encoding> = vec![];
    for line in BufReader::new(File::open(Path::new("data/big.txt")).unwrap()).lines() {
        let line: EncodeInput = line.unwrap().into();
        let encoded: Encoding = tokenizer.encode(line, false).unwrap();
        encodeds.push(encoded);
    }

    c.bench_function("TemplateProcessing single encode", |b| {
        b.iter_custom(|iters| {
            let mut duration = Duration::new(0, 0);
            for i in 0..iters as usize {
                let encoded_index = i % encodeds.len();
                let encoded: Encoding = encodeds[encoded_index].clone();
                let start = Instant::now();
                let _ = black_box(processor.process(encoded, None, false));
                duration = duration.checked_add(start.elapsed()).unwrap();
            }
            duration
        })
    });

    c.bench_function("TemplateProcessing pair encode", |b| {
        b.iter_custom(|iters| {
            let mut duration = Duration::new(0, 0);
            for i in 0..iters as usize {
                let encoded_index = i % encodeds.len();
                let encoded: Encoding = encodeds[encoded_index].clone();
                let encoded_index2 = (i + 1) % encodeds.len();
                let pair: Encoding = encodeds[encoded_index2].clone();
                let start = Instant::now();
                let _ = black_box(processor.process(encoded, Some(pair), false));
                duration = duration.checked_add(start.elapsed()).unwrap();
            }
            duration
        })
    });
}

criterion_group! {
    name = layout_benches;
    config = Criterion::default().sample_size(20);
    targets = bench_layout
}

criterion_main!(layout_benches);
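Note how the `iter_custom` loops keep setup out of the measurement: the `Encoding` clones that feed `process` happen before `Instant::now()`, so only the post-processing call itself is timed.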

tokenizers/src/processors/bert.rs

@@ -1,4 +1,4 @@
use crate::tokenizer::{Encoding, PostProcessor, ProcessorError, Result};
use crate::tokenizer::{Encoding, PostProcessor, Result};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::iter::FromIterator;
@@ -49,53 +49,11 @@ impl PostProcessor for BertProcessing {
return Ok(encodings);
}
let (mut encoding, pair_encoding): (Encoding, Option<Encoding>) = match encodings.len() {
1 => (
encodings
.pop()
.ok_or(ProcessorError::InvalidEncodingsVecLength)?,
None,
),
2 => {
let pair = encodings
.pop()
.ok_or(ProcessorError::InvalidEncodingsVecLength)?;
let encoding = encodings
.pop()
.ok_or(ProcessorError::InvalidEncodingsVecLength)?;
(encoding, Some(pair))
}
_ => return Err(Box::new(ProcessorError::InvalidEncodingsVecLength)),
};
let ids = [&[self.cls.1], encoding.get_ids(), &[self.sep.1]].concat();
let type_ids = [&[0], encoding.get_type_ids(), &[0]].concat();
let tokens = [
&[self.cls.0.clone()],
encoding.get_tokens(),
&[self.sep.0.clone()],
]
.concat();
let words = [&[None], encoding.get_word_ids(), &[None]].concat();
let offsets = [&[(0, 0)], encoding.get_offsets(), &[(0, 0)]].concat();
let special_tokens = [&[1u32], &vec![0; encoding.get_ids().len()][..], &[1]].concat();
let attention_mask = vec![1; ids.len()];
// For compatibility with `TemplateProcessing`, the sequence_ranges shouldn't contain
// the special tokens.
let sequence_ranges = HashMap::from_iter(vec![(0, 1..ids.len() - 1)]);
let mut new_encoding = Encoding::new(
ids,
type_ids,
tokens,
words,
offsets,
special_tokens,
attention_mask,
encoding
.take_overflowing()
.into_iter()
.map(|encoding| {
let encodings: Vec<Encoding> = encodings
.iter_mut()
.enumerate()
.map(|(i, encoding)| {
if i == 0 {
let ids = [&[self.cls.1], encoding.get_ids(), &[self.sep.1]].concat();
let type_ids = [&[0], encoding.get_type_ids(), &[0]].concat();
let tokens = [
@@ -110,8 +68,8 @@ impl PostProcessor for BertProcessing {
[&[1u32], &vec![0; encoding.get_ids().len()][..], &[1]].concat();
let attention_mask = vec![1; ids.len()];
// For compatibility with `TemplateProcessing`, the sequence_ranges shouldn't
// contain the special tokens.
// For compatibility with `TemplateProcessing`, the sequence_ranges shouldn't contain
// the special tokens.
let sequence_ranges = HashMap::from_iter(vec![(0, 1..ids.len() - 1)]);
Encoding::new(
ids,
@@ -121,72 +79,105 @@ impl PostProcessor for BertProcessing {
offsets,
special_tokens,
attention_mask,
vec![],
encoding
.take_overflowing()
.into_iter()
.map(|encoding| {
let ids =
[&[self.cls.1], encoding.get_ids(), &[self.sep.1]].concat();
let type_ids = [&[0], encoding.get_type_ids(), &[0]].concat();
let tokens = [
&[self.cls.0.clone()],
encoding.get_tokens(),
&[self.sep.0.clone()],
]
.concat();
let words = [&[None], encoding.get_word_ids(), &[None]].concat();
let offsets =
[&[(0, 0)], encoding.get_offsets(), &[(0, 0)]].concat();
let special_tokens =
[&[1u32], &vec![0; encoding.get_ids().len()][..], &[1]]
.concat();
let attention_mask = vec![1; ids.len()];
// For compatibility with `TemplateProcessing`, the sequence_ranges shouldn't
// contain the special tokens.
let sequence_ranges =
HashMap::from_iter(vec![(0, 1..ids.len() - 1)]);
Encoding::new(
ids,
type_ids,
tokens,
words,
offsets,
special_tokens,
attention_mask,
vec![],
sequence_ranges,
)
})
.collect(),
sequence_ranges,
)
})
.collect(),
sequence_ranges,
);
} else {
let pair_ids = [encoding.get_ids(), &[self.sep.1]].concat();
let pair_type_ids = [encoding.get_type_ids(), &[1]].concat();
let pair_tokens = [encoding.get_tokens(), &[self.sep.0.clone()]].concat();
let pair_words = [encoding.get_word_ids(), &[None]].concat();
let pair_offsets = [encoding.get_offsets(), &[(0, 0)]].concat();
let pair_special_tokens =
[&vec![0u32; encoding.get_type_ids().len()][..], &[1]].concat();
let pair_attention_mask = vec![1; pair_ids.len()];
if let Some(mut encoding) = pair_encoding {
let pair_ids = [encoding.get_ids(), &[self.sep.1]].concat();
let pair_type_ids = [encoding.get_type_ids(), &[1]].concat();
let pair_tokens = [encoding.get_tokens(), &[self.sep.0.clone()]].concat();
let pair_words = [encoding.get_word_ids(), &[None]].concat();
let pair_offsets = [encoding.get_offsets(), &[(0, 0)]].concat();
let pair_special_tokens =
[&vec![0u32; encoding.get_type_ids().len()][..], &[1]].concat();
let pair_attention_mask = vec![1; pair_ids.len()];
// For compatibility with `TemplateProcessing`, the sequence_ranges shouldn't contain
// the special tokens.
let pair_sequence_ranges = HashMap::from_iter(vec![(1, 0..pair_ids.len() - 1)]);
Encoding::new(
pair_ids,
pair_type_ids,
pair_tokens,
pair_words,
pair_offsets,
pair_special_tokens,
pair_attention_mask,
encoding
.take_overflowing()
.into_iter()
.map(|encoding| {
let pair_ids = [encoding.get_ids(), &[self.sep.1]].concat();
let pair_type_ids = [encoding.get_type_ids(), &[1]].concat();
let pair_tokens =
[encoding.get_tokens(), &[self.sep.0.clone()]].concat();
let pair_words = [encoding.get_word_ids(), &[None]].concat();
let pair_offsets = [encoding.get_offsets(), &[(0, 0)]].concat();
let pair_special_tokens =
[&vec![0u32; encoding.get_type_ids().len()][..], &[1]].concat();
let pair_attention_mask = vec![1; pair_ids.len()];
// For compatibility with `TemplateProcessing`, the sequence_ranges shouldn't contain
// the special tokens.
let pair_sequence_ranges = HashMap::from_iter(vec![(1, 0..pair_ids.len() - 1)]);
let new_pair_encoding = Encoding::new(
pair_ids,
pair_type_ids,
pair_tokens,
pair_words,
pair_offsets,
pair_special_tokens,
pair_attention_mask,
encoding
.take_overflowing()
.into_iter()
.map(|encoding| {
let pair_ids = [encoding.get_ids(), &[self.sep.1]].concat();
let pair_type_ids = [encoding.get_type_ids(), &[1]].concat();
let pair_tokens = [encoding.get_tokens(), &[self.sep.0.clone()]].concat();
let pair_words = [encoding.get_word_ids(), &[None]].concat();
let pair_offsets = [encoding.get_offsets(), &[(0, 0)]].concat();
let pair_special_tokens =
[&vec![0u32; encoding.get_type_ids().len()][..], &[1]].concat();
let pair_attention_mask = vec![1; pair_ids.len()];
// For compatibility with `TemplateProcessing`, the sequence_ranges
// shouldn't contain the special tokens.
let pair_sequence_ranges =
HashMap::from_iter(vec![(1, 0..pair_ids.len() - 1)]);
Encoding::new(
pair_ids,
pair_type_ids,
pair_tokens,
pair_words,
pair_offsets,
pair_special_tokens,
pair_attention_mask,
vec![],
pair_sequence_ranges,
)
})
.collect(),
pair_sequence_ranges,
)
}
})
.collect();
// For compatibility with `TemplateProcessing`, the sequence_ranges
// shouldn't contain the special tokens.
let pair_sequence_ranges =
HashMap::from_iter(vec![(1, 0..pair_ids.len() - 1)]);
Encoding::new(
pair_ids,
pair_type_ids,
pair_tokens,
pair_words,
pair_offsets,
pair_special_tokens,
pair_attention_mask,
vec![],
pair_sequence_ranges,
)
})
.collect(),
pair_sequence_ranges,
);
new_encoding.merge_with(new_pair_encoding, false);
}
Ok(vec![new_encoding])
Ok(encodings)
}
}
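
Compressed to just the token ids, the new BERT flow above amounts to the following sketch (bookkeeping for offsets, words, masks and overflowing elided; `process_ids` is an illustrative helper, not part of the crate):

fn process_ids(cls: u32, sep: u32, mut ids_per_seq: Vec<Vec<u32>>) -> Vec<Vec<u32>> {
    for (i, ids) in ids_per_seq.iter_mut().enumerate() {
        if i == 0 {
            ids.insert(0, cls); // [CLS] A [SEP]
            ids.push(sep);
        } else {
            ids.push(sep); // B [SEP], type_id 1 in the real code
        }
    }
    ids_per_seq // nothing merged, nothing dropped
}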

tokenizers/src/processors/roberta.rs

@@ -1,5 +1,5 @@
use crate::processors::byte_level::process_offsets;
use crate::tokenizer::{Encoding, PostProcessor, ProcessorError, Result};
use crate::tokenizer::{Encoding, PostProcessor, Result};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::iter::FromIterator;
@@ -74,53 +74,11 @@ impl PostProcessor for RobertaProcessing {
return Ok(encodings);
}
let (mut encoding, pair_encoding): (Encoding, Option<Encoding>) = match encodings.len() {
1 => (
encodings
.pop()
.ok_or(ProcessorError::InvalidEncodingsVecLength)?,
None,
),
2 => {
let pair = encodings
.pop()
.ok_or(ProcessorError::InvalidEncodingsVecLength)?;
let encoding = encodings
.pop()
.ok_or(ProcessorError::InvalidEncodingsVecLength)?;
(encoding, Some(pair))
}
_ => return Err(Box::new(ProcessorError::InvalidEncodingsVecLength)),
};
let ids = [&[self.cls.1], encoding.get_ids(), &[self.sep.1]].concat();
let type_ids = [&[0], encoding.get_type_ids(), &[0]].concat();
let tokens = [
&[self.cls.0.clone()],
encoding.get_tokens(),
&[self.sep.0.clone()],
]
.concat();
let words = [&[None], encoding.get_word_ids(), &[None]].concat();
let offsets = [&[(0, 0)], encoding.get_offsets(), &[(0, 0)]].concat();
let special_tokens = [&[1u32], &vec![0; encoding.get_ids().len()][..], &[1]].concat();
let attention_mask = vec![1; ids.len()];
// For compatibility with `TemplateProcessing`, the sequence_ranges shouldn't contain
// the special tokens.
let sequence_ranges = HashMap::from_iter(vec![(0, 1..ids.len() - 1)]);
let mut new_encoding = Encoding::new(
ids,
type_ids,
tokens,
words,
offsets,
special_tokens,
attention_mask,
encoding
.take_overflowing()
.into_iter()
.map(|encoding| {
let encodings: Vec<Encoding> = encodings
.iter_mut()
.enumerate()
.map(|(i, encoding)| {
if i == 0 {
let ids = [&[self.cls.1], encoding.get_ids(), &[self.sep.1]].concat();
let type_ids = [&[0], encoding.get_type_ids(), &[0]].concat();
let tokens = [
@@ -135,8 +93,8 @@ impl PostProcessor for RobertaProcessing {
[&[1u32], &vec![0; encoding.get_ids().len()][..], &[1]].concat();
let attention_mask = vec![1; ids.len()];
// For compatibility with `TemplateProcessing`, the sequence_ranges shouldn't
// contain the special tokens.
// For compatibility with `TemplateProcessing`, the sequence_ranges shouldn't contain
// the special tokens.
let sequence_ranges = HashMap::from_iter(vec![(0, 1..ids.len() - 1)]);
Encoding::new(
ids,
@@ -146,82 +104,118 @@ impl PostProcessor for RobertaProcessing {
offsets,
special_tokens,
attention_mask,
vec![],
encoding
.take_overflowing()
.into_iter()
.map(|encoding| {
let ids =
[&[self.cls.1], encoding.get_ids(), &[self.sep.1]].concat();
let type_ids = [&[0], encoding.get_type_ids(), &[0]].concat();
let tokens = [
&[self.cls.0.clone()],
encoding.get_tokens(),
&[self.sep.0.clone()],
]
.concat();
let words = [&[None], encoding.get_word_ids(), &[None]].concat();
let offsets =
[&[(0, 0)], encoding.get_offsets(), &[(0, 0)]].concat();
let special_tokens =
[&[1u32], &vec![0; encoding.get_ids().len()][..], &[1]]
.concat();
let attention_mask = vec![1; ids.len()];
// For compatibility with `TemplateProcessing`, the sequence_ranges shouldn't
// contain the special tokens.
let sequence_ranges =
HashMap::from_iter(vec![(0, 1..ids.len() - 1)]);
Encoding::new(
ids,
type_ids,
tokens,
words,
offsets,
special_tokens,
attention_mask,
vec![],
sequence_ranges,
)
})
.collect(),
sequence_ranges,
)
})
.collect(),
sequence_ranges,
);
} else {
let pair_ids = [&[self.sep.1], encoding.get_ids(), &[self.sep.1]].concat();
let pair_type_ids = vec![0; encoding.get_ids().len() + 2];
let pair_tokens = [
&[self.sep.0.clone()],
encoding.get_tokens(),
&[self.sep.0.clone()],
]
.concat();
let pair_words = [&[None], encoding.get_word_ids(), &[None]].concat();
let pair_offsets = [&[(0, 0)], encoding.get_offsets(), &[(0, 0)]].concat();
let pair_special_tokens =
[&[1], &vec![0u32; encoding.get_type_ids().len()][..], &[1]].concat();
let pair_attention_mask = vec![1; pair_ids.len()];
if let Some(mut encoding) = pair_encoding {
let pair_ids = [&[self.sep.1], encoding.get_ids(), &[self.sep.1]].concat();
let pair_type_ids = vec![0; encoding.get_ids().len() + 2];
let pair_tokens = [
&[self.sep.0.clone()],
encoding.get_tokens(),
&[self.sep.0.clone()],
]
.concat();
let pair_words = [&[None], encoding.get_word_ids(), &[None]].concat();
let pair_offsets = [&[(0, 0)], encoding.get_offsets(), &[(0, 0)]].concat();
let pair_special_tokens =
[&[1], &vec![0u32; encoding.get_type_ids().len()][..], &[1]].concat();
let pair_attention_mask = vec![1; pair_ids.len()];
// For compatibility with `TemplateProcessing`, the sequence_ranges shouldn't contain
// the special tokens.
let pair_sequence_ranges = HashMap::from_iter(vec![(1, 1..pair_ids.len() - 1)]);
Encoding::new(
pair_ids,
pair_type_ids,
pair_tokens,
pair_words,
pair_offsets,
pair_special_tokens,
pair_attention_mask,
encoding
.take_overflowing()
.into_iter()
.map(|encoding| {
let pair_ids =
[&[self.sep.1], encoding.get_ids(), &[self.sep.1]].concat();
let pair_type_ids = vec![0; encoding.get_ids().len() + 2];
let pair_tokens = [
&[self.sep.0.clone()],
encoding.get_tokens(),
&[self.sep.0.clone()],
]
.concat();
let pair_words =
[&[None], encoding.get_word_ids(), &[None]].concat();
let pair_offsets =
[&[(0, 0)], encoding.get_offsets(), &[(0, 0)]].concat();
let pair_special_tokens =
[&[1], &vec![0u32; encoding.get_type_ids().len()][..], &[1]]
.concat();
let pair_attention_mask = vec![1; pair_ids.len()];
// For compatibility with `TemplateProcessing`, the sequence_ranges shouldn't contain
// the special tokens.
let pair_sequence_ranges = HashMap::from_iter(vec![(1, 1..pair_ids.len() - 1)]);
let new_pair_encoding = Encoding::new(
pair_ids,
pair_type_ids,
pair_tokens,
pair_words,
pair_offsets,
pair_special_tokens,
pair_attention_mask,
encoding
.take_overflowing()
.into_iter()
.map(|encoding| {
let pair_ids = [&[self.sep.1], encoding.get_ids(), &[self.sep.1]].concat();
let pair_type_ids = vec![0; encoding.get_ids().len() + 2];
let pair_tokens = [
&[self.sep.0.clone()],
encoding.get_tokens(),
&[self.sep.0.clone()],
]
.concat();
let pair_words = [&[None], encoding.get_word_ids(), &[None]].concat();
let pair_offsets = [&[(0, 0)], encoding.get_offsets(), &[(0, 0)]].concat();
let pair_special_tokens =
[&[1], &vec![0u32; encoding.get_type_ids().len()][..], &[1]].concat();
let pair_attention_mask = vec![1; pair_ids.len()];
// For compatibility with `TemplateProcessing`, the sequence_ranges
// shouldn't contain the special tokens.
let pair_sequence_ranges =
HashMap::from_iter(vec![(1, 1..pair_ids.len() - 1)]);
Encoding::new(
pair_ids,
pair_type_ids,
pair_tokens,
pair_words,
pair_offsets,
pair_special_tokens,
pair_attention_mask,
vec![],
pair_sequence_ranges,
)
})
.collect(),
pair_sequence_ranges,
)
}
})
.collect();
// For compatibility with `TemplateProcessing`, the sequence_ranges
// shouldn't contain the special tokens.
let pair_sequence_ranges =
HashMap::from_iter(vec![(1, 1..pair_ids.len() - 1)]);
Encoding::new(
pair_ids,
pair_type_ids,
pair_tokens,
pair_words,
pair_offsets,
pair_special_tokens,
pair_attention_mask,
vec![],
pair_sequence_ranges,
)
})
.collect(),
pair_sequence_ranges,
);
new_encoding.merge_with(new_pair_encoding, false);
}
Ok(vec![new_encoding])
Ok(encodings)
}
}
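
RobertaProcessing follows the same restructuring, with one template difference visible in the hunk above: the pair is wrapped with the sep token on both sides (`<s> A </s> </s> B </s>`) and all type ids stay 0, which is why its pair sequence range is `1..len - 1` instead of BERT's `0..len - 1`. Ids-only sketch (`roberta_pair_ids` is illustrative):

fn roberta_pair_ids(sep: u32, pair_ids: &[u32]) -> Vec<u32> {
    // Builds the "</s> B </s>" part; the leading sep is why the pair's
    // sequence range starts at offset 1.
    [&[sep], pair_ids, &[sep]].concat()
}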

tokenizers/src/processors/template.rs

@@ -55,14 +55,14 @@
//!
//! [`TemplateProcessing`]: struct.TemplateProcessing.html
//!
use crate::{tokenizer::ProcessorError, Encoding, PostProcessor, Result};
use crate::{Encoding, PostProcessor, Result};
use itertools::Itertools;
use serde::{Deserialize, Serialize};
use std::collections::{HashMap, HashSet};
use std::convert::{TryFrom, TryInto};
use std::result::Result as StdResult;
/// Represents both sequences received as input of the PostProcessor
/// Represents any sequences received as input of the PostProcessor
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, Eq)]
pub enum Sequence {
/// This is the first sequence, the one that is always specified
@@ -479,145 +479,102 @@ impl TemplateProcessing {
fn apply_template(
&self,
template: &[Piece],
mut encoding: Encoding,
mut pair: Option<Encoding>,
mut encodings: Vec<Encoding>,
add_special_tokens: bool,
) -> Result<Encoding> {
// Compute the new size
let mut new_len = 0;
for piece in template {
new_len += match piece {
Piece::Sequence {
id: Sequence::A, ..
} => encoding.len(),
Piece::Sequence {
id: Sequence::B, ..
} => pair
.as_ref()
.ok_or("Template expected a pair sequence, but none provided")?
.len(),
Piece::SpecialToken { id, .. } => {
if add_special_tokens {
self.special_tokens
.0
.get(id)
.ok_or_else(|| format!("Missing SpecialToken with id {}", id))?
.ids
.len()
} else {
0
) -> Result<Vec<Encoding>> {
let final_encodings: Vec<Encoding> = template
.iter()
.flat_map(|piece| {
match piece {
Piece::Sequence { id, type_id } => {
let i = if *id == Sequence::A { 0 } else { 1 };
let encoding = &mut encodings[i];
encoding.set_type_ids(vec![*type_id; encoding.len()]);
encoding.set_sequence_id(i);
Some(encoding.clone())
}
Piece::SpecialToken { id, type_id } => {
if add_special_tokens {
let tok = &self.special_tokens.0[id]; // We already checked existence above
let len = tok.ids.len();
let encoding = Encoding::new(
tok.ids.clone(),
std::iter::repeat(*type_id).take(len).collect(),
tok.tokens.clone(),
// words
std::iter::repeat(None).take(len).collect(),
// offsets
std::iter::repeat((0, 0)).take(len).collect(),
// special_tokens_mask
std::iter::repeat(1).take(len).collect(),
// attention_mask
std::iter::repeat(1).take(len).collect(),
// overflowing
vec![],
// sequence_range
HashMap::new(),
);
Some(encoding)
} else {
None
}
}
}
};
}
// Then build the new Encoding
let mut ids = Vec::with_capacity(new_len);
let mut type_ids = Vec::with_capacity(new_len);
let mut tokens = Vec::with_capacity(new_len);
let mut words = Vec::with_capacity(new_len);
let mut offsets = Vec::with_capacity(new_len);
let mut special_tokens_mask = Vec::with_capacity(new_len);
let mut attention_mask = Vec::with_capacity(new_len);
let mut sequence_ranges = HashMap::new();
let pair_overflowing = pair.as_mut().map_or(vec![], |e| e.take_overflowing());
let mut overflowing = encoding
.take_overflowing()
.into_iter()
.flat_map(|encoding| {
// 1. The pair itself
let mut overflowings = vec![self.apply_template(
template,
encoding.clone(),
pair.clone(),
add_special_tokens,
)];
// 2. Its overflowings
for other_o in &pair_overflowing {
overflowings.push(self.apply_template(
template,
encoding.clone(),
Some(other_o.clone()),
add_special_tokens,
));
}
overflowings
})
.collect::<Result<Vec<_>>>()?;
// We also need to combine the first sequence with all other overflowings
overflowing.extend(
pair_overflowing
.into_iter()
.map(|pair| {
self.apply_template(template, encoding.clone(), Some(pair), add_special_tokens)
})
.collect::<Result<Vec<_>>>()?,
);
.collect();
for piece in template {
match piece {
Piece::Sequence {
id: Sequence::A,
type_id,
} => {
let seq_start = ids.len();
let seq_end = seq_start + encoding.len();
sequence_ranges.insert(0, seq_start..seq_end);
ids.extend(encoding.get_ids());
type_ids.extend(std::iter::repeat(type_id).take(encoding.len()));
tokens.extend(encoding.get_tokens().iter().map(|s| s.to_owned()));
words.extend(encoding.get_word_ids());
offsets.extend(encoding.get_offsets());
special_tokens_mask.extend(encoding.get_special_tokens_mask());
attention_mask.extend(encoding.get_attention_mask());
}
Piece::Sequence {
id: Sequence::B,
type_id,
} => {
let pair = pair.as_ref().expect("Missing pair sequence, checked above");
let seq_start = ids.len();
let seq_end = seq_start + pair.len();
sequence_ranges.insert(1, seq_start..seq_end);
ids.extend(pair.get_ids());
type_ids.extend(std::iter::repeat(type_id).take(pair.len()));
tokens.extend(pair.get_tokens().iter().map(|s| s.to_owned()));
words.extend(pair.get_word_ids());
offsets.extend(pair.get_offsets());
special_tokens_mask.extend(pair.get_special_tokens_mask());
attention_mask.extend(pair.get_attention_mask());
}
Piece::SpecialToken { id, type_id } => {
if add_special_tokens {
let tok = &self.special_tokens.0[id]; // We already checked existence above
let len = tok.ids.len();
//let mut pair = if encodings.len() > 1 {
// Some(encodings.pop().unwrap())
//} else {
// None
//};
//let mut encoding = encodings.pop().unwrap();
ids.extend(&tok.ids);
type_ids.extend(std::iter::repeat(type_id).take(len));
tokens.extend(tok.tokens.clone());
words.extend(std::iter::repeat(None).take(len));
offsets.extend(std::iter::repeat((0, 0)).take(len));
special_tokens_mask.extend(std::iter::repeat(1).take(len));
attention_mask.extend(std::iter::repeat(1).take(len));
}
}
}
}
//let pair_overflowing = pair.as_mut().map_or(vec![], |e| e.take_overflowing());
//let mut overflowing: Vec<Encoding> = encoding
// .take_overflowing()
// .iter()
// .map(|encoding| -> Result<Vec<Encoding>> {
// // 1. The pair itself
// let mut overflowings = self.apply_template(
// template,
// if encodings.len() > 1 {
// vec![encoding.clone(), encodings[1].clone()]
// } else {
// vec![encoding.clone()]
// },
// add_special_tokens,
// )?;
Ok(Encoding::new(
ids,
type_ids,
tokens,
words,
offsets,
special_tokens_mask,
attention_mask,
overflowing,
sequence_ranges,
))
// // 2. Its overflowings
// for other_o in &pair_overflowing {
// overflowings.extend(self.apply_template(
// template,
// vec![encoding.clone(), other_o.clone()],
// add_special_tokens,
// )?);
// }
// Ok(overflowings)
// })
// .collect::<Result<Vec<Vec<Encoding>>>>()?
// .into_iter()
// .flatten()
// .collect();
//// We also need to combine the first sequence with all other overflowings
//overflowing.extend(
// pair_overflowing
// .into_iter()
// .map(|pair| {
// self.apply_template(template, vec![encoding.clone(), pair], add_special_tokens)
// })
// .collect::<Result<Vec<_>>>()?
// .into_iter()
// .flatten(),
//);
Ok(final_encodings)
}
}
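
The rewritten `apply_template` is now piece-driven: each template piece yields at most one `Encoding` (a retagged input sequence, or a synthesized special-token encoding), and the resulting list is merged downstream. A trimmed, ids-only model of that flow (the real code additionally stamps `set_type_ids`/`set_sequence_id` on sequence chunks; names here are illustrative):

enum Piece {
    Sequence { index: usize }, // $A -> 0, $B -> 1
    Special { ids: Vec<u32> }, // e.g. [CLS] or [SEP]
}

fn apply(template: &[Piece], seqs: &[Vec<u32>], add_special_tokens: bool) -> Vec<Vec<u32>> {
    template
        .iter()
        .filter_map(|piece| match piece {
            Piece::Sequence { index } => Some(seqs[*index].clone()),
            // Special tokens simply drop out when they are not wanted.
            Piece::Special { ids } => add_special_tokens.then(|| ids.clone()),
        })
        .collect() // merged into a single Encoding by the caller
}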
@@ -632,39 +589,34 @@ impl PostProcessor for TemplateProcessing {
fn process_encodings(
&self,
mut encodings: Vec<Encoding>,
encodings: Vec<Encoding>,
add_special_tokens: bool,
) -> Result<Vec<Encoding>> {
let (encoding, pair): (Encoding, Option<Encoding>) = match encodings.len() {
1 => (
encodings
.pop()
.ok_or(ProcessorError::InvalidEncodingsVecLength)?,
None,
),
2 => {
let pair = encodings
.pop()
.ok_or(ProcessorError::InvalidEncodingsVecLength)?;
let encoding = encodings
.pop()
.ok_or(ProcessorError::InvalidEncodingsVecLength)?;
(encoding, Some(pair))
}
_ => return Err(Box::new(ProcessorError::InvalidEncodingsVecLength)),
// let (encoding, pair): (Encoding, Option<Encoding>) = match encodings.len() {
// 1 => (
// encodings
// .pop()
// .ok_or(ProcessorError::InvalidEncodingsVecLength)?,
// None,
// ),
// 2 => {
// let pair = encodings
// .pop()
// .ok_or(ProcessorError::InvalidEncodingsVecLength)?;
// let encoding = encodings
// .pop()
// .ok_or(ProcessorError::InvalidEncodingsVecLength)?;
// (encoding, Some(pair))
// }
// _ => return Err(Box::new(ProcessorError::InvalidEncodingsVecLength)),
// };
let template = match encodings.len() {
2 => &self.pair.0,
1 => &self.single.0,
_ => todo!(),
};
let encoding = self.apply_template(
if pair.is_some() {
&self.pair.0
} else {
&self.single.0
},
encoding,
pair,
add_special_tokens,
)?;
Ok(vec![encoding])
let encodings = self.apply_template(template, encodings, add_special_tokens)?;
Ok(encodings)
}
}
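
Call sites are unaffected: `process` still builds the `Vec` from the one or two input sequences, `process_encodings` picks `self.single` or `self.pair` from its length, and the results are merged afterwards. Note the `_ => todo!()` above: other lengths currently panic instead of returning the old `InvalidEncodingsVecLength` error. A usage sketch against the public API (error handling elided):

use tokenizers::processors::template::TemplateProcessing;
use tokenizers::{Encoding, PostProcessor};

fn postprocess(processor: &TemplateProcessing, enc: Encoding, pair: Option<Encoding>) -> Encoding {
    // add_special_tokens = true applies the [CLS]/[SEP] pieces.
    processor.process(enc, pair, true).unwrap()
}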
@@ -884,7 +836,6 @@ mod tests {
);
let pair = Encoding::from_tokens(vec![Token::new(15, "pair".into(), (0, 4))], 0);
let single_encoding = processor.process(encoding.clone(), None, true).unwrap();
let pair_encoding = processor.process(encoding, Some(pair), true).unwrap();
assert_eq!(
single_encoding,
Encoding::new(
@@ -906,6 +857,7 @@
);
assert_eq!(single_encoding.token_to_sequence(2), Some(0));
assert_eq!(single_encoding.token_to_sequence(3), None);
let pair_encoding = processor.process(encoding, Some(pair), true).unwrap();
assert_eq!(
pair_encoding,
Encoding::new(

tokenizers/src/tokenizer/encoding.rs

@@ -152,6 +152,10 @@ impl Encoding {
&self.type_ids
}
pub fn set_type_ids(&mut self, type_ids: Vec<u32>) {
self.type_ids = type_ids;
}
pub fn get_offsets(&self) -> &[Offsets] {
&self.offsets
}
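
The new setter exists so `apply_template` can restamp a sequence's type ids before cloning it into the output. Minimal usage sketch (`stamp_type_ids` is illustrative):

use tokenizers::Encoding;

fn stamp_type_ids(enc: &mut Encoding, type_id: u32) {
    // One uniform type id across the whole sequence.
    enc.set_type_ids(vec![type_id; enc.len()]);
}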
@@ -383,6 +387,12 @@
pub fn merge<I: IntoIterator<Item = Encoding>>(encodings: I, growing_offsets: bool) -> Self {
let mut encoding = Encoding::default();
// TODO this is suboptimal as we're doing this iteratively instead of preallocating
// all the encodings sizes all at once and only copying into this preallocated vector
// https://github.com/huggingface/tokenizers/pull/1049
// In order to fix, we just need to preallocate all vectors, then copy everything
// into it (and deal with overflowings correctly)
for sub in encodings {
encoding.merge_with(sub, growing_offsets);
}
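
`merge` is what collapses the per-piece (and per-sequence) encodings that `process_encodings` now returns into one flat encoding. A sketch of calling it directly (`collapse` is illustrative; `growing_offsets = false` keeps each sub-encoding's own offsets rather than shifting them):

use tokenizers::Encoding;

fn collapse(parts: Vec<Encoding>) -> Encoding {
    Encoding::merge(parts, false)
}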