mirror of https://github.com/mii443/tokenizers.git (synced 2025-08-23 16:49:27 +00:00)
Tokenizer handles Truncation and Padding
@@ -1,56 +1,30 @@
 use crate::tokenizer::{Encoding, PostProcessor, Result};
-use crate::utils::{truncate_encodings, TruncationStrategy};
 
 pub struct BertProcessing {
-    max_len: usize,
-    trunc_strategy: TruncationStrategy,
-    trunc_stride: usize,
-
     sep: (String, u32),
     cls: (String, u32),
 }
 
 impl BertProcessing {
-    pub fn new(
-        max_len: usize,
-        trunc_strategy: TruncationStrategy,
-        trunc_stride: usize,
-        sep: (String, u32),
-        cls: (String, u32),
-    ) -> Self {
-        BertProcessing {
-            max_len,
-            trunc_strategy,
-            trunc_stride,
-            sep,
-            cls,
-        }
+    pub fn new(sep: (String, u32), cls: (String, u32)) -> Self {
+        BertProcessing { sep, cls }
     }
 }
 
 impl PostProcessor for BertProcessing {
-    fn process(&self, encoding: Encoding, pair_encoding: Option<Encoding>) -> Result<Encoding> {
-        let special_token_len = if pair_encoding.is_some() { 3 } else { 2 };
-        let total_len = encoding.get_ids().len()
-            + pair_encoding
-                .as_ref()
-                .map(|e| e.get_ids().len())
-                .unwrap_or(0)
-            + special_token_len;
-
-        let need_trunc = if total_len > self.max_len {
-            total_len - self.max_len
+    fn added_tokens(
+        &self,
+        _encoding: &Encoding,
+        pair_encoding: &Option<Encoding>,
+    ) -> Result<usize> {
+        if pair_encoding.is_some() {
+            Ok(3)
         } else {
-            0
-        };
-        let (mut encoding, pair_encoding) = truncate_encodings(
-            encoding,
-            pair_encoding,
-            need_trunc,
-            self.trunc_strategy,
-            self.trunc_stride,
-        )?;
+            Ok(2)
+        }
+    }
 
+    fn process(&self, mut encoding: Encoding, pair_encoding: Option<Encoding>) -> Result<Encoding> {
         // Prepare ids
         let ids = [&[self.cls.1], &encoding.get_ids()[..], &[self.sep.1]].concat();
         let pair_ids = pair_encoding
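After this change, BertProcessing no longer owns truncation settings; it only reports, via added_tokens, how many special tokens it will add (2 for a single sequence, 3 for a pair), so the Tokenizer can truncate before post-processing. A minimal sketch of the new surface; the module path and the [CLS]/[SEP] ids are assumptions, not part of this diff:

    use crate::processors::bert::BertProcessing; // module path assumed
    use crate::tokenizer::{Encoding, PostProcessor, Result};

    fn bert_processor_demo(single: Encoding) -> Result<()> {
        // [SEP]/[CLS] ids (102/101) are illustrative values.
        let processor = BertProcessing::new(("[SEP]".into(), 102), ("[CLS]".into(), 101));

        // No pair => [CLS] ... [SEP] => 2 added tokens; a pair would yield 3.
        assert_eq!(processor.added_tokens(&single, &None)?, 2);
        Ok(())
    }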
@@ -1,4 +1,5 @@
 /// The various possible padding directions
+#[derive(Debug, Clone)]
 pub enum PaddingDirection {
     Left,
     Right,
@@ -155,8 +156,8 @@ impl Encoding {
         pad_length: usize,
         pad_id: u32,
         pad_type_id: u32,
-        pad_token: String,
-        direction: PaddingDirection,
+        pad_token: &str,
+        direction: &PaddingDirection,
     ) {
         if self.ids.len() > pad_length {
             // We just do nothing if the wanted padding length is smaller than us
@@ -166,13 +167,13 @@ impl Encoding {
 
         let ids_pad = vec![pad_id; pad_length];
         let type_ids_pad = vec![pad_type_id; pad_length];
-        let tokens_pad = vec![pad_token; pad_length];
+        let tokens_pad = vec![pad_token.to_owned(); pad_length];
         let attention_pad = vec![0; pad_length];
         let special_pad = vec![1; pad_length];
         let offsets_pad = vec![(0, 0); pad_length];
 
         match direction {
-            Left => {
+            PaddingDirection::Left => {
                 self.ids = [&ids_pad[..], &self.ids[..]].concat();
                 self.type_ids = [&type_ids_pad[..], &self.type_ids[..]].concat();
                 self.tokens = [&tokens_pad[..], &self.tokens[..]].concat();
@@ -181,7 +182,7 @@ impl Encoding {
                     [&special_pad[..], &self.special_tokens_mask[..]].concat();
                 self.offsets = [&offsets_pad[..], &self.offsets[..]].concat();
             }
-            Right => {
+            PaddingDirection::Right => {
                 self.ids = [&self.ids[..], &ids_pad[..]].concat();
                 self.type_ids = [&self.type_ids[..], &type_ids_pad[..]].concat();
                 self.tokens = [&self.tokens[..], &tokens_pad[..]].concat();
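Encoding::pad now borrows both the pad token and the direction, so one set of padding parameters can be reused across a whole batch without cloning per call. A sketch of a call under the new signature (length and id values are illustrative):

    use crate::tokenizer::{Encoding, PaddingDirection};

    fn pad_right_to(encoding: &mut Encoding, len: usize) {
        // No-op if the encoding is already longer than `len` (see the guard above).
        encoding.pad(len, 0, 0, "[PAD]", &PaddingDirection::Right);
    }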
@@ -12,6 +12,9 @@
 //! - PostProcessor: Takes care of the processing after tokenization. (Like truncating, padding,
 //! ...)
 //!
+use crate::utils::{
+    pad_encodings, truncate_encodings, PaddingParams, PaddingStrategy, TruncationParams,
+};
 use rayon::prelude::*;
 use std::{
     collections::HashMap,
@@ -20,7 +23,7 @@ use std::{
 };
 
 mod encoding;
-pub use encoding::Encoding;
+pub use encoding::*;
 
 pub type Result<T> = std::result::Result<T, Box<dyn std::error::Error + Send + Sync>>;
 
@@ -46,6 +49,7 @@ pub trait Model {
 /// A PostProcessor has the responsibility to post process an encoded output of the Tokenizer.
 /// Truncating, Padding, etc... are PostProcessor steps
 pub trait PostProcessor {
+    fn added_tokens(&self, encoding: &Encoding, pair_encoding: &Option<Encoding>) -> Result<usize>;
     fn process(&self, encoding: Encoding, pair_encoding: Option<Encoding>) -> Result<Encoding>;
 }
 
@@ -125,6 +129,10 @@ pub struct Tokenizer {
     added_tokens: HashMap<AddedToken, u32>,
     added_tokens_r: HashMap<u32, AddedToken>,
     split_re: Option<regex::Regex>,
+
+    // General processing parameters
+    trunc: Option<TruncationParams>,
+    padding: Option<PaddingParams>,
 }
 
 impl Tokenizer {
@@ -139,11 +147,13 @@ impl Tokenizer {
             added_tokens: HashMap::new(),
             added_tokens_r: HashMap::new(),
             split_re: None,
+            trunc: None,
+            padding: None,
         }
     }
 
-    /// Set the normalizers
-    pub fn with_normalizers(&mut self, normalizer: Box<dyn Normalizer + Sync>) -> &Self {
+    /// Set the normalizer
+    pub fn with_normalizer(&mut self, normalizer: Box<dyn Normalizer + Sync>) -> &Self {
         self.normalizer = Some(normalizer);
         self
     }
@@ -172,6 +182,18 @@ impl Tokenizer {
         self
     }
 
+    /// Set the truncation parameters
+    pub fn with_truncation(&mut self, trunc: Option<TruncationParams>) -> &Self {
+        self.trunc = trunc;
+        self
+    }
+
+    /// Set the padding strategy
+    pub fn with_padding(&mut self, padding: Option<PaddingParams>) -> &Self {
+        self.padding = padding;
+        self
+    }
+
     /// Get the size of the vocabulary
     pub fn get_vocab_size(&self) -> usize {
         self.model.get_vocab_size()
@@ -283,10 +305,17 @@ impl Tokenizer {
 
     /// Encode all the sentences in parallel, using multiple threads
    pub fn encode_batch(&self, inputs: Vec<EncodeInput>) -> Result<Vec<Encoding>> {
-        inputs
+        let encodings = inputs
             .into_par_iter()
             .map(|input| self.encode(input))
-            .collect()
+            .collect::<Result<Vec<Encoding>>>()?;
+
+        if let Some(params) = &self.padding {
+            // We do the padding here to make sure we handle the batch padding
+            pad_encodings(encodings, &params)
+        } else {
+            Ok(encodings)
+        }
     }
 
     /// Decode the given ids, back to a String
@@ -376,20 +405,61 @@ impl Tokenizer {
     /// Post processing logic, handling the case where there is no PostProcessor set
     fn post_process(
         &self,
-        mut encoding: Encoding,
+        encoding: Encoding,
         pair_encoding: Option<Encoding>,
     ) -> Result<Encoding> {
-        if let Some(processor) = &self.post_processor {
-            processor.process(encoding, pair_encoding)
+        // 1. First we truncate if needed
+        let (mut encoding, pair_encoding) = {
+            if let Some(trunc) = &self.trunc {
+                let n_added_tokens = if let Some(processor) = &self.post_processor {
+                    processor.added_tokens(&encoding, &pair_encoding)?
+                } else {
+                    0
+                };
+
+                if n_added_tokens > 0 {
+                    let params = TruncationParams {
+                        max_length: trunc.max_length - n_added_tokens,
+                        ..*trunc
+                    };
+                    truncate_encodings(encoding, pair_encoding, &params)?
+                } else {
+                    truncate_encodings(encoding, pair_encoding, &trunc)?
+                }
+            } else {
+                (encoding, pair_encoding)
+            }
+        };
+
+        // 2. Then We post process
+        let mut final_encoding = if let Some(processor) = &self.post_processor {
+            processor.process(encoding, pair_encoding)?
         } else {
             match pair_encoding {
-                None => Ok(encoding),
+                None => encoding,
                 Some(pair) => {
                     encoding.merge_with(pair);
-                    Ok(encoding)
+                    encoding
                 }
             }
+        };
+
+        // 3. Then we pad if needed
+        if let Some(params) = &self.padding {
+            // We can only pad for a given size. If the Strategy is BatchLongest, it will be done
+            // when we handle a batch
+            if let PaddingStrategy::Fixed(size) = params.strategy {
+                final_encoding.pad(
+                    size,
+                    params.pad_id,
+                    params.pad_type_id,
+                    &params.pad_token,
+                    &params.direction,
+                );
+            }
         }
+
+        Ok(final_encoding)
     }
 
     /// Add the given tokens to the added vocabulary
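With the new fields and setters, truncation and padding become tokenizer-level configuration: post_process first truncates (reserving room for the post-processor's added_tokens), then runs the post-processor, then pads when a Fixed size is set, while BatchLongest is resolved later in encode_batch. A configuration sketch, assuming crate-internal paths and illustrative values:

    use crate::tokenizer::{PaddingDirection, Tokenizer};
    use crate::utils::{PaddingParams, PaddingStrategy, TruncationParams, TruncationStrategy};

    fn configure(tokenizer: &mut Tokenizer) {
        tokenizer.with_truncation(Some(TruncationParams {
            max_length: 512, // post_process subtracts added_tokens() from this
            strategy: TruncationStrategy::LongestFirst,
            stride: 0,
        }));
        tokenizer.with_padding(Some(PaddingParams {
            strategy: PaddingStrategy::Fixed(512), // applied per encoding in post_process
            direction: PaddingDirection::Right,
            pad_id: 0,
            pad_type_id: 0,
            pad_token: String::from("[PAD]"),
        }));
    }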
@@ -1,15 +1,35 @@
-use crate::tokenizer::{Encoding, Result};
+use crate::tokenizer::{Encoding, PaddingDirection, Result};
+
+#[derive(Debug, Clone)]
+pub struct TruncationParams {
+    pub max_length: usize,
+    pub strategy: TruncationStrategy,
+    pub stride: usize,
+}
+
+#[derive(Debug, Clone)]
+pub struct PaddingParams {
+    pub strategy: PaddingStrategy,
+    pub direction: PaddingDirection,
+    pub pad_id: u32,
+    pub pad_type_id: u32,
+    pub pad_token: String,
+}
+
+#[derive(Debug, Clone)]
+pub enum PaddingStrategy {
+    BatchLongest,
+    Fixed(usize),
+}
 
 #[derive(Debug)]
 pub enum Error {
-    SequenceTooSmall,
     SecondSequenceNotProvided,
 }
 
 impl std::fmt::Display for Error {
     fn fmt(&self, fmt: &mut std::fmt::Formatter) -> std::fmt::Result {
         match self {
-            Error::SequenceTooSmall => write!(fmt, "Truncation error: Sequence is too small"),
             Error::SecondSequenceNotProvided => {
                 write!(fmt, "Truncation error: Second sequence not provided")
             }
@@ -18,7 +38,7 @@ impl std::fmt::Display for Error {
 }
 impl std::error::Error for Error {}
 
-#[derive(Clone, Copy, PartialEq)]
+#[derive(Debug, Clone, Copy, PartialEq)]
 pub enum TruncationStrategy {
     LongestFirst,
     OnlyFirst,
@@ -28,19 +48,27 @@ pub enum TruncationStrategy {
 pub fn truncate_encodings(
     mut encoding: Encoding,
     mut pair_encoding: Option<Encoding>,
-    to_remove: usize,
-    strategy: TruncationStrategy,
-    stride: usize,
+    params: &TruncationParams,
 ) -> Result<(Encoding, Option<Encoding>)> {
-    if to_remove == 0 {
+    if params.max_length == 0 {
         return Ok((encoding, pair_encoding));
     }
 
-    match strategy {
+    match params.strategy {
         TruncationStrategy::LongestFirst => {
+            let total_length = encoding.get_ids().len()
+                + pair_encoding
+                    .as_ref()
+                    .map(|e| e.get_ids().len())
+                    .unwrap_or(0);
+            let to_remove = if total_length > params.max_length {
+                total_length - params.max_length
+            } else {
+                0
+            };
+
             let mut n_first = 0;
             let mut n_second = 0;
 
             for _ in 0..to_remove {
                 if pair_encoding.is_none()
                     || encoding.get_ids().len() > pair_encoding.as_ref().unwrap().get_ids().len()
@@ -51,13 +79,13 @@ pub fn truncate_encodings(
                 }
             }
 
-            encoding.truncate(encoding.get_ids().len() - n_first, stride);
+            encoding.truncate(encoding.get_ids().len() - n_first, params.stride);
             if let Some(encoding) = pair_encoding.as_mut() {
-                encoding.truncate(encoding.get_ids().len() - n_second, stride);
+                encoding.truncate(encoding.get_ids().len() - n_second, params.stride);
             }
         }
         TruncationStrategy::OnlyFirst | TruncationStrategy::OnlySecond => {
-            let target = if strategy == TruncationStrategy::OnlyFirst {
+            let target = if params.strategy == TruncationStrategy::OnlyFirst {
                 Ok(&mut encoding)
             } else if let Some(encoding) = pair_encoding.as_mut() {
                 Ok(encoding)
@@ -65,13 +93,37 @@ pub fn truncate_encodings(
                 Err(Box::new(Error::SecondSequenceNotProvided))
             }?;
 
-            if target.get_ids().len() <= to_remove {
-                return Err(Box::new(Error::SequenceTooSmall));
+            if target.get_ids().len() > params.max_length {
+                target.truncate(params.max_length, params.stride);
             }
-
-            target.truncate(target.get_ids().len() - to_remove, stride);
         }
     }
 
     Ok((encoding, pair_encoding))
 }
+
+pub fn pad_encodings(
+    mut encodings: Vec<Encoding>,
+    params: &PaddingParams,
+) -> Result<Vec<Encoding>> {
+    if encodings.is_empty() {
+        return Ok(encodings);
+    }
+
+    let pad_length = match params.strategy {
+        PaddingStrategy::Fixed(size) => size,
+        PaddingStrategy::BatchLongest => encodings.iter().map(|e| e.get_ids().len()).max().unwrap(),
+    };
+
+    for encoding in encodings.iter_mut() {
+        encoding.pad(
+            pad_length,
+            params.pad_id,
+            params.pad_type_id,
+            &params.pad_token,
+            &params.direction,
+        );
+    }
+
+    Ok(encodings)
+}
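The helpers can also be driven directly: truncate_encodings now derives how much to remove from params.max_length instead of taking a precomputed to_remove, and pad_encodings brings every encoding in a batch to a common length (the batch maximum under BatchLongest). A direct-use sketch with illustrative values:

    use crate::tokenizer::{Encoding, PaddingDirection, Result};
    use crate::utils::{pad_encodings, PaddingParams, PaddingStrategy};

    fn pad_batch(encodings: Vec<Encoding>) -> Result<Vec<Encoding>> {
        let params = PaddingParams {
            // Target length = longest ids length in the batch.
            strategy: PaddingStrategy::BatchLongest,
            direction: PaddingDirection::Right,
            pad_id: 0,
            pad_type_id: 0,
            pad_token: String::from("[PAD]"),
        };
        pad_encodings(encodings, &params)
    }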