mirror of
https://github.com/mii443/tokenizers.git
synced 2025-08-22 16:25:30 +00:00
Modify Processor
trait to support chaining. (#1054)
0 modifications yet, everything will consume the vector. Every test should be green without any modifications.
This commit is contained in:
@ -22,16 +22,15 @@ impl tk::PostProcessor for Processor {
|
|||||||
.added_tokens(is_pair)
|
.added_tokens(is_pair)
|
||||||
}
|
}
|
||||||
|
|
||||||
fn process(
|
fn process_encodings(
|
||||||
&self,
|
&self,
|
||||||
encoding: Encoding,
|
encodings: Vec<Encoding>,
|
||||||
pair_encoding: Option<Encoding>,
|
|
||||||
add_special_tokens: bool,
|
add_special_tokens: bool,
|
||||||
) -> tk::Result<Encoding> {
|
) -> tk::Result<Vec<Encoding>> {
|
||||||
self.processor
|
self.processor
|
||||||
.as_ref()
|
.as_ref()
|
||||||
.ok_or("Uninitialized PostProcessor")?
|
.ok_or("Uninitialized PostProcessor")?
|
||||||
.process(encoding, pair_encoding, add_special_tokens)
|
.process_encodings(encodings, add_special_tokens)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -59,14 +59,13 @@ impl PostProcessor for PyPostProcessor {
|
|||||||
self.processor.added_tokens(is_pair)
|
self.processor.added_tokens(is_pair)
|
||||||
}
|
}
|
||||||
|
|
||||||
fn process(
|
fn process_encodings(
|
||||||
&self,
|
&self,
|
||||||
encoding: Encoding,
|
encodings: Vec<Encoding>,
|
||||||
pair_encoding: Option<Encoding>,
|
|
||||||
add_special_tokens: bool,
|
add_special_tokens: bool,
|
||||||
) -> tk::Result<Encoding> {
|
) -> tk::Result<Vec<Encoding>> {
|
||||||
self.processor
|
self.processor
|
||||||
.process(encoding, pair_encoding, add_special_tokens)
|
.process_encodings(encodings, add_special_tokens)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -174,20 +174,13 @@ impl PostProcessor for ByteLevel {
|
|||||||
0
|
0
|
||||||
}
|
}
|
||||||
|
|
||||||
fn process(
|
fn process_encodings(
|
||||||
&self,
|
&self,
|
||||||
mut encoding: Encoding,
|
mut encodings: Vec<Encoding>,
|
||||||
mut pair_encoding: Option<Encoding>,
|
|
||||||
add_special_tokens: bool,
|
add_special_tokens: bool,
|
||||||
) -> Result<Encoding> {
|
) -> Result<Vec<Encoding>> {
|
||||||
if self.trim_offsets {
|
if self.trim_offsets {
|
||||||
process_offsets(&mut encoding, self.add_prefix_space);
|
for encoding in encodings.iter_mut() {
|
||||||
encoding
|
|
||||||
.get_overflowing_mut()
|
|
||||||
.iter_mut()
|
|
||||||
.for_each(|encoding| process_offsets(encoding, self.add_prefix_space));
|
|
||||||
|
|
||||||
if let Some(encoding) = pair_encoding.as_mut() {
|
|
||||||
process_offsets(encoding, self.add_prefix_space);
|
process_offsets(encoding, self.add_prefix_space);
|
||||||
encoding
|
encoding
|
||||||
.get_overflowing_mut()
|
.get_overflowing_mut()
|
||||||
@ -195,8 +188,7 @@ impl PostProcessor for ByteLevel {
|
|||||||
.for_each(|encoding| process_offsets(encoding, self.add_prefix_space));
|
.for_each(|encoding| process_offsets(encoding, self.add_prefix_space));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
<dyn PostProcessor>::default_process(encodings, add_special_tokens)
|
||||||
<dyn PostProcessor>::default_process(encoding, pair_encoding, add_special_tokens)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1,4 +1,4 @@
|
|||||||
use crate::tokenizer::{Encoding, PostProcessor, Result};
|
use crate::tokenizer::{Encoding, PostProcessor, ProcessorError, Result};
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
use std::collections::HashMap;
|
use std::collections::HashMap;
|
||||||
use std::iter::FromIterator;
|
use std::iter::FromIterator;
|
||||||
@ -25,6 +25,12 @@ impl BertProcessing {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[derive(thiserror::Error, Debug)]
|
||||||
|
pub enum BertProcessorError {
|
||||||
|
#[error("encodings vector length must be either 1 or 2")]
|
||||||
|
InvalidEncodingsVecLength,
|
||||||
|
}
|
||||||
|
|
||||||
impl PostProcessor for BertProcessing {
|
impl PostProcessor for BertProcessing {
|
||||||
fn added_tokens(&self, is_pair: bool) -> usize {
|
fn added_tokens(&self, is_pair: bool) -> usize {
|
||||||
if is_pair {
|
if is_pair {
|
||||||
@ -34,20 +40,34 @@ impl PostProcessor for BertProcessing {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn process(
|
fn process_encodings(
|
||||||
&self,
|
&self,
|
||||||
mut encoding: Encoding,
|
mut encodings: Vec<Encoding>,
|
||||||
pair_encoding: Option<Encoding>,
|
|
||||||
add_special_tokens: bool,
|
add_special_tokens: bool,
|
||||||
) -> Result<Encoding> {
|
) -> Result<Vec<Encoding>> {
|
||||||
if !add_special_tokens {
|
if !add_special_tokens {
|
||||||
return <dyn PostProcessor>::default_process(
|
return Ok(encodings);
|
||||||
encoding,
|
|
||||||
pair_encoding,
|
|
||||||
add_special_tokens,
|
|
||||||
);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
let (mut encoding, pair_encoding): (Encoding, Option<Encoding>) = match encodings.len() {
|
||||||
|
1 => (
|
||||||
|
encodings
|
||||||
|
.pop()
|
||||||
|
.ok_or(ProcessorError::InvalidEncodingsVecLength)?,
|
||||||
|
None,
|
||||||
|
),
|
||||||
|
2 => {
|
||||||
|
let pair = encodings
|
||||||
|
.pop()
|
||||||
|
.ok_or(ProcessorError::InvalidEncodingsVecLength)?;
|
||||||
|
let encoding = encodings
|
||||||
|
.pop()
|
||||||
|
.ok_or(ProcessorError::InvalidEncodingsVecLength)?;
|
||||||
|
(encoding, Some(pair))
|
||||||
|
}
|
||||||
|
_ => return Err(Box::new(ProcessorError::InvalidEncodingsVecLength)),
|
||||||
|
};
|
||||||
|
|
||||||
let ids = [&[self.cls.1], encoding.get_ids(), &[self.sep.1]].concat();
|
let ids = [&[self.cls.1], encoding.get_ids(), &[self.sep.1]].concat();
|
||||||
let type_ids = [&[0], encoding.get_type_ids(), &[0]].concat();
|
let type_ids = [&[0], encoding.get_type_ids(), &[0]].concat();
|
||||||
let tokens = [
|
let tokens = [
|
||||||
@ -166,7 +186,7 @@ impl PostProcessor for BertProcessing {
|
|||||||
new_encoding.merge_with(new_pair_encoding, false);
|
new_encoding.merge_with(new_pair_encoding, false);
|
||||||
}
|
}
|
||||||
|
|
||||||
Ok(new_encoding)
|
Ok(vec![new_encoding])
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -33,19 +33,16 @@ impl PostProcessor for PostProcessorWrapper {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn process(
|
fn process_encodings(
|
||||||
&self,
|
&self,
|
||||||
encoding: Encoding,
|
encodings: Vec<Encoding>,
|
||||||
pair_encoding: Option<Encoding>,
|
|
||||||
add_special_tokens: bool,
|
add_special_tokens: bool,
|
||||||
) -> Result<Encoding> {
|
) -> Result<Vec<Encoding>> {
|
||||||
match self {
|
match self {
|
||||||
Self::Bert(bert) => bert.process(encoding, pair_encoding, add_special_tokens),
|
Self::Bert(bert) => bert.process_encodings(encodings, add_special_tokens),
|
||||||
Self::ByteLevel(bl) => bl.process(encoding, pair_encoding, add_special_tokens),
|
Self::ByteLevel(bl) => bl.process_encodings(encodings, add_special_tokens),
|
||||||
Self::Roberta(roberta) => roberta.process(encoding, pair_encoding, add_special_tokens),
|
Self::Roberta(roberta) => roberta.process_encodings(encodings, add_special_tokens),
|
||||||
Self::Template(template) => {
|
Self::Template(template) => template.process_encodings(encodings, add_special_tokens),
|
||||||
template.process(encoding, pair_encoding, add_special_tokens)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -1,5 +1,5 @@
|
|||||||
use crate::processors::byte_level::process_offsets;
|
use crate::processors::byte_level::process_offsets;
|
||||||
use crate::tokenizer::{Encoding, PostProcessor, Result};
|
use crate::tokenizer::{Encoding, PostProcessor, ProcessorError, Result};
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
use std::collections::HashMap;
|
use std::collections::HashMap;
|
||||||
use std::iter::FromIterator;
|
use std::iter::FromIterator;
|
||||||
@ -55,20 +55,13 @@ impl PostProcessor for RobertaProcessing {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn process(
|
fn process_encodings(
|
||||||
&self,
|
&self,
|
||||||
mut encoding: Encoding,
|
mut encodings: Vec<Encoding>,
|
||||||
mut pair_encoding: Option<Encoding>,
|
|
||||||
add_special_tokens: bool,
|
add_special_tokens: bool,
|
||||||
) -> Result<Encoding> {
|
) -> Result<Vec<Encoding>> {
|
||||||
if self.trim_offsets {
|
if self.trim_offsets {
|
||||||
process_offsets(&mut encoding, self.add_prefix_space);
|
for encoding in encodings.iter_mut() {
|
||||||
encoding
|
|
||||||
.get_overflowing_mut()
|
|
||||||
.iter_mut()
|
|
||||||
.for_each(|encoding| process_offsets(encoding, self.add_prefix_space));
|
|
||||||
|
|
||||||
if let Some(encoding) = pair_encoding.as_mut() {
|
|
||||||
process_offsets(encoding, self.add_prefix_space);
|
process_offsets(encoding, self.add_prefix_space);
|
||||||
encoding
|
encoding
|
||||||
.get_overflowing_mut()
|
.get_overflowing_mut()
|
||||||
@ -78,13 +71,28 @@ impl PostProcessor for RobertaProcessing {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if !add_special_tokens {
|
if !add_special_tokens {
|
||||||
return <dyn PostProcessor>::default_process(
|
return Ok(encodings);
|
||||||
encoding,
|
|
||||||
pair_encoding,
|
|
||||||
add_special_tokens,
|
|
||||||
);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
let (mut encoding, pair_encoding): (Encoding, Option<Encoding>) = match encodings.len() {
|
||||||
|
1 => (
|
||||||
|
encodings
|
||||||
|
.pop()
|
||||||
|
.ok_or(ProcessorError::InvalidEncodingsVecLength)?,
|
||||||
|
None,
|
||||||
|
),
|
||||||
|
2 => {
|
||||||
|
let pair = encodings
|
||||||
|
.pop()
|
||||||
|
.ok_or(ProcessorError::InvalidEncodingsVecLength)?;
|
||||||
|
let encoding = encodings
|
||||||
|
.pop()
|
||||||
|
.ok_or(ProcessorError::InvalidEncodingsVecLength)?;
|
||||||
|
(encoding, Some(pair))
|
||||||
|
}
|
||||||
|
_ => return Err(Box::new(ProcessorError::InvalidEncodingsVecLength)),
|
||||||
|
};
|
||||||
|
|
||||||
let ids = [&[self.cls.1], encoding.get_ids(), &[self.sep.1]].concat();
|
let ids = [&[self.cls.1], encoding.get_ids(), &[self.sep.1]].concat();
|
||||||
let type_ids = [&[0], encoding.get_type_ids(), &[0]].concat();
|
let type_ids = [&[0], encoding.get_type_ids(), &[0]].concat();
|
||||||
let tokens = [
|
let tokens = [
|
||||||
@ -213,7 +221,7 @@ impl PostProcessor for RobertaProcessing {
|
|||||||
new_encoding.merge_with(new_pair_encoding, false);
|
new_encoding.merge_with(new_pair_encoding, false);
|
||||||
}
|
}
|
||||||
|
|
||||||
Ok(new_encoding)
|
Ok(vec![new_encoding])
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -55,7 +55,7 @@
|
|||||||
//!
|
//!
|
||||||
//! [`TemplateProcessing`]: struct.TemplateProcessing.html
|
//! [`TemplateProcessing`]: struct.TemplateProcessing.html
|
||||||
//!
|
//!
|
||||||
use crate::{Encoding, PostProcessor, Result};
|
use crate::{tokenizer::ProcessorError, Encoding, PostProcessor, Result};
|
||||||
use itertools::Itertools;
|
use itertools::Itertools;
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
use std::collections::{HashMap, HashSet};
|
use std::collections::{HashMap, HashSet};
|
||||||
@ -630,13 +630,31 @@ impl PostProcessor for TemplateProcessing {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn process(
|
fn process_encodings(
|
||||||
&self,
|
&self,
|
||||||
encoding: Encoding,
|
mut encodings: Vec<Encoding>,
|
||||||
pair: Option<Encoding>,
|
|
||||||
add_special_tokens: bool,
|
add_special_tokens: bool,
|
||||||
) -> Result<Encoding> {
|
) -> Result<Vec<Encoding>> {
|
||||||
self.apply_template(
|
let (encoding, pair): (Encoding, Option<Encoding>) = match encodings.len() {
|
||||||
|
1 => (
|
||||||
|
encodings
|
||||||
|
.pop()
|
||||||
|
.ok_or(ProcessorError::InvalidEncodingsVecLength)?,
|
||||||
|
None,
|
||||||
|
),
|
||||||
|
2 => {
|
||||||
|
let pair = encodings
|
||||||
|
.pop()
|
||||||
|
.ok_or(ProcessorError::InvalidEncodingsVecLength)?;
|
||||||
|
let encoding = encodings
|
||||||
|
.pop()
|
||||||
|
.ok_or(ProcessorError::InvalidEncodingsVecLength)?;
|
||||||
|
(encoding, Some(pair))
|
||||||
|
}
|
||||||
|
_ => return Err(Box::new(ProcessorError::InvalidEncodingsVecLength)),
|
||||||
|
};
|
||||||
|
|
||||||
|
let encoding = self.apply_template(
|
||||||
if pair.is_some() {
|
if pair.is_some() {
|
||||||
&self.pair.0
|
&self.pair.0
|
||||||
} else {
|
} else {
|
||||||
@ -645,7 +663,8 @@ impl PostProcessor for TemplateProcessing {
|
|||||||
encoding,
|
encoding,
|
||||||
pair,
|
pair,
|
||||||
add_special_tokens,
|
add_special_tokens,
|
||||||
)
|
)?;
|
||||||
|
Ok(vec![encoding])
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -99,26 +99,50 @@ pub trait PostProcessor {
|
|||||||
encoding: Encoding,
|
encoding: Encoding,
|
||||||
pair_encoding: Option<Encoding>,
|
pair_encoding: Option<Encoding>,
|
||||||
add_special_tokens: bool,
|
add_special_tokens: bool,
|
||||||
) -> Result<Encoding>;
|
) -> Result<Encoding> {
|
||||||
|
let encodings = if let Some(pair_encoding) = pair_encoding {
|
||||||
|
vec![encoding, pair_encoding]
|
||||||
|
} else {
|
||||||
|
vec![encoding]
|
||||||
|
};
|
||||||
|
|
||||||
|
let encodings = self.process_encodings(encodings, add_special_tokens)?;
|
||||||
|
|
||||||
|
Ok(Encoding::merge(encodings, false))
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Process any amount of encodings and returns a series of encoding (might merge them)
|
||||||
|
fn process_encodings(
|
||||||
|
&self,
|
||||||
|
encodings: Vec<Encoding>,
|
||||||
|
add_special_tokens: bool,
|
||||||
|
) -> Result<Vec<Encoding>>;
|
||||||
}
|
}
|
||||||
impl dyn PostProcessor {
|
impl dyn PostProcessor {
|
||||||
pub fn default_process(
|
pub fn default_process(
|
||||||
mut encoding: Encoding,
|
encodings: Vec<Encoding>,
|
||||||
pair_encoding: Option<Encoding>,
|
|
||||||
_add_special_tokens: bool,
|
_add_special_tokens: bool,
|
||||||
) -> Result<Encoding> {
|
) -> Result<Vec<Encoding>> {
|
||||||
match pair_encoding {
|
match encodings.len() {
|
||||||
None => Ok(encoding),
|
1 => Ok(encodings),
|
||||||
Some(mut pair) => {
|
_ => {
|
||||||
encoding.set_sequence_id(0);
|
let mut final_encoding = Encoding::default();
|
||||||
pair.set_sequence_id(1);
|
for (i, mut encoding) in encodings.into_iter().enumerate() {
|
||||||
encoding.merge_with(pair, false);
|
encoding.set_sequence_id(i);
|
||||||
Ok(encoding)
|
final_encoding.merge_with(encoding, false);
|
||||||
|
}
|
||||||
|
Ok(vec![final_encoding])
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[derive(thiserror::Error, Debug)]
|
||||||
|
pub enum ProcessorError {
|
||||||
|
#[error("encodings vector length must be either 1 or 2")]
|
||||||
|
InvalidEncodingsVecLength,
|
||||||
|
}
|
||||||
|
|
||||||
/// A `Decoder` changes the raw tokens into its more readable form.
|
/// A `Decoder` changes the raw tokens into its more readable form.
|
||||||
pub trait Decoder {
|
pub trait Decoder {
|
||||||
fn decode(&self, tokens: Vec<String>) -> Result<String> {
|
fn decode(&self, tokens: Vec<String>) -> Result<String> {
|
||||||
@ -895,7 +919,17 @@ where
|
|||||||
let final_encoding = if let Some(processor) = &self.post_processor {
|
let final_encoding = if let Some(processor) = &self.post_processor {
|
||||||
processor.process(encoding, pair_encoding, add_special_tokens)?
|
processor.process(encoding, pair_encoding, add_special_tokens)?
|
||||||
} else {
|
} else {
|
||||||
<dyn PostProcessor>::default_process(encoding, pair_encoding, add_special_tokens)?
|
let encodings = if let Some(pair_encoding) = pair_encoding {
|
||||||
|
vec![encoding, pair_encoding]
|
||||||
|
} else {
|
||||||
|
vec![encoding]
|
||||||
|
};
|
||||||
|
let mut encodings =
|
||||||
|
<dyn PostProcessor>::default_process(encodings, add_special_tokens)?;
|
||||||
|
if encodings.len() != 1 {
|
||||||
|
panic!("We haven't reduced the encodings like we should have");
|
||||||
|
}
|
||||||
|
encodings.pop().unwrap()
|
||||||
};
|
};
|
||||||
|
|
||||||
// 3. Then we pad if needed
|
// 3. Then we pad if needed
|
||||||
|
Reference in New Issue
Block a user