mirror of
https://github.com/mii443/tokenizers.git
synced 2025-12-07 13:18:31 +00:00
Python - Update bindings for TemplateProcessing
This commit is contained in:
@@ -119,22 +119,29 @@ class TemplateProcessing(PostProcessor):
|
||||
sequences. The final result looks like this:
|
||||
- Single sequence: `[CLS] Hello there [SEP]`
|
||||
- Pair sequences: `[CLS] My name is Anthony [SEP] What is my name? [SEP]`
|
||||
With the type ids as following:
|
||||
```markdown
|
||||
[CLS] ... [SEP] ... [SEP]
|
||||
0 0 0 1 1
|
||||
```
|
||||
|
||||
You can achieve such behavior using a TemplateProcessing:
|
||||
```
|
||||
TemplateProcessing(
|
||||
seq_a="[CLS] $0 [SEP]",
|
||||
seq_b="$1 [SEP]",
|
||||
single="[CLS] $0 [SEP]",
|
||||
pair="[CLS] $A [SEP] $B:1 [SEP]:1",
|
||||
special_tokens=[("[CLS]", 1), ("[SEP]", 0)],
|
||||
)
|
||||
```
|
||||
|
||||
In this example, $0 and $1 both represent the input sequences. The number in this
|
||||
identifier is actually the default type_id that will be used for each sequence. So,
|
||||
in this case, the first sequence will use 0, while the pair sequence will use 1.
|
||||
In this example, each input sequence is identified using a `$` construct. This identifier
|
||||
lets us specify each input sequence, and the type_id to use. When nothing is specified,
|
||||
it uses the default values. Here are the different ways to specify it:
|
||||
- Specifying the sequence, with default `type_id == 0`: `$A` or `$B`
|
||||
- Specifying the `type_id` with default `sequence == A`: `$0`, `$1`, `$2`, ...
|
||||
- Specifying both: `$A:0`, `$B:1`, ...
|
||||
|
||||
Note that we are saying the "default" type_id because each SpecialToken can define
|
||||
its own type_id which would override the provided default.
|
||||
The same construct is used for special tokens: `<identifier>(:<type_id>)?`.
|
||||
|
||||
**Warning**: You must ensure that you are giving the correct tokens/ids as these
|
||||
will be added to the Encoding without any further check. If the given ids correspond
|
||||
@@ -142,15 +149,15 @@ class TemplateProcessing(PostProcessor):
|
||||
might lead to unexpected results.
|
||||
"""
|
||||
|
||||
def __init__(self, seq_a: Template, seq_b: Template, special_tokens: Tokens) -> None:
|
||||
def __init__(self, single: Template, seq_b: Template, special_tokens: Tokens) -> None:
|
||||
"""Instantiate a new TemplateProcessing
|
||||
|
||||
Args:
|
||||
seq_a: Template
|
||||
The template for the first sequence.
|
||||
single: Template
|
||||
The template used for single sequences
|
||||
|
||||
seq_b: Template:
|
||||
The template for the pair sequence.
|
||||
pair: Template:
|
||||
The template used when both sequences are specified
|
||||
|
||||
special_tokens: Tokens:
|
||||
The list of special tokens used in each sequences
|
||||
@@ -165,10 +172,7 @@ class TemplateProcessing(PostProcessor):
|
||||
- "id": str => The special token id, as specified in the Template
|
||||
- "ids": List[int] => The associated IDs
|
||||
- "tokens": List[str] => The associated tokens
|
||||
- "type_ids": Optional[List[Optional[int]]] => If specified, a list of optional
|
||||
type_ids. In the `type_id` is not specified, the one from the input sequence
|
||||
will be used.
|
||||
The given dict expects the provided `ids`, `tokens` and `type_ids` lists to have
|
||||
The given dict expects the provided `ids` and `tokens` lists to have
|
||||
the same length.
|
||||
"""
|
||||
pass
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
use std::convert::TryInto;
|
||||
use std::sync::Arc;
|
||||
|
||||
use pyo3::exceptions;
|
||||
@@ -200,17 +201,13 @@ impl FromPyObject<'_> for PySpecialToken {
|
||||
.get_item("ids")
|
||||
.ok_or_else(|| exceptions::PyValueError::new_err("`ids` must be specified"))?
|
||||
.extract::<Vec<u32>>()?;
|
||||
let type_ids = d.get_item("type_ids").map_or_else(
|
||||
|| Ok(vec![None; ids.len()]),
|
||||
|v| v.extract::<Vec<Option<u32>>>(),
|
||||
)?;
|
||||
let tokens = d
|
||||
.get_item("tokens")
|
||||
.ok_or_else(|| exceptions::PyValueError::new_err("`tokens` must be specified"))?
|
||||
.extract::<Vec<String>>()?;
|
||||
|
||||
Ok(Self(
|
||||
ToPyResult(SpecialToken::new(id, ids, type_ids, tokens)).into_py()?,
|
||||
ToPyResult(SpecialToken::new(id, ids, tokens)).into_py()?,
|
||||
))
|
||||
} else {
|
||||
Err(exceptions::PyTypeError::new_err(
|
||||
@@ -232,9 +229,13 @@ impl From<PyTemplate> for Template {
|
||||
impl FromPyObject<'_> for PyTemplate {
|
||||
fn extract(ob: &PyAny) -> PyResult<Self> {
|
||||
if let Ok(s) = ob.extract::<&str>() {
|
||||
Ok(Self(s.into()))
|
||||
Ok(Self(
|
||||
s.try_into().map_err(exceptions::PyValueError::new_err)?,
|
||||
))
|
||||
} else if let Ok(s) = ob.extract::<Vec<&str>>() {
|
||||
Ok(Self(s.into()))
|
||||
Ok(Self(
|
||||
s.try_into().map_err(exceptions::PyValueError::new_err)?,
|
||||
))
|
||||
} else {
|
||||
Err(exceptions::PyTypeError::new_err(
|
||||
"Expected Union[str, List[str]]",
|
||||
@@ -248,19 +249,19 @@ pub struct PyTemplateProcessing {}
|
||||
#[pymethods]
|
||||
impl PyTemplateProcessing {
|
||||
#[new]
|
||||
#[args(seq_a = "None", seq_b = "None", special_tokens = "None")]
|
||||
#[args(single = "None", pair = "None", special_tokens = "None")]
|
||||
fn new(
|
||||
seq_a: Option<PyTemplate>,
|
||||
seq_b: Option<PyTemplate>,
|
||||
single: Option<PyTemplate>,
|
||||
pair: Option<PyTemplate>,
|
||||
special_tokens: Option<Vec<PySpecialToken>>,
|
||||
) -> PyResult<(Self, PyPostProcessor)> {
|
||||
let mut builder = tk::processors::template::TemplateProcessing::builder();
|
||||
|
||||
if let Some(seq) = seq_a {
|
||||
builder.sequence_a(seq);
|
||||
if let Some(seq) = single {
|
||||
builder.single(seq.into());
|
||||
}
|
||||
if let Some(seq) = seq_b {
|
||||
builder.sequence_b(seq);
|
||||
if let Some(seq) = pair {
|
||||
builder.pair(seq.into());
|
||||
}
|
||||
if let Some(sp) = special_tokens {
|
||||
builder.special_tokens(sp);
|
||||
|
||||
@@ -88,15 +88,15 @@ class TestByteLevelProcessing:
|
||||
class TestTemplateProcessing:
|
||||
def get_bert(self):
|
||||
return TemplateProcessing(
|
||||
seq_a=["[CLS]", "$0", "[SEP]"],
|
||||
seq_b=["$1", "[SEP]"],
|
||||
single=["[CLS]", "$0", "[SEP]"],
|
||||
pair=["[CLS]", "$A", "[SEP]", "$B:1", "[SEP]:1"],
|
||||
special_tokens=[("[CLS]", 1), ("[SEP]", 0)],
|
||||
)
|
||||
|
||||
def get_roberta(self):
|
||||
return TemplateProcessing(
|
||||
seq_a="<s> $0 </s>",
|
||||
seq_b="</s> $0 </s>",
|
||||
single="<s> $0 </s>",
|
||||
pair="<s> $A </s> </s> $B </s>",
|
||||
special_tokens=[("<s>", 0), ("</s>", 1)],
|
||||
)
|
||||
|
||||
@@ -113,19 +113,17 @@ class TestTemplateProcessing:
|
||||
# [822, 10]
|
||||
|
||||
return TemplateProcessing(
|
||||
seq_a=["Q", "$0"],
|
||||
seq_b=["C", "$1"],
|
||||
single=["$0"],
|
||||
pair=["Q", "$A", "C", "$B"],
|
||||
special_tokens=[
|
||||
{
|
||||
"id": "Q",
|
||||
"ids": [2625, 10],
|
||||
"type_ids": [None, None],
|
||||
"tokens": ["_question", ":"],
|
||||
},
|
||||
{
|
||||
"id": "C",
|
||||
"ids": [822, 10],
|
||||
"type_ids": [None, None],
|
||||
"tokens": ["_context", ":"],
|
||||
},
|
||||
],
|
||||
|
||||
Reference in New Issue
Block a user