Modify Processor trait to support chaining. (#1054)

No behavioral modifications yet; everything now consumes the vector of encodings.
Every test should stay green without any modifications.
Nicolas Patry
2022-08-24 19:49:23 +02:00
committed by GitHub
parent b1c9bc68b5
commit 460bdded80
8 changed files with 149 additions and 81 deletions
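For orientation before the per-file diffs: `PostProcessor::process` keeps its one-or-two-encoding signature but becomes a provided method that wraps its inputs in a `Vec` and delegates to the new required `process_encodings`, as the trait diff in the last file shows. A minimal sketch of the resulting shape, using simplified stand-in types rather than the real `tk::Encoding` (the toy `merge` here only concatenates ids):

use std::error::Error;

type Result<T> = std::result::Result<T, Box<dyn Error + Send + Sync>>;

#[derive(Debug, Default, Clone)]
struct Encoding {
    ids: Vec<u32>,
}

impl Encoding {
    // Toy stand-in for the real `Encoding::merge`: just concatenate ids.
    fn merge(encodings: Vec<Encoding>, _growing_offsets: bool) -> Encoding {
        let mut out = Encoding::default();
        for e in encodings {
            out.ids.extend(e.ids);
        }
        out
    }
}

trait PostProcessor {
    // Old entry point, now a provided method: wrap the one or two inputs
    // in a vector and forward them to `process_encodings`.
    fn process(
        &self,
        encoding: Encoding,
        pair_encoding: Option<Encoding>,
        add_special_tokens: bool,
    ) -> Result<Encoding> {
        let encodings = if let Some(pair_encoding) = pair_encoding {
            vec![encoding, pair_encoding]
        } else {
            vec![encoding]
        };
        let encodings = self.process_encodings(encodings, add_special_tokens)?;
        Ok(Encoding::merge(encodings, false))
    }

    // New required method: a series of encodings in, a series out,
    // which is what makes chaining processors possible.
    fn process_encodings(
        &self,
        encodings: Vec<Encoding>,
        add_special_tokens: bool,
    ) -> Result<Vec<Encoding>>;
}

With this shape, existing callers of `process` keep working unchanged while new code can push a whole batch of encodings through `process_encodings`.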

View File

@@ -22,16 +22,15 @@ impl tk::PostProcessor for Processor {
             .added_tokens(is_pair)
     }

-    fn process(
+    fn process_encodings(
         &self,
-        encoding: Encoding,
-        pair_encoding: Option<Encoding>,
+        encodings: Vec<Encoding>,
         add_special_tokens: bool,
-    ) -> tk::Result<Encoding> {
+    ) -> tk::Result<Vec<Encoding>> {
         self.processor
             .as_ref()
             .ok_or("Uninitialized PostProcessor")?
-            .process(encoding, pair_encoding, add_special_tokens)
+            .process_encodings(encodings, add_special_tokens)
     }
 }

View File

@@ -59,14 +59,13 @@ impl PostProcessor for PyPostProcessor {
         self.processor.added_tokens(is_pair)
     }

-    fn process(
+    fn process_encodings(
         &self,
-        encoding: Encoding,
-        pair_encoding: Option<Encoding>,
+        encodings: Vec<Encoding>,
         add_special_tokens: bool,
-    ) -> tk::Result<Encoding> {
+    ) -> tk::Result<Vec<Encoding>> {
         self.processor
-            .process(encoding, pair_encoding, add_special_tokens)
+            .process_encodings(encodings, add_special_tokens)
     }
 }

View File

@@ -174,20 +174,13 @@ impl PostProcessor for ByteLevel {
         0
     }

-    fn process(
+    fn process_encodings(
         &self,
-        mut encoding: Encoding,
-        mut pair_encoding: Option<Encoding>,
+        mut encodings: Vec<Encoding>,
         add_special_tokens: bool,
-    ) -> Result<Encoding> {
+    ) -> Result<Vec<Encoding>> {
         if self.trim_offsets {
-            process_offsets(&mut encoding, self.add_prefix_space);
-            encoding
-                .get_overflowing_mut()
-                .iter_mut()
-                .for_each(|encoding| process_offsets(encoding, self.add_prefix_space));
-
-            if let Some(encoding) = pair_encoding.as_mut() {
+            for encoding in encodings.iter_mut() {
                 process_offsets(encoding, self.add_prefix_space);
                 encoding
                     .get_overflowing_mut()
@@ -195,8 +188,7 @@ impl PostProcessor for ByteLevel {
                     .for_each(|encoding| process_offsets(encoding, self.add_prefix_space));
             }
         }
-
-        <dyn PostProcessor>::default_process(encoding, pair_encoding, add_special_tokens)
+        <dyn PostProcessor>::default_process(encodings, add_special_tokens)
     }
 }
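The pattern in this hunk, repeated in `RobertaProcessing` further down, is the collapse of two near-identical branches (one for `encoding`, one for `pair_encoding`) into a single loop over the vector that also visits each encoding's overflowing entries. A condensed sketch with stand-in types; `process_offsets` here is a no-op placeholder for the real helper in `byte_level.rs`:

#[derive(Debug, Default)]
struct Encoding {
    overflowing: Vec<Encoding>,
}

impl Encoding {
    fn get_overflowing_mut(&mut self) -> &mut Vec<Encoding> {
        &mut self.overflowing
    }
}

// No-op placeholder for the real offset-trimming helper.
fn process_offsets(_encoding: &mut Encoding, _add_prefix_space: bool) {}

// One loop now handles any number of encodings, where the old code
// spelled out the `encoding` and `pair_encoding` cases separately.
fn trim_all(encodings: &mut [Encoding], add_prefix_space: bool) {
    for encoding in encodings.iter_mut() {
        process_offsets(encoding, add_prefix_space);
        encoding
            .get_overflowing_mut()
            .iter_mut()
            .for_each(|e| process_offsets(e, add_prefix_space));
    }
}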

View File

@@ -1,4 +1,4 @@
-use crate::tokenizer::{Encoding, PostProcessor, Result};
+use crate::tokenizer::{Encoding, PostProcessor, ProcessorError, Result};
 use serde::{Deserialize, Serialize};
 use std::collections::HashMap;
 use std::iter::FromIterator;
@@ -25,6 +25,12 @@ impl BertProcessing {
         }
     }

+#[derive(thiserror::Error, Debug)]
+pub enum BertProcessorError {
+    #[error("encodings vector length must be either 1 or 2")]
+    InvalidEncodingsVecLength,
+}
+
 impl PostProcessor for BertProcessing {
     fn added_tokens(&self, is_pair: bool) -> usize {
         if is_pair {
@@ -34,20 +40,34 @@ impl PostProcessor for BertProcessing {
         }
     }

-    fn process(
+    fn process_encodings(
         &self,
-        mut encoding: Encoding,
-        pair_encoding: Option<Encoding>,
+        mut encodings: Vec<Encoding>,
         add_special_tokens: bool,
-    ) -> Result<Encoding> {
+    ) -> Result<Vec<Encoding>> {
         if !add_special_tokens {
-            return <dyn PostProcessor>::default_process(
-                encoding,
-                pair_encoding,
-                add_special_tokens,
-            );
+            return Ok(encodings);
         }

+        let (mut encoding, pair_encoding): (Encoding, Option<Encoding>) = match encodings.len() {
+            1 => (
+                encodings
+                    .pop()
+                    .ok_or(ProcessorError::InvalidEncodingsVecLength)?,
+                None,
+            ),
+            2 => {
+                let pair = encodings
+                    .pop()
+                    .ok_or(ProcessorError::InvalidEncodingsVecLength)?;
+                let encoding = encodings
+                    .pop()
+                    .ok_or(ProcessorError::InvalidEncodingsVecLength)?;
+                (encoding, Some(pair))
+            }
+            _ => return Err(Box::new(ProcessorError::InvalidEncodingsVecLength)),
+        };
+
         let ids = [&[self.cls.1], encoding.get_ids(), &[self.sep.1]].concat();
         let type_ids = [&[0], encoding.get_type_ids(), &[0]].concat();
         let tokens = [
@@ -166,7 +186,7 @@ impl PostProcessor for BertProcessing {
             new_encoding.merge_with(new_pair_encoding, false);
         }

-        Ok(new_encoding)
+        Ok(vec![new_encoding])
     }
 }
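The `match encodings.len()` block above converts the incoming `Vec<Encoding>` back into the `(encoding, Option<pair>)` shape the existing special-token logic expects, and it reappears almost verbatim in `RobertaProcessing` and `TemplateProcessing` below. A hypothetical shared helper, not part of this commit, that factors the destructuring out (assumes the `thiserror` dependency the commit already uses):

use std::error::Error;

type Result<T> = std::result::Result<T, Box<dyn Error + Send + Sync>>;

#[derive(Debug, Default)]
struct Encoding; // stand-in for the real tk::Encoding

#[derive(thiserror::Error, Debug)]
enum ProcessorError {
    #[error("encodings vector length must be either 1 or 2")]
    InvalidEncodingsVecLength,
}

// Hypothetical helper: split a vector of one or two encodings back into
// the (encoding, optional pair) shape the legacy processors expect.
fn split_pair(mut encodings: Vec<Encoding>) -> Result<(Encoding, Option<Encoding>)> {
    match encodings.len() {
        // `pop` cannot fail after the length check; `ok_or` just mirrors
        // the defensive style used in the commit.
        1 => Ok((
            encodings
                .pop()
                .ok_or(ProcessorError::InvalidEncodingsVecLength)?,
            None,
        )),
        2 => {
            let pair = encodings
                .pop()
                .ok_or(ProcessorError::InvalidEncodingsVecLength)?;
            let encoding = encodings
                .pop()
                .ok_or(ProcessorError::InvalidEncodingsVecLength)?;
            Ok((encoding, Some(pair)))
        }
        _ => Err(Box::new(ProcessorError::InvalidEncodingsVecLength)),
    }
}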

View File

@@ -33,19 +33,16 @@ impl PostProcessor for PostProcessorWrapper {
         }
     }

-    fn process(
+    fn process_encodings(
         &self,
-        encoding: Encoding,
-        pair_encoding: Option<Encoding>,
+        encodings: Vec<Encoding>,
         add_special_tokens: bool,
-    ) -> Result<Encoding> {
+    ) -> Result<Vec<Encoding>> {
         match self {
-            Self::Bert(bert) => bert.process(encoding, pair_encoding, add_special_tokens),
-            Self::ByteLevel(bl) => bl.process(encoding, pair_encoding, add_special_tokens),
-            Self::Roberta(roberta) => roberta.process(encoding, pair_encoding, add_special_tokens),
-            Self::Template(template) => {
-                template.process(encoding, pair_encoding, add_special_tokens)
-            }
+            Self::Bert(bert) => bert.process_encodings(encodings, add_special_tokens),
+            Self::ByteLevel(bl) => bl.process_encodings(encodings, add_special_tokens),
+            Self::Roberta(roberta) => roberta.process_encodings(encodings, add_special_tokens),
+            Self::Template(template) => template.process_encodings(encodings, add_special_tokens),
         }
     }
 }

View File

@@ -1,5 +1,5 @@
 use crate::processors::byte_level::process_offsets;
-use crate::tokenizer::{Encoding, PostProcessor, Result};
+use crate::tokenizer::{Encoding, PostProcessor, ProcessorError, Result};
 use serde::{Deserialize, Serialize};
 use std::collections::HashMap;
 use std::iter::FromIterator;
@@ -55,20 +55,13 @@ impl PostProcessor for RobertaProcessing {
         }
     }

-    fn process(
+    fn process_encodings(
         &self,
-        mut encoding: Encoding,
-        mut pair_encoding: Option<Encoding>,
+        mut encodings: Vec<Encoding>,
         add_special_tokens: bool,
-    ) -> Result<Encoding> {
+    ) -> Result<Vec<Encoding>> {
         if self.trim_offsets {
-            process_offsets(&mut encoding, self.add_prefix_space);
-            encoding
-                .get_overflowing_mut()
-                .iter_mut()
-                .for_each(|encoding| process_offsets(encoding, self.add_prefix_space));
-
-            if let Some(encoding) = pair_encoding.as_mut() {
+            for encoding in encodings.iter_mut() {
                 process_offsets(encoding, self.add_prefix_space);
                 encoding
                     .get_overflowing_mut()
@@ -78,13 +71,28 @@ impl PostProcessor for RobertaProcessing {
             }
         }

         if !add_special_tokens {
-            return <dyn PostProcessor>::default_process(
-                encoding,
-                pair_encoding,
-                add_special_tokens,
-            );
+            return Ok(encodings);
         }

+        let (mut encoding, pair_encoding): (Encoding, Option<Encoding>) = match encodings.len() {
+            1 => (
+                encodings
+                    .pop()
+                    .ok_or(ProcessorError::InvalidEncodingsVecLength)?,
+                None,
+            ),
+            2 => {
+                let pair = encodings
+                    .pop()
+                    .ok_or(ProcessorError::InvalidEncodingsVecLength)?;
+                let encoding = encodings
+                    .pop()
+                    .ok_or(ProcessorError::InvalidEncodingsVecLength)?;
+                (encoding, Some(pair))
+            }
+            _ => return Err(Box::new(ProcessorError::InvalidEncodingsVecLength)),
+        };
+
         let ids = [&[self.cls.1], encoding.get_ids(), &[self.sep.1]].concat();
         let type_ids = [&[0], encoding.get_type_ids(), &[0]].concat();
         let tokens = [
@@ -213,7 +221,7 @@ impl PostProcessor for RobertaProcessing {
             new_encoding.merge_with(new_pair_encoding, false);
         }

-        Ok(new_encoding)
+        Ok(vec![new_encoding])
     }
 }

View File

@@ -55,7 +55,7 @@
 //!
 //! [`TemplateProcessing`]: struct.TemplateProcessing.html
 //!
-use crate::{Encoding, PostProcessor, Result};
+use crate::{tokenizer::ProcessorError, Encoding, PostProcessor, Result};
 use itertools::Itertools;
 use serde::{Deserialize, Serialize};
 use std::collections::{HashMap, HashSet};
@@ -630,13 +630,31 @@ impl PostProcessor for TemplateProcessing {
         }
     }

-    fn process(
+    fn process_encodings(
         &self,
-        encoding: Encoding,
-        pair: Option<Encoding>,
+        mut encodings: Vec<Encoding>,
         add_special_tokens: bool,
-    ) -> Result<Encoding> {
-        self.apply_template(
+    ) -> Result<Vec<Encoding>> {
+        let (encoding, pair): (Encoding, Option<Encoding>) = match encodings.len() {
+            1 => (
+                encodings
+                    .pop()
+                    .ok_or(ProcessorError::InvalidEncodingsVecLength)?,
+                None,
+            ),
+            2 => {
+                let pair = encodings
+                    .pop()
+                    .ok_or(ProcessorError::InvalidEncodingsVecLength)?;
+                let encoding = encodings
+                    .pop()
+                    .ok_or(ProcessorError::InvalidEncodingsVecLength)?;
+                (encoding, Some(pair))
+            }
+            _ => return Err(Box::new(ProcessorError::InvalidEncodingsVecLength)),
+        };
+
+        let encoding = self.apply_template(
             if pair.is_some() {
                 &self.pair.0
             } else {
@@ -645,7 +663,8 @@ impl PostProcessor for TemplateProcessing {
             encoding,
             pair,
             add_special_tokens,
-        )
+        )?;
+        Ok(vec![encoding])
     }
 }

View File

@@ -99,26 +99,50 @@ pub trait PostProcessor {
         encoding: Encoding,
         pair_encoding: Option<Encoding>,
         add_special_tokens: bool,
-    ) -> Result<Encoding>;
+    ) -> Result<Encoding> {
+        let encodings = if let Some(pair_encoding) = pair_encoding {
+            vec![encoding, pair_encoding]
+        } else {
+            vec![encoding]
+        };
+
+        let encodings = self.process_encodings(encodings, add_special_tokens)?;
+        Ok(Encoding::merge(encodings, false))
+    }
+
+    /// Process any amount of encodings and returns a series of encoding (might merge them)
+    fn process_encodings(
+        &self,
+        encodings: Vec<Encoding>,
+        add_special_tokens: bool,
+    ) -> Result<Vec<Encoding>>;
 }

 impl dyn PostProcessor {
     pub fn default_process(
-        mut encoding: Encoding,
-        pair_encoding: Option<Encoding>,
+        encodings: Vec<Encoding>,
         _add_special_tokens: bool,
-    ) -> Result<Encoding> {
-        match pair_encoding {
-            None => Ok(encoding),
-            Some(mut pair) => {
-                encoding.set_sequence_id(0);
-                pair.set_sequence_id(1);
-                encoding.merge_with(pair, false);
-                Ok(encoding)
+    ) -> Result<Vec<Encoding>> {
+        match encodings.len() {
+            1 => Ok(encodings),
+            _ => {
+                let mut final_encoding = Encoding::default();
+                for (i, mut encoding) in encodings.into_iter().enumerate() {
+                    encoding.set_sequence_id(i);
+                    final_encoding.merge_with(encoding, false);
+                }
+                Ok(vec![final_encoding])
             }
         }
     }
 }

+#[derive(thiserror::Error, Debug)]
+pub enum ProcessorError {
+    #[error("encodings vector length must be either 1 or 2")]
+    InvalidEncodingsVecLength,
+}
+
 /// A `Decoder` changes the raw tokens into its more readable form.
 pub trait Decoder {
     fn decode(&self, tokens: Vec<String>) -> Result<String> {
@@ -895,7 +919,17 @@ where
         let final_encoding = if let Some(processor) = &self.post_processor {
             processor.process(encoding, pair_encoding, add_special_tokens)?
         } else {
-            <dyn PostProcessor>::default_process(encoding, pair_encoding, add_special_tokens)?
+            let encodings = if let Some(pair_encoding) = pair_encoding {
+                vec![encoding, pair_encoding]
+            } else {
+                vec![encoding]
+            };
+            let mut encodings =
+                <dyn PostProcessor>::default_process(encodings, add_special_tokens)?;
+            if encodings.len() != 1 {
+                panic!("We haven't reduced the encodings like we should have");
+            }
+            encodings.pop().unwrap()
         };

         // 3. Then we pad if needed
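Taken together, the point of the new `Vec<Encoding>` to `Vec<Encoding>` contract is that processors compose: one step's output feeds the next, and only at the end is the vector merged down to a single `Encoding`, the same reduction `default_process` performs above. A self-contained sketch of such chaining with toy processors (stand-in types; `add_special_tokens` omitted for brevity):

use std::error::Error;

type Result<T> = std::result::Result<T, Box<dyn Error + Send + Sync>>;

#[derive(Debug, Default, Clone)]
struct Encoding {
    ids: Vec<u32>,
    sequence_id: Option<usize>,
}

impl Encoding {
    fn set_sequence_id(&mut self, id: usize) {
        self.sequence_id = Some(id);
    }

    fn merge_with(&mut self, other: Encoding, _growing_offsets: bool) {
        self.ids.extend(other.ids);
    }
}

trait PostProcessor {
    fn process_encodings(&self, encodings: Vec<Encoding>) -> Result<Vec<Encoding>>;
}

// Two toy processors: one prepends a CLS-like id to every encoding,
// the other appends a SEP-like id.
struct AddCls;
struct AddSep;

impl PostProcessor for AddCls {
    fn process_encodings(&self, mut encodings: Vec<Encoding>) -> Result<Vec<Encoding>> {
        for e in encodings.iter_mut() {
            e.ids.insert(0, 101);
        }
        Ok(encodings)
    }
}

impl PostProcessor for AddSep {
    fn process_encodings(&self, mut encodings: Vec<Encoding>) -> Result<Vec<Encoding>> {
        for e in encodings.iter_mut() {
            e.ids.push(102);
        }
        Ok(encodings)
    }
}

fn main() -> Result<()> {
    let encodings = vec![
        Encoding { ids: vec![7, 8], sequence_id: None },
        Encoding { ids: vec![9], sequence_id: None },
    ];

    // Chain: each stage is Vec -> Vec, so stages compose freely.
    let encodings = AddCls.process_encodings(encodings)?;
    let encodings = AddSep.process_encodings(encodings)?;

    // Final reduction, mirroring `default_process`: tag each encoding with
    // its sequence id, then merge everything into one Encoding.
    let mut final_encoding = Encoding::default();
    for (i, mut e) in encodings.into_iter().enumerate() {
        e.set_sequence_id(i);
        final_encoding.merge_with(e, false);
    }

    assert_eq!(final_encoding.ids, vec![101, 7, 8, 102, 101, 9, 102]);
    Ok(())
}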