mirror of
https://github.com/mii443/tokenizers.git
synced 2025-12-07 05:08:24 +00:00
Roberta PostProcessor (#111)
* Added RobertaProcessor on Rust side. Required to match the double separator token in the middle of pairs. Signed-off-by: Morgan Funtowicz <morgan@huggingface.co> * Fix typo in RobertaProcessing method declaration Signed-off-by: Morgan Funtowicz <morgan@huggingface.co> * Correctly include RobertProcessor in the Python binding Signed-off-by: Morgan Funtowicz <morgan@huggingface.co> * Roberta doesnt use token_type_ids so let's set everything to 0 Signed-off-by: Morgan Funtowicz <morgan@huggingface.co> * Attempt to make it works on Node side too. Signed-off-by: Morgan Funtowicz <morgan@huggingface.co> * fix js bindings / `npm run lint` * Make RustFmt happy. Signed-off-by: Morgan Funtowicz <morgan@huggingface.co> Co-authored-by: Pierric Cistac <Pierrci@users.noreply.github.com>
This commit is contained in:
@@ -59,6 +59,7 @@ fn decoders(_py: Python, m: &PyModule) -> PyResult<()> {
|
||||
fn processors(_py: Python, m: &PyModule) -> PyResult<()> {
|
||||
m.add_class::<processors::PostProcessor>()?;
|
||||
m.add_class::<processors::BertProcessing>()?;
|
||||
m.add_class::<processors::RobertaProcessing>()?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
|
||||
@@ -21,3 +21,18 @@ impl BertProcessing {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
#[pyclass]
|
||||
pub struct RobertaProcessing {}
|
||||
#[pymethods]
|
||||
impl RobertaProcessing {
|
||||
#[staticmethod]
|
||||
fn new(sep: (String, u32), cls: (String, u32)) -> PyResult<PostProcessor> {
|
||||
Ok(PostProcessor {
|
||||
processor: Container::Owned(Box::new(tk::processors::roberta::RobertaProcessing::new(
|
||||
sep, cls,
|
||||
))),
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -2,3 +2,4 @@ from .. import processors
|
||||
|
||||
PostProcessor = processors.PostProcessor
|
||||
BertProcessing = processors.BertProcessing
|
||||
RobertaProcessing = processors.RobertaProcessing
|
||||
@@ -31,3 +31,29 @@ class BertProcessing:
|
||||
PostProcessor
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
class RobertaProcessing:
|
||||
""" RobertaProcessing
|
||||
|
||||
This post-processor takes care of adding the special tokens needed by
|
||||
a Roberta model:
|
||||
- a SEP token
|
||||
- a CLS token
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def new(sep: Tuple[str, int], cls: Tuple[str, int]) -> PostProcessor:
|
||||
""" Instantiate a new RobertaProcessing with the given tokens
|
||||
|
||||
Args:
|
||||
sep: Tuple[str, int]:
|
||||
A tuple with the string representation of the SEP token, and its id
|
||||
|
||||
cls: Tuple[str, int]:
|
||||
A tuple with the string representation of the CLS token, and its id
|
||||
|
||||
Returns:
|
||||
PostProcessor
|
||||
"""
|
||||
pass
|
||||
Reference in New Issue
Block a user