mirror of
https://github.com/mii443/tokenizers.git
synced 2025-08-22 16:25:30 +00:00
Making process_encodings
not eat up the encodings any more. (#1051)
* Making `process_encodings` not eat up the encodings any more. * Fixing clippy.
This commit is contained in:
@ -34,6 +34,10 @@ harness = false
|
|||||||
name = "bert_benchmark"
|
name = "bert_benchmark"
|
||||||
harness = false
|
harness = false
|
||||||
|
|
||||||
|
[[bench]]
|
||||||
|
name = "layout_benchmark"
|
||||||
|
harness = false
|
||||||
|
|
||||||
[dependencies]
|
[dependencies]
|
||||||
lazy_static = "1.4"
|
lazy_static = "1.4"
|
||||||
rand = "0.8"
|
rand = "0.8"
|
||||||
|
76
tokenizers/benches/layout_benchmark.rs
Normal file
76
tokenizers/benches/layout_benchmark.rs
Normal file
@ -0,0 +1,76 @@
|
|||||||
|
#[macro_use]
|
||||||
|
extern crate criterion;
|
||||||
|
|
||||||
|
use std::fs::File;
|
||||||
|
use std::io::{BufRead, BufReader};
|
||||||
|
use std::path::Path;
|
||||||
|
use std::time::{Duration, Instant};
|
||||||
|
|
||||||
|
use criterion::black_box;
|
||||||
|
use criterion::Criterion;
|
||||||
|
use tokenizers::processors::template::TemplateProcessing;
|
||||||
|
use tokenizers::{EncodeInput, Encoding, PostProcessor, Tokenizer};
|
||||||
|
|
||||||
|
/// Simple TemplateProcessing
|
||||||
|
fn create_processor() -> TemplateProcessing {
|
||||||
|
TemplateProcessing::builder()
|
||||||
|
.try_single("[CLS]:0 $A:0 [SEP]:0")
|
||||||
|
.unwrap()
|
||||||
|
.try_pair("[CLS]:0 $A:0 [SEP]:0 $B:1 [SEP]:1")
|
||||||
|
.unwrap()
|
||||||
|
.special_tokens(vec![("[CLS]", 0), ("[SEP]", 1)])
|
||||||
|
.build()
|
||||||
|
.unwrap()
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn bench_layout(c: &mut Criterion) {
|
||||||
|
let processor = create_processor();
|
||||||
|
let tokenizer = Tokenizer::from_file("data/albert-base-v1-tokenizer.json").unwrap();
|
||||||
|
let mut encodeds: Vec<Encoding> = vec![];
|
||||||
|
for line in BufReader::new(File::open(Path::new("data/big.txt")).unwrap()).lines() {
|
||||||
|
let line: EncodeInput = line.unwrap().into();
|
||||||
|
|
||||||
|
let encoded: Encoding = tokenizer.encode(line, false).unwrap();
|
||||||
|
encodeds.push(encoded);
|
||||||
|
}
|
||||||
|
|
||||||
|
c.bench_function("TemplateProcessing single encode", |b| {
|
||||||
|
b.iter_custom(|iters| {
|
||||||
|
let mut duration = Duration::new(0, 0);
|
||||||
|
for i in 0..iters as usize {
|
||||||
|
let encoded_index = i % encodeds.len();
|
||||||
|
let encoded: Encoding = encodeds[encoded_index].clone();
|
||||||
|
|
||||||
|
let start = Instant::now();
|
||||||
|
let _ = black_box(processor.process(encoded, None, false));
|
||||||
|
duration = duration.checked_add(start.elapsed()).unwrap();
|
||||||
|
}
|
||||||
|
duration
|
||||||
|
})
|
||||||
|
});
|
||||||
|
c.bench_function("TemplateProcessing pair encode", |b| {
|
||||||
|
b.iter_custom(|iters| {
|
||||||
|
let mut duration = Duration::new(0, 0);
|
||||||
|
for i in 0..iters as usize {
|
||||||
|
let encoded_index = i % encodeds.len();
|
||||||
|
let encoded: Encoding = encodeds[encoded_index].clone();
|
||||||
|
|
||||||
|
let encoded_index2 = (i + 1) % encodeds.len();
|
||||||
|
let pair: Encoding = encodeds[encoded_index2].clone();
|
||||||
|
|
||||||
|
let start = Instant::now();
|
||||||
|
let _ = black_box(processor.process(encoded, Some(pair), false));
|
||||||
|
duration = duration.checked_add(start.elapsed()).unwrap();
|
||||||
|
}
|
||||||
|
duration
|
||||||
|
})
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
criterion_group! {
|
||||||
|
name = layout_benches;
|
||||||
|
config = Criterion::default().sample_size(20);
|
||||||
|
targets = bench_layout
|
||||||
|
}
|
||||||
|
|
||||||
|
criterion_main!(layout_benches);
|
@ -1,4 +1,4 @@
|
|||||||
use crate::tokenizer::{Encoding, PostProcessor, ProcessorError, Result};
|
use crate::tokenizer::{Encoding, PostProcessor, Result};
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
use std::collections::HashMap;
|
use std::collections::HashMap;
|
||||||
use std::iter::FromIterator;
|
use std::iter::FromIterator;
|
||||||
@ -49,53 +49,11 @@ impl PostProcessor for BertProcessing {
|
|||||||
return Ok(encodings);
|
return Ok(encodings);
|
||||||
}
|
}
|
||||||
|
|
||||||
let (mut encoding, pair_encoding): (Encoding, Option<Encoding>) = match encodings.len() {
|
let encodings: Vec<Encoding> = encodings
|
||||||
1 => (
|
.iter_mut()
|
||||||
encodings
|
.enumerate()
|
||||||
.pop()
|
.map(|(i, encoding)| {
|
||||||
.ok_or(ProcessorError::InvalidEncodingsVecLength)?,
|
if i == 0 {
|
||||||
None,
|
|
||||||
),
|
|
||||||
2 => {
|
|
||||||
let pair = encodings
|
|
||||||
.pop()
|
|
||||||
.ok_or(ProcessorError::InvalidEncodingsVecLength)?;
|
|
||||||
let encoding = encodings
|
|
||||||
.pop()
|
|
||||||
.ok_or(ProcessorError::InvalidEncodingsVecLength)?;
|
|
||||||
(encoding, Some(pair))
|
|
||||||
}
|
|
||||||
_ => return Err(Box::new(ProcessorError::InvalidEncodingsVecLength)),
|
|
||||||
};
|
|
||||||
|
|
||||||
let ids = [&[self.cls.1], encoding.get_ids(), &[self.sep.1]].concat();
|
|
||||||
let type_ids = [&[0], encoding.get_type_ids(), &[0]].concat();
|
|
||||||
let tokens = [
|
|
||||||
&[self.cls.0.clone()],
|
|
||||||
encoding.get_tokens(),
|
|
||||||
&[self.sep.0.clone()],
|
|
||||||
]
|
|
||||||
.concat();
|
|
||||||
let words = [&[None], encoding.get_word_ids(), &[None]].concat();
|
|
||||||
let offsets = [&[(0, 0)], encoding.get_offsets(), &[(0, 0)]].concat();
|
|
||||||
let special_tokens = [&[1u32], &vec![0; encoding.get_ids().len()][..], &[1]].concat();
|
|
||||||
let attention_mask = vec![1; ids.len()];
|
|
||||||
|
|
||||||
// For compatibility with `TemplateProcessing`, the sequence_ranges shouldn't contain
|
|
||||||
// the special tokens.
|
|
||||||
let sequence_ranges = HashMap::from_iter(vec![(0, 1..ids.len() - 1)]);
|
|
||||||
let mut new_encoding = Encoding::new(
|
|
||||||
ids,
|
|
||||||
type_ids,
|
|
||||||
tokens,
|
|
||||||
words,
|
|
||||||
offsets,
|
|
||||||
special_tokens,
|
|
||||||
attention_mask,
|
|
||||||
encoding
|
|
||||||
.take_overflowing()
|
|
||||||
.into_iter()
|
|
||||||
.map(|encoding| {
|
|
||||||
let ids = [&[self.cls.1], encoding.get_ids(), &[self.sep.1]].concat();
|
let ids = [&[self.cls.1], encoding.get_ids(), &[self.sep.1]].concat();
|
||||||
let type_ids = [&[0], encoding.get_type_ids(), &[0]].concat();
|
let type_ids = [&[0], encoding.get_type_ids(), &[0]].concat();
|
||||||
let tokens = [
|
let tokens = [
|
||||||
@ -110,8 +68,8 @@ impl PostProcessor for BertProcessing {
|
|||||||
[&[1u32], &vec![0; encoding.get_ids().len()][..], &[1]].concat();
|
[&[1u32], &vec![0; encoding.get_ids().len()][..], &[1]].concat();
|
||||||
let attention_mask = vec![1; ids.len()];
|
let attention_mask = vec![1; ids.len()];
|
||||||
|
|
||||||
// For compatibility with `TemplateProcessing`, the sequence_ranges shouldn't
|
// For compatibility with `TemplateProcessing`, the sequence_ranges shouldn't contain
|
||||||
// contain the special tokens.
|
// the special tokens.
|
||||||
let sequence_ranges = HashMap::from_iter(vec![(0, 1..ids.len() - 1)]);
|
let sequence_ranges = HashMap::from_iter(vec![(0, 1..ids.len() - 1)]);
|
||||||
Encoding::new(
|
Encoding::new(
|
||||||
ids,
|
ids,
|
||||||
@ -121,72 +79,105 @@ impl PostProcessor for BertProcessing {
|
|||||||
offsets,
|
offsets,
|
||||||
special_tokens,
|
special_tokens,
|
||||||
attention_mask,
|
attention_mask,
|
||||||
vec![],
|
encoding
|
||||||
|
.take_overflowing()
|
||||||
|
.into_iter()
|
||||||
|
.map(|encoding| {
|
||||||
|
let ids =
|
||||||
|
[&[self.cls.1], encoding.get_ids(), &[self.sep.1]].concat();
|
||||||
|
let type_ids = [&[0], encoding.get_type_ids(), &[0]].concat();
|
||||||
|
let tokens = [
|
||||||
|
&[self.cls.0.clone()],
|
||||||
|
encoding.get_tokens(),
|
||||||
|
&[self.sep.0.clone()],
|
||||||
|
]
|
||||||
|
.concat();
|
||||||
|
let words = [&[None], encoding.get_word_ids(), &[None]].concat();
|
||||||
|
let offsets =
|
||||||
|
[&[(0, 0)], encoding.get_offsets(), &[(0, 0)]].concat();
|
||||||
|
let special_tokens =
|
||||||
|
[&[1u32], &vec![0; encoding.get_ids().len()][..], &[1]]
|
||||||
|
.concat();
|
||||||
|
let attention_mask = vec![1; ids.len()];
|
||||||
|
|
||||||
|
// For compatibility with `TemplateProcessing`, the sequence_ranges shouldn't
|
||||||
|
// contain the special tokens.
|
||||||
|
let sequence_ranges =
|
||||||
|
HashMap::from_iter(vec![(0, 1..ids.len() - 1)]);
|
||||||
|
Encoding::new(
|
||||||
|
ids,
|
||||||
|
type_ids,
|
||||||
|
tokens,
|
||||||
|
words,
|
||||||
|
offsets,
|
||||||
|
special_tokens,
|
||||||
|
attention_mask,
|
||||||
|
vec![],
|
||||||
|
sequence_ranges,
|
||||||
|
)
|
||||||
|
})
|
||||||
|
.collect(),
|
||||||
sequence_ranges,
|
sequence_ranges,
|
||||||
)
|
)
|
||||||
})
|
} else {
|
||||||
.collect(),
|
let pair_ids = [encoding.get_ids(), &[self.sep.1]].concat();
|
||||||
sequence_ranges,
|
let pair_type_ids = [encoding.get_type_ids(), &[1]].concat();
|
||||||
);
|
let pair_tokens = [encoding.get_tokens(), &[self.sep.0.clone()]].concat();
|
||||||
|
let pair_words = [encoding.get_word_ids(), &[None]].concat();
|
||||||
|
let pair_offsets = [encoding.get_offsets(), &[(0, 0)]].concat();
|
||||||
|
let pair_special_tokens =
|
||||||
|
[&vec![0u32; encoding.get_type_ids().len()][..], &[1]].concat();
|
||||||
|
let pair_attention_mask = vec![1; pair_ids.len()];
|
||||||
|
|
||||||
if let Some(mut encoding) = pair_encoding {
|
// For compatibility with `TemplateProcessing`, the sequence_ranges shouldn't contain
|
||||||
let pair_ids = [encoding.get_ids(), &[self.sep.1]].concat();
|
// the special tokens.
|
||||||
let pair_type_ids = [encoding.get_type_ids(), &[1]].concat();
|
let pair_sequence_ranges = HashMap::from_iter(vec![(1, 0..pair_ids.len() - 1)]);
|
||||||
let pair_tokens = [encoding.get_tokens(), &[self.sep.0.clone()]].concat();
|
Encoding::new(
|
||||||
let pair_words = [encoding.get_word_ids(), &[None]].concat();
|
pair_ids,
|
||||||
let pair_offsets = [encoding.get_offsets(), &[(0, 0)]].concat();
|
pair_type_ids,
|
||||||
let pair_special_tokens =
|
pair_tokens,
|
||||||
[&vec![0u32; encoding.get_type_ids().len()][..], &[1]].concat();
|
pair_words,
|
||||||
let pair_attention_mask = vec![1; pair_ids.len()];
|
pair_offsets,
|
||||||
|
pair_special_tokens,
|
||||||
|
pair_attention_mask,
|
||||||
|
encoding
|
||||||
|
.take_overflowing()
|
||||||
|
.into_iter()
|
||||||
|
.map(|encoding| {
|
||||||
|
let pair_ids = [encoding.get_ids(), &[self.sep.1]].concat();
|
||||||
|
let pair_type_ids = [encoding.get_type_ids(), &[1]].concat();
|
||||||
|
let pair_tokens =
|
||||||
|
[encoding.get_tokens(), &[self.sep.0.clone()]].concat();
|
||||||
|
let pair_words = [encoding.get_word_ids(), &[None]].concat();
|
||||||
|
let pair_offsets = [encoding.get_offsets(), &[(0, 0)]].concat();
|
||||||
|
let pair_special_tokens =
|
||||||
|
[&vec![0u32; encoding.get_type_ids().len()][..], &[1]].concat();
|
||||||
|
let pair_attention_mask = vec![1; pair_ids.len()];
|
||||||
|
|
||||||
// For compatibility with `TemplateProcessing`, the sequence_ranges shouldn't contain
|
// For compatibility with `TemplateProcessing`, the sequence_ranges
|
||||||
// the special tokens.
|
// shouldn't contain the special tokens.
|
||||||
let pair_sequence_ranges = HashMap::from_iter(vec![(1, 0..pair_ids.len() - 1)]);
|
let pair_sequence_ranges =
|
||||||
let new_pair_encoding = Encoding::new(
|
HashMap::from_iter(vec![(1, 0..pair_ids.len() - 1)]);
|
||||||
pair_ids,
|
Encoding::new(
|
||||||
pair_type_ids,
|
pair_ids,
|
||||||
pair_tokens,
|
pair_type_ids,
|
||||||
pair_words,
|
pair_tokens,
|
||||||
pair_offsets,
|
pair_words,
|
||||||
pair_special_tokens,
|
pair_offsets,
|
||||||
pair_attention_mask,
|
pair_special_tokens,
|
||||||
encoding
|
pair_attention_mask,
|
||||||
.take_overflowing()
|
vec![],
|
||||||
.into_iter()
|
pair_sequence_ranges,
|
||||||
.map(|encoding| {
|
)
|
||||||
let pair_ids = [encoding.get_ids(), &[self.sep.1]].concat();
|
})
|
||||||
let pair_type_ids = [encoding.get_type_ids(), &[1]].concat();
|
.collect(),
|
||||||
let pair_tokens = [encoding.get_tokens(), &[self.sep.0.clone()]].concat();
|
pair_sequence_ranges,
|
||||||
let pair_words = [encoding.get_word_ids(), &[None]].concat();
|
)
|
||||||
let pair_offsets = [encoding.get_offsets(), &[(0, 0)]].concat();
|
}
|
||||||
let pair_special_tokens =
|
})
|
||||||
[&vec![0u32; encoding.get_type_ids().len()][..], &[1]].concat();
|
.collect();
|
||||||
let pair_attention_mask = vec![1; pair_ids.len()];
|
|
||||||
|
|
||||||
// For compatibility with `TemplateProcessing`, the sequence_ranges
|
Ok(encodings)
|
||||||
// shouldn't contain the special tokens.
|
|
||||||
let pair_sequence_ranges =
|
|
||||||
HashMap::from_iter(vec![(1, 0..pair_ids.len() - 1)]);
|
|
||||||
Encoding::new(
|
|
||||||
pair_ids,
|
|
||||||
pair_type_ids,
|
|
||||||
pair_tokens,
|
|
||||||
pair_words,
|
|
||||||
pair_offsets,
|
|
||||||
pair_special_tokens,
|
|
||||||
pair_attention_mask,
|
|
||||||
vec![],
|
|
||||||
pair_sequence_ranges,
|
|
||||||
)
|
|
||||||
})
|
|
||||||
.collect(),
|
|
||||||
pair_sequence_ranges,
|
|
||||||
);
|
|
||||||
|
|
||||||
new_encoding.merge_with(new_pair_encoding, false);
|
|
||||||
}
|
|
||||||
|
|
||||||
Ok(vec![new_encoding])
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1,5 +1,5 @@
|
|||||||
use crate::processors::byte_level::process_offsets;
|
use crate::processors::byte_level::process_offsets;
|
||||||
use crate::tokenizer::{Encoding, PostProcessor, ProcessorError, Result};
|
use crate::tokenizer::{Encoding, PostProcessor, Result};
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
use std::collections::HashMap;
|
use std::collections::HashMap;
|
||||||
use std::iter::FromIterator;
|
use std::iter::FromIterator;
|
||||||
@ -74,53 +74,11 @@ impl PostProcessor for RobertaProcessing {
|
|||||||
return Ok(encodings);
|
return Ok(encodings);
|
||||||
}
|
}
|
||||||
|
|
||||||
let (mut encoding, pair_encoding): (Encoding, Option<Encoding>) = match encodings.len() {
|
let encodings: Vec<Encoding> = encodings
|
||||||
1 => (
|
.iter_mut()
|
||||||
encodings
|
.enumerate()
|
||||||
.pop()
|
.map(|(i, encoding)| {
|
||||||
.ok_or(ProcessorError::InvalidEncodingsVecLength)?,
|
if i == 0 {
|
||||||
None,
|
|
||||||
),
|
|
||||||
2 => {
|
|
||||||
let pair = encodings
|
|
||||||
.pop()
|
|
||||||
.ok_or(ProcessorError::InvalidEncodingsVecLength)?;
|
|
||||||
let encoding = encodings
|
|
||||||
.pop()
|
|
||||||
.ok_or(ProcessorError::InvalidEncodingsVecLength)?;
|
|
||||||
(encoding, Some(pair))
|
|
||||||
}
|
|
||||||
_ => return Err(Box::new(ProcessorError::InvalidEncodingsVecLength)),
|
|
||||||
};
|
|
||||||
|
|
||||||
let ids = [&[self.cls.1], encoding.get_ids(), &[self.sep.1]].concat();
|
|
||||||
let type_ids = [&[0], encoding.get_type_ids(), &[0]].concat();
|
|
||||||
let tokens = [
|
|
||||||
&[self.cls.0.clone()],
|
|
||||||
encoding.get_tokens(),
|
|
||||||
&[self.sep.0.clone()],
|
|
||||||
]
|
|
||||||
.concat();
|
|
||||||
let words = [&[None], encoding.get_word_ids(), &[None]].concat();
|
|
||||||
let offsets = [&[(0, 0)], encoding.get_offsets(), &[(0, 0)]].concat();
|
|
||||||
let special_tokens = [&[1u32], &vec![0; encoding.get_ids().len()][..], &[1]].concat();
|
|
||||||
let attention_mask = vec![1; ids.len()];
|
|
||||||
|
|
||||||
// For compatibility with `TemplateProcessing`, the sequence_ranges shouldn't contain
|
|
||||||
// the special tokens.
|
|
||||||
let sequence_ranges = HashMap::from_iter(vec![(0, 1..ids.len() - 1)]);
|
|
||||||
let mut new_encoding = Encoding::new(
|
|
||||||
ids,
|
|
||||||
type_ids,
|
|
||||||
tokens,
|
|
||||||
words,
|
|
||||||
offsets,
|
|
||||||
special_tokens,
|
|
||||||
attention_mask,
|
|
||||||
encoding
|
|
||||||
.take_overflowing()
|
|
||||||
.into_iter()
|
|
||||||
.map(|encoding| {
|
|
||||||
let ids = [&[self.cls.1], encoding.get_ids(), &[self.sep.1]].concat();
|
let ids = [&[self.cls.1], encoding.get_ids(), &[self.sep.1]].concat();
|
||||||
let type_ids = [&[0], encoding.get_type_ids(), &[0]].concat();
|
let type_ids = [&[0], encoding.get_type_ids(), &[0]].concat();
|
||||||
let tokens = [
|
let tokens = [
|
||||||
@ -135,8 +93,8 @@ impl PostProcessor for RobertaProcessing {
|
|||||||
[&[1u32], &vec![0; encoding.get_ids().len()][..], &[1]].concat();
|
[&[1u32], &vec![0; encoding.get_ids().len()][..], &[1]].concat();
|
||||||
let attention_mask = vec![1; ids.len()];
|
let attention_mask = vec![1; ids.len()];
|
||||||
|
|
||||||
// For compatibility with `TemplateProcessing`, the sequence_ranges shouldn't
|
// For compatibility with `TemplateProcessing`, the sequence_ranges shouldn't contain
|
||||||
// contain the special tokens.
|
// the special tokens.
|
||||||
let sequence_ranges = HashMap::from_iter(vec![(0, 1..ids.len() - 1)]);
|
let sequence_ranges = HashMap::from_iter(vec![(0, 1..ids.len() - 1)]);
|
||||||
Encoding::new(
|
Encoding::new(
|
||||||
ids,
|
ids,
|
||||||
@ -146,82 +104,118 @@ impl PostProcessor for RobertaProcessing {
|
|||||||
offsets,
|
offsets,
|
||||||
special_tokens,
|
special_tokens,
|
||||||
attention_mask,
|
attention_mask,
|
||||||
vec![],
|
encoding
|
||||||
|
.take_overflowing()
|
||||||
|
.into_iter()
|
||||||
|
.map(|encoding| {
|
||||||
|
let ids =
|
||||||
|
[&[self.cls.1], encoding.get_ids(), &[self.sep.1]].concat();
|
||||||
|
let type_ids = [&[0], encoding.get_type_ids(), &[0]].concat();
|
||||||
|
let tokens = [
|
||||||
|
&[self.cls.0.clone()],
|
||||||
|
encoding.get_tokens(),
|
||||||
|
&[self.sep.0.clone()],
|
||||||
|
]
|
||||||
|
.concat();
|
||||||
|
let words = [&[None], encoding.get_word_ids(), &[None]].concat();
|
||||||
|
let offsets =
|
||||||
|
[&[(0, 0)], encoding.get_offsets(), &[(0, 0)]].concat();
|
||||||
|
let special_tokens =
|
||||||
|
[&[1u32], &vec![0; encoding.get_ids().len()][..], &[1]]
|
||||||
|
.concat();
|
||||||
|
let attention_mask = vec![1; ids.len()];
|
||||||
|
|
||||||
|
// For compatibility with `TemplateProcessing`, the sequence_ranges shouldn't
|
||||||
|
// contain the special tokens.
|
||||||
|
let sequence_ranges =
|
||||||
|
HashMap::from_iter(vec![(0, 1..ids.len() - 1)]);
|
||||||
|
Encoding::new(
|
||||||
|
ids,
|
||||||
|
type_ids,
|
||||||
|
tokens,
|
||||||
|
words,
|
||||||
|
offsets,
|
||||||
|
special_tokens,
|
||||||
|
attention_mask,
|
||||||
|
vec![],
|
||||||
|
sequence_ranges,
|
||||||
|
)
|
||||||
|
})
|
||||||
|
.collect(),
|
||||||
sequence_ranges,
|
sequence_ranges,
|
||||||
)
|
)
|
||||||
})
|
} else {
|
||||||
.collect(),
|
let pair_ids = [&[self.sep.1], encoding.get_ids(), &[self.sep.1]].concat();
|
||||||
sequence_ranges,
|
let pair_type_ids = vec![0; encoding.get_ids().len() + 2];
|
||||||
);
|
let pair_tokens = [
|
||||||
|
&[self.sep.0.clone()],
|
||||||
|
encoding.get_tokens(),
|
||||||
|
&[self.sep.0.clone()],
|
||||||
|
]
|
||||||
|
.concat();
|
||||||
|
let pair_words = [&[None], encoding.get_word_ids(), &[None]].concat();
|
||||||
|
let pair_offsets = [&[(0, 0)], encoding.get_offsets(), &[(0, 0)]].concat();
|
||||||
|
let pair_special_tokens =
|
||||||
|
[&[1], &vec![0u32; encoding.get_type_ids().len()][..], &[1]].concat();
|
||||||
|
let pair_attention_mask = vec![1; pair_ids.len()];
|
||||||
|
|
||||||
if let Some(mut encoding) = pair_encoding {
|
// For compatibility with `TemplateProcessing`, the sequence_ranges shouldn't contain
|
||||||
let pair_ids = [&[self.sep.1], encoding.get_ids(), &[self.sep.1]].concat();
|
// the special tokens.
|
||||||
let pair_type_ids = vec![0; encoding.get_ids().len() + 2];
|
let pair_sequence_ranges = HashMap::from_iter(vec![(1, 1..pair_ids.len() - 1)]);
|
||||||
let pair_tokens = [
|
Encoding::new(
|
||||||
&[self.sep.0.clone()],
|
pair_ids,
|
||||||
encoding.get_tokens(),
|
pair_type_ids,
|
||||||
&[self.sep.0.clone()],
|
pair_tokens,
|
||||||
]
|
pair_words,
|
||||||
.concat();
|
pair_offsets,
|
||||||
let pair_words = [&[None], encoding.get_word_ids(), &[None]].concat();
|
pair_special_tokens,
|
||||||
let pair_offsets = [&[(0, 0)], encoding.get_offsets(), &[(0, 0)]].concat();
|
pair_attention_mask,
|
||||||
let pair_special_tokens =
|
encoding
|
||||||
[&[1], &vec![0u32; encoding.get_type_ids().len()][..], &[1]].concat();
|
.take_overflowing()
|
||||||
let pair_attention_mask = vec![1; pair_ids.len()];
|
.into_iter()
|
||||||
|
.map(|encoding| {
|
||||||
|
let pair_ids =
|
||||||
|
[&[self.sep.1], encoding.get_ids(), &[self.sep.1]].concat();
|
||||||
|
let pair_type_ids = vec![0; encoding.get_ids().len() + 2];
|
||||||
|
let pair_tokens = [
|
||||||
|
&[self.sep.0.clone()],
|
||||||
|
encoding.get_tokens(),
|
||||||
|
&[self.sep.0.clone()],
|
||||||
|
]
|
||||||
|
.concat();
|
||||||
|
let pair_words =
|
||||||
|
[&[None], encoding.get_word_ids(), &[None]].concat();
|
||||||
|
let pair_offsets =
|
||||||
|
[&[(0, 0)], encoding.get_offsets(), &[(0, 0)]].concat();
|
||||||
|
let pair_special_tokens =
|
||||||
|
[&[1], &vec![0u32; encoding.get_type_ids().len()][..], &[1]]
|
||||||
|
.concat();
|
||||||
|
let pair_attention_mask = vec![1; pair_ids.len()];
|
||||||
|
|
||||||
// For compatibility with `TemplateProcessing`, the sequence_ranges shouldn't contain
|
// For compatibility with `TemplateProcessing`, the sequence_ranges
|
||||||
// the special tokens.
|
// shouldn't contain the special tokens.
|
||||||
let pair_sequence_ranges = HashMap::from_iter(vec![(1, 1..pair_ids.len() - 1)]);
|
let pair_sequence_ranges =
|
||||||
let new_pair_encoding = Encoding::new(
|
HashMap::from_iter(vec![(1, 1..pair_ids.len() - 1)]);
|
||||||
pair_ids,
|
Encoding::new(
|
||||||
pair_type_ids,
|
pair_ids,
|
||||||
pair_tokens,
|
pair_type_ids,
|
||||||
pair_words,
|
pair_tokens,
|
||||||
pair_offsets,
|
pair_words,
|
||||||
pair_special_tokens,
|
pair_offsets,
|
||||||
pair_attention_mask,
|
pair_special_tokens,
|
||||||
encoding
|
pair_attention_mask,
|
||||||
.take_overflowing()
|
vec![],
|
||||||
.into_iter()
|
pair_sequence_ranges,
|
||||||
.map(|encoding| {
|
)
|
||||||
let pair_ids = [&[self.sep.1], encoding.get_ids(), &[self.sep.1]].concat();
|
})
|
||||||
let pair_type_ids = vec![0; encoding.get_ids().len() + 2];
|
.collect(),
|
||||||
let pair_tokens = [
|
pair_sequence_ranges,
|
||||||
&[self.sep.0.clone()],
|
)
|
||||||
encoding.get_tokens(),
|
}
|
||||||
&[self.sep.0.clone()],
|
})
|
||||||
]
|
.collect();
|
||||||
.concat();
|
|
||||||
let pair_words = [&[None], encoding.get_word_ids(), &[None]].concat();
|
|
||||||
let pair_offsets = [&[(0, 0)], encoding.get_offsets(), &[(0, 0)]].concat();
|
|
||||||
let pair_special_tokens =
|
|
||||||
[&[1], &vec![0u32; encoding.get_type_ids().len()][..], &[1]].concat();
|
|
||||||
let pair_attention_mask = vec![1; pair_ids.len()];
|
|
||||||
|
|
||||||
// For compatibility with `TemplateProcessing`, the sequence_ranges
|
Ok(encodings)
|
||||||
// shouldn't contain the special tokens.
|
|
||||||
let pair_sequence_ranges =
|
|
||||||
HashMap::from_iter(vec![(1, 1..pair_ids.len() - 1)]);
|
|
||||||
Encoding::new(
|
|
||||||
pair_ids,
|
|
||||||
pair_type_ids,
|
|
||||||
pair_tokens,
|
|
||||||
pair_words,
|
|
||||||
pair_offsets,
|
|
||||||
pair_special_tokens,
|
|
||||||
pair_attention_mask,
|
|
||||||
vec![],
|
|
||||||
pair_sequence_ranges,
|
|
||||||
)
|
|
||||||
})
|
|
||||||
.collect(),
|
|
||||||
pair_sequence_ranges,
|
|
||||||
);
|
|
||||||
|
|
||||||
new_encoding.merge_with(new_pair_encoding, false);
|
|
||||||
}
|
|
||||||
|
|
||||||
Ok(vec![new_encoding])
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -55,14 +55,14 @@
|
|||||||
//!
|
//!
|
||||||
//! [`TemplateProcessing`]: struct.TemplateProcessing.html
|
//! [`TemplateProcessing`]: struct.TemplateProcessing.html
|
||||||
//!
|
//!
|
||||||
use crate::{tokenizer::ProcessorError, Encoding, PostProcessor, Result};
|
use crate::{Encoding, PostProcessor, Result};
|
||||||
use itertools::Itertools;
|
use itertools::Itertools;
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
use std::collections::{HashMap, HashSet};
|
use std::collections::{HashMap, HashSet};
|
||||||
use std::convert::{TryFrom, TryInto};
|
use std::convert::{TryFrom, TryInto};
|
||||||
use std::result::Result as StdResult;
|
use std::result::Result as StdResult;
|
||||||
|
|
||||||
/// Represents both sequences received as input of the PostProcessor
|
/// Represents any sequences received as input of the PostProcessor
|
||||||
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, Eq)]
|
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, Eq)]
|
||||||
pub enum Sequence {
|
pub enum Sequence {
|
||||||
/// This is the first sequence, the one that is always specified
|
/// This is the first sequence, the one that is always specified
|
||||||
@ -479,145 +479,102 @@ impl TemplateProcessing {
|
|||||||
fn apply_template(
|
fn apply_template(
|
||||||
&self,
|
&self,
|
||||||
template: &[Piece],
|
template: &[Piece],
|
||||||
mut encoding: Encoding,
|
mut encodings: Vec<Encoding>,
|
||||||
mut pair: Option<Encoding>,
|
|
||||||
add_special_tokens: bool,
|
add_special_tokens: bool,
|
||||||
) -> Result<Encoding> {
|
) -> Result<Vec<Encoding>> {
|
||||||
// Compute the new size
|
let final_encodings: Vec<Encoding> = template
|
||||||
let mut new_len = 0;
|
.iter()
|
||||||
for piece in template {
|
.flat_map(|piece| {
|
||||||
new_len += match piece {
|
match piece {
|
||||||
Piece::Sequence {
|
Piece::Sequence { id, type_id } => {
|
||||||
id: Sequence::A, ..
|
let i = if *id == Sequence::A { 0 } else { 1 };
|
||||||
} => encoding.len(),
|
let encoding = &mut encodings[i];
|
||||||
Piece::Sequence {
|
encoding.set_type_ids(vec![*type_id; encoding.len()]);
|
||||||
id: Sequence::B, ..
|
encoding.set_sequence_id(i);
|
||||||
} => pair
|
Some(encoding.clone())
|
||||||
.as_ref()
|
}
|
||||||
.ok_or("Template expected a pair sequence, but none provided")?
|
Piece::SpecialToken { id, type_id } => {
|
||||||
.len(),
|
if add_special_tokens {
|
||||||
Piece::SpecialToken { id, .. } => {
|
let tok = &self.special_tokens.0[id]; // We already checked existance above
|
||||||
if add_special_tokens {
|
let len = tok.ids.len();
|
||||||
self.special_tokens
|
|
||||||
.0
|
let encoding = Encoding::new(
|
||||||
.get(id)
|
tok.ids.clone(),
|
||||||
.ok_or_else(|| format!("Missing SpecialToken with id {}", id))?
|
std::iter::repeat(*type_id).take(len).collect(),
|
||||||
.ids
|
tok.tokens.clone(),
|
||||||
.len()
|
// words
|
||||||
} else {
|
std::iter::repeat(None).take(len).collect(),
|
||||||
0
|
// offsets
|
||||||
|
std::iter::repeat((0, 0)).take(len).collect(),
|
||||||
|
// special_tokens_mask
|
||||||
|
std::iter::repeat(1).take(len).collect(),
|
||||||
|
// attention_mask
|
||||||
|
std::iter::repeat(1).take(len).collect(),
|
||||||
|
// overflowing
|
||||||
|
vec![],
|
||||||
|
// sequence_range
|
||||||
|
HashMap::new(),
|
||||||
|
);
|
||||||
|
Some(encoding)
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
};
|
|
||||||
}
|
|
||||||
|
|
||||||
// Then build the new Encoding
|
|
||||||
let mut ids = Vec::with_capacity(new_len);
|
|
||||||
let mut type_ids = Vec::with_capacity(new_len);
|
|
||||||
let mut tokens = Vec::with_capacity(new_len);
|
|
||||||
let mut words = Vec::with_capacity(new_len);
|
|
||||||
let mut offsets = Vec::with_capacity(new_len);
|
|
||||||
let mut special_tokens_mask = Vec::with_capacity(new_len);
|
|
||||||
let mut attention_mask = Vec::with_capacity(new_len);
|
|
||||||
let mut sequence_ranges = HashMap::new();
|
|
||||||
|
|
||||||
let pair_overflowing = pair.as_mut().map_or(vec![], |e| e.take_overflowing());
|
|
||||||
let mut overflowing = encoding
|
|
||||||
.take_overflowing()
|
|
||||||
.into_iter()
|
|
||||||
.flat_map(|encoding| {
|
|
||||||
// 1. The pair itself
|
|
||||||
let mut overflowings = vec![self.apply_template(
|
|
||||||
template,
|
|
||||||
encoding.clone(),
|
|
||||||
pair.clone(),
|
|
||||||
add_special_tokens,
|
|
||||||
)];
|
|
||||||
|
|
||||||
// 2. Its overflowings
|
|
||||||
for other_o in &pair_overflowing {
|
|
||||||
overflowings.push(self.apply_template(
|
|
||||||
template,
|
|
||||||
encoding.clone(),
|
|
||||||
Some(other_o.clone()),
|
|
||||||
add_special_tokens,
|
|
||||||
));
|
|
||||||
}
|
|
||||||
|
|
||||||
overflowings
|
|
||||||
})
|
})
|
||||||
.collect::<Result<Vec<_>>>()?;
|
.collect();
|
||||||
// We also need to combine the first sequence with all other overflowings
|
|
||||||
overflowing.extend(
|
|
||||||
pair_overflowing
|
|
||||||
.into_iter()
|
|
||||||
.map(|pair| {
|
|
||||||
self.apply_template(template, encoding.clone(), Some(pair), add_special_tokens)
|
|
||||||
})
|
|
||||||
.collect::<Result<Vec<_>>>()?,
|
|
||||||
);
|
|
||||||
|
|
||||||
for piece in template {
|
//let mut pair = if encodings.len() > 1 {
|
||||||
match piece {
|
// Some(encodings.pop().unwrap())
|
||||||
Piece::Sequence {
|
//} else {
|
||||||
id: Sequence::A,
|
// None
|
||||||
type_id,
|
//};
|
||||||
} => {
|
//let mut encoding = encodings.pop().unwrap();
|
||||||
let seq_start = ids.len();
|
|
||||||
let seq_end = seq_start + encoding.len();
|
|
||||||
sequence_ranges.insert(0, seq_start..seq_end);
|
|
||||||
ids.extend(encoding.get_ids());
|
|
||||||
type_ids.extend(std::iter::repeat(type_id).take(encoding.len()));
|
|
||||||
tokens.extend(encoding.get_tokens().iter().map(|s| s.to_owned()));
|
|
||||||
words.extend(encoding.get_word_ids());
|
|
||||||
offsets.extend(encoding.get_offsets());
|
|
||||||
special_tokens_mask.extend(encoding.get_special_tokens_mask());
|
|
||||||
attention_mask.extend(encoding.get_attention_mask());
|
|
||||||
}
|
|
||||||
Piece::Sequence {
|
|
||||||
id: Sequence::B,
|
|
||||||
type_id,
|
|
||||||
} => {
|
|
||||||
let pair = pair.as_ref().expect("Missing pair sequence, checked above");
|
|
||||||
let seq_start = ids.len();
|
|
||||||
let seq_end = seq_start + pair.len();
|
|
||||||
sequence_ranges.insert(1, seq_start..seq_end);
|
|
||||||
ids.extend(pair.get_ids());
|
|
||||||
type_ids.extend(std::iter::repeat(type_id).take(pair.len()));
|
|
||||||
tokens.extend(pair.get_tokens().iter().map(|s| s.to_owned()));
|
|
||||||
words.extend(pair.get_word_ids());
|
|
||||||
offsets.extend(pair.get_offsets());
|
|
||||||
special_tokens_mask.extend(pair.get_special_tokens_mask());
|
|
||||||
attention_mask.extend(pair.get_attention_mask());
|
|
||||||
}
|
|
||||||
Piece::SpecialToken { id, type_id } => {
|
|
||||||
if add_special_tokens {
|
|
||||||
let tok = &self.special_tokens.0[id]; // We already checked existance above
|
|
||||||
let len = tok.ids.len();
|
|
||||||
|
|
||||||
ids.extend(&tok.ids);
|
//let pair_overflowing = pair.as_mut().map_or(vec![], |e| e.take_overflowing());
|
||||||
type_ids.extend(std::iter::repeat(type_id).take(len));
|
//let mut overflowing: Vec<Encoding> = encoding
|
||||||
tokens.extend(tok.tokens.clone());
|
// .take_overflowing()
|
||||||
words.extend(std::iter::repeat(None).take(len));
|
// .iter()
|
||||||
offsets.extend(std::iter::repeat((0, 0)).take(len));
|
// .map(|encoding| -> Result<Vec<Encoding>> {
|
||||||
special_tokens_mask.extend(std::iter::repeat(1).take(len));
|
// // 1. The pair itself
|
||||||
attention_mask.extend(std::iter::repeat(1).take(len));
|
// let mut overflowings = self.apply_template(
|
||||||
}
|
// template,
|
||||||
}
|
// if encodings.len() > 1 {
|
||||||
}
|
// vec![encoding.clone(), encodings[1].clone()]
|
||||||
}
|
// } else {
|
||||||
|
// vec![encoding.clone()]
|
||||||
|
// },
|
||||||
|
// add_special_tokens,
|
||||||
|
// )?;
|
||||||
|
|
||||||
Ok(Encoding::new(
|
// // 2. Its overflowings
|
||||||
ids,
|
// for other_o in &pair_overflowing {
|
||||||
type_ids,
|
// overflowings.extend(self.apply_template(
|
||||||
tokens,
|
// template,
|
||||||
words,
|
// vec![encoding.clone(), other_o.clone()],
|
||||||
offsets,
|
// add_special_tokens,
|
||||||
special_tokens_mask,
|
// )?);
|
||||||
attention_mask,
|
// }
|
||||||
overflowing,
|
|
||||||
sequence_ranges,
|
// Ok(overflowings)
|
||||||
))
|
// })
|
||||||
|
// .collect::<Result<Vec<Vec<Encoding>>>>()?
|
||||||
|
// .into_iter()
|
||||||
|
// .flatten()
|
||||||
|
// .collect();
|
||||||
|
//// We also need to combine the first sequence with all other overflowings
|
||||||
|
//overflowing.extend(
|
||||||
|
// pair_overflowing
|
||||||
|
// .into_iter()
|
||||||
|
// .map(|pair| {
|
||||||
|
// self.apply_template(template, vec![encoding.clone(), pair], add_special_tokens)
|
||||||
|
// })
|
||||||
|
// .collect::<Result<Vec<_>>>()?
|
||||||
|
// .into_iter()
|
||||||
|
// .flatten(),
|
||||||
|
//);
|
||||||
|
|
||||||
|
Ok(final_encodings)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -632,39 +589,34 @@ impl PostProcessor for TemplateProcessing {
|
|||||||
|
|
||||||
fn process_encodings(
|
fn process_encodings(
|
||||||
&self,
|
&self,
|
||||||
mut encodings: Vec<Encoding>,
|
encodings: Vec<Encoding>,
|
||||||
add_special_tokens: bool,
|
add_special_tokens: bool,
|
||||||
) -> Result<Vec<Encoding>> {
|
) -> Result<Vec<Encoding>> {
|
||||||
let (encoding, pair): (Encoding, Option<Encoding>) = match encodings.len() {
|
// let (encoding, pair): (Encoding, Option<Encoding>) = match encodings.len() {
|
||||||
1 => (
|
// 1 => (
|
||||||
encodings
|
// encodings
|
||||||
.pop()
|
// .pop()
|
||||||
.ok_or(ProcessorError::InvalidEncodingsVecLength)?,
|
// .ok_or(ProcessorError::InvalidEncodingsVecLength)?,
|
||||||
None,
|
// None,
|
||||||
),
|
// ),
|
||||||
2 => {
|
// 2 => {
|
||||||
let pair = encodings
|
// let pair = encodings
|
||||||
.pop()
|
// .pop()
|
||||||
.ok_or(ProcessorError::InvalidEncodingsVecLength)?;
|
// .ok_or(ProcessorError::InvalidEncodingsVecLength)?;
|
||||||
let encoding = encodings
|
// let encoding = encodings
|
||||||
.pop()
|
// .pop()
|
||||||
.ok_or(ProcessorError::InvalidEncodingsVecLength)?;
|
// .ok_or(ProcessorError::InvalidEncodingsVecLength)?;
|
||||||
(encoding, Some(pair))
|
// (encoding, Some(pair))
|
||||||
}
|
// }
|
||||||
_ => return Err(Box::new(ProcessorError::InvalidEncodingsVecLength)),
|
// _ => return Err(Box::new(ProcessorError::InvalidEncodingsVecLength)),
|
||||||
|
// };
|
||||||
|
let template = match encodings.len() {
|
||||||
|
2 => &self.pair.0,
|
||||||
|
1 => &self.single.0,
|
||||||
|
_ => todo!(),
|
||||||
};
|
};
|
||||||
|
let encodings = self.apply_template(template, encodings, add_special_tokens)?;
|
||||||
let encoding = self.apply_template(
|
Ok(encodings)
|
||||||
if pair.is_some() {
|
|
||||||
&self.pair.0
|
|
||||||
} else {
|
|
||||||
&self.single.0
|
|
||||||
},
|
|
||||||
encoding,
|
|
||||||
pair,
|
|
||||||
add_special_tokens,
|
|
||||||
)?;
|
|
||||||
Ok(vec![encoding])
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -884,7 +836,6 @@ mod tests {
|
|||||||
);
|
);
|
||||||
let pair = Encoding::from_tokens(vec![Token::new(15, "pair".into(), (0, 4))], 0);
|
let pair = Encoding::from_tokens(vec![Token::new(15, "pair".into(), (0, 4))], 0);
|
||||||
let single_encoding = processor.process(encoding.clone(), None, true).unwrap();
|
let single_encoding = processor.process(encoding.clone(), None, true).unwrap();
|
||||||
let pair_encoding = processor.process(encoding, Some(pair), true).unwrap();
|
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
single_encoding,
|
single_encoding,
|
||||||
Encoding::new(
|
Encoding::new(
|
||||||
@ -906,6 +857,7 @@ mod tests {
|
|||||||
);
|
);
|
||||||
assert_eq!(single_encoding.token_to_sequence(2), Some(0));
|
assert_eq!(single_encoding.token_to_sequence(2), Some(0));
|
||||||
assert_eq!(single_encoding.token_to_sequence(3), None);
|
assert_eq!(single_encoding.token_to_sequence(3), None);
|
||||||
|
let pair_encoding = processor.process(encoding, Some(pair), true).unwrap();
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
pair_encoding,
|
pair_encoding,
|
||||||
Encoding::new(
|
Encoding::new(
|
||||||
|
@ -152,6 +152,10 @@ impl Encoding {
|
|||||||
&self.type_ids
|
&self.type_ids
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn set_type_ids(&mut self, type_ids: Vec<u32>) {
|
||||||
|
self.type_ids = type_ids;
|
||||||
|
}
|
||||||
|
|
||||||
pub fn get_offsets(&self) -> &[Offsets] {
|
pub fn get_offsets(&self) -> &[Offsets] {
|
||||||
&self.offsets
|
&self.offsets
|
||||||
}
|
}
|
||||||
@ -383,6 +387,12 @@ impl Encoding {
|
|||||||
pub fn merge<I: IntoIterator<Item = Encoding>>(encodings: I, growing_offsets: bool) -> Self {
|
pub fn merge<I: IntoIterator<Item = Encoding>>(encodings: I, growing_offsets: bool) -> Self {
|
||||||
let mut encoding = Encoding::default();
|
let mut encoding = Encoding::default();
|
||||||
|
|
||||||
|
// TODO this is suboptimal as we're doing this iteratively instead of preallocating
|
||||||
|
// all the encodings sizes all at once and only copying into this preallocated vector
|
||||||
|
// https://github.com/huggingface/tokenizers/pull/1049
|
||||||
|
|
||||||
|
// In order to fix, we just need to preallocate all vectors, then copy everything
|
||||||
|
// into it (and deal with overlowings correctly)
|
||||||
for sub in encodings {
|
for sub in encodings {
|
||||||
encoding.merge_with(sub, growing_offsets);
|
encoding.merge_with(sub, growing_offsets);
|
||||||
}
|
}
|
||||||
|
Reference in New Issue
Block a user