Roberta PostProcessor (#111)

* Added RobertaProcessor on Rust side.

Required to match the double separator token in the middle of pairs.

Signed-off-by: Morgan Funtowicz <morgan@huggingface.co>

* Fix typo in RobertaProcessing method declaration

Signed-off-by: Morgan Funtowicz <morgan@huggingface.co>

* Correctly include RobertaProcessor in the Python binding

Signed-off-by: Morgan Funtowicz <morgan@huggingface.co>

* Roberta doesn't use token_type_ids so let's set everything to 0

Signed-off-by: Morgan Funtowicz <morgan@huggingface.co>

* Attempt to make it work on the Node side too.

Signed-off-by: Morgan Funtowicz <morgan@huggingface.co>

* fix js bindings / `npm run lint`

* Make RustFmt happy.

Signed-off-by: Morgan Funtowicz <morgan@huggingface.co>

Co-authored-by: Pierric Cistac <Pierrci@users.noreply.github.com>
This commit is contained in:
Funtowicz Morgan
2020-02-03 10:39:48 +00:00
committed by GitHub
parent 8138ece1a6
commit 6524f09e99
9 changed files with 175 additions and 1 deletions

View File

@ -15,3 +15,14 @@ export function bertProcessing(
sep: [string, number],
cls: [string, number]
): PostProcessor;
/**
 * Instantiate a new RobertaProcessing with the given tokens
 *
 * The resulting post-processor wraps a single sequence as `CLS ... SEP`.
 * When a pair is given, the second sequence is wrapped as `SEP ... SEP`
 * before being merged, which yields the double separator in the middle
 * of the pair.
 *
 * @param sep A tuple with the string representation of the SEP token, and its id
 * @param cls A tuple with the string representation of the CLS token, and its id
 * @returns A PostProcessor backed by the native RobertaProcessing
 */
export function robertaProcessing(
  sep: [string, number],
  cls: [string, number]
): PostProcessor;

View File

@ -1,5 +1,6 @@
const native = require("./native");
module.exports = {
bertProcessing: native.processors_BertProcessing
bertProcessing: native.processors_BertProcessing,
robertaProcessing: native.processors_RobertaProcessing
};

View File

@ -55,8 +55,45 @@ fn bert_processing(mut cx: FunctionContext) -> JsResult<JsPostProcessor> {
Ok(processor)
}
/// roberta_processing(sep: [String, number], cls: [String, number])
///
/// Builds a `JsPostProcessor` backed by a `RobertaProcessing`.
///
/// Both arguments must be JS arrays of exactly two elements: the token's
/// string representation followed by its numeric id. A JS error is thrown when
/// an array has the wrong length, when an element has the wrong JS type, or
/// when an id is not a non-negative integer that fits in a `u32`.
fn roberta_processing(mut cx: FunctionContext) -> JsResult<JsPostProcessor> {
    let sep = cx.argument::<JsArray>(0)?;
    let cls = cx.argument::<JsArray>(1)?;
    if sep.len() != 2 || cls.len() != 2 {
        return cx.throw_error("SEP and CLS must be of the form: [String, number]");
    }

    // JS numbers are `f64`. A bare `as u32` would silently turn negative,
    // fractional, NaN or too-large values into a different token id, so the
    // value is validated before the cast.
    let to_token_id = |raw: f64| -> Option<u32> {
        if raw.fract() == 0.0 && raw >= 0.0 && raw <= f64::from(std::u32::MAX) {
            Some(raw as u32)
        } else {
            None
        }
    };

    let sep_token = sep
        .get(&mut cx, 0)?
        .downcast::<JsString>()
        .or_throw(&mut cx)?
        .value();
    let sep_raw_id = sep
        .get(&mut cx, 1)?
        .downcast::<JsNumber>()
        .or_throw(&mut cx)?
        .value();
    let cls_token = cls
        .get(&mut cx, 0)?
        .downcast::<JsString>()
        .or_throw(&mut cx)?
        .value();
    let cls_raw_id = cls
        .get(&mut cx, 1)?
        .downcast::<JsNumber>()
        .or_throw(&mut cx)?
        .value();

    let (sep_id, cls_id) = match (to_token_id(sep_raw_id), to_token_id(cls_raw_id)) {
        (Some(sep_id), Some(cls_id)) => (sep_id, cls_id),
        _ => {
            return cx
                .throw_error("SEP and CLS ids must be non-negative integers that fit in 32 bits")
        }
    };
    let sep: (String, u32) = (sep_token, sep_id);
    let cls: (String, u32) = (cls_token, cls_id);

    // Create the JS wrapper first, then hand it the boxed Rust processor
    // while holding the VM lock required to borrow its internals mutably.
    let mut processor = JsPostProcessor::new::<_, JsPostProcessor, _>(&mut cx, vec![])?;
    let guard = cx.lock();
    processor.borrow_mut(&guard).processor.to_owned(Box::new(
        tk::processors::roberta::RobertaProcessing::new(sep, cls),
    ));
    Ok(processor)
}
/// Register everything here
///
/// Exports the post-processor factories on the module `m`, naming each one
/// `<prefix>_<ProcessorName>`.
pub fn register(m: &mut ModuleContext, prefix: &str) -> NeonResult<()> {
    let bert_export = format!("{}_BertProcessing", prefix);
    let roberta_export = format!("{}_RobertaProcessing", prefix);

    m.export_function(&bert_export, bert_processing)?;
    m.export_function(&roberta_export, roberta_processing)?;
    Ok(())
}

View File

@ -59,6 +59,7 @@ fn decoders(_py: Python, m: &PyModule) -> PyResult<()> {
// Adds the post-processor classes to the Python module `m`.
// Registration stops at the first class that fails to be added and that
// error is returned to the caller.
fn processors(_py: Python, m: &PyModule) -> PyResult<()> {
    m.add_class::<processors::PostProcessor>()
        .and_then(|_| m.add_class::<processors::BertProcessing>())
        .and_then(|_| m.add_class::<processors::RobertaProcessing>())
}

View File

@ -21,3 +21,18 @@ impl BertProcessing {
})
}
}
// Python-visible class that only serves as a factory: it carries no state and
// its `new` static method returns a generic `PostProcessor`.
#[pyclass]
pub struct RobertaProcessing {}

#[pymethods]
impl RobertaProcessing {
    // Builds the core `RobertaProcessing` from the (token, id) pairs and wraps
    // it in an owned container inside a `PostProcessor`.
    #[staticmethod]
    fn new(sep: (String, u32), cls: (String, u32)) -> PyResult<PostProcessor> {
        let roberta = tk::processors::roberta::RobertaProcessing::new(sep, cls);
        Ok(PostProcessor {
            processor: Container::Owned(Box::new(roberta)),
        })
    }
}

View File

@ -2,3 +2,4 @@ from .. import processors
# Re-export the post-processor classes from the parent package's `processors`
# module so they can be imported directly from this package.
PostProcessor = processors.PostProcessor
BertProcessing = processors.BertProcessing
RobertaProcessing = processors.RobertaProcessing

View File

@ -31,3 +31,29 @@ class BertProcessing:
PostProcessor
"""
pass
class RobertaProcessing:
    """ RobertaProcessing

    This post-processor takes care of adding the special tokens needed by
    a Roberta model:
        - a SEP token
        - a CLS token

    A single sequence is wrapped as `CLS ... SEP`. For a pair, the second
    sequence is wrapped as `SEP ... SEP`, which gives a double SEP token in
    the middle of the pair.
    """

    @staticmethod
    def new(sep: Tuple[str, int], cls: Tuple[str, int]) -> PostProcessor:
        """ Instantiate a new RobertaProcessing with the given tokens

        Args:
            sep: Tuple[str, int]:
                A tuple with the string representation of the SEP token, and its id

            cls: Tuple[str, int]:
                A tuple with the string representation of the CLS token, and its id

        Returns:
            PostProcessor
        """
        pass

View File

@ -1 +1,2 @@
pub mod bert;
pub mod roberta;

View File

@ -0,0 +1,81 @@
use crate::tokenizer::{Encoding, PostProcessor, Result};
/// Post-processor configuration for RoBERTa-style special tokens.
///
/// Each token is stored as a `(string representation, id)` pair.
pub struct RobertaProcessing {
    /// Separator token and its id.
    sep: (String, u32),
    /// Classifier token and its id.
    cls: (String, u32),
}

impl RobertaProcessing {
    /// Creates a processor from the SEP and CLS `(token, id)` pairs.
    pub fn new(sep: (String, u32), cls: (String, u32)) -> Self {
        Self { sep, cls }
    }
}
impl PostProcessor for RobertaProcessing {
    /// Number of special tokens that `process` adds: 2 for a single sequence
    /// (CLS + SEP), 4 when a pair is present (the pair gets SEP on both sides).
    fn added_tokens(
        &self,
        _encoding: &Encoding,
        pair_encoding: &Option<Encoding>,
    ) -> Result<usize> {
        let added = match pair_encoding {
            Some(_) => 4,
            None => 2,
        };
        Ok(added)
    }

    /// Wraps `encoding` as `CLS ... SEP`. If `pair_encoding` is present it is
    /// wrapped as `SEP ... SEP` and merged into the first encoding with
    /// `merge_with`.
    ///
    /// Added special tokens get offsets `(0, 0)`, a special-tokens mask of 1
    /// and type id 0; the attention mask is 1 everywhere. Overflowing
    /// encodings are moved over to the new encodings as they are.
    fn process(&self, mut encoding: Encoding, pair_encoding: Option<Encoding>) -> Result<Encoding> {
        // Returns `head`, then a copy of `body`, then `tail`, as one vector.
        fn surround<T: Clone>(head: T, body: &[T], tail: T) -> Vec<T> {
            let mut wrapped = Vec::with_capacity(body.len() + 2);
            wrapped.push(head);
            wrapped.extend_from_slice(body);
            wrapped.push(tail);
            wrapped
        }

        let (cls_token, cls_id) = (&self.cls.0, self.cls.1);
        let (sep_token, sep_id) = (&self.sep.0, self.sep.1);

        // First sequence: CLS ... SEP. Its existing type ids are kept, only
        // the two added tokens get type id 0.
        let ids = surround(cls_id, &encoding.get_ids()[..], sep_id);
        let type_ids = surround(0, &encoding.get_type_ids()[..], 0);
        let tokens = surround(
            cls_token.clone(),
            &encoding.get_tokens()[..],
            sep_token.clone(),
        );
        let offsets = surround((0, 0), &encoding.get_offsets()[..], (0, 0));
        let special_tokens = surround(1u32, &vec![0; encoding.get_ids().len()][..], 1);
        let attention_mask = vec![1; ids.len()];

        let mut merged = Encoding::new(
            encoding.get_normalized().clone(),
            ids,
            type_ids,
            tokens,
            offsets,
            special_tokens,
            attention_mask,
            encoding.take_overflowing(),
        );

        // Second sequence, when present: SEP ... SEP, with every type id at 0.
        if let Some(mut pair) = pair_encoding {
            let pair_ids = surround(sep_id, &pair.get_ids()[..], sep_id);
            let pair_type_ids = vec![0; pair.get_ids().len() + 2];
            let pair_tokens = surround(
                sep_token.clone(),
                &pair.get_tokens()[..],
                sep_token.clone(),
            );
            let pair_offsets = surround((0, 0), &pair.get_offsets()[..], (0, 0));
            let pair_special_tokens =
                surround(1, &vec![0u32; pair.get_type_ids().len()][..], 1);
            let pair_attention_mask = vec![1; pair_ids.len()];

            let wrapped_pair = Encoding::new(
                pair.get_normalized().clone(),
                pair_ids,
                pair_type_ids,
                pair_tokens,
                pair_offsets,
                pair_special_tokens,
                pair_attention_mask,
                pair.take_overflowing(),
            );
            merged.merge_with(wrapped_pair);
        }

        Ok(merged)
    }
}