mirror of
https://github.com/mii443/tokenizers.git
synced 2025-12-08 13:48:19 +00:00
Python - Update bindings for TemplateProcessing
This commit is contained in:
@@ -119,22 +119,29 @@ class TemplateProcessing(PostProcessor):
|
|||||||
sequences. The final result looks like this:
|
sequences. The final result looks like this:
|
||||||
- Single sequence: `[CLS] Hello there [SEP]`
|
- Single sequence: `[CLS] Hello there [SEP]`
|
||||||
- Pair sequences: `[CLS] My name is Anthony [SEP] What is my name? [SEP]`
|
- Pair sequences: `[CLS] My name is Anthony [SEP] What is my name? [SEP]`
|
||||||
|
With the type ids as following:
|
||||||
|
```markdown
|
||||||
|
[CLS] ... [SEP] ... [SEP]
|
||||||
|
0 0 0 1 1
|
||||||
|
```
|
||||||
|
|
||||||
You can achieve such behavior using a TemplateProcessing:
|
You can achieve such behavior using a TemplateProcessing:
|
||||||
```
|
```
|
||||||
TemplateProcessing(
|
TemplateProcessing(
|
||||||
seq_a="[CLS] $0 [SEP]",
|
single="[CLS] $0 [SEP]",
|
||||||
seq_b="$1 [SEP]",
|
pair="[CLS] $A [SEP] $B:1 [SEP]:1",
|
||||||
special_tokens=[("[CLS]", 1), ("[SEP]", 0)],
|
special_tokens=[("[CLS]", 1), ("[SEP]", 0)],
|
||||||
)
|
)
|
||||||
```
|
```
|
||||||
|
|
||||||
In this example, $0 and $1 both represent the input sequences. The number in this
|
In this example, each input sequence is identified using a `$` construct. This identifier
|
||||||
identifier is actually the default type_id that will be used for each sequence. So,
|
lets us specify each input sequence, and the type_id to use. When nothing is specified,
|
||||||
in this case, the first sequence will use 0, while the pair sequence will use 1.
|
it uses the default values. Here are the different ways to specify it:
|
||||||
|
- Specifying the sequence, with default `type_id == 0`: `$A` or `$B`
|
||||||
|
- Specifying the `type_id` with default `sequence == A`: `$0`, `$1`, `$2`, ...
|
||||||
|
- Specifying both: `$A:0`, `$B:1`, ...
|
||||||
|
|
||||||
Note that we are saying the "default" type_id because each SpecialToken can define
|
The same construct is used for special tokens: `<identifier>(:<type_id>)?`.
|
||||||
its own type_id which would override the provided default.
|
|
||||||
|
|
||||||
**Warning**: You must ensure that you are giving the correct tokens/ids as these
|
**Warning**: You must ensure that you are giving the correct tokens/ids as these
|
||||||
will be added to the Encoding without any further check. If the given ids correspond
|
will be added to the Encoding without any further check. If the given ids correspond
|
||||||
@@ -142,15 +149,15 @@ class TemplateProcessing(PostProcessor):
|
|||||||
might lead to unexpected results.
|
might lead to unexpected results.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, seq_a: Template, seq_b: Template, special_tokens: Tokens) -> None:
|
def __init__(self, single: Template, seq_b: Template, special_tokens: Tokens) -> None:
|
||||||
"""Instantiate a new TemplateProcessing
|
"""Instantiate a new TemplateProcessing
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
seq_a: Template
|
single: Template
|
||||||
The template for the first sequence.
|
The template used for single sequences
|
||||||
|
|
||||||
seq_b: Template:
|
pair: Template:
|
||||||
The template for the pair sequence.
|
The template used when both sequences are specified
|
||||||
|
|
||||||
special_tokens: Tokens:
|
special_tokens: Tokens:
|
||||||
The list of special tokens used in each sequences
|
The list of special tokens used in each sequences
|
||||||
@@ -165,10 +172,7 @@ class TemplateProcessing(PostProcessor):
|
|||||||
- "id": str => The special token id, as specified in the Template
|
- "id": str => The special token id, as specified in the Template
|
||||||
- "ids": List[int] => The associated IDs
|
- "ids": List[int] => The associated IDs
|
||||||
- "tokens": List[str] => The associated tokens
|
- "tokens": List[str] => The associated tokens
|
||||||
- "type_ids": Optional[List[Optional[int]]] => If specified, a list of optional
|
The given dict expects the provided `ids` and `tokens` lists to have
|
||||||
type_ids. In the `type_id` is not specified, the one from the input sequence
|
|
||||||
will be used.
|
|
||||||
The given dict expects the provided `ids`, `tokens` and `type_ids` lists to have
|
|
||||||
the same length.
|
the same length.
|
||||||
"""
|
"""
|
||||||
pass
|
pass
|
||||||
|
|||||||
@@ -1,3 +1,4 @@
|
|||||||
|
use std::convert::TryInto;
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
|
|
||||||
use pyo3::exceptions;
|
use pyo3::exceptions;
|
||||||
@@ -200,17 +201,13 @@ impl FromPyObject<'_> for PySpecialToken {
|
|||||||
.get_item("ids")
|
.get_item("ids")
|
||||||
.ok_or_else(|| exceptions::PyValueError::new_err("`ids` must be specified"))?
|
.ok_or_else(|| exceptions::PyValueError::new_err("`ids` must be specified"))?
|
||||||
.extract::<Vec<u32>>()?;
|
.extract::<Vec<u32>>()?;
|
||||||
let type_ids = d.get_item("type_ids").map_or_else(
|
|
||||||
|| Ok(vec![None; ids.len()]),
|
|
||||||
|v| v.extract::<Vec<Option<u32>>>(),
|
|
||||||
)?;
|
|
||||||
let tokens = d
|
let tokens = d
|
||||||
.get_item("tokens")
|
.get_item("tokens")
|
||||||
.ok_or_else(|| exceptions::PyValueError::new_err("`tokens` must be specified"))?
|
.ok_or_else(|| exceptions::PyValueError::new_err("`tokens` must be specified"))?
|
||||||
.extract::<Vec<String>>()?;
|
.extract::<Vec<String>>()?;
|
||||||
|
|
||||||
Ok(Self(
|
Ok(Self(
|
||||||
ToPyResult(SpecialToken::new(id, ids, type_ids, tokens)).into_py()?,
|
ToPyResult(SpecialToken::new(id, ids, tokens)).into_py()?,
|
||||||
))
|
))
|
||||||
} else {
|
} else {
|
||||||
Err(exceptions::PyTypeError::new_err(
|
Err(exceptions::PyTypeError::new_err(
|
||||||
@@ -232,9 +229,13 @@ impl From<PyTemplate> for Template {
|
|||||||
impl FromPyObject<'_> for PyTemplate {
|
impl FromPyObject<'_> for PyTemplate {
|
||||||
fn extract(ob: &PyAny) -> PyResult<Self> {
|
fn extract(ob: &PyAny) -> PyResult<Self> {
|
||||||
if let Ok(s) = ob.extract::<&str>() {
|
if let Ok(s) = ob.extract::<&str>() {
|
||||||
Ok(Self(s.into()))
|
Ok(Self(
|
||||||
|
s.try_into().map_err(exceptions::PyValueError::new_err)?,
|
||||||
|
))
|
||||||
} else if let Ok(s) = ob.extract::<Vec<&str>>() {
|
} else if let Ok(s) = ob.extract::<Vec<&str>>() {
|
||||||
Ok(Self(s.into()))
|
Ok(Self(
|
||||||
|
s.try_into().map_err(exceptions::PyValueError::new_err)?,
|
||||||
|
))
|
||||||
} else {
|
} else {
|
||||||
Err(exceptions::PyTypeError::new_err(
|
Err(exceptions::PyTypeError::new_err(
|
||||||
"Expected Union[str, List[str]]",
|
"Expected Union[str, List[str]]",
|
||||||
@@ -248,19 +249,19 @@ pub struct PyTemplateProcessing {}
|
|||||||
#[pymethods]
|
#[pymethods]
|
||||||
impl PyTemplateProcessing {
|
impl PyTemplateProcessing {
|
||||||
#[new]
|
#[new]
|
||||||
#[args(seq_a = "None", seq_b = "None", special_tokens = "None")]
|
#[args(single = "None", pair = "None", special_tokens = "None")]
|
||||||
fn new(
|
fn new(
|
||||||
seq_a: Option<PyTemplate>,
|
single: Option<PyTemplate>,
|
||||||
seq_b: Option<PyTemplate>,
|
pair: Option<PyTemplate>,
|
||||||
special_tokens: Option<Vec<PySpecialToken>>,
|
special_tokens: Option<Vec<PySpecialToken>>,
|
||||||
) -> PyResult<(Self, PyPostProcessor)> {
|
) -> PyResult<(Self, PyPostProcessor)> {
|
||||||
let mut builder = tk::processors::template::TemplateProcessing::builder();
|
let mut builder = tk::processors::template::TemplateProcessing::builder();
|
||||||
|
|
||||||
if let Some(seq) = seq_a {
|
if let Some(seq) = single {
|
||||||
builder.sequence_a(seq);
|
builder.single(seq.into());
|
||||||
}
|
}
|
||||||
if let Some(seq) = seq_b {
|
if let Some(seq) = pair {
|
||||||
builder.sequence_b(seq);
|
builder.pair(seq.into());
|
||||||
}
|
}
|
||||||
if let Some(sp) = special_tokens {
|
if let Some(sp) = special_tokens {
|
||||||
builder.special_tokens(sp);
|
builder.special_tokens(sp);
|
||||||
|
|||||||
@@ -88,15 +88,15 @@ class TestByteLevelProcessing:
|
|||||||
class TestTemplateProcessing:
|
class TestTemplateProcessing:
|
||||||
def get_bert(self):
|
def get_bert(self):
|
||||||
return TemplateProcessing(
|
return TemplateProcessing(
|
||||||
seq_a=["[CLS]", "$0", "[SEP]"],
|
single=["[CLS]", "$0", "[SEP]"],
|
||||||
seq_b=["$1", "[SEP]"],
|
pair=["[CLS]", "$A", "[SEP]", "$B:1", "[SEP]:1"],
|
||||||
special_tokens=[("[CLS]", 1), ("[SEP]", 0)],
|
special_tokens=[("[CLS]", 1), ("[SEP]", 0)],
|
||||||
)
|
)
|
||||||
|
|
||||||
def get_roberta(self):
|
def get_roberta(self):
|
||||||
return TemplateProcessing(
|
return TemplateProcessing(
|
||||||
seq_a="<s> $0 </s>",
|
single="<s> $0 </s>",
|
||||||
seq_b="</s> $0 </s>",
|
pair="<s> $A </s> </s> $B </s>",
|
||||||
special_tokens=[("<s>", 0), ("</s>", 1)],
|
special_tokens=[("<s>", 0), ("</s>", 1)],
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -113,19 +113,17 @@ class TestTemplateProcessing:
|
|||||||
# [822, 10]
|
# [822, 10]
|
||||||
|
|
||||||
return TemplateProcessing(
|
return TemplateProcessing(
|
||||||
seq_a=["Q", "$0"],
|
single=["$0"],
|
||||||
seq_b=["C", "$1"],
|
pair=["Q", "$A", "C", "$B"],
|
||||||
special_tokens=[
|
special_tokens=[
|
||||||
{
|
{
|
||||||
"id": "Q",
|
"id": "Q",
|
||||||
"ids": [2625, 10],
|
"ids": [2625, 10],
|
||||||
"type_ids": [None, None],
|
|
||||||
"tokens": ["_question", ":"],
|
"tokens": ["_question", ":"],
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": "C",
|
"id": "C",
|
||||||
"ids": [822, 10],
|
"ids": [822, 10],
|
||||||
"type_ids": [None, None],
|
|
||||||
"tokens": ["_context", ":"],
|
"tokens": ["_context", ":"],
|
||||||
},
|
},
|
||||||
],
|
],
|
||||||
|
|||||||
Reference in New Issue
Block a user