Roberta PostProcessor (#111)

* Added RobertaProcessor on Rust side.

Required to match the double separator token in the middle of pairs.

Signed-off-by: Morgan Funtowicz <morgan@huggingface.co>

* Fix typo in RobertaProcessing method declaration

Signed-off-by: Morgan Funtowicz <morgan@huggingface.co>

* Correctly include RobertProcessor in the Python binding

Signed-off-by: Morgan Funtowicz <morgan@huggingface.co>

* Roberta doesnt use token_type_ids so let's set everything to 0

Signed-off-by: Morgan Funtowicz <morgan@huggingface.co>

* Attempt to make it works on Node side too.

Signed-off-by: Morgan Funtowicz <morgan@huggingface.co>

* fix js bindings / `npm run lint`

* Make RustFmt happy.

Signed-off-by: Morgan Funtowicz <morgan@huggingface.co>

Co-authored-by: Pierric Cistac <Pierrci@users.noreply.github.com>
This commit is contained in:
Funtowicz Morgan
2020-02-03 10:39:48 +00:00
committed by GitHub
parent 8138ece1a6
commit 6524f09e99
9 changed files with 175 additions and 1 deletions

View File

@@ -59,6 +59,7 @@ fn decoders(_py: Python, m: &PyModule) -> PyResult<()> {
fn processors(_py: Python, m: &PyModule) -> PyResult<()> {
m.add_class::<processors::PostProcessor>()?;
m.add_class::<processors::BertProcessing>()?;
m.add_class::<processors::RobertaProcessing>()?;
Ok(())
}

View File

@@ -21,3 +21,18 @@ impl BertProcessing {
})
}
}
#[pyclass]
pub struct RobertaProcessing {}
#[pymethods]
impl RobertaProcessing {
#[staticmethod]
fn new(sep: (String, u32), cls: (String, u32)) -> PyResult<PostProcessor> {
Ok(PostProcessor {
processor: Container::Owned(Box::new(tk::processors::roberta::RobertaProcessing::new(
sep, cls,
))),
})
}
}

View File

@@ -2,3 +2,4 @@ from .. import processors
PostProcessor = processors.PostProcessor
BertProcessing = processors.BertProcessing
RobertaProcessing = processors.RobertaProcessing

View File

@@ -31,3 +31,29 @@ class BertProcessing:
PostProcessor
"""
pass
class RobertaProcessing:
""" RobertaProcessing
This post-processor takes care of adding the special tokens needed by
a Roberta model:
- a SEP token
- a CLS token
"""
@staticmethod
def new(sep: Tuple[str, int], cls: Tuple[str, int]) -> PostProcessor:
""" Instantiate a new RobertaProcessing with the given tokens
Args:
sep: Tuple[str, int]:
A tuple with the string representation of the SEP token, and its id
cls: Tuple[str, int]:
A tuple with the string representation of the CLS token, and its id
Returns:
PostProcessor
"""
pass