Modify Processor trait to support chaining. (#1054)

No modifications yet, everything will consume the vector.
Every test should be green without any modifications.
This commit is contained in:
Nicolas Patry
2022-08-24 19:49:23 +02:00
committed by GitHub
parent b1c9bc68b5
commit 460bdded80
8 changed files with 149 additions and 81 deletions

View File

@@ -22,16 +22,15 @@ impl tk::PostProcessor for Processor {
.added_tokens(is_pair)
}
fn process(
fn process_encodings(
&self,
encoding: Encoding,
pair_encoding: Option<Encoding>,
encodings: Vec<Encoding>,
add_special_tokens: bool,
) -> tk::Result<Encoding> {
) -> tk::Result<Vec<Encoding>> {
self.processor
.as_ref()
.ok_or("Uninitialized PostProcessor")?
.process(encoding, pair_encoding, add_special_tokens)
.process_encodings(encodings, add_special_tokens)
}
}

View File

@@ -59,14 +59,13 @@ impl PostProcessor for PyPostProcessor {
self.processor.added_tokens(is_pair)
}
fn process(
fn process_encodings(
&self,
encoding: Encoding,
pair_encoding: Option<Encoding>,
encodings: Vec<Encoding>,
add_special_tokens: bool,
) -> tk::Result<Encoding> {
) -> tk::Result<Vec<Encoding>> {
self.processor
.process(encoding, pair_encoding, add_special_tokens)
.process_encodings(encodings, add_special_tokens)
}
}

View File

@@ -174,20 +174,13 @@ impl PostProcessor for ByteLevel {
0
}
fn process(
fn process_encodings(
&self,
mut encoding: Encoding,
mut pair_encoding: Option<Encoding>,
mut encodings: Vec<Encoding>,
add_special_tokens: bool,
) -> Result<Encoding> {
) -> Result<Vec<Encoding>> {
if self.trim_offsets {
process_offsets(&mut encoding, self.add_prefix_space);
encoding
.get_overflowing_mut()
.iter_mut()
.for_each(|encoding| process_offsets(encoding, self.add_prefix_space));
if let Some(encoding) = pair_encoding.as_mut() {
for encoding in encodings.iter_mut() {
process_offsets(encoding, self.add_prefix_space);
encoding
.get_overflowing_mut()
@@ -195,8 +188,7 @@ impl PostProcessor for ByteLevel {
.for_each(|encoding| process_offsets(encoding, self.add_prefix_space));
}
}
<dyn PostProcessor>::default_process(encoding, pair_encoding, add_special_tokens)
<dyn PostProcessor>::default_process(encodings, add_special_tokens)
}
}

View File

@@ -1,4 +1,4 @@
use crate::tokenizer::{Encoding, PostProcessor, Result};
use crate::tokenizer::{Encoding, PostProcessor, ProcessorError, Result};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::iter::FromIterator;
@@ -25,6 +25,12 @@ impl BertProcessing {
}
}
#[derive(thiserror::Error, Debug)]
pub enum BertProcessorError {
#[error("encodings vector length must be either 1 or 2")]
InvalidEncodingsVecLength,
}
impl PostProcessor for BertProcessing {
fn added_tokens(&self, is_pair: bool) -> usize {
if is_pair {
@ -34,20 +40,34 @@ impl PostProcessor for BertProcessing {
}
}
fn process(
fn process_encodings(
&self,
mut encoding: Encoding,
pair_encoding: Option<Encoding>,
mut encodings: Vec<Encoding>,
add_special_tokens: bool,
) -> Result<Encoding> {
) -> Result<Vec<Encoding>> {
if !add_special_tokens {
return <dyn PostProcessor>::default_process(
encoding,
pair_encoding,
add_special_tokens,
);
return Ok(encodings);
}
let (mut encoding, pair_encoding): (Encoding, Option<Encoding>) = match encodings.len() {
1 => (
encodings
.pop()
.ok_or(ProcessorError::InvalidEncodingsVecLength)?,
None,
),
2 => {
let pair = encodings
.pop()
.ok_or(ProcessorError::InvalidEncodingsVecLength)?;
let encoding = encodings
.pop()
.ok_or(ProcessorError::InvalidEncodingsVecLength)?;
(encoding, Some(pair))
}
_ => return Err(Box::new(ProcessorError::InvalidEncodingsVecLength)),
};
let ids = [&[self.cls.1], encoding.get_ids(), &[self.sep.1]].concat();
let type_ids = [&[0], encoding.get_type_ids(), &[0]].concat();
let tokens = [
@@ -166,7 +186,7 @@ impl PostProcessor for BertProcessing {
new_encoding.merge_with(new_pair_encoding, false);
}
Ok(new_encoding)
Ok(vec![new_encoding])
}
}

View File

@@ -33,19 +33,16 @@ impl PostProcessor for PostProcessorWrapper {
}
}
fn process(
fn process_encodings(
&self,
encoding: Encoding,
pair_encoding: Option<Encoding>,
encodings: Vec<Encoding>,
add_special_tokens: bool,
) -> Result<Encoding> {
) -> Result<Vec<Encoding>> {
match self {
Self::Bert(bert) => bert.process(encoding, pair_encoding, add_special_tokens),
Self::ByteLevel(bl) => bl.process(encoding, pair_encoding, add_special_tokens),
Self::Roberta(roberta) => roberta.process(encoding, pair_encoding, add_special_tokens),
Self::Template(template) => {
template.process(encoding, pair_encoding, add_special_tokens)
}
Self::Bert(bert) => bert.process_encodings(encodings, add_special_tokens),
Self::ByteLevel(bl) => bl.process_encodings(encodings, add_special_tokens),
Self::Roberta(roberta) => roberta.process_encodings(encodings, add_special_tokens),
Self::Template(template) => template.process_encodings(encodings, add_special_tokens),
}
}
}

View File

@@ -1,5 +1,5 @@
use crate::processors::byte_level::process_offsets;
use crate::tokenizer::{Encoding, PostProcessor, Result};
use crate::tokenizer::{Encoding, PostProcessor, ProcessorError, Result};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::iter::FromIterator;
@@ -55,20 +55,13 @@ impl PostProcessor for RobertaProcessing {
}
}
fn process(
fn process_encodings(
&self,
mut encoding: Encoding,
mut pair_encoding: Option<Encoding>,
mut encodings: Vec<Encoding>,
add_special_tokens: bool,
) -> Result<Encoding> {
) -> Result<Vec<Encoding>> {
if self.trim_offsets {
process_offsets(&mut encoding, self.add_prefix_space);
encoding
.get_overflowing_mut()
.iter_mut()
.for_each(|encoding| process_offsets(encoding, self.add_prefix_space));
if let Some(encoding) = pair_encoding.as_mut() {
for encoding in encodings.iter_mut() {
process_offsets(encoding, self.add_prefix_space);
encoding
.get_overflowing_mut()
@@ -78,13 +71,28 @@ impl PostProcessor for RobertaProcessing {
}
if !add_special_tokens {
return <dyn PostProcessor>::default_process(
encoding,
pair_encoding,
add_special_tokens,
);
return Ok(encodings);
}
let (mut encoding, pair_encoding): (Encoding, Option<Encoding>) = match encodings.len() {
1 => (
encodings
.pop()
.ok_or(ProcessorError::InvalidEncodingsVecLength)?,
None,
),
2 => {
let pair = encodings
.pop()
.ok_or(ProcessorError::InvalidEncodingsVecLength)?;
let encoding = encodings
.pop()
.ok_or(ProcessorError::InvalidEncodingsVecLength)?;
(encoding, Some(pair))
}
_ => return Err(Box::new(ProcessorError::InvalidEncodingsVecLength)),
};
let ids = [&[self.cls.1], encoding.get_ids(), &[self.sep.1]].concat();
let type_ids = [&[0], encoding.get_type_ids(), &[0]].concat();
let tokens = [
@@ -213,7 +221,7 @@ impl PostProcessor for RobertaProcessing {
new_encoding.merge_with(new_pair_encoding, false);
}
Ok(new_encoding)
Ok(vec![new_encoding])
}
}

View File

@@ -55,7 +55,7 @@
//!
//! [`TemplateProcessing`]: struct.TemplateProcessing.html
//!
use crate::{Encoding, PostProcessor, Result};
use crate::{tokenizer::ProcessorError, Encoding, PostProcessor, Result};
use itertools::Itertools;
use serde::{Deserialize, Serialize};
use std::collections::{HashMap, HashSet};
@@ -630,13 +630,31 @@ impl PostProcessor for TemplateProcessing {
}
}
fn process(
fn process_encodings(
&self,
encoding: Encoding,
pair: Option<Encoding>,
mut encodings: Vec<Encoding>,
add_special_tokens: bool,
) -> Result<Encoding> {
self.apply_template(
) -> Result<Vec<Encoding>> {
let (encoding, pair): (Encoding, Option<Encoding>) = match encodings.len() {
1 => (
encodings
.pop()
.ok_or(ProcessorError::InvalidEncodingsVecLength)?,
None,
),
2 => {
let pair = encodings
.pop()
.ok_or(ProcessorError::InvalidEncodingsVecLength)?;
let encoding = encodings
.pop()
.ok_or(ProcessorError::InvalidEncodingsVecLength)?;
(encoding, Some(pair))
}
_ => return Err(Box::new(ProcessorError::InvalidEncodingsVecLength)),
};
let encoding = self.apply_template(
if pair.is_some() {
&self.pair.0
} else {
@@ -645,7 +663,8 @@ impl PostProcessor for TemplateProcessing {
encoding,
pair,
add_special_tokens,
)
)?;
Ok(vec![encoding])
}
}

View File

@@ -99,26 +99,50 @@ pub trait PostProcessor {
encoding: Encoding,
pair_encoding: Option<Encoding>,
add_special_tokens: bool,
) -> Result<Encoding>;
) -> Result<Encoding> {
let encodings = if let Some(pair_encoding) = pair_encoding {
vec![encoding, pair_encoding]
} else {
vec![encoding]
};
let encodings = self.process_encodings(encodings, add_special_tokens)?;
Ok(Encoding::merge(encodings, false))
}
/// Process any amount of encodings and returns a series of encoding (might merge them)
fn process_encodings(
&self,
encodings: Vec<Encoding>,
add_special_tokens: bool,
) -> Result<Vec<Encoding>>;
}
impl dyn PostProcessor {
pub fn default_process(
mut encoding: Encoding,
pair_encoding: Option<Encoding>,
encodings: Vec<Encoding>,
_add_special_tokens: bool,
) -> Result<Encoding> {
match pair_encoding {
None => Ok(encoding),
Some(mut pair) => {
encoding.set_sequence_id(0);
pair.set_sequence_id(1);
encoding.merge_with(pair, false);
Ok(encoding)
) -> Result<Vec<Encoding>> {
match encodings.len() {
1 => Ok(encodings),
_ => {
let mut final_encoding = Encoding::default();
for (i, mut encoding) in encodings.into_iter().enumerate() {
encoding.set_sequence_id(i);
final_encoding.merge_with(encoding, false);
}
Ok(vec![final_encoding])
}
}
}
}
#[derive(thiserror::Error, Debug)]
pub enum ProcessorError {
#[error("encodings vector length must be either 1 or 2")]
InvalidEncodingsVecLength,
}
/// A `Decoder` changes the raw tokens into its more readable form.
pub trait Decoder {
fn decode(&self, tokens: Vec<String>) -> Result<String> {
@@ -895,7 +919,17 @@ where
let final_encoding = if let Some(processor) = &self.post_processor {
processor.process(encoding, pair_encoding, add_special_tokens)?
} else {
<dyn PostProcessor>::default_process(encoding, pair_encoding, add_special_tokens)?
let encodings = if let Some(pair_encoding) = pair_encoding {
vec![encoding, pair_encoding]
} else {
vec![encoding]
};
let mut encodings =
<dyn PostProcessor>::default_process(encodings, add_special_tokens)?;
if encodings.len() != 1 {
panic!("We haven't reduced the encodings like we should have");
}
encodings.pop().unwrap()
};
// 3. Then we pad if needed