mirror of
https://github.com/mii443/tokenizers.git
synced 2025-08-22 16:25:30 +00:00
Modify Processor
trait to support chaining. (#1054)
0 modifications yet, everything will consume the vector. Every test should be green without any modifications.
This commit is contained in:
@ -22,16 +22,15 @@ impl tk::PostProcessor for Processor {
|
||||
.added_tokens(is_pair)
|
||||
}
|
||||
|
||||
fn process(
|
||||
fn process_encodings(
|
||||
&self,
|
||||
encoding: Encoding,
|
||||
pair_encoding: Option<Encoding>,
|
||||
encodings: Vec<Encoding>,
|
||||
add_special_tokens: bool,
|
||||
) -> tk::Result<Encoding> {
|
||||
) -> tk::Result<Vec<Encoding>> {
|
||||
self.processor
|
||||
.as_ref()
|
||||
.ok_or("Uninitialized PostProcessor")?
|
||||
.process(encoding, pair_encoding, add_special_tokens)
|
||||
.process_encodings(encodings, add_special_tokens)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -59,14 +59,13 @@ impl PostProcessor for PyPostProcessor {
|
||||
self.processor.added_tokens(is_pair)
|
||||
}
|
||||
|
||||
fn process(
|
||||
fn process_encodings(
|
||||
&self,
|
||||
encoding: Encoding,
|
||||
pair_encoding: Option<Encoding>,
|
||||
encodings: Vec<Encoding>,
|
||||
add_special_tokens: bool,
|
||||
) -> tk::Result<Encoding> {
|
||||
) -> tk::Result<Vec<Encoding>> {
|
||||
self.processor
|
||||
.process(encoding, pair_encoding, add_special_tokens)
|
||||
.process_encodings(encodings, add_special_tokens)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -174,20 +174,13 @@ impl PostProcessor for ByteLevel {
|
||||
0
|
||||
}
|
||||
|
||||
fn process(
|
||||
fn process_encodings(
|
||||
&self,
|
||||
mut encoding: Encoding,
|
||||
mut pair_encoding: Option<Encoding>,
|
||||
mut encodings: Vec<Encoding>,
|
||||
add_special_tokens: bool,
|
||||
) -> Result<Encoding> {
|
||||
) -> Result<Vec<Encoding>> {
|
||||
if self.trim_offsets {
|
||||
process_offsets(&mut encoding, self.add_prefix_space);
|
||||
encoding
|
||||
.get_overflowing_mut()
|
||||
.iter_mut()
|
||||
.for_each(|encoding| process_offsets(encoding, self.add_prefix_space));
|
||||
|
||||
if let Some(encoding) = pair_encoding.as_mut() {
|
||||
for encoding in encodings.iter_mut() {
|
||||
process_offsets(encoding, self.add_prefix_space);
|
||||
encoding
|
||||
.get_overflowing_mut()
|
||||
@ -195,8 +188,7 @@ impl PostProcessor for ByteLevel {
|
||||
.for_each(|encoding| process_offsets(encoding, self.add_prefix_space));
|
||||
}
|
||||
}
|
||||
|
||||
<dyn PostProcessor>::default_process(encoding, pair_encoding, add_special_tokens)
|
||||
<dyn PostProcessor>::default_process(encodings, add_special_tokens)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -1,4 +1,4 @@
|
||||
use crate::tokenizer::{Encoding, PostProcessor, Result};
|
||||
use crate::tokenizer::{Encoding, PostProcessor, ProcessorError, Result};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::collections::HashMap;
|
||||
use std::iter::FromIterator;
|
||||
@ -25,6 +25,12 @@ impl BertProcessing {
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(thiserror::Error, Debug)]
|
||||
pub enum BertProcessorError {
|
||||
#[error("encodings vector length must be either 1 or 2")]
|
||||
InvalidEncodingsVecLength,
|
||||
}
|
||||
|
||||
impl PostProcessor for BertProcessing {
|
||||
fn added_tokens(&self, is_pair: bool) -> usize {
|
||||
if is_pair {
|
||||
@ -34,20 +40,34 @@ impl PostProcessor for BertProcessing {
|
||||
}
|
||||
}
|
||||
|
||||
fn process(
|
||||
fn process_encodings(
|
||||
&self,
|
||||
mut encoding: Encoding,
|
||||
pair_encoding: Option<Encoding>,
|
||||
mut encodings: Vec<Encoding>,
|
||||
add_special_tokens: bool,
|
||||
) -> Result<Encoding> {
|
||||
) -> Result<Vec<Encoding>> {
|
||||
if !add_special_tokens {
|
||||
return <dyn PostProcessor>::default_process(
|
||||
encoding,
|
||||
pair_encoding,
|
||||
add_special_tokens,
|
||||
);
|
||||
return Ok(encodings);
|
||||
}
|
||||
|
||||
let (mut encoding, pair_encoding): (Encoding, Option<Encoding>) = match encodings.len() {
|
||||
1 => (
|
||||
encodings
|
||||
.pop()
|
||||
.ok_or(ProcessorError::InvalidEncodingsVecLength)?,
|
||||
None,
|
||||
),
|
||||
2 => {
|
||||
let pair = encodings
|
||||
.pop()
|
||||
.ok_or(ProcessorError::InvalidEncodingsVecLength)?;
|
||||
let encoding = encodings
|
||||
.pop()
|
||||
.ok_or(ProcessorError::InvalidEncodingsVecLength)?;
|
||||
(encoding, Some(pair))
|
||||
}
|
||||
_ => return Err(Box::new(ProcessorError::InvalidEncodingsVecLength)),
|
||||
};
|
||||
|
||||
let ids = [&[self.cls.1], encoding.get_ids(), &[self.sep.1]].concat();
|
||||
let type_ids = [&[0], encoding.get_type_ids(), &[0]].concat();
|
||||
let tokens = [
|
||||
@ -166,7 +186,7 @@ impl PostProcessor for BertProcessing {
|
||||
new_encoding.merge_with(new_pair_encoding, false);
|
||||
}
|
||||
|
||||
Ok(new_encoding)
|
||||
Ok(vec![new_encoding])
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -33,19 +33,16 @@ impl PostProcessor for PostProcessorWrapper {
|
||||
}
|
||||
}
|
||||
|
||||
fn process(
|
||||
fn process_encodings(
|
||||
&self,
|
||||
encoding: Encoding,
|
||||
pair_encoding: Option<Encoding>,
|
||||
encodings: Vec<Encoding>,
|
||||
add_special_tokens: bool,
|
||||
) -> Result<Encoding> {
|
||||
) -> Result<Vec<Encoding>> {
|
||||
match self {
|
||||
Self::Bert(bert) => bert.process(encoding, pair_encoding, add_special_tokens),
|
||||
Self::ByteLevel(bl) => bl.process(encoding, pair_encoding, add_special_tokens),
|
||||
Self::Roberta(roberta) => roberta.process(encoding, pair_encoding, add_special_tokens),
|
||||
Self::Template(template) => {
|
||||
template.process(encoding, pair_encoding, add_special_tokens)
|
||||
}
|
||||
Self::Bert(bert) => bert.process_encodings(encodings, add_special_tokens),
|
||||
Self::ByteLevel(bl) => bl.process_encodings(encodings, add_special_tokens),
|
||||
Self::Roberta(roberta) => roberta.process_encodings(encodings, add_special_tokens),
|
||||
Self::Template(template) => template.process_encodings(encodings, add_special_tokens),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -1,5 +1,5 @@
|
||||
use crate::processors::byte_level::process_offsets;
|
||||
use crate::tokenizer::{Encoding, PostProcessor, Result};
|
||||
use crate::tokenizer::{Encoding, PostProcessor, ProcessorError, Result};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::collections::HashMap;
|
||||
use std::iter::FromIterator;
|
||||
@ -55,20 +55,13 @@ impl PostProcessor for RobertaProcessing {
|
||||
}
|
||||
}
|
||||
|
||||
fn process(
|
||||
fn process_encodings(
|
||||
&self,
|
||||
mut encoding: Encoding,
|
||||
mut pair_encoding: Option<Encoding>,
|
||||
mut encodings: Vec<Encoding>,
|
||||
add_special_tokens: bool,
|
||||
) -> Result<Encoding> {
|
||||
) -> Result<Vec<Encoding>> {
|
||||
if self.trim_offsets {
|
||||
process_offsets(&mut encoding, self.add_prefix_space);
|
||||
encoding
|
||||
.get_overflowing_mut()
|
||||
.iter_mut()
|
||||
.for_each(|encoding| process_offsets(encoding, self.add_prefix_space));
|
||||
|
||||
if let Some(encoding) = pair_encoding.as_mut() {
|
||||
for encoding in encodings.iter_mut() {
|
||||
process_offsets(encoding, self.add_prefix_space);
|
||||
encoding
|
||||
.get_overflowing_mut()
|
||||
@ -78,13 +71,28 @@ impl PostProcessor for RobertaProcessing {
|
||||
}
|
||||
|
||||
if !add_special_tokens {
|
||||
return <dyn PostProcessor>::default_process(
|
||||
encoding,
|
||||
pair_encoding,
|
||||
add_special_tokens,
|
||||
);
|
||||
return Ok(encodings);
|
||||
}
|
||||
|
||||
let (mut encoding, pair_encoding): (Encoding, Option<Encoding>) = match encodings.len() {
|
||||
1 => (
|
||||
encodings
|
||||
.pop()
|
||||
.ok_or(ProcessorError::InvalidEncodingsVecLength)?,
|
||||
None,
|
||||
),
|
||||
2 => {
|
||||
let pair = encodings
|
||||
.pop()
|
||||
.ok_or(ProcessorError::InvalidEncodingsVecLength)?;
|
||||
let encoding = encodings
|
||||
.pop()
|
||||
.ok_or(ProcessorError::InvalidEncodingsVecLength)?;
|
||||
(encoding, Some(pair))
|
||||
}
|
||||
_ => return Err(Box::new(ProcessorError::InvalidEncodingsVecLength)),
|
||||
};
|
||||
|
||||
let ids = [&[self.cls.1], encoding.get_ids(), &[self.sep.1]].concat();
|
||||
let type_ids = [&[0], encoding.get_type_ids(), &[0]].concat();
|
||||
let tokens = [
|
||||
@ -213,7 +221,7 @@ impl PostProcessor for RobertaProcessing {
|
||||
new_encoding.merge_with(new_pair_encoding, false);
|
||||
}
|
||||
|
||||
Ok(new_encoding)
|
||||
Ok(vec![new_encoding])
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -55,7 +55,7 @@
|
||||
//!
|
||||
//! [`TemplateProcessing`]: struct.TemplateProcessing.html
|
||||
//!
|
||||
use crate::{Encoding, PostProcessor, Result};
|
||||
use crate::{tokenizer::ProcessorError, Encoding, PostProcessor, Result};
|
||||
use itertools::Itertools;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::collections::{HashMap, HashSet};
|
||||
@ -630,13 +630,31 @@ impl PostProcessor for TemplateProcessing {
|
||||
}
|
||||
}
|
||||
|
||||
fn process(
|
||||
fn process_encodings(
|
||||
&self,
|
||||
encoding: Encoding,
|
||||
pair: Option<Encoding>,
|
||||
mut encodings: Vec<Encoding>,
|
||||
add_special_tokens: bool,
|
||||
) -> Result<Encoding> {
|
||||
self.apply_template(
|
||||
) -> Result<Vec<Encoding>> {
|
||||
let (encoding, pair): (Encoding, Option<Encoding>) = match encodings.len() {
|
||||
1 => (
|
||||
encodings
|
||||
.pop()
|
||||
.ok_or(ProcessorError::InvalidEncodingsVecLength)?,
|
||||
None,
|
||||
),
|
||||
2 => {
|
||||
let pair = encodings
|
||||
.pop()
|
||||
.ok_or(ProcessorError::InvalidEncodingsVecLength)?;
|
||||
let encoding = encodings
|
||||
.pop()
|
||||
.ok_or(ProcessorError::InvalidEncodingsVecLength)?;
|
||||
(encoding, Some(pair))
|
||||
}
|
||||
_ => return Err(Box::new(ProcessorError::InvalidEncodingsVecLength)),
|
||||
};
|
||||
|
||||
let encoding = self.apply_template(
|
||||
if pair.is_some() {
|
||||
&self.pair.0
|
||||
} else {
|
||||
@ -645,7 +663,8 @@ impl PostProcessor for TemplateProcessing {
|
||||
encoding,
|
||||
pair,
|
||||
add_special_tokens,
|
||||
)
|
||||
)?;
|
||||
Ok(vec![encoding])
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -99,26 +99,50 @@ pub trait PostProcessor {
|
||||
encoding: Encoding,
|
||||
pair_encoding: Option<Encoding>,
|
||||
add_special_tokens: bool,
|
||||
) -> Result<Encoding>;
|
||||
) -> Result<Encoding> {
|
||||
let encodings = if let Some(pair_encoding) = pair_encoding {
|
||||
vec![encoding, pair_encoding]
|
||||
} else {
|
||||
vec![encoding]
|
||||
};
|
||||
|
||||
let encodings = self.process_encodings(encodings, add_special_tokens)?;
|
||||
|
||||
Ok(Encoding::merge(encodings, false))
|
||||
}
|
||||
|
||||
/// Process any amount of encodings and returns a series of encoding (might merge them)
|
||||
fn process_encodings(
|
||||
&self,
|
||||
encodings: Vec<Encoding>,
|
||||
add_special_tokens: bool,
|
||||
) -> Result<Vec<Encoding>>;
|
||||
}
|
||||
impl dyn PostProcessor {
|
||||
pub fn default_process(
|
||||
mut encoding: Encoding,
|
||||
pair_encoding: Option<Encoding>,
|
||||
encodings: Vec<Encoding>,
|
||||
_add_special_tokens: bool,
|
||||
) -> Result<Encoding> {
|
||||
match pair_encoding {
|
||||
None => Ok(encoding),
|
||||
Some(mut pair) => {
|
||||
encoding.set_sequence_id(0);
|
||||
pair.set_sequence_id(1);
|
||||
encoding.merge_with(pair, false);
|
||||
Ok(encoding)
|
||||
) -> Result<Vec<Encoding>> {
|
||||
match encodings.len() {
|
||||
1 => Ok(encodings),
|
||||
_ => {
|
||||
let mut final_encoding = Encoding::default();
|
||||
for (i, mut encoding) in encodings.into_iter().enumerate() {
|
||||
encoding.set_sequence_id(i);
|
||||
final_encoding.merge_with(encoding, false);
|
||||
}
|
||||
Ok(vec![final_encoding])
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(thiserror::Error, Debug)]
|
||||
pub enum ProcessorError {
|
||||
#[error("encodings vector length must be either 1 or 2")]
|
||||
InvalidEncodingsVecLength,
|
||||
}
|
||||
|
||||
/// A `Decoder` changes the raw tokens into its more readable form.
|
||||
pub trait Decoder {
|
||||
fn decode(&self, tokens: Vec<String>) -> Result<String> {
|
||||
@ -895,7 +919,17 @@ where
|
||||
let final_encoding = if let Some(processor) = &self.post_processor {
|
||||
processor.process(encoding, pair_encoding, add_special_tokens)?
|
||||
} else {
|
||||
<dyn PostProcessor>::default_process(encoding, pair_encoding, add_special_tokens)?
|
||||
let encodings = if let Some(pair_encoding) = pair_encoding {
|
||||
vec![encoding, pair_encoding]
|
||||
} else {
|
||||
vec![encoding]
|
||||
};
|
||||
let mut encodings =
|
||||
<dyn PostProcessor>::default_process(encodings, add_special_tokens)?;
|
||||
if encodings.len() != 1 {
|
||||
panic!("We haven't reduced the encodings like we should have");
|
||||
}
|
||||
encodings.pop().unwrap()
|
||||
};
|
||||
|
||||
// 3. Then we pad if needed
|
||||
|
Reference in New Issue
Block a user