Python - Update bindings for TemplateProcessing

This commit is contained in:
Anthony MOI
2020-09-28 14:22:39 -04:00
committed by Anthony MOI
parent 2546974bfc
commit 1070eb471e
3 changed files with 41 additions and 38 deletions

View File

@@ -119,22 +119,29 @@ class TemplateProcessing(PostProcessor):
sequences. The final result looks like this:
- Single sequence: `[CLS] Hello there [SEP]`
- Pair sequences: `[CLS] My name is Anthony [SEP] What is my name? [SEP]`
With the type ids as follows:
```markdown
[CLS] ... [SEP] ... [SEP]
0 0 0 1 1
```
You can achieve such behavior using a TemplateProcessing:
```
TemplateProcessing(
seq_a="[CLS] $0 [SEP]",
seq_b="$1 [SEP]",
single="[CLS] $0 [SEP]",
pair="[CLS] $A [SEP] $B:1 [SEP]:1",
special_tokens=[("[CLS]", 1), ("[SEP]", 0)],
)
```
In this example, $0 and $1 both represent the input sequences. The number in this
identifier is actually the default type_id that will be used for each sequence. So,
in this case, the first sequence will use 0, while the pair sequence will use 1.
In this example, each input sequence is identified using a `$` construct. This identifier
lets us specify each input sequence, and the type_id to use. When nothing is specified,
it uses the default values. Here are the different ways to specify it:
- Specifying the sequence, with default `type_id == 0`: `$A` or `$B`
- Specifying the `type_id` with default `sequence == A`: `$0`, `$1`, `$2`, ...
- Specifying both: `$A:0`, `$B:1`, ...
Note that we are saying the "default" type_id because each SpecialToken can define
its own type_id which would override the provided default.
The same construct is used for special tokens: `<identifier>(:<type_id>)?`.
**Warning**: You must ensure that you are giving the correct tokens/ids as these
will be added to the Encoding without any further check. If the given ids correspond
@@ -142,15 +149,15 @@ class TemplateProcessing(PostProcessor):
might lead to unexpected results.
"""
def __init__(self, seq_a: Template, seq_b: Template, special_tokens: Tokens) -> None:
def __init__(self, single: Template, seq_b: Template, special_tokens: Tokens) -> None:
"""Instantiate a new TemplateProcessing
Args:
seq_a: Template
The template for the first sequence.
single: Template
The template used for single sequences
seq_b: Template:
The template for the pair sequence.
pair: Template:
The template used when both sequences are specified
special_tokens: Tokens:
The list of special tokens used in each sequence
@@ -165,10 +172,7 @@ class TemplateProcessing(PostProcessor):
- "id": str => The special token id, as specified in the Template
- "ids": List[int] => The associated IDs
- "tokens": List[str] => The associated tokens
- "type_ids": Optional[List[Optional[int]]] => If specified, a list of optional
type_ids. In the `type_id` is not specified, the one from the input sequence
will be used.
The given dict expects the provided `ids`, `tokens` and `type_ids` lists to have
The given dict expects the provided `ids` and `tokens` lists to have
the same length.
"""
pass

View File

@@ -1,3 +1,4 @@
use std::convert::TryInto;
use std::sync::Arc;
use pyo3::exceptions;
@@ -200,17 +201,13 @@ impl FromPyObject<'_> for PySpecialToken {
.get_item("ids")
.ok_or_else(|| exceptions::PyValueError::new_err("`ids` must be specified"))?
.extract::<Vec<u32>>()?;
let type_ids = d.get_item("type_ids").map_or_else(
|| Ok(vec![None; ids.len()]),
|v| v.extract::<Vec<Option<u32>>>(),
)?;
let tokens = d
.get_item("tokens")
.ok_or_else(|| exceptions::PyValueError::new_err("`tokens` must be specified"))?
.extract::<Vec<String>>()?;
Ok(Self(
ToPyResult(SpecialToken::new(id, ids, type_ids, tokens)).into_py()?,
ToPyResult(SpecialToken::new(id, ids, tokens)).into_py()?,
))
} else {
Err(exceptions::PyTypeError::new_err(
@@ -232,9 +229,13 @@ impl From<PyTemplate> for Template {
impl FromPyObject<'_> for PyTemplate {
fn extract(ob: &PyAny) -> PyResult<Self> {
if let Ok(s) = ob.extract::<&str>() {
Ok(Self(s.into()))
Ok(Self(
s.try_into().map_err(exceptions::PyValueError::new_err)?,
))
} else if let Ok(s) = ob.extract::<Vec<&str>>() {
Ok(Self(s.into()))
Ok(Self(
s.try_into().map_err(exceptions::PyValueError::new_err)?,
))
} else {
Err(exceptions::PyTypeError::new_err(
"Expected Union[str, List[str]]",
@@ -248,19 +249,19 @@ pub struct PyTemplateProcessing {}
#[pymethods]
impl PyTemplateProcessing {
#[new]
#[args(seq_a = "None", seq_b = "None", special_tokens = "None")]
#[args(single = "None", pair = "None", special_tokens = "None")]
fn new(
seq_a: Option<PyTemplate>,
seq_b: Option<PyTemplate>,
single: Option<PyTemplate>,
pair: Option<PyTemplate>,
special_tokens: Option<Vec<PySpecialToken>>,
) -> PyResult<(Self, PyPostProcessor)> {
let mut builder = tk::processors::template::TemplateProcessing::builder();
if let Some(seq) = seq_a {
builder.sequence_a(seq);
if let Some(seq) = single {
builder.single(seq.into());
}
if let Some(seq) = seq_b {
builder.sequence_b(seq);
if let Some(seq) = pair {
builder.pair(seq.into());
}
if let Some(sp) = special_tokens {
builder.special_tokens(sp);

View File

@@ -88,15 +88,15 @@ class TestByteLevelProcessing:
class TestTemplateProcessing:
def get_bert(self):
return TemplateProcessing(
seq_a=["[CLS]", "$0", "[SEP]"],
seq_b=["$1", "[SEP]"],
single=["[CLS]", "$0", "[SEP]"],
pair=["[CLS]", "$A", "[SEP]", "$B:1", "[SEP]:1"],
special_tokens=[("[CLS]", 1), ("[SEP]", 0)],
)
def get_roberta(self):
return TemplateProcessing(
seq_a="<s> $0 </s>",
seq_b="</s> $0 </s>",
single="<s> $0 </s>",
pair="<s> $A </s> </s> $B </s>",
special_tokens=[("<s>", 0), ("</s>", 1)],
)
@@ -113,19 +113,17 @@ class TestTemplateProcessing:
# [822, 10]
return TemplateProcessing(
seq_a=["Q", "$0"],
seq_b=["C", "$1"],
single=["$0"],
pair=["Q", "$A", "C", "$B"],
special_tokens=[
{
"id": "Q",
"ids": [2625, 10],
"type_ids": [None, None],
"tokens": ["_question", ":"],
},
{
"id": "C",
"ids": [822, 10],
"type_ids": [None, None],
"tokens": ["_context", ":"],
},
],