Mirror of https://github.com/mii443/tokenizers.git, synced 2025-08-22 16:25:30 +00:00

Making `process_encodings` not eat up the encodings any more. (#1051)

* Making `process_encodings` not eat up the encodings any more.
* Fixing clippy.
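The hunks below change every `PostProcessor` in the crate from popping one or two encodings out of the input vector to receiving and returning the whole vector. A minimal sketch of the resulting trait-method shape, taken from the template.rs hunk further down (the rest of the trait is elided here):

    fn process_encodings(
        &self,
        encodings: Vec<Encoding>,   // the full list, taken by value rather than drained
        add_special_tokens: bool,
    ) -> Result<Vec<Encoding>>;     // processed encodings are handed back to the caller

Callers that still want one merged `Encoding` can combine the returned vector themselves; see the `Encoding::merge` hunk at the end of this diff.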
tokenizers/Cargo.toml

@@ -34,6 +34,10 @@ harness = false
name = "bert_benchmark"
harness = false

[[bench]]
name = "layout_benchmark"
harness = false

[dependencies]
lazy_static = "1.4"
rand = "0.8"
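The `harness = false` entries opt these bench targets out of the default libtest harness, so that criterion can supply its own `main` (the `criterion_main!` macro in the new file below).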
tokenizers/benches/layout_benchmark.rs (new file, 76 lines)
@@ -0,0 +1,76 @@
#[macro_use]
extern crate criterion;

use std::fs::File;
use std::io::{BufRead, BufReader};
use std::path::Path;
use std::time::{Duration, Instant};

use criterion::black_box;
use criterion::Criterion;
use tokenizers::processors::template::TemplateProcessing;
use tokenizers::{EncodeInput, Encoding, PostProcessor, Tokenizer};

/// Simple TemplateProcessing
fn create_processor() -> TemplateProcessing {
    TemplateProcessing::builder()
        .try_single("[CLS]:0 $A:0 [SEP]:0")
        .unwrap()
        .try_pair("[CLS]:0 $A:0 [SEP]:0 $B:1 [SEP]:1")
        .unwrap()
        .special_tokens(vec![("[CLS]", 0), ("[SEP]", 1)])
        .build()
        .unwrap()
}

pub fn bench_layout(c: &mut Criterion) {
    let processor = create_processor();
    let tokenizer = Tokenizer::from_file("data/albert-base-v1-tokenizer.json").unwrap();
    let mut encodeds: Vec<Encoding> = vec![];
    for line in BufReader::new(File::open(Path::new("data/big.txt")).unwrap()).lines() {
        let line: EncodeInput = line.unwrap().into();

        let encoded: Encoding = tokenizer.encode(line, false).unwrap();
        encodeds.push(encoded);
    }

    c.bench_function("TemplateProcessing single encode", |b| {
        b.iter_custom(|iters| {
            let mut duration = Duration::new(0, 0);
            for i in 0..iters as usize {
                let encoded_index = i % encodeds.len();
                let encoded: Encoding = encodeds[encoded_index].clone();

                let start = Instant::now();
                let _ = black_box(processor.process(encoded, None, false));
                duration = duration.checked_add(start.elapsed()).unwrap();
            }
            duration
        })
    });
    c.bench_function("TemplateProcessing pair encode", |b| {
        b.iter_custom(|iters| {
            let mut duration = Duration::new(0, 0);
            for i in 0..iters as usize {
                let encoded_index = i % encodeds.len();
                let encoded: Encoding = encodeds[encoded_index].clone();

                let encoded_index2 = (i + 1) % encodeds.len();
                let pair: Encoding = encodeds[encoded_index2].clone();

                let start = Instant::now();
                let _ = black_box(processor.process(encoded, Some(pair), false));
                duration = duration.checked_add(start.elapsed()).unwrap();
            }
            duration
        })
    });
}

criterion_group! {
    name = layout_benches;
    config = Criterion::default().sample_size(20);
    targets = bench_layout
}

criterion_main!(layout_benches);
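With the `[[bench]]` entry registered above, this benchmark should be runnable with `cargo bench --bench layout_benchmark` from the `tokenizers` crate root, assuming the hard-coded inputs exist locally: `data/big.txt` (the corpus encoded line by line) and `data/albert-base-v1-tokenizer.json` (the serialized tokenizer). Note that `iter_custom` starts the timer only around `processor.process(...)`, so the per-iteration `clone()` of the pre-computed encodings stays out of the measured duration.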
tokenizers/src/processors/bert.rs

@@ -1,4 +1,4 @@
use crate::tokenizer::{Encoding, PostProcessor, ProcessorError, Result};
use crate::tokenizer::{Encoding, PostProcessor, Result};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::iter::FromIterator;

@@ -49,53 +49,11 @@ impl PostProcessor for BertProcessing {
            return Ok(encodings);
        }

        let (mut encoding, pair_encoding): (Encoding, Option<Encoding>) = match encodings.len() {
            1 => (
                encodings
                    .pop()
                    .ok_or(ProcessorError::InvalidEncodingsVecLength)?,
                None,
            ),
            2 => {
                let pair = encodings
                    .pop()
                    .ok_or(ProcessorError::InvalidEncodingsVecLength)?;
                let encoding = encodings
                    .pop()
                    .ok_or(ProcessorError::InvalidEncodingsVecLength)?;
                (encoding, Some(pair))
            }
            _ => return Err(Box::new(ProcessorError::InvalidEncodingsVecLength)),
        };

        let ids = [&[self.cls.1], encoding.get_ids(), &[self.sep.1]].concat();
        let type_ids = [&[0], encoding.get_type_ids(), &[0]].concat();
        let tokens = [
            &[self.cls.0.clone()],
            encoding.get_tokens(),
            &[self.sep.0.clone()],
        ]
        .concat();
        let words = [&[None], encoding.get_word_ids(), &[None]].concat();
        let offsets = [&[(0, 0)], encoding.get_offsets(), &[(0, 0)]].concat();
        let special_tokens = [&[1u32], &vec![0; encoding.get_ids().len()][..], &[1]].concat();
        let attention_mask = vec![1; ids.len()];

        // For compatibility with `TemplateProcessing`, the sequence_ranges shouldn't contain
        // the special tokens.
        let sequence_ranges = HashMap::from_iter(vec![(0, 1..ids.len() - 1)]);
        let mut new_encoding = Encoding::new(
            ids,
            type_ids,
            tokens,
            words,
            offsets,
            special_tokens,
            attention_mask,
            encoding
                .take_overflowing()
                .into_iter()
                .map(|encoding| {
        let encodings: Vec<Encoding> = encodings
            .iter_mut()
            .enumerate()
            .map(|(i, encoding)| {
                if i == 0 {
                    let ids = [&[self.cls.1], encoding.get_ids(), &[self.sep.1]].concat();
                    let type_ids = [&[0], encoding.get_type_ids(), &[0]].concat();
                    let tokens = [

@@ -110,8 +68,8 @@ impl PostProcessor for BertProcessing {
                        [&[1u32], &vec![0; encoding.get_ids().len()][..], &[1]].concat();
                    let attention_mask = vec![1; ids.len()];

                    // For compatibility with `TemplateProcessing`, the sequence_ranges shouldn't
                    // contain the special tokens.
                    // For compatibility with `TemplateProcessing`, the sequence_ranges shouldn't contain
                    // the special tokens.
                    let sequence_ranges = HashMap::from_iter(vec![(0, 1..ids.len() - 1)]);
                    Encoding::new(
                        ids,

@@ -121,72 +79,105 @@ impl PostProcessor for BertProcessing {
                        offsets,
                        special_tokens,
                        attention_mask,
                        vec![],
                        encoding
                            .take_overflowing()
                            .into_iter()
                            .map(|encoding| {
                                let ids =
                                    [&[self.cls.1], encoding.get_ids(), &[self.sep.1]].concat();
                                let type_ids = [&[0], encoding.get_type_ids(), &[0]].concat();
                                let tokens = [
                                    &[self.cls.0.clone()],
                                    encoding.get_tokens(),
                                    &[self.sep.0.clone()],
                                ]
                                .concat();
                                let words = [&[None], encoding.get_word_ids(), &[None]].concat();
                                let offsets =
                                    [&[(0, 0)], encoding.get_offsets(), &[(0, 0)]].concat();
                                let special_tokens =
                                    [&[1u32], &vec![0; encoding.get_ids().len()][..], &[1]]
                                        .concat();
                                let attention_mask = vec![1; ids.len()];

                                // For compatibility with `TemplateProcessing`, the sequence_ranges shouldn't
                                // contain the special tokens.
                                let sequence_ranges =
                                    HashMap::from_iter(vec![(0, 1..ids.len() - 1)]);
                                Encoding::new(
                                    ids,
                                    type_ids,
                                    tokens,
                                    words,
                                    offsets,
                                    special_tokens,
                                    attention_mask,
                                    vec![],
                                    sequence_ranges,
                                )
                            })
                            .collect(),
                        sequence_ranges,
                    )
                })
                .collect(),
            sequence_ranges,
        );
                } else {
                    let pair_ids = [encoding.get_ids(), &[self.sep.1]].concat();
                    let pair_type_ids = [encoding.get_type_ids(), &[1]].concat();
                    let pair_tokens = [encoding.get_tokens(), &[self.sep.0.clone()]].concat();
                    let pair_words = [encoding.get_word_ids(), &[None]].concat();
                    let pair_offsets = [encoding.get_offsets(), &[(0, 0)]].concat();
                    let pair_special_tokens =
                        [&vec![0u32; encoding.get_type_ids().len()][..], &[1]].concat();
                    let pair_attention_mask = vec![1; pair_ids.len()];

        if let Some(mut encoding) = pair_encoding {
            let pair_ids = [encoding.get_ids(), &[self.sep.1]].concat();
            let pair_type_ids = [encoding.get_type_ids(), &[1]].concat();
            let pair_tokens = [encoding.get_tokens(), &[self.sep.0.clone()]].concat();
            let pair_words = [encoding.get_word_ids(), &[None]].concat();
            let pair_offsets = [encoding.get_offsets(), &[(0, 0)]].concat();
            let pair_special_tokens =
                [&vec![0u32; encoding.get_type_ids().len()][..], &[1]].concat();
            let pair_attention_mask = vec![1; pair_ids.len()];
                    // For compatibility with `TemplateProcessing`, the sequence_ranges shouldn't contain
                    // the special tokens.
                    let pair_sequence_ranges = HashMap::from_iter(vec![(1, 0..pair_ids.len() - 1)]);
                    Encoding::new(
                        pair_ids,
                        pair_type_ids,
                        pair_tokens,
                        pair_words,
                        pair_offsets,
                        pair_special_tokens,
                        pair_attention_mask,
                        encoding
                            .take_overflowing()
                            .into_iter()
                            .map(|encoding| {
                                let pair_ids = [encoding.get_ids(), &[self.sep.1]].concat();
                                let pair_type_ids = [encoding.get_type_ids(), &[1]].concat();
                                let pair_tokens =
                                    [encoding.get_tokens(), &[self.sep.0.clone()]].concat();
                                let pair_words = [encoding.get_word_ids(), &[None]].concat();
                                let pair_offsets = [encoding.get_offsets(), &[(0, 0)]].concat();
                                let pair_special_tokens =
                                    [&vec![0u32; encoding.get_type_ids().len()][..], &[1]].concat();
                                let pair_attention_mask = vec![1; pair_ids.len()];

                                // For compatibility with `TemplateProcessing`, the sequence_ranges shouldn't contain
                                // the special tokens.
                                let pair_sequence_ranges = HashMap::from_iter(vec![(1, 0..pair_ids.len() - 1)]);
            let new_pair_encoding = Encoding::new(
                pair_ids,
                pair_type_ids,
                pair_tokens,
                pair_words,
                pair_offsets,
                pair_special_tokens,
                pair_attention_mask,
                encoding
                    .take_overflowing()
                    .into_iter()
                    .map(|encoding| {
                        let pair_ids = [encoding.get_ids(), &[self.sep.1]].concat();
                        let pair_type_ids = [encoding.get_type_ids(), &[1]].concat();
                        let pair_tokens = [encoding.get_tokens(), &[self.sep.0.clone()]].concat();
                        let pair_words = [encoding.get_word_ids(), &[None]].concat();
                        let pair_offsets = [encoding.get_offsets(), &[(0, 0)]].concat();
                        let pair_special_tokens =
                            [&vec![0u32; encoding.get_type_ids().len()][..], &[1]].concat();
                        let pair_attention_mask = vec![1; pair_ids.len()];
                                // For compatibility with `TemplateProcessing`, the sequence_ranges
                                // shouldn't contain the special tokens.
                                let pair_sequence_ranges =
                                    HashMap::from_iter(vec![(1, 0..pair_ids.len() - 1)]);
                                Encoding::new(
                                    pair_ids,
                                    pair_type_ids,
                                    pair_tokens,
                                    pair_words,
                                    pair_offsets,
                                    pair_special_tokens,
                                    pair_attention_mask,
                                    vec![],
                                    pair_sequence_ranges,
                                )
                            })
                            .collect(),
                        pair_sequence_ranges,
                    )
                }
            })
            .collect();

                        // For compatibility with `TemplateProcessing`, the sequence_ranges
                        // shouldn't contain the special tokens.
                        let pair_sequence_ranges =
                            HashMap::from_iter(vec![(1, 0..pair_ids.len() - 1)]);
                        Encoding::new(
                            pair_ids,
                            pair_type_ids,
                            pair_tokens,
                            pair_words,
                            pair_offsets,
                            pair_special_tokens,
                            pair_attention_mask,
                            vec![],
                            pair_sequence_ranges,
                        )
                    })
                    .collect(),
                pair_sequence_ranges,
            );

            new_encoding.merge_with(new_pair_encoding, false);
        }

        Ok(vec![new_encoding])
        Ok(encodings)
    }
}
tokenizers/src/processors/roberta.rs

@@ -1,5 +1,5 @@
use crate::processors::byte_level::process_offsets;
use crate::tokenizer::{Encoding, PostProcessor, ProcessorError, Result};
use crate::tokenizer::{Encoding, PostProcessor, Result};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::iter::FromIterator;

@@ -74,53 +74,11 @@ impl PostProcessor for RobertaProcessing {
            return Ok(encodings);
        }

        let (mut encoding, pair_encoding): (Encoding, Option<Encoding>) = match encodings.len() {
            1 => (
                encodings
                    .pop()
                    .ok_or(ProcessorError::InvalidEncodingsVecLength)?,
                None,
            ),
            2 => {
                let pair = encodings
                    .pop()
                    .ok_or(ProcessorError::InvalidEncodingsVecLength)?;
                let encoding = encodings
                    .pop()
                    .ok_or(ProcessorError::InvalidEncodingsVecLength)?;
                (encoding, Some(pair))
            }
            _ => return Err(Box::new(ProcessorError::InvalidEncodingsVecLength)),
        };

        let ids = [&[self.cls.1], encoding.get_ids(), &[self.sep.1]].concat();
        let type_ids = [&[0], encoding.get_type_ids(), &[0]].concat();
        let tokens = [
            &[self.cls.0.clone()],
            encoding.get_tokens(),
            &[self.sep.0.clone()],
        ]
        .concat();
        let words = [&[None], encoding.get_word_ids(), &[None]].concat();
        let offsets = [&[(0, 0)], encoding.get_offsets(), &[(0, 0)]].concat();
        let special_tokens = [&[1u32], &vec![0; encoding.get_ids().len()][..], &[1]].concat();
        let attention_mask = vec![1; ids.len()];

        // For compatibility with `TemplateProcessing`, the sequence_ranges shouldn't contain
        // the special tokens.
        let sequence_ranges = HashMap::from_iter(vec![(0, 1..ids.len() - 1)]);
        let mut new_encoding = Encoding::new(
            ids,
            type_ids,
            tokens,
            words,
            offsets,
            special_tokens,
            attention_mask,
            encoding
                .take_overflowing()
                .into_iter()
                .map(|encoding| {
        let encodings: Vec<Encoding> = encodings
            .iter_mut()
            .enumerate()
            .map(|(i, encoding)| {
                if i == 0 {
                    let ids = [&[self.cls.1], encoding.get_ids(), &[self.sep.1]].concat();
                    let type_ids = [&[0], encoding.get_type_ids(), &[0]].concat();
                    let tokens = [

@@ -135,8 +93,8 @@ impl PostProcessor for RobertaProcessing {
                        [&[1u32], &vec![0; encoding.get_ids().len()][..], &[1]].concat();
                    let attention_mask = vec![1; ids.len()];

                    // For compatibility with `TemplateProcessing`, the sequence_ranges shouldn't
                    // contain the special tokens.
                    // For compatibility with `TemplateProcessing`, the sequence_ranges shouldn't contain
                    // the special tokens.
                    let sequence_ranges = HashMap::from_iter(vec![(0, 1..ids.len() - 1)]);
                    Encoding::new(
                        ids,

@@ -146,82 +104,118 @@ impl PostProcessor for RobertaProcessing {
                        offsets,
                        special_tokens,
                        attention_mask,
                        vec![],
                        encoding
                            .take_overflowing()
                            .into_iter()
                            .map(|encoding| {
                                let ids =
                                    [&[self.cls.1], encoding.get_ids(), &[self.sep.1]].concat();
                                let type_ids = [&[0], encoding.get_type_ids(), &[0]].concat();
                                let tokens = [
                                    &[self.cls.0.clone()],
                                    encoding.get_tokens(),
                                    &[self.sep.0.clone()],
                                ]
                                .concat();
                                let words = [&[None], encoding.get_word_ids(), &[None]].concat();
                                let offsets =
                                    [&[(0, 0)], encoding.get_offsets(), &[(0, 0)]].concat();
                                let special_tokens =
                                    [&[1u32], &vec![0; encoding.get_ids().len()][..], &[1]]
                                        .concat();
                                let attention_mask = vec![1; ids.len()];

                                // For compatibility with `TemplateProcessing`, the sequence_ranges shouldn't
                                // contain the special tokens.
                                let sequence_ranges =
                                    HashMap::from_iter(vec![(0, 1..ids.len() - 1)]);
                                Encoding::new(
                                    ids,
                                    type_ids,
                                    tokens,
                                    words,
                                    offsets,
                                    special_tokens,
                                    attention_mask,
                                    vec![],
                                    sequence_ranges,
                                )
                            })
                            .collect(),
                        sequence_ranges,
                    )
                })
                .collect(),
            sequence_ranges,
        );
                } else {
                    let pair_ids = [&[self.sep.1], encoding.get_ids(), &[self.sep.1]].concat();
                    let pair_type_ids = vec![0; encoding.get_ids().len() + 2];
                    let pair_tokens = [
                        &[self.sep.0.clone()],
                        encoding.get_tokens(),
                        &[self.sep.0.clone()],
                    ]
                    .concat();
                    let pair_words = [&[None], encoding.get_word_ids(), &[None]].concat();
                    let pair_offsets = [&[(0, 0)], encoding.get_offsets(), &[(0, 0)]].concat();
                    let pair_special_tokens =
                        [&[1], &vec![0u32; encoding.get_type_ids().len()][..], &[1]].concat();
                    let pair_attention_mask = vec![1; pair_ids.len()];

        if let Some(mut encoding) = pair_encoding {
            let pair_ids = [&[self.sep.1], encoding.get_ids(), &[self.sep.1]].concat();
            let pair_type_ids = vec![0; encoding.get_ids().len() + 2];
            let pair_tokens = [
                &[self.sep.0.clone()],
                encoding.get_tokens(),
                &[self.sep.0.clone()],
            ]
            .concat();
            let pair_words = [&[None], encoding.get_word_ids(), &[None]].concat();
            let pair_offsets = [&[(0, 0)], encoding.get_offsets(), &[(0, 0)]].concat();
            let pair_special_tokens =
                [&[1], &vec![0u32; encoding.get_type_ids().len()][..], &[1]].concat();
            let pair_attention_mask = vec![1; pair_ids.len()];
                    // For compatibility with `TemplateProcessing`, the sequence_ranges shouldn't contain
                    // the special tokens.
                    let pair_sequence_ranges = HashMap::from_iter(vec![(1, 1..pair_ids.len() - 1)]);
                    Encoding::new(
                        pair_ids,
                        pair_type_ids,
                        pair_tokens,
                        pair_words,
                        pair_offsets,
                        pair_special_tokens,
                        pair_attention_mask,
                        encoding
                            .take_overflowing()
                            .into_iter()
                            .map(|encoding| {
                                let pair_ids =
                                    [&[self.sep.1], encoding.get_ids(), &[self.sep.1]].concat();
                                let pair_type_ids = vec![0; encoding.get_ids().len() + 2];
                                let pair_tokens = [
                                    &[self.sep.0.clone()],
                                    encoding.get_tokens(),
                                    &[self.sep.0.clone()],
                                ]
                                .concat();
                                let pair_words =
                                    [&[None], encoding.get_word_ids(), &[None]].concat();
                                let pair_offsets =
                                    [&[(0, 0)], encoding.get_offsets(), &[(0, 0)]].concat();
                                let pair_special_tokens =
                                    [&[1], &vec![0u32; encoding.get_type_ids().len()][..], &[1]]
                                        .concat();
                                let pair_attention_mask = vec![1; pair_ids.len()];

                                // For compatibility with `TemplateProcessing`, the sequence_ranges shouldn't contain
                                // the special tokens.
                                let pair_sequence_ranges = HashMap::from_iter(vec![(1, 1..pair_ids.len() - 1)]);
            let new_pair_encoding = Encoding::new(
                pair_ids,
                pair_type_ids,
                pair_tokens,
                pair_words,
                pair_offsets,
                pair_special_tokens,
                pair_attention_mask,
                encoding
                    .take_overflowing()
                    .into_iter()
                    .map(|encoding| {
                        let pair_ids = [&[self.sep.1], encoding.get_ids(), &[self.sep.1]].concat();
                        let pair_type_ids = vec![0; encoding.get_ids().len() + 2];
                        let pair_tokens = [
                            &[self.sep.0.clone()],
                            encoding.get_tokens(),
                            &[self.sep.0.clone()],
                        ]
                        .concat();
                        let pair_words = [&[None], encoding.get_word_ids(), &[None]].concat();
                        let pair_offsets = [&[(0, 0)], encoding.get_offsets(), &[(0, 0)]].concat();
                        let pair_special_tokens =
                            [&[1], &vec![0u32; encoding.get_type_ids().len()][..], &[1]].concat();
                        let pair_attention_mask = vec![1; pair_ids.len()];
                                // For compatibility with `TemplateProcessing`, the sequence_ranges
                                // shouldn't contain the special tokens.
                                let pair_sequence_ranges =
                                    HashMap::from_iter(vec![(1, 1..pair_ids.len() - 1)]);
                                Encoding::new(
                                    pair_ids,
                                    pair_type_ids,
                                    pair_tokens,
                                    pair_words,
                                    pair_offsets,
                                    pair_special_tokens,
                                    pair_attention_mask,
                                    vec![],
                                    pair_sequence_ranges,
                                )
                            })
                            .collect(),
                        pair_sequence_ranges,
                    )
                }
            })
            .collect();

                        // For compatibility with `TemplateProcessing`, the sequence_ranges
                        // shouldn't contain the special tokens.
                        let pair_sequence_ranges =
                            HashMap::from_iter(vec![(1, 1..pair_ids.len() - 1)]);
                        Encoding::new(
                            pair_ids,
                            pair_type_ids,
                            pair_tokens,
                            pair_words,
                            pair_offsets,
                            pair_special_tokens,
                            pair_attention_mask,
                            vec![],
                            pair_sequence_ranges,
                        )
                    })
                    .collect(),
                pair_sequence_ranges,
            );

            new_encoding.merge_with(new_pair_encoding, false);
        }

        Ok(vec![new_encoding])
        Ok(encodings)
    }
}
tokenizers/src/processors/template.rs

@@ -55,14 +55,14 @@
//!
//! [`TemplateProcessing`]: struct.TemplateProcessing.html
//!
use crate::{tokenizer::ProcessorError, Encoding, PostProcessor, Result};
use crate::{Encoding, PostProcessor, Result};
use itertools::Itertools;
use serde::{Deserialize, Serialize};
use std::collections::{HashMap, HashSet};
use std::convert::{TryFrom, TryInto};
use std::result::Result as StdResult;

/// Represents both sequences received as input of the PostProcessor
/// Represents any sequences received as input of the PostProcessor
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, Eq)]
pub enum Sequence {
    /// This is the first sequence, the one that is always specified

@@ -479,145 +479,102 @@ impl TemplateProcessing {
    fn apply_template(
        &self,
        template: &[Piece],
        mut encoding: Encoding,
        mut pair: Option<Encoding>,
        mut encodings: Vec<Encoding>,
        add_special_tokens: bool,
    ) -> Result<Encoding> {
        // Compute the new size
        let mut new_len = 0;
        for piece in template {
            new_len += match piece {
                Piece::Sequence {
                    id: Sequence::A, ..
                } => encoding.len(),
                Piece::Sequence {
                    id: Sequence::B, ..
                } => pair
                    .as_ref()
                    .ok_or("Template expected a pair sequence, but none provided")?
                    .len(),
                Piece::SpecialToken { id, .. } => {
                    if add_special_tokens {
                        self.special_tokens
                            .0
                            .get(id)
                            .ok_or_else(|| format!("Missing SpecialToken with id {}", id))?
                            .ids
                            .len()
                    } else {
                        0
    ) -> Result<Vec<Encoding>> {
        let final_encodings: Vec<Encoding> = template
            .iter()
            .flat_map(|piece| {
                match piece {
                    Piece::Sequence { id, type_id } => {
                        let i = if *id == Sequence::A { 0 } else { 1 };
                        let encoding = &mut encodings[i];
                        encoding.set_type_ids(vec![*type_id; encoding.len()]);
                        encoding.set_sequence_id(i);
                        Some(encoding.clone())
                    }
                    Piece::SpecialToken { id, type_id } => {
                        if add_special_tokens {
                            let tok = &self.special_tokens.0[id]; // We already checked existance above
                            let len = tok.ids.len();

                            let encoding = Encoding::new(
                                tok.ids.clone(),
                                std::iter::repeat(*type_id).take(len).collect(),
                                tok.tokens.clone(),
                                // words
                                std::iter::repeat(None).take(len).collect(),
                                // offsets
                                std::iter::repeat((0, 0)).take(len).collect(),
                                // special_tokens_mask
                                std::iter::repeat(1).take(len).collect(),
                                // attention_mask
                                std::iter::repeat(1).take(len).collect(),
                                // overflowing
                                vec![],
                                // sequence_range
                                HashMap::new(),
                            );
                            Some(encoding)
                        } else {
                            None
                        }
                    }
                }
            };
        }

        // Then build the new Encoding
        let mut ids = Vec::with_capacity(new_len);
        let mut type_ids = Vec::with_capacity(new_len);
        let mut tokens = Vec::with_capacity(new_len);
        let mut words = Vec::with_capacity(new_len);
        let mut offsets = Vec::with_capacity(new_len);
        let mut special_tokens_mask = Vec::with_capacity(new_len);
        let mut attention_mask = Vec::with_capacity(new_len);
        let mut sequence_ranges = HashMap::new();

        let pair_overflowing = pair.as_mut().map_or(vec![], |e| e.take_overflowing());
        let mut overflowing = encoding
            .take_overflowing()
            .into_iter()
            .flat_map(|encoding| {
                // 1. The pair itself
                let mut overflowings = vec![self.apply_template(
                    template,
                    encoding.clone(),
                    pair.clone(),
                    add_special_tokens,
                )];

                // 2. Its overflowings
                for other_o in &pair_overflowing {
                    overflowings.push(self.apply_template(
                        template,
                        encoding.clone(),
                        Some(other_o.clone()),
                        add_special_tokens,
                    ));
                }

                overflowings
            })
            .collect::<Result<Vec<_>>>()?;
        // We also need to combine the first sequence with all other overflowings
        overflowing.extend(
            pair_overflowing
                .into_iter()
                .map(|pair| {
                    self.apply_template(template, encoding.clone(), Some(pair), add_special_tokens)
                })
                .collect::<Result<Vec<_>>>()?,
        );
            .collect();

        for piece in template {
            match piece {
                Piece::Sequence {
                    id: Sequence::A,
                    type_id,
                } => {
                    let seq_start = ids.len();
                    let seq_end = seq_start + encoding.len();
                    sequence_ranges.insert(0, seq_start..seq_end);
                    ids.extend(encoding.get_ids());
                    type_ids.extend(std::iter::repeat(type_id).take(encoding.len()));
                    tokens.extend(encoding.get_tokens().iter().map(|s| s.to_owned()));
                    words.extend(encoding.get_word_ids());
                    offsets.extend(encoding.get_offsets());
                    special_tokens_mask.extend(encoding.get_special_tokens_mask());
                    attention_mask.extend(encoding.get_attention_mask());
                }
                Piece::Sequence {
                    id: Sequence::B,
                    type_id,
                } => {
                    let pair = pair.as_ref().expect("Missing pair sequence, checked above");
                    let seq_start = ids.len();
                    let seq_end = seq_start + pair.len();
                    sequence_ranges.insert(1, seq_start..seq_end);
                    ids.extend(pair.get_ids());
                    type_ids.extend(std::iter::repeat(type_id).take(pair.len()));
                    tokens.extend(pair.get_tokens().iter().map(|s| s.to_owned()));
                    words.extend(pair.get_word_ids());
                    offsets.extend(pair.get_offsets());
                    special_tokens_mask.extend(pair.get_special_tokens_mask());
                    attention_mask.extend(pair.get_attention_mask());
                }
                Piece::SpecialToken { id, type_id } => {
                    if add_special_tokens {
                        let tok = &self.special_tokens.0[id]; // We already checked existance above
                        let len = tok.ids.len();
        //let mut pair = if encodings.len() > 1 {
        //    Some(encodings.pop().unwrap())
        //} else {
        //    None
        //};
        //let mut encoding = encodings.pop().unwrap();

                        ids.extend(&tok.ids);
                        type_ids.extend(std::iter::repeat(type_id).take(len));
                        tokens.extend(tok.tokens.clone());
                        words.extend(std::iter::repeat(None).take(len));
                        offsets.extend(std::iter::repeat((0, 0)).take(len));
                        special_tokens_mask.extend(std::iter::repeat(1).take(len));
                        attention_mask.extend(std::iter::repeat(1).take(len));
                    }
                }
            }
        }
        //let pair_overflowing = pair.as_mut().map_or(vec![], |e| e.take_overflowing());
        //let mut overflowing: Vec<Encoding> = encoding
        //    .take_overflowing()
        //    .iter()
        //    .map(|encoding| -> Result<Vec<Encoding>> {
        //        // 1. The pair itself
        //        let mut overflowings = self.apply_template(
        //            template,
        //            if encodings.len() > 1 {
        //                vec![encoding.clone(), encodings[1].clone()]
        //            } else {
        //                vec![encoding.clone()]
        //            },
        //            add_special_tokens,
        //        )?;

        Ok(Encoding::new(
            ids,
            type_ids,
            tokens,
            words,
            offsets,
            special_tokens_mask,
            attention_mask,
            overflowing,
            sequence_ranges,
        ))
        //        // 2. Its overflowings
        //        for other_o in &pair_overflowing {
        //            overflowings.extend(self.apply_template(
        //                template,
        //                vec![encoding.clone(), other_o.clone()],
        //                add_special_tokens,
        //            )?);
        //        }

        //        Ok(overflowings)
        //    })
        //    .collect::<Result<Vec<Vec<Encoding>>>>()?
        //    .into_iter()
        //    .flatten()
        //    .collect();
        //// We also need to combine the first sequence with all other overflowings
        //overflowing.extend(
        //    pair_overflowing
        //        .into_iter()
        //        .map(|pair| {
        //            self.apply_template(template, vec![encoding.clone(), pair], add_special_tokens)
        //        })
        //        .collect::<Result<Vec<_>>>()?
        //        .into_iter()
        //        .flatten(),
        //);

        Ok(final_encodings)
    }
}

@@ -632,39 +589,34 @@ impl PostProcessor for TemplateProcessing {

    fn process_encodings(
        &self,
        mut encodings: Vec<Encoding>,
        encodings: Vec<Encoding>,
        add_special_tokens: bool,
    ) -> Result<Vec<Encoding>> {
        let (encoding, pair): (Encoding, Option<Encoding>) = match encodings.len() {
            1 => (
                encodings
                    .pop()
                    .ok_or(ProcessorError::InvalidEncodingsVecLength)?,
                None,
            ),
            2 => {
                let pair = encodings
                    .pop()
                    .ok_or(ProcessorError::InvalidEncodingsVecLength)?;
                let encoding = encodings
                    .pop()
                    .ok_or(ProcessorError::InvalidEncodingsVecLength)?;
                (encoding, Some(pair))
            }
            _ => return Err(Box::new(ProcessorError::InvalidEncodingsVecLength)),
        // let (encoding, pair): (Encoding, Option<Encoding>) = match encodings.len() {
        //     1 => (
        //         encodings
        //             .pop()
        //             .ok_or(ProcessorError::InvalidEncodingsVecLength)?,
        //         None,
        //     ),
        //     2 => {
        //         let pair = encodings
        //             .pop()
        //             .ok_or(ProcessorError::InvalidEncodingsVecLength)?;
        //         let encoding = encodings
        //             .pop()
        //             .ok_or(ProcessorError::InvalidEncodingsVecLength)?;
        //         (encoding, Some(pair))
        //     }
        //     _ => return Err(Box::new(ProcessorError::InvalidEncodingsVecLength)),
        // };
        let template = match encodings.len() {
            2 => &self.pair.0,
            1 => &self.single.0,
            _ => todo!(),
        };

        let encoding = self.apply_template(
            if pair.is_some() {
                &self.pair.0
            } else {
                &self.single.0
            },
            encoding,
            pair,
            add_special_tokens,
        )?;
        Ok(vec![encoding])
        let encodings = self.apply_template(template, encodings, add_special_tokens)?;
        Ok(encodings)
    }
}

@@ -884,7 +836,6 @@ mod tests {
        );
        let pair = Encoding::from_tokens(vec![Token::new(15, "pair".into(), (0, 4))], 0);
        let single_encoding = processor.process(encoding.clone(), None, true).unwrap();
        let pair_encoding = processor.process(encoding, Some(pair), true).unwrap();
        assert_eq!(
            single_encoding,
            Encoding::new(

@@ -906,6 +857,7 @@ mod tests {
        );
        assert_eq!(single_encoding.token_to_sequence(2), Some(0));
        assert_eq!(single_encoding.token_to_sequence(3), None);
        let pair_encoding = processor.process(encoding, Some(pair), true).unwrap();
        assert_eq!(
            pair_encoding,
            Encoding::new(
tokenizers/src/tokenizer/encoding.rs

@@ -152,6 +152,10 @@ impl Encoding {
        &self.type_ids
    }

    pub fn set_type_ids(&mut self, type_ids: Vec<u32>) {
        self.type_ids = type_ids;
    }

    pub fn get_offsets(&self) -> &[Offsets] {
        &self.offsets
    }

@@ -383,6 +387,12 @@ impl Encoding {
    pub fn merge<I: IntoIterator<Item = Encoding>>(encodings: I, growing_offsets: bool) -> Self {
        let mut encoding = Encoding::default();

        // TODO this is suboptimal as we're doing this iteratively instead of preallocating
        // all the encodings sizes all at once and only copying into this preallocated vector
        // https://github.com/huggingface/tokenizers/pull/1049

        // In order to fix, we just need to preallocate all vectors, then copy everything
        // into it (and deal with overlowings correctly)
        for sub in encodings {
            encoding.merge_with(sub, growing_offsets);
        }
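The TODO above sketches the intended fix: size everything up front instead of growing through repeated `merge_with` calls. A minimal, self-contained illustration of that preallocate-then-copy pattern on plain vectors (a hypothetical helper, not code from this commit; the real fix would apply the same idea to each of `Encoding`'s field vectors and still handle overflowing correctly):

    /// Merge several columns into one, reserving the full capacity first.
    fn merge_columns(columns: Vec<Vec<u32>>) -> Vec<u32> {
        // One allocation sized from the summed lengths, then straight copies.
        let total: usize = columns.iter().map(|c| c.len()).sum();
        let mut merged = Vec::with_capacity(total);
        for col in columns {
            merged.extend(col); // never reallocates: capacity was reserved above
        }
        merged
    }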