Adding Sequence for PostProcessor. (#1052)

* Adding `Sequence` for `PostProcessor`.

* Fixing node? Writing in the dark here, don't have Python2.7

* `undefined` is not accepted.

* Other test.
Author: Nicolas Patry
Date: 2022-08-25 14:50:06 +02:00 (committed by GitHub)
Commit: 06025e4ca1 (parent 37f7bae0f7)
12 changed files with 344 additions and 7 deletions
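
In short: post-processors can now be chained. A minimal sketch of the new user-facing API (Python binding; all names below are taken from the tests in this diff):

from tokenizers import Tokenizer
from tokenizers.models import BPE
from tokenizers.processors import ByteLevel, Sequence, TemplateProcessing

# Trim byte-level offsets first, then add the special tokens.
template = TemplateProcessing(
    single=["[CLS]", "$0", "[SEP]"],
    special_tokens=[("[CLS]", 1), ("[SEP]", 0)],
)
tokenizer = Tokenizer(BPE())
tokenizer.post_processor = Sequence([ByteLevel(trim_offsets=True), template])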


@@ -54,3 +54,11 @@ export function templateProcessing(
pair?: string,
specialTokens?: [string, number][]
): PostProcessor;
/**
* Instantiate a new SequenceProcessing.
*
* @param PostProcessor[] The list of Processors to use
* @since 0.13.0
*/
export function sequenceProcessing(processors: PostProcessor[]): PostProcessor;


@@ -5,4 +5,5 @@ module.exports = {
byteLevelProcessing: native.processors_ByteLevel,
robertaProcessing: native.processors_RobertaProcessing,
templateProcessing: native.processors_TemplateProcessing,
sequenceProcessing: native.processors_Sequence,
};


@@ -4,6 +4,7 @@ import {
bertProcessing,
byteLevelProcessing,
robertaProcessing,
sequenceProcessing,
templateProcessing,
} from "./post-processors";
@@ -81,3 +82,14 @@ describe("templateProcessing", () => {
expect(processor.constructor.name).toEqual("Processor");
});
});
describe("sequenceProcessing", () => {
it("accepts `PostProcessor[]` as first parameter", () => {
const template = templateProcessing("[CLS] $A [SEP]", "[CLS] $A [SEP] $B:1 [SEP]:1", [
["[CLS]", 1],
["[SEP]", 2],
]);
const bytelevel = byteLevelProcessing(true);
expect(sequenceProcessing([bytelevel, template])).toBeDefined();
});
});


@@ -129,6 +129,33 @@ fn template_processing(mut cx: FunctionContext) -> JsResult<JsPostProcessor> {
Ok(js_processor)
}
/// sequence(processors: List[Processor])
fn sequence(mut cx: FunctionContext) -> JsResult<JsPostProcessor> {
let processors = cx.argument::<JsArray>(0)?.to_vec(&mut cx)?;
let mut sequence = Vec::with_capacity(processors.len());
processors.into_iter().try_for_each(|processor| {
match processor.downcast::<JsPostProcessor>().or_throw(&mut cx) {
Ok(processor) => {
let guard = cx.lock();
if let Some(processor_arc) = &processor.borrow(&guard).processor {
let processor: PostProcessorWrapper = (**processor_arc).clone();
sequence.push(processor);
}
Ok(())
}
Err(e) => Err(e),
}
})?;
let mut pretok = JsPostProcessor::new::<_, JsPostProcessor, _>(&mut cx, vec![])?;
let guard = cx.lock();
pretok.borrow_mut(&guard).processor = Some(Arc::new(PostProcessorWrapper::Sequence(
tk::processors::sequence::Sequence::new(sequence),
)));
Ok(pretok)
}
/// Register everything here
pub fn register(m: &mut ModuleContext, prefix: &str) -> NeonResult<()> {
m.export_function(&format!("{}_BertProcessing", prefix), bert_processing)?;
@@ -138,5 +165,6 @@ pub fn register(m: &mut ModuleContext, prefix: &str) -> NeonResult<()> {
&format!("{}_TemplateProcessing", prefix),
template_processing,
)?;
m.export_function(&format!("{}_Sequence", prefix), sequence)?;
Ok(())
}


@@ -5,4 +5,5 @@ PostProcessor = processors.PostProcessor
BertProcessing = processors.BertProcessing
ByteLevel = processors.ByteLevel
RobertaProcessing = processors.RobertaProcessing
Sequence = processors.Sequence
TemplateProcessing = processors.TemplateProcessing


@@ -193,6 +193,48 @@ class RobertaProcessing(PostProcessor):
"""
pass
class Sequence(PostProcessor):
"""
Sequence Processor
Args:
processors (:obj:`List[PostProcessor]`)
The processors that need to be chained
"""
def __init__(self, processors):
pass
def num_special_tokens_to_add(self, is_pair):
"""
Return the number of special tokens that would be added for single/pair sentences.
Args:
is_pair (:obj:`bool`):
Whether the input would be a pair of sequences
Returns:
:obj:`int`: The number of tokens to add
"""
pass
def process(self, encoding, pair=None, add_special_tokens=True):
"""
Post-process the given encodings, generating the final one
Args:
encoding (:class:`~tokenizers.Encoding`):
The encoding for the first sequence
pair (:class:`~tokenizers.Encoding`, `optional`):
The encoding for the pair sequence
add_special_tokens (:obj:`bool`):
Whether to add the special tokens
Return:
:class:`~tokenizers.Encoding`: The final encoding
"""
pass
class TemplateProcessing(PostProcessor):
"""
Provides a way to specify templates in order to add the special tokens to each
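
A sketch of how the documented methods compose, using the classes above (ByteLevel contributes no special tokens; this template contributes [CLS] and [SEP]):

template = TemplateProcessing(
    single=["[CLS]", "$0", "[SEP]"],
    special_tokens=[("[CLS]", 1), ("[SEP]", 0)],
)
seq = Sequence([ByteLevel(trim_offsets=True), template])
# added counts sum across the chain: 0 (ByteLevel) + 2 (template)
assert seq.num_special_tokens_to_add(False) == 2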


@@ -104,6 +104,7 @@ fn processors(_py: Python, m: &PyModule) -> PyResult<()> {
m.add_class::<processors::PyRobertaProcessing>()?;
m.add_class::<processors::PyByteLevel>()?;
m.add_class::<processors::PyTemplateProcessing>()?;
m.add_class::<processors::PySequence>()?;
Ok(())
}


@@ -11,6 +11,7 @@ use serde::{Deserialize, Serialize};
use tk::processors::bert::BertProcessing;
use tk::processors::byte_level::ByteLevel;
use tk::processors::roberta::RobertaProcessing;
use tk::processors::sequence::Sequence;
use tk::processors::template::{SpecialToken, Template};
use tk::processors::PostProcessorWrapper;
use tk::{Encoding, PostProcessor};
@@ -50,6 +51,7 @@ impl PyPostProcessor {
PostProcessorWrapper::Template(_) => {
Py::new(py, (PyTemplateProcessing {}, base))?.into_py(py)
}
PostProcessorWrapper::Sequence(_) => Py::new(py, (PySequence {}, base))?.into_py(py),
})
}
}
@@ -414,6 +416,37 @@ impl PyTemplateProcessing {
}
}
/// Sequence Processor
///
/// Args:
/// processors (:obj:`List[PostProcessor]`)
/// The processors that need to be chained
#[pyclass(extends=PyPostProcessor, module = "tokenizers.processors", name = "Sequence")]
#[pyo3(text_signature = "(self, processors)")]
pub struct PySequence {}
#[pymethods]
impl PySequence {
#[new]
#[args(processors)]
fn new(processors_py: &PyList) -> (Self, PyPostProcessor) {
let mut processors: Vec<PostProcessorWrapper> = Vec::with_capacity(processors_py.len());
for n in processors_py.iter() {
let processor: PyRef<PyPostProcessor> = n.extract().unwrap();
let processor = processor.processor.as_ref();
processors.push(processor.clone());
}
let sequence_processor = Sequence::new(processors);
(
PySequence {},
PyPostProcessor::new(Arc::new(PostProcessorWrapper::Sequence(sequence_processor))),
)
}
fn __getnewargs__<'p>(&self, py: Python<'p>) -> &'p PyTuple {
PyTuple::new(py, &[PyList::empty(py)])
}
}
#[cfg(test)]
mod test {
use std::sync::Arc;
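
`__getnewargs__` above is what makes the new class picklable; the state itself is presumably round-tripped through the shared `PyPostProcessor` serialization, as for the other processors. A quick sketch mirroring the Python test further down:

import pickle
from tokenizers.processors import Sequence

seq = Sequence([])
assert isinstance(pickle.loads(pickle.dumps(seq)), Sequence)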


@@ -13,6 +13,7 @@ from tokenizers.processors import (
RobertaProcessing,
ByteLevel,
TemplateProcessing,
Sequence,
)
@@ -179,3 +180,49 @@ class TestTemplateProcessing:
tokenizer.post_processor = self.get_roberta()
template = tokenizer.encode("my name is john", "pair")
assert original.ids == template.ids
class TestSequenceProcessing:
def test_sequence_processing(self):
assert Sequence([]) is not None
assert Sequence([ByteLevel()]) is not None
assert isinstance(Sequence([]), PostProcessor)
assert isinstance(Sequence([]), Sequence)
serialized = pickle.dumps(Sequence([]))
assert isinstance(pickle.loads(serialized), Sequence)
def test_post_process(self):
byte_level = ByteLevel(trim_offsets=True)
template = TemplateProcessing(
single=["[CLS]", "$0", "[SEP]"],
pair=["[CLS]:0", "$A", "[SEP]:0", "$B:1", "[SEP]:1"],
special_tokens=[("[CLS]", 1), ("[SEP]", 0)],
)
tokenizer = Tokenizer(BPE())
tokenizer.add_special_tokens(["[SEP]", "[CLS]"])
tokenizer.add_tokens(["my", "name", "is", "Ġjohn", "pair"])
tokenizer.post_processor = template
# Before the sequence
original = tokenizer.encode("my name is Ġjohn")
assert original.ids == [1, 2, 3, 4, 5, 0]
assert original.type_ids == [0, 0, 0, 0, 0, 0]
assert original.offsets == [(0, 0), (0, 2), (3, 7), (8, 10), (11, 16), (0, 0)]
pair = tokenizer.encode("my name is Ġjohn", "pair")
# assert pair.ids == [1, 2, 3, 4, 5, 0, 6, 0]
assert pair.type_ids == [0, 0, 0, 0, 0, 0, 1, 1]
assert pair.offsets == [(0, 0), (0, 2), (3, 7), (8, 10), (11, 16), (0, 0), (0, 4), (0, 0)]
processor = Sequence([byte_level, template])
tokenizer.post_processor = processor
original = tokenizer.encode("my name is Ġjohn")
assert original.ids == [1, 2, 3, 4, 5, 0]
assert original.type_ids == [0, 0, 0, 0, 0, 0]
# Offsets ARE trimmed
assert original.offsets == [(0, 0), (0, 2), (3, 7), (8, 10), (12, 16), (0, 0)]
pair = tokenizer.encode("my name is Ġjohn", "pair")
# assert pair.ids == [1, 2, 3, 4, 5, 0, 6, 0]
assert pair.type_ids == [0, 0, 0, 0, 0, 0, 1, 1]
assert pair.offsets == [(0, 0), (0, 2), (3, 7), (8, 10), (12, 16), (0, 0), (0, 4), (0, 0)]
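
Why `Ġjohn`'s span moves from (11, 16) to (12, 16) once ByteLevel runs in the chain: `trim_offsets=True` drops the leading `Ġ` (the byte-level marker for the preceding space) from the token's offsets. In plain Python terms:

text = "my name is Ġjohn"
assert text[11:16] == "Ġjohn"  # untrimmed span still covers the Ġ marker
assert text[12:16] == "john"   # trimmed span covers only the visible word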


@@ -177,7 +177,7 @@ impl PostProcessor for ByteLevel {
fn process_encodings(
&self,
mut encodings: Vec<Encoding>,
add_special_tokens: bool,
_add_special_tokens: bool,
) -> Result<Vec<Encoding>> {
if self.trim_offsets {
for encoding in encodings.iter_mut() {
@@ -188,7 +188,11 @@
.for_each(|encoding| process_offsets(encoding, self.add_prefix_space));
}
}
<dyn PostProcessor>::default_process(encodings, add_special_tokens)
for (i, encoding) in encodings.iter_mut().enumerate() {
encoding.set_sequence_id(i);
}
Ok(encodings)
//<dyn PostProcessor>::default_process(encodings, add_special_tokens)
}
}
@@ -493,7 +497,7 @@ mod tests {
vec![],
vec![],
vec![],
HashMap::new(),
HashMap::from_iter(vec![(0, 0..5)]),
);
let bytelevel = ByteLevel::default().trim_offsets(true);
@@ -502,8 +506,8 @@
bytelevel.process(start.clone(), None, false).unwrap()
);
let mut pair_expected = Encoding::new(
vec![0; 5],
let pair_expected = Encoding::new(
vec![0; 10],
vec![],
vec![
"Ġ".into(),
@@ -511,15 +515,30 @@
"ĠĠHello".into(),
"HelloĠĠ".into(),
"ĠĠĠĠ".into(),
"Ġ".into(),
"ĠĠĠĠHelloĠĠ".into(),
"ĠĠHello".into(),
"HelloĠĠ".into(),
"ĠĠĠĠ".into(),
],
vec![],
vec![(0, 0), (4, 9), (13, 18), (18, 23), (29, 29)],
vec![
(0, 0),
(4, 9),
(13, 18),
(18, 23),
(29, 29),
(0, 0),
(4, 9),
(13, 18),
(18, 23),
(29, 29),
],
vec![],
vec![],
vec![],
HashMap::from_iter(vec![(0, 0..5), (1, 5..10)]),
);
pair_expected.merge_with(expected, false);
assert_eq!(
pair_expected,
bytelevel
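
The core change in this file: `ByteLevel::process_encodings` no longer funnels through `default_process`; it now just tags each encoding with its sequence index and returns them, leaving merging and special tokens to whatever runs later in the pipeline. The Python-visible effect is that tokens keep track of their source sequence (a sketch; the exact ids depend on the toy vocabulary from the tests above):

tokenizer.post_processor = ByteLevel(trim_offsets=True)
pair = tokenizer.encode("my name is Ġjohn", "pair")
print(pair.sequence_ids)  # e.g. [0, 0, 0, 0, 1] -- the last token comes from the pair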


@@ -1,5 +1,6 @@
pub mod bert;
pub mod roberta;
pub mod sequence;
pub mod template;
// Re-export these as processors
@@ -10,6 +11,7 @@ use serde::{Deserialize, Serialize};
use crate::pre_tokenizers::byte_level::ByteLevel;
use crate::processors::bert::BertProcessing;
use crate::processors::roberta::RobertaProcessing;
use crate::processors::sequence::Sequence;
use crate::processors::template::TemplateProcessing;
use crate::{Encoding, PostProcessor, Result};
@@ -21,6 +23,7 @@ pub enum PostProcessorWrapper {
Bert(BertProcessing),
ByteLevel(ByteLevel),
Template(TemplateProcessing),
Sequence(Sequence),
}
impl PostProcessor for PostProcessorWrapper {
@@ -30,6 +33,7 @@ impl PostProcessor for PostProcessorWrapper {
Self::ByteLevel(bl) => bl.added_tokens(is_pair),
Self::Roberta(roberta) => roberta.added_tokens(is_pair),
Self::Template(template) => template.added_tokens(is_pair),
Self::Sequence(bl) => bl.added_tokens(is_pair),
}
}
@@ -43,6 +47,7 @@ impl PostProcessor for PostProcessorWrapper {
Self::ByteLevel(bl) => bl.process_encodings(encodings, add_special_tokens),
Self::Roberta(roberta) => roberta.process_encodings(encodings, add_special_tokens),
Self::Template(template) => template.process_encodings(encodings, add_special_tokens),
Self::Sequence(bl) => bl.process_encodings(encodings, add_special_tokens),
}
}
}
@@ -51,6 +56,7 @@ impl_enum_from!(BertProcessing, PostProcessorWrapper, Bert);
impl_enum_from!(ByteLevel, PostProcessorWrapper, ByteLevel);
impl_enum_from!(RobertaProcessing, PostProcessorWrapper, Roberta);
impl_enum_from!(TemplateProcessing, PostProcessorWrapper, Template);
impl_enum_from!(Sequence, PostProcessorWrapper, Sequence);
#[cfg(test)]
mod tests {


@@ -0,0 +1,139 @@
use crate::processors::PostProcessorWrapper;
use crate::tokenizer::{Encoding, PostProcessor, Result};
use crate::utils::macro_rules_attribute;
use serde::{Deserialize, Serialize};
#[derive(Clone, Debug, PartialEq, Eq)]
#[macro_rules_attribute(impl_serde_type!)]
pub struct Sequence {
processors: Vec<PostProcessorWrapper>,
}
impl Sequence {
pub fn new(processors: Vec<PostProcessorWrapper>) -> Self {
Self { processors }
}
}
impl PostProcessor for Sequence {
fn added_tokens(&self, is_pair: bool) -> usize {
self.processors
.iter()
.map(|p| p.added_tokens(is_pair))
.sum::<usize>()
}
fn process_encodings(
&self,
mut encodings: Vec<Encoding>,
add_special_tokens: bool,
) -> Result<Vec<Encoding>> {
for processor in &self.processors {
encodings = processor.process_encodings(encodings, add_special_tokens)?;
}
Ok(encodings)
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::processors::{ByteLevel, PostProcessorWrapper};
use crate::tokenizer::{Encoding, PostProcessor};
use std::collections::HashMap;
use std::iter::FromIterator;
#[test]
fn process_chain() {
let start = Encoding::new(
vec![0; 5],
vec![],
vec![
"Ġ".into(),
"ĠĠĠĠHelloĠĠ".into(),
"ĠĠHello".into(),
"HelloĠĠ".into(),
"ĠĠĠĠ".into(),
],
vec![],
vec![(0, 1), (0, 11), (11, 18), (18, 25), (25, 29)],
vec![],
vec![],
vec![],
HashMap::new(),
);
let bytelevel = ByteLevel::default().trim_offsets(true);
let sequence = Sequence::new(vec![PostProcessorWrapper::ByteLevel(bytelevel)]);
let expected = Encoding::new(
vec![0; 5],
vec![],
vec![
"Ġ".into(),
"ĠĠĠĠHelloĠĠ".into(),
"ĠĠHello".into(),
"HelloĠĠ".into(),
"ĠĠĠĠ".into(),
],
vec![],
vec![(0, 0), (4, 9), (13, 18), (18, 23), (29, 29)],
vec![],
vec![],
vec![],
HashMap::from_iter(vec![(0, 0..5)]),
);
assert_eq!(
expected,
bytelevel.process(start.clone(), None, false).unwrap()
);
assert_eq!(
expected,
sequence.process(start.clone(), None, false).unwrap()
);
let pair_expected = Encoding::new(
vec![0; 10],
vec![],
vec![
"Ġ".into(),
"ĠĠĠĠHelloĠĠ".into(),
"ĠĠHello".into(),
"HelloĠĠ".into(),
"ĠĠĠĠ".into(),
"Ġ".into(),
"ĠĠĠĠHelloĠĠ".into(),
"ĠĠHello".into(),
"HelloĠĠ".into(),
"ĠĠĠĠ".into(),
],
vec![],
vec![
(0, 0),
(4, 9),
(13, 18),
(18, 23),
(29, 29),
(0, 0),
(4, 9),
(13, 18),
(18, 23),
(29, 29),
],
vec![],
vec![],
vec![],
HashMap::from_iter(vec![(0, 0..5), (1, 5..10)]),
);
assert_eq!(
pair_expected,
bytelevel
.process(start.clone(), Some(start.clone()), false)
.unwrap()
);
assert_eq!(
pair_expected,
sequence.process(start.clone(), Some(start), false).unwrap()
);
}
}
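
One property worth noting from `process_encodings` above: the chain is order-sensitive, since each processor consumes the previous one's output. In the Python API that means (a sketch, reusing the objects from the tests earlier):

# ByteLevel trims offsets first; the template then wraps the already-trimmed
# tokens with [CLS]/[SEP].
processor = Sequence([byte_level, template])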