mirror of
https://github.com/mii443/tokenizers.git
synced 2025-08-31 12:39:21 +00:00
Adding `Sequence` for `PostProcessor`. (#1052)
* Adding `Sequence` for `PostProcessor`. * Fixing node? Writing in the dark here, don't have Python2.7 * `undefined` is not accepted. * Other test.
This commit is contained in:
@ -54,3 +54,11 @@ export function templateProcessing(
|
|||||||
pair?: string,
|
pair?: string,
|
||||||
specialTokens?: [string, number][]
|
specialTokens?: [string, number][]
|
||||||
): PostProcessor;
|
): PostProcessor;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Instantiate a new SequenceProcessing.
|
||||||
|
*
|
||||||
|
* @param processors The list of Processors to use
|
||||||
|
* @since 0.13.0
|
||||||
|
*/
|
||||||
|
export function sequenceProcessing(processors: PostProcessor[]): PostProcessor;
|
||||||
|
@ -5,4 +5,5 @@ module.exports = {
|
|||||||
byteLevelProcessing: native.processors_ByteLevel,
|
byteLevelProcessing: native.processors_ByteLevel,
|
||||||
robertaProcessing: native.processors_RobertaProcessing,
|
robertaProcessing: native.processors_RobertaProcessing,
|
||||||
templateProcessing: native.processors_TemplateProcessing,
|
templateProcessing: native.processors_TemplateProcessing,
|
||||||
|
sequenceProcessing: native.processors_Sequence,
|
||||||
};
|
};
|
||||||
|
@ -4,6 +4,7 @@ import {
|
|||||||
bertProcessing,
|
bertProcessing,
|
||||||
byteLevelProcessing,
|
byteLevelProcessing,
|
||||||
robertaProcessing,
|
robertaProcessing,
|
||||||
|
sequenceProcessing,
|
||||||
templateProcessing,
|
templateProcessing,
|
||||||
} from "./post-processors";
|
} from "./post-processors";
|
||||||
|
|
||||||
@ -81,3 +82,14 @@ describe("templateProcessing", () => {
|
|||||||
expect(processor.constructor.name).toEqual("Processor");
|
expect(processor.constructor.name).toEqual("Processor");
|
||||||
});
|
});
|
||||||
});
|
});
|
||||||
|
|
||||||
|
describe("sequenceProcessing", () => {
  it("accepts `PostProcessor[]` as first parameter", () => {
    // Build two independent processors first, then check they can be combined.
    const specialTokens: [string, number][] = [
      ["[CLS]", 1],
      ["[SEP]", 2],
    ];
    const template = templateProcessing(
      "[CLS] $A [SEP]",
      "[CLS] $A [SEP] $B:1 [SEP]:1",
      specialTokens
    );
    const bytelevel = byteLevelProcessing(true);

    const chained = sequenceProcessing([bytelevel, template]);
    expect(chained).toBeDefined();
  });
});
|
||||||
|
@ -129,6 +129,33 @@ fn template_processing(mut cx: FunctionContext) -> JsResult<JsPostProcessor> {
|
|||||||
Ok(js_processor)
|
Ok(js_processor)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// sequence(processors: List[Processor])
|
||||||
|
fn sequence(mut cx: FunctionContext) -> JsResult<JsPostProcessor> {
|
||||||
|
let processors = cx.argument::<JsArray>(0)?.to_vec(&mut cx)?;
|
||||||
|
let mut sequence = Vec::with_capacity(processors.len());
|
||||||
|
|
||||||
|
processors.into_iter().try_for_each(|processor| {
|
||||||
|
match processor.downcast::<JsPostProcessor>().or_throw(&mut cx) {
|
||||||
|
Ok(processor) => {
|
||||||
|
let guard = cx.lock();
|
||||||
|
if let Some(processor_arc) = &processor.borrow(&guard).processor {
|
||||||
|
let processor: PostProcessorWrapper = (**processor_arc).clone();
|
||||||
|
sequence.push(processor);
|
||||||
|
}
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
Err(e) => Err(e),
|
||||||
|
}
|
||||||
|
})?;
|
||||||
|
|
||||||
|
let mut pretok = JsPostProcessor::new::<_, JsPostProcessor, _>(&mut cx, vec![])?;
|
||||||
|
let guard = cx.lock();
|
||||||
|
pretok.borrow_mut(&guard).processor = Some(Arc::new(PostProcessorWrapper::Sequence(
|
||||||
|
tk::processors::sequence::Sequence::new(sequence),
|
||||||
|
)));
|
||||||
|
Ok(pretok)
|
||||||
|
}
|
||||||
|
|
||||||
/// Register everything here
|
/// Register everything here
|
||||||
pub fn register(m: &mut ModuleContext, prefix: &str) -> NeonResult<()> {
|
pub fn register(m: &mut ModuleContext, prefix: &str) -> NeonResult<()> {
|
||||||
m.export_function(&format!("{}_BertProcessing", prefix), bert_processing)?;
|
m.export_function(&format!("{}_BertProcessing", prefix), bert_processing)?;
|
||||||
@ -138,5 +165,6 @@ pub fn register(m: &mut ModuleContext, prefix: &str) -> NeonResult<()> {
|
|||||||
&format!("{}_TemplateProcessing", prefix),
|
&format!("{}_TemplateProcessing", prefix),
|
||||||
template_processing,
|
template_processing,
|
||||||
)?;
|
)?;
|
||||||
|
m.export_function(&format!("{}_Sequence", prefix), sequence)?;
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
@ -5,4 +5,5 @@ PostProcessor = processors.PostProcessor
|
|||||||
BertProcessing = processors.BertProcessing
|
BertProcessing = processors.BertProcessing
|
||||||
ByteLevel = processors.ByteLevel
|
ByteLevel = processors.ByteLevel
|
||||||
RobertaProcessing = processors.RobertaProcessing
|
RobertaProcessing = processors.RobertaProcessing
|
||||||
|
Sequence = processors.Sequence
|
||||||
TemplateProcessing = processors.TemplateProcessing
|
TemplateProcessing = processors.TemplateProcessing
|
||||||
|
@ -193,6 +193,48 @@ class RobertaProcessing(PostProcessor):
|
|||||||
"""
|
"""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
class Sequence(PostProcessor):
    """
    Sequence Processor

    Chains several post-processors together.

    Args:
        processors (:obj:`List[PostProcessor]`):
            The processors that need to be chained
    """

    def __init__(self, processors):
        pass

    def num_special_tokens_to_add(self, is_pair):
        """
        Return the number of special tokens that would be added for single/pair sentences.

        Args:
            is_pair (:obj:`bool`):
                Whether the input would be a pair of sequences

        Returns:
            :obj:`int`: The number of tokens to add
        """
        pass

    def process(self, encoding, pair=None, add_special_tokens=True):
        """
        Post-process the given encodings, generating the final one

        Args:
            encoding (:class:`~tokenizers.Encoding`):
                The encoding for the first sequence

            pair (:class:`~tokenizers.Encoding`, `optional`):
                The encoding for the pair sequence

            add_special_tokens (:obj:`bool`):
                Whether to add the special tokens

        Returns:
            :class:`~tokenizers.Encoding`: The final encoding
        """
        pass
|
||||||
|
|
||||||
class TemplateProcessing(PostProcessor):
|
class TemplateProcessing(PostProcessor):
|
||||||
"""
|
"""
|
||||||
Provides a way to specify templates in order to add the special tokens to each
|
Provides a way to specify templates in order to add the special tokens to each
|
||||||
|
@ -104,6 +104,7 @@ fn processors(_py: Python, m: &PyModule) -> PyResult<()> {
|
|||||||
m.add_class::<processors::PyRobertaProcessing>()?;
|
m.add_class::<processors::PyRobertaProcessing>()?;
|
||||||
m.add_class::<processors::PyByteLevel>()?;
|
m.add_class::<processors::PyByteLevel>()?;
|
||||||
m.add_class::<processors::PyTemplateProcessing>()?;
|
m.add_class::<processors::PyTemplateProcessing>()?;
|
||||||
|
m.add_class::<processors::PySequence>()?;
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -11,6 +11,7 @@ use serde::{Deserialize, Serialize};
|
|||||||
use tk::processors::bert::BertProcessing;
|
use tk::processors::bert::BertProcessing;
|
||||||
use tk::processors::byte_level::ByteLevel;
|
use tk::processors::byte_level::ByteLevel;
|
||||||
use tk::processors::roberta::RobertaProcessing;
|
use tk::processors::roberta::RobertaProcessing;
|
||||||
|
use tk::processors::sequence::Sequence;
|
||||||
use tk::processors::template::{SpecialToken, Template};
|
use tk::processors::template::{SpecialToken, Template};
|
||||||
use tk::processors::PostProcessorWrapper;
|
use tk::processors::PostProcessorWrapper;
|
||||||
use tk::{Encoding, PostProcessor};
|
use tk::{Encoding, PostProcessor};
|
||||||
@ -50,6 +51,7 @@ impl PyPostProcessor {
|
|||||||
PostProcessorWrapper::Template(_) => {
|
PostProcessorWrapper::Template(_) => {
|
||||||
Py::new(py, (PyTemplateProcessing {}, base))?.into_py(py)
|
Py::new(py, (PyTemplateProcessing {}, base))?.into_py(py)
|
||||||
}
|
}
|
||||||
|
PostProcessorWrapper::Sequence(_) => Py::new(py, (PySequence {}, base))?.into_py(py),
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -414,6 +416,37 @@ impl PyTemplateProcessing {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Sequence Processor
|
||||||
|
///
|
||||||
|
/// Args:
|
||||||
|
/// processors (:obj:`List[PostProcessor]`)
|
||||||
|
/// The processors that need to be chained
|
||||||
|
#[pyclass(extends=PyPostProcessor, module = "tokenizers.processors", name = "Sequence")]
|
||||||
|
#[pyo3(text_signature = "(self, processors)")]
|
||||||
|
pub struct PySequence {}
|
||||||
|
#[pymethods]
|
||||||
|
impl PySequence {
|
||||||
|
#[new]
|
||||||
|
#[args(processors)]
|
||||||
|
fn new(processors_py: &PyList) -> (Self, PyPostProcessor) {
|
||||||
|
let mut processors: Vec<PostProcessorWrapper> = Vec::with_capacity(processors_py.len());
|
||||||
|
for n in processors_py.iter() {
|
||||||
|
let processor: PyRef<PyPostProcessor> = n.extract().unwrap();
|
||||||
|
let processor = processor.processor.as_ref();
|
||||||
|
processors.push(processor.clone());
|
||||||
|
}
|
||||||
|
let sequence_processor = Sequence::new(processors);
|
||||||
|
(
|
||||||
|
PySequence {},
|
||||||
|
PyPostProcessor::new(Arc::new(PostProcessorWrapper::Sequence(sequence_processor))),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn __getnewargs__<'p>(&self, py: Python<'p>) -> &'p PyTuple {
|
||||||
|
PyTuple::new(py, &[PyList::empty(py)])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod test {
|
mod test {
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
|
@ -13,6 +13,7 @@ from tokenizers.processors import (
|
|||||||
RobertaProcessing,
|
RobertaProcessing,
|
||||||
ByteLevel,
|
ByteLevel,
|
||||||
TemplateProcessing,
|
TemplateProcessing,
|
||||||
|
Sequence,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@ -179,3 +180,49 @@ class TestTemplateProcessing:
|
|||||||
tokenizer.post_processor = self.get_roberta()
|
tokenizer.post_processor = self.get_roberta()
|
||||||
template = tokenizer.encode("my name is john", "pair")
|
template = tokenizer.encode("my name is john", "pair")
|
||||||
assert original.ids == template.ids
|
assert original.ids == template.ids
|
||||||
|
|
||||||
|
|
||||||
|
class TestSequenceProcessing:
    """Covers construction, pickling and chaining of the ``Sequence`` post-processor."""

    def test_sequence_processing(self):
        # Construction works for an empty list as well as a populated one.
        assert Sequence([]) is not None
        assert Sequence([ByteLevel()]) is not None

        # The instance is both a generic PostProcessor and a Sequence.
        empty = Sequence([])
        assert isinstance(empty, PostProcessor)
        assert isinstance(empty, Sequence)

        # It survives a pickle round-trip.
        restored = pickle.loads(pickle.dumps(Sequence([])))
        assert isinstance(restored, Sequence)

    def test_post_process(self):
        sentence = "my name is Ġjohn"
        expected_ids = [1, 2, 3, 4, 5, 0]
        expected_type_ids = [0] * 6
        expected_pair_type_ids = [0] * 6 + [1] * 2

        template = TemplateProcessing(
            single=["[CLS]", "$0", "[SEP]"],
            pair=["[CLS]:0", "$A", "[SEP]:0", "$B:1", "[SEP]:1"],
            special_tokens=[("[CLS]", 1), ("[SEP]", 0)],
        )
        byte_level = ByteLevel(trim_offsets=True)

        tokenizer = Tokenizer(BPE())
        tokenizer.add_special_tokens(["[SEP]", "[CLS]"])
        tokenizer.add_tokens(["my", "name", "is", "Ġjohn", "pair"])

        def check(john_offsets):
            # Only the offsets of the "Ġjohn" token differ between the two runs.
            single_offsets = [(0, 0), (0, 2), (3, 7), (8, 10), john_offsets, (0, 0)]

            original = tokenizer.encode(sentence)
            assert original.ids == expected_ids
            assert original.type_ids == expected_type_ids
            assert original.offsets == single_offsets

            pair = tokenizer.encode(sentence, "pair")
            # assert pair.ids == [1, 2, 3, 4, 5, 0, 6, 0]
            assert pair.type_ids == expected_pair_type_ids
            assert pair.offsets == single_offsets + [(0, 4), (0, 0)]

        # Before the sequence
        tokenizer.post_processor = template
        check((11, 16))

        # With the sequence: offsets ARE trimmed
        tokenizer.post_processor = Sequence([byte_level, template])
        check((12, 16))
|
||||||
|
@ -177,7 +177,7 @@ impl PostProcessor for ByteLevel {
|
|||||||
fn process_encodings(
|
fn process_encodings(
|
||||||
&self,
|
&self,
|
||||||
mut encodings: Vec<Encoding>,
|
mut encodings: Vec<Encoding>,
|
||||||
add_special_tokens: bool,
|
_add_special_tokens: bool,
|
||||||
) -> Result<Vec<Encoding>> {
|
) -> Result<Vec<Encoding>> {
|
||||||
if self.trim_offsets {
|
if self.trim_offsets {
|
||||||
for encoding in encodings.iter_mut() {
|
for encoding in encodings.iter_mut() {
|
||||||
@ -188,7 +188,11 @@ impl PostProcessor for ByteLevel {
|
|||||||
.for_each(|encoding| process_offsets(encoding, self.add_prefix_space));
|
.for_each(|encoding| process_offsets(encoding, self.add_prefix_space));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
<dyn PostProcessor>::default_process(encodings, add_special_tokens)
|
for (i, encoding) in encodings.iter_mut().enumerate() {
|
||||||
|
encoding.set_sequence_id(i);
|
||||||
|
}
|
||||||
|
Ok(encodings)
|
||||||
|
//<dyn PostProcessor>::default_process(encodings, add_special_tokens)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -493,7 +497,7 @@ mod tests {
|
|||||||
vec![],
|
vec![],
|
||||||
vec![],
|
vec![],
|
||||||
vec![],
|
vec![],
|
||||||
HashMap::new(),
|
HashMap::from_iter(vec![(0, 0..5)]),
|
||||||
);
|
);
|
||||||
|
|
||||||
let bytelevel = ByteLevel::default().trim_offsets(true);
|
let bytelevel = ByteLevel::default().trim_offsets(true);
|
||||||
@ -502,8 +506,8 @@ mod tests {
|
|||||||
bytelevel.process(start.clone(), None, false).unwrap()
|
bytelevel.process(start.clone(), None, false).unwrap()
|
||||||
);
|
);
|
||||||
|
|
||||||
let mut pair_expected = Encoding::new(
|
let pair_expected = Encoding::new(
|
||||||
vec![0; 5],
|
vec![0; 10],
|
||||||
vec![],
|
vec![],
|
||||||
vec![
|
vec![
|
||||||
"Ġ".into(),
|
"Ġ".into(),
|
||||||
@ -511,15 +515,30 @@ mod tests {
|
|||||||
"ĠĠHello".into(),
|
"ĠĠHello".into(),
|
||||||
"HelloĠĠ".into(),
|
"HelloĠĠ".into(),
|
||||||
"ĠĠĠĠ".into(),
|
"ĠĠĠĠ".into(),
|
||||||
|
"Ġ".into(),
|
||||||
|
"ĠĠĠĠHelloĠĠ".into(),
|
||||||
|
"ĠĠHello".into(),
|
||||||
|
"HelloĠĠ".into(),
|
||||||
|
"ĠĠĠĠ".into(),
|
||||||
],
|
],
|
||||||
vec![],
|
vec![],
|
||||||
vec![(0, 0), (4, 9), (13, 18), (18, 23), (29, 29)],
|
vec![
|
||||||
|
(0, 0),
|
||||||
|
(4, 9),
|
||||||
|
(13, 18),
|
||||||
|
(18, 23),
|
||||||
|
(29, 29),
|
||||||
|
(0, 0),
|
||||||
|
(4, 9),
|
||||||
|
(13, 18),
|
||||||
|
(18, 23),
|
||||||
|
(29, 29),
|
||||||
|
],
|
||||||
vec![],
|
vec![],
|
||||||
vec![],
|
vec![],
|
||||||
vec![],
|
vec![],
|
||||||
HashMap::from_iter(vec![(0, 0..5), (1, 5..10)]),
|
HashMap::from_iter(vec![(0, 0..5), (1, 5..10)]),
|
||||||
);
|
);
|
||||||
pair_expected.merge_with(expected, false);
|
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
pair_expected,
|
pair_expected,
|
||||||
bytelevel
|
bytelevel
|
||||||
|
@ -1,5 +1,6 @@
|
|||||||
pub mod bert;
|
pub mod bert;
|
||||||
pub mod roberta;
|
pub mod roberta;
|
||||||
|
pub mod sequence;
|
||||||
pub mod template;
|
pub mod template;
|
||||||
|
|
||||||
// Re-export these as processors
|
// Re-export these as processors
|
||||||
@ -10,6 +11,7 @@ use serde::{Deserialize, Serialize};
|
|||||||
use crate::pre_tokenizers::byte_level::ByteLevel;
|
use crate::pre_tokenizers::byte_level::ByteLevel;
|
||||||
use crate::processors::bert::BertProcessing;
|
use crate::processors::bert::BertProcessing;
|
||||||
use crate::processors::roberta::RobertaProcessing;
|
use crate::processors::roberta::RobertaProcessing;
|
||||||
|
use crate::processors::sequence::Sequence;
|
||||||
use crate::processors::template::TemplateProcessing;
|
use crate::processors::template::TemplateProcessing;
|
||||||
use crate::{Encoding, PostProcessor, Result};
|
use crate::{Encoding, PostProcessor, Result};
|
||||||
|
|
||||||
@ -21,6 +23,7 @@ pub enum PostProcessorWrapper {
|
|||||||
Bert(BertProcessing),
|
Bert(BertProcessing),
|
||||||
ByteLevel(ByteLevel),
|
ByteLevel(ByteLevel),
|
||||||
Template(TemplateProcessing),
|
Template(TemplateProcessing),
|
||||||
|
Sequence(Sequence),
|
||||||
}
|
}
|
||||||
|
|
||||||
impl PostProcessor for PostProcessorWrapper {
|
impl PostProcessor for PostProcessorWrapper {
|
||||||
@ -30,6 +33,7 @@ impl PostProcessor for PostProcessorWrapper {
|
|||||||
Self::ByteLevel(bl) => bl.added_tokens(is_pair),
|
Self::ByteLevel(bl) => bl.added_tokens(is_pair),
|
||||||
Self::Roberta(roberta) => roberta.added_tokens(is_pair),
|
Self::Roberta(roberta) => roberta.added_tokens(is_pair),
|
||||||
Self::Template(template) => template.added_tokens(is_pair),
|
Self::Template(template) => template.added_tokens(is_pair),
|
||||||
|
Self::Sequence(bl) => bl.added_tokens(is_pair),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -43,6 +47,7 @@ impl PostProcessor for PostProcessorWrapper {
|
|||||||
Self::ByteLevel(bl) => bl.process_encodings(encodings, add_special_tokens),
|
Self::ByteLevel(bl) => bl.process_encodings(encodings, add_special_tokens),
|
||||||
Self::Roberta(roberta) => roberta.process_encodings(encodings, add_special_tokens),
|
Self::Roberta(roberta) => roberta.process_encodings(encodings, add_special_tokens),
|
||||||
Self::Template(template) => template.process_encodings(encodings, add_special_tokens),
|
Self::Template(template) => template.process_encodings(encodings, add_special_tokens),
|
||||||
|
Self::Sequence(bl) => bl.process_encodings(encodings, add_special_tokens),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -51,6 +56,7 @@ impl_enum_from!(BertProcessing, PostProcessorWrapper, Bert);
|
|||||||
impl_enum_from!(ByteLevel, PostProcessorWrapper, ByteLevel);
|
impl_enum_from!(ByteLevel, PostProcessorWrapper, ByteLevel);
|
||||||
impl_enum_from!(RobertaProcessing, PostProcessorWrapper, Roberta);
|
impl_enum_from!(RobertaProcessing, PostProcessorWrapper, Roberta);
|
||||||
impl_enum_from!(TemplateProcessing, PostProcessorWrapper, Template);
|
impl_enum_from!(TemplateProcessing, PostProcessorWrapper, Template);
|
||||||
|
impl_enum_from!(Sequence, PostProcessorWrapper, Sequence);
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod tests {
|
mod tests {
|
||||||
|
139
tokenizers/src/processors/sequence.rs
Normal file
139
tokenizers/src/processors/sequence.rs
Normal file
@ -0,0 +1,139 @@
|
|||||||
|
use crate::processors::PostProcessorWrapper;
|
||||||
|
use crate::tokenizer::{Encoding, PostProcessor, Result};
|
||||||
|
use crate::utils::macro_rules_attribute;
|
||||||
|
use serde::{Deserialize, Serialize};
|
||||||
|
|
||||||
|
#[derive(Clone, Debug, PartialEq, Eq)]
|
||||||
|
#[macro_rules_attribute(impl_serde_type!)]
|
||||||
|
pub struct Sequence {
|
||||||
|
processors: Vec<PostProcessorWrapper>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Sequence {
|
||||||
|
pub fn new(processors: Vec<PostProcessorWrapper>) -> Self {
|
||||||
|
Self { processors }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl PostProcessor for Sequence {
|
||||||
|
fn added_tokens(&self, is_pair: bool) -> usize {
|
||||||
|
self.processors
|
||||||
|
.iter()
|
||||||
|
.map(|p| p.added_tokens(is_pair))
|
||||||
|
.sum::<usize>()
|
||||||
|
}
|
||||||
|
|
||||||
|
fn process_encodings(
|
||||||
|
&self,
|
||||||
|
mut encodings: Vec<Encoding>,
|
||||||
|
add_special_tokens: bool,
|
||||||
|
) -> Result<Vec<Encoding>> {
|
||||||
|
for processor in &self.processors {
|
||||||
|
encodings = processor.process_encodings(encodings, add_special_tokens)?;
|
||||||
|
}
|
||||||
|
Ok(encodings)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
mod tests {
    use super::*;
    use crate::processors::{ByteLevel, PostProcessorWrapper};
    use crate::tokenizer::{Encoding, PostProcessor};
    use std::collections::HashMap;
    use std::iter::FromIterator;

    /// A `Sequence` holding a single `ByteLevel` must give the same result as
    /// that `ByteLevel` on its own, for single and for pair inputs.
    #[test]
    fn process_chain() {
        // Tokens are shared by the input and by every expected encoding.
        let tokens: Vec<String> = vec![
            "Ġ".into(),
            "ĠĠĠĠHelloĠĠ".into(),
            "ĠĠHello".into(),
            "HelloĠĠ".into(),
            "ĠĠĠĠ".into(),
        ];
        // Offsets expected once trimming has been applied.
        let trimmed_offsets = vec![(0, 0), (4, 9), (13, 18), (18, 23), (29, 29)];

        let start = Encoding::new(
            vec![0; 5],
            vec![],
            tokens.clone(),
            vec![],
            vec![(0, 1), (0, 11), (11, 18), (18, 25), (25, 29)],
            vec![],
            vec![],
            vec![],
            HashMap::new(),
        );

        let bytelevel = ByteLevel::default().trim_offsets(true);
        let sequence = Sequence::new(vec![PostProcessorWrapper::ByteLevel(bytelevel)]);

        let expected = Encoding::new(
            vec![0; 5],
            vec![],
            tokens.clone(),
            vec![],
            trimmed_offsets.clone(),
            vec![],
            vec![],
            vec![],
            HashMap::from_iter(vec![(0, 0..5)]),
        );

        assert_eq!(
            expected,
            bytelevel.process(start.clone(), None, false).unwrap()
        );
        assert_eq!(
            expected,
            sequence.process(start.clone(), None, false).unwrap()
        );

        // The pair case is the single case repeated twice back to back.
        let pair_expected = Encoding::new(
            vec![0; 10],
            vec![],
            [tokens.clone(), tokens].concat(),
            vec![],
            [trimmed_offsets.clone(), trimmed_offsets].concat(),
            vec![],
            vec![],
            vec![],
            HashMap::from_iter(vec![(0, 0..5), (1, 5..10)]),
        );

        assert_eq!(
            pair_expected,
            bytelevel
                .process(start.clone(), Some(start.clone()), false)
                .unwrap()
        );
        assert_eq!(
            pair_expected,
            sequence.process(start.clone(), Some(start), false).unwrap()
        );
    }
}
|
Reference in New Issue
Block a user