diff --git a/bindings/python/py_src/tokenizers/processors/__init__.py b/bindings/python/py_src/tokenizers/processors/__init__.py
index 7b9ae056..ec00a6a0 100644
--- a/bindings/python/py_src/tokenizers/processors/__init__.py
+++ b/bindings/python/py_src/tokenizers/processors/__init__.py
@@ -4,3 +4,4 @@ PostProcessor = processors.PostProcessor
 BertProcessing = processors.BertProcessing
 RobertaProcessing = processors.RobertaProcessing
 ByteLevel = processors.ByteLevel
+TemplateProcessing = processors.TemplateProcessing
diff --git a/bindings/python/py_src/tokenizers/processors/__init__.pyi b/bindings/python/py_src/tokenizers/processors/__init__.pyi
index 61143214..5cde2b90 100644
--- a/bindings/python/py_src/tokenizers/processors/__init__.pyi
+++ b/bindings/python/py_src/tokenizers/processors/__init__.pyi
@@ -1,4 +1,4 @@
-from typing import Tuple
+from typing import Tuple, Union, List
 
 class PostProcessor:
     """ Base class for all post-processors
@@ -89,7 +89,7 @@ class ByteLevel(PostProcessor):
         want the offsets to include these whitespaces, then this PostProcessor must be used.
     """
 
-    def __init(self, trim_offsets: bool = True) -> None:
+    def __init__(self, trim_offsets: bool = True) -> None:
         """ Instantiate a new ByteLevel
 
         Args:
@@ -97,3 +97,67 @@
                 Whether to trim the whitespaces from the produced offsets.
         """
         pass
+
+Template = Union[str, List[str]]
+Tokens = List[Union[Tuple[int, str], Tuple[str, int], dict]]
+
+class TemplateProcessing(PostProcessor):
+    """ TemplateProcessing
+
+    Provides a way to specify templates in order to add the special tokens to each
+    input sequence as relevant.
+
+    Let's take the `BERT` tokenizer as an example. It uses two special tokens to
+    delimit each sequence: `[CLS]` is always used at the beginning of the first
+    sequence, and `[SEP]` is added at the end of both the first and the pair
+    sequences. The final result looks like this:
+        - Single sequence: `[CLS] Hello there [SEP]`
+        - Pair sequences: `[CLS] My name is Anthony [SEP] What is my name? [SEP]`
+
+    You can achieve such behavior using a TemplateProcessing:
+    ```
+    TemplateProcessing(
+        seq_a="[CLS] $0 [SEP]",
+        seq_b="$1 [SEP]",
+        special_tokens=[("[CLS]", 1), ("[SEP]", 0)],
+    )
+    ```
+
+    In this example, $0 and $1 both represent the input sequences. The number in the
+    identifier is the default type_id that will be used for each sequence. So, in
+    this case, the first sequence will use 0, while the pair sequence will use 1.
+
+    We say "default" type_id because each SpecialToken can define its own type_id,
+    which overrides the provided default.
+    """
+
+    def __init__(self, seq_a: Template, seq_b: Template, special_tokens: Tokens) -> None:
+        """ Instantiate a new TemplateProcessing
+
+        Args:
+            seq_a: Template:
+                The template for the first sequence.
+
+            seq_b: Template:
+                The template for the pair sequence.
+
+            special_tokens: Tokens:
+                The list of special tokens used in each sequence
+
+            Template: Union[str, List[str]]:
+                - If a `str` is provided, the whitespace is used as delimiter between tokens
+                - If a `List[str]` is provided, each string is a single token
+
+            Tokens: List[Union[Tuple[int, str], Tuple[str, int], dict]]:
+                - A Tuple with both a token and its associated ID, in any order
+                - A dict with the following keys:
+                    - "id": str => The special token id, as specified in the Template
+                    - "ids": List[int] => The associated IDs
+                    - "tokens": List[str] => The associated tokens
+                    - "type_ids": Optional[List[Optional[int]]] => If specified, a list of optional
+                        type_ids. If a `type_id` is not specified, the one from the input sequence
+                        will be used.
+                The given dict expects the provided `ids`, `tokens` and `type_ids` lists to have
+                the same length.
+        """
+        pass
diff --git a/bindings/python/src/lib.rs b/bindings/python/src/lib.rs
index 0ecbd247..a04449d5 100644
--- a/bindings/python/src/lib.rs
+++ b/bindings/python/src/lib.rs
@@ -1,3 +1,5 @@
+#![warn(clippy::all)]
+
 extern crate tokenizers as tk;
 
 mod decoders;
@@ -90,6 +92,7 @@ fn processors(_py: Python, m: &PyModule) -> PyResult<()> {
     m.add_class::<processors::PyBertProcessing>()?;
     m.add_class::<processors::PyRobertaProcessing>()?;
     m.add_class::<processors::PyByteLevel>()?;
+    m.add_class::<processors::PyTemplateProcessing>()?;
     Ok(())
 }
 
diff --git a/bindings/python/src/processors.rs b/bindings/python/src/processors.rs
index bbd530b6..039e7b49 100644
--- a/bindings/python/src/processors.rs
+++ b/bindings/python/src/processors.rs
@@ -4,10 +4,12 @@ use pyo3::exceptions;
 use pyo3::prelude::*;
 use pyo3::types::*;
 
+use crate::error::ToPyResult;
 use serde::{Deserialize, Serialize};
 use tk::processors::bert::BertProcessing;
 use tk::processors::byte_level::ByteLevel;
 use tk::processors::roberta::RobertaProcessing;
+use tk::processors::template::{SpecialToken, Template};
 use tk::processors::PostProcessorWrapper;
 use tk::{Encoding, PostProcessor};
 use tokenizers as tk;
@@ -38,6 +40,9 @@ impl PyPostProcessor {
             PostProcessorWrapper::Roberta(_) => {
                 Py::new(py, (PyRobertaProcessing {}, base)).map(Into::into)
             }
+            PostProcessorWrapper::Template(_) => {
+                Py::new(py, (PyTemplateProcessing {}, base)).map(Into::into)
+            }
         }
     }
 }
@@ -158,6 +163,104 @@ impl PyByteLevel {
     }
 }
 
+#[derive(Clone, Debug)]
+pub struct PySpecialToken(SpecialToken);
+
+impl From<PySpecialToken> for SpecialToken {
+    fn from(v: PySpecialToken) -> Self {
+        v.0
+    }
+}
+
+impl FromPyObject<'_> for PySpecialToken {
+    fn extract(ob: &PyAny) -> PyResult<Self> {
+        if let Ok(v) = ob.extract::<(String, u32)>() {
+            Ok(Self(v.into()))
+        } else if let Ok(v) = ob.extract::<(u32, String)>() {
+            Ok(Self(v.into()))
+        } else if let Ok(d) = ob.downcast::<PyDict>() {
+            let id = d
+                .get_item("id")
+                .ok_or_else(|| exceptions::ValueError::py_err("`id` must be specified"))?
+                .extract::<String>()?;
+            let ids = d
+                .get_item("ids")
+                .ok_or_else(|| exceptions::ValueError::py_err("`ids` must be specified"))?
+                .extract::<Vec<u32>>()?;
+            let type_ids = d.get_item("type_ids").map_or_else(
+                || Ok(vec![None; ids.len()]),
+                |v| v.extract::<Vec<Option<u32>>>(),
+            )?;
+            let tokens = d
+                .get_item("tokens")
+                .ok_or_else(|| exceptions::ValueError::py_err("`tokens` must be specified"))?
+                .extract::<Vec<String>>()?;
+
+            Ok(Self(
+                ToPyResult(SpecialToken::new(id, ids, type_ids, tokens)).into_py()?,
+            ))
+        } else {
+            Err(exceptions::TypeError::py_err(
+                "Expected Union[Tuple[str, int], Tuple[int, str], dict]",
+            ))
+        }
+    }
+}
+
+#[derive(Clone, Debug)]
+pub struct PyTemplate(Template);
+
+impl From<PyTemplate> for Template {
+    fn from(v: PyTemplate) -> Self {
+        v.0
+    }
+}
+
+impl FromPyObject<'_> for PyTemplate {
+    fn extract(ob: &PyAny) -> PyResult<Self> {
+        if let Ok(s) = ob.extract::<&str>() {
+            Ok(Self(s.into()))
+        } else if let Ok(s) = ob.extract::<Vec<&str>>() {
+            Ok(Self(s.into()))
+        } else {
+            Err(exceptions::TypeError::py_err(
+                "Expected Union[str, List[str]]",
+            ))
+        }
+    }
+}
+
+#[pyclass(extends=PyPostProcessor, module = "tokenizers.processors", name=TemplateProcessing)]
+pub struct PyTemplateProcessing {}
+#[pymethods]
+impl PyTemplateProcessing {
+    #[new]
+    #[args(seq_a = "None", seq_b = "None", special_tokens = "None")]
+    fn new(
+        seq_a: Option<PyTemplate>,
+        seq_b: Option<PyTemplate>,
+        special_tokens: Option<Vec<PySpecialToken>>,
+    ) -> PyResult<(Self, PyPostProcessor)> {
+        let mut builder = tk::processors::template::TemplateProcessing::builder();
+
+        if let Some(seq) = seq_a {
+            builder.sequence_a(seq);
+        }
+        if let Some(seq) = seq_b {
+            builder.sequence_b(seq);
+        }
+        if let Some(sp) = special_tokens {
+            builder.special_tokens(sp);
+        }
+        let processor = builder.build().map_err(exceptions::ValueError::py_err)?;
+
+        Ok((
+            PyTemplateProcessing {},
+            PyPostProcessor::new(Arc::new(processor.into())),
+        ))
+    }
+}
+
 #[cfg(test)]
 mod test {
     use std::sync::Arc;
diff --git a/bindings/python/tests/bindings/test_processors.py b/bindings/python/tests/bindings/test_processors.py
index 5f6d2f06..aba1f14b 100644
--- a/bindings/python/tests/bindings/test_processors.py
+++ b/bindings/python/tests/bindings/test_processors.py
@@ -5,7 +5,13 @@ from ..utils import data_dir, roberta_files
 from tokenizers import Tokenizer
 from tokenizers.models import BPE
 from tokenizers.pre_tokenizers import ByteLevel as ByteLevelPreTokenizer
-from tokenizers.processors import PostProcessor, BertProcessing, RobertaProcessing, ByteLevel
+from tokenizers.processors import (
+    PostProcessor,
+    BertProcessing,
+    RobertaProcessing,
+    ByteLevel,
+    TemplateProcessing,
+)
 
 
 class TestBertProcessing:
@@ -73,3 +79,78 @@ class TestByteLevelProcessing:
         output = tokenizer.encode("My name is John")
         assert output.tokens == ["ĠMy", "Ġname", "Ġis", "ĠJohn"]
         assert output.offsets == [(0, 2), (3, 7), (8, 10), (11, 15)]
+
+
+class TestTemplateProcessing:
+    def get_bert(self):
+        return TemplateProcessing(
+            seq_a=["[CLS]", "$0", "[SEP]"],
+            seq_b=["$1", "[SEP]"],
+            special_tokens=[("[CLS]", 1), ("[SEP]", 0)],
+        )
+
+    def get_roberta(self):
+        return TemplateProcessing(
+            seq_a="<s> $0 </s>", seq_b="</s> $0 </s>", special_tokens=[("<s>", 0), ("</s>", 1)],
+        )
+
+    def get_t5_squad(self):
+        # >>> from transformers import AutoTokenizer
+        # >>> tok = AutoTokenizer.from_pretrained("t5-small")
+        # >>> tok.tokenize("question: ")
+        # ['▁question', ':']
+        # >>> tok.tokenize("context: ")
+        # ['▁context', ':']
+        # >>> tok.encode("context: ")
+        # [2625, 10]
+        # >>> tok.encode("question: ")
+        # [822, 10]
+
+        return TemplateProcessing(
+            seq_a=["Q", "$0"],
+            seq_b=["C", "$1"],
+            special_tokens=[
+                {
+                    "id": "Q",
+                    "ids": [2625, 10],
+                    "type_ids": [None, None],
+                    "tokens": ["_question", ":"],
+                },
+                {
+                    "id": "C",
+                    "ids": [822, 10],
+                    "type_ids": [None, None],
+                    "tokens": ["_context", ":"],
+                },
+            ],
+        )
+
+    def test_instantiate(self):
+        bert = self.get_bert()
+        assert bert is not None
+        assert isinstance(bert, PostProcessor)
+        assert isinstance(bert, TemplateProcessing)
+        assert isinstance(pickle.loads(pickle.dumps(bert)), TemplateProcessing)
+
+    def test_bert_parity(self):
+        tokenizer = Tokenizer(BPE())
+        tokenizer.add_special_tokens(["[SEP]", "[CLS]"])
+        tokenizer.add_tokens(["my", "name", "is", "john", "pair"])
+        tokenizer.post_processor = BertProcessing(("[SEP]", 0), ("[CLS]", 1))
+
+        original = tokenizer.encode("my name", "pair")
+
+        tokenizer.post_processor = self.get_bert()
+        template = tokenizer.encode("my name", "pair")
+        assert original.ids == template.ids
+
+    def test_roberta_parity(self):
+        tokenizer = Tokenizer(BPE())
+        tokenizer.add_special_tokens(["<s>", "</s>"])
+        tokenizer.add_tokens(["my", "name", "is", "john", "pair"])
+        tokenizer.post_processor = RobertaProcessing(("</s>", 1), ("<s>", 0))
+
+        original = tokenizer.encode("my name is john", "pair")
+        tokenizer.post_processor = self.get_roberta()
+        template = tokenizer.encode("my name is john", "pair")
+        assert original.ids == template.ids
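
For reference (not part of the patch): a minimal usage sketch of the `TemplateProcessing` post-processor added above. It mirrors the setup of `test_bert_parity`, so the vocabulary, token ids, and the BERT-style string template all come from the diff itself; the printed expectations in the comments are what the BERT convention implies rather than captured output.

```python
from tokenizers import Tokenizer
from tokenizers.models import BPE
from tokenizers.processors import TemplateProcessing

# Test-style setup: special tokens get ids 0 ([SEP]) and 1 ([CLS]),
# then the regular tokens follow (my=2, name=3, ...).
tokenizer = Tokenizer(BPE())
tokenizer.add_special_tokens(["[SEP]", "[CLS]"])
tokenizer.add_tokens(["my", "name", "is", "john", "pair"])

# BERT-style template, using the whitespace-delimited string form
# described in the __init__.pyi docstring above.
tokenizer.post_processor = TemplateProcessing(
    seq_a="[CLS] $0 [SEP]",
    seq_b="$1 [SEP]",
    special_tokens=[("[CLS]", 1), ("[SEP]", 0)],
)

output = tokenizer.encode("my name", "pair")
print(output.tokens)    # expected: ['[CLS]', 'my', 'name', '[SEP]', 'pair', '[SEP]']
print(output.type_ids)  # the pair segment defaults to type_id 1, taken from "$1"
```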