mirror of
https://github.com/mii443/tokenizers.git
synced 2025-08-22 16:25:30 +00:00
Roberta PostProcessor (#111)
* Added RobertaProcessor on the Rust side. Required to match the double separator token in the middle of pairs. Signed-off-by: Morgan Funtowicz <morgan@huggingface.co>
* Fix typo in RobertaProcessing method declaration. Signed-off-by: Morgan Funtowicz <morgan@huggingface.co>
* Correctly include RobertaProcessor in the Python binding. Signed-off-by: Morgan Funtowicz <morgan@huggingface.co>
* RoBERTa doesn't use token_type_ids, so let's set everything to 0. Signed-off-by: Morgan Funtowicz <morgan@huggingface.co>
* Attempt to make it work on the Node side too. Signed-off-by: Morgan Funtowicz <morgan@huggingface.co>
* Fix JS bindings / `npm run lint`.
* Make RustFmt happy. Signed-off-by: Morgan Funtowicz <morgan@huggingface.co>
Co-authored-by: Pierric Cistac <Pierrci@users.noreply.github.com>
This commit is contained in:
11
bindings/node/lib/bindings/post-processors.d.ts
vendored
11
bindings/node/lib/bindings/post-processors.d.ts
vendored
@ -15,3 +15,14 @@ export function bertProcessing(
|
||||
sep: [string, number],
|
||||
cls: [string, number]
|
||||
): PostProcessor;
|
||||
|
||||
/**
 * Create a RobertaProcessing post-processor from its two special tokens.
 *
 * @param sep The SEP token, given as a `[token, id]` tuple (string form first, then its id)
 * @param cls The CLS token, given as a `[token, id]` tuple (string form first, then its id)
 * @returns The newly created PostProcessor
 */
export function robertaProcessing(
  sep: [string, number],
  cls: [string, number]
): PostProcessor;
|
||||
|
@ -1,5 +1,6 @@
|
||||
const native = require("./native");
|
||||
|
||||
module.exports = {
|
||||
bertProcessing: native.processors_BertProcessing
|
||||
bertProcessing: native.processors_BertProcessing,
|
||||
robertaProcessing: native.processors_RobertaProcessing
|
||||
};
|
||||
|
@ -55,8 +55,45 @@ fn bert_processing(mut cx: FunctionContext) -> JsResult<JsPostProcessor> {
|
||||
Ok(processor)
|
||||
}
|
||||
|
||||
/// roberta_processing(sep: [String, number], cls: [String, number])
|
||||
fn roberta_processing(mut cx: FunctionContext) -> JsResult<JsPostProcessor> {
|
||||
let sep = cx.argument::<JsArray>(0)?;
|
||||
let cls = cx.argument::<JsArray>(1)?;
|
||||
if sep.len() != 2 || cls.len() != 2 {
|
||||
return cx.throw_error("SEP and CLS must be of the form: [String, number]");
|
||||
}
|
||||
let sep: (String, u32) = (
|
||||
sep.get(&mut cx, 0)?
|
||||
.downcast::<JsString>()
|
||||
.or_throw(&mut cx)?
|
||||
.value(),
|
||||
sep.get(&mut cx, 1)?
|
||||
.downcast::<JsNumber>()
|
||||
.or_throw(&mut cx)?
|
||||
.value() as u32,
|
||||
);
|
||||
let cls: (String, u32) = (
|
||||
cls.get(&mut cx, 0)?
|
||||
.downcast::<JsString>()
|
||||
.or_throw(&mut cx)?
|
||||
.value(),
|
||||
cls.get(&mut cx, 1)?
|
||||
.downcast::<JsNumber>()
|
||||
.or_throw(&mut cx)?
|
||||
.value() as u32,
|
||||
);
|
||||
|
||||
let mut processor = JsPostProcessor::new::<_, JsPostProcessor, _>(&mut cx, vec![])?;
|
||||
let guard = cx.lock();
|
||||
processor.borrow_mut(&guard).processor.to_owned(Box::new(
|
||||
tk::processors::roberta::RobertaProcessing::new(sep, cls),
|
||||
));
|
||||
Ok(processor)
|
||||
}
|
||||
|
||||
/// Register everything here
|
||||
pub fn register(m: &mut ModuleContext, prefix: &str) -> NeonResult<()> {
|
||||
m.export_function(&format!("{}_BertProcessing", prefix), bert_processing)?;
|
||||
m.export_function(&format!("{}_RobertaProcessing", prefix), roberta_processing)?;
|
||||
Ok(())
|
||||
}
|
||||
|
@ -59,6 +59,7 @@ fn decoders(_py: Python, m: &PyModule) -> PyResult<()> {
|
||||
fn processors(_py: Python, m: &PyModule) -> PyResult<()> {
|
||||
m.add_class::<processors::PostProcessor>()?;
|
||||
m.add_class::<processors::BertProcessing>()?;
|
||||
m.add_class::<processors::RobertaProcessing>()?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
|
@ -21,3 +21,18 @@ impl BertProcessing {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
#[pyclass]
|
||||
pub struct RobertaProcessing {}
|
||||
#[pymethods]
|
||||
impl RobertaProcessing {
|
||||
#[staticmethod]
|
||||
fn new(sep: (String, u32), cls: (String, u32)) -> PyResult<PostProcessor> {
|
||||
Ok(PostProcessor {
|
||||
processor: Container::Owned(Box::new(tk::processors::roberta::RobertaProcessing::new(
|
||||
sep, cls,
|
||||
))),
|
||||
})
|
||||
}
|
||||
}
|
@ -2,3 +2,4 @@ from .. import processors
|
||||
|
||||
# Re-export the post-processor classes from the `processors` module.
PostProcessor, BertProcessing, RobertaProcessing = (
    processors.PostProcessor,
    processors.BertProcessing,
    processors.RobertaProcessing,
)
|
@ -31,3 +31,29 @@ class BertProcessing:
|
||||
PostProcessor
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
class RobertaProcessing:
    """Post-processor that adds the special tokens a Roberta model needs.

    Two special tokens are involved:
        - a SEP token
        - a CLS token
    """

    @staticmethod
    def new(sep: Tuple[str, int], cls: Tuple[str, int]) -> PostProcessor:
        """Create a RobertaProcessing post-processor from the given tokens.

        Args:
            sep: Tuple[str, int]:
                The SEP token as a string, together with its id

            cls: Tuple[str, int]:
                The CLS token as a string, together with its id

        Returns:
            PostProcessor
        """
        ...
|
@ -1 +1,2 @@
|
||||
pub mod bert;
|
||||
pub mod roberta;
|
||||
|
81
tokenizers/src/processors/roberta.rs
Normal file
81
tokenizers/src/processors/roberta.rs
Normal file
@ -0,0 +1,81 @@
|
||||
use crate::tokenizer::{Encoding, PostProcessor, Result};
|
||||
|
||||
/// Post-processor configuration for RoBERTa: the two special tokens, each
/// stored as a `(token, id)` pair.
pub struct RobertaProcessing {
    /// Separator token and its id.
    sep: (String, u32),
    /// Classifier (start) token and its id.
    cls: (String, u32),
}

impl RobertaProcessing {
    /// Creates a processor from the SEP and CLS `(token, id)` pairs.
    pub fn new(sep: (String, u32), cls: (String, u32)) -> Self {
        Self { sep, cls }
    }
}
|
||||
|
||||
impl PostProcessor for RobertaProcessing {
|
||||
fn added_tokens(
|
||||
&self,
|
||||
_encoding: &Encoding,
|
||||
pair_encoding: &Option<Encoding>,
|
||||
) -> Result<usize> {
|
||||
if pair_encoding.is_some() {
|
||||
Ok(4)
|
||||
} else {
|
||||
Ok(2)
|
||||
}
|
||||
}
|
||||
|
||||
fn process(&self, mut encoding: Encoding, pair_encoding: Option<Encoding>) -> Result<Encoding> {
|
||||
let ids = [&[self.cls.1], &encoding.get_ids()[..], &[self.sep.1]].concat();
|
||||
let type_ids = [&[0], &encoding.get_type_ids()[..], &[0]].concat();
|
||||
let tokens = [
|
||||
&[self.cls.0.clone()],
|
||||
&encoding.get_tokens()[..],
|
||||
&[self.sep.0.clone()],
|
||||
]
|
||||
.concat();
|
||||
let offsets = [&[(0, 0)], &encoding.get_offsets()[..], &[(0, 0)]].concat();
|
||||
let special_tokens = [&[1u32], &vec![0; encoding.get_ids().len()][..], &[1]].concat();
|
||||
let attention_mask = vec![1; ids.len()];
|
||||
|
||||
let mut new_encoding = Encoding::new(
|
||||
encoding.get_normalized().clone(),
|
||||
ids,
|
||||
type_ids,
|
||||
tokens,
|
||||
offsets,
|
||||
special_tokens,
|
||||
attention_mask,
|
||||
encoding.take_overflowing(),
|
||||
);
|
||||
|
||||
if let Some(mut encoding) = pair_encoding {
|
||||
let pair_ids = [&[self.sep.1], &encoding.get_ids()[..], &[self.sep.1]].concat();
|
||||
let pair_type_ids = vec![0; encoding.get_ids().len() + 2];
|
||||
let pair_tokens = [
|
||||
&[self.sep.0.clone()],
|
||||
&encoding.get_tokens()[..],
|
||||
&[self.sep.0.clone()],
|
||||
]
|
||||
.concat();
|
||||
let pair_offsets = [&[(0, 0)], &encoding.get_offsets()[..], &[(0, 0)]].concat();
|
||||
let pair_special_tokens =
|
||||
[&[1], &vec![0u32; encoding.get_type_ids().len()][..], &[1]].concat();
|
||||
let pair_attention_mask = vec![1; pair_ids.len()];
|
||||
|
||||
let new_pair_encoding = Encoding::new(
|
||||
encoding.get_normalized().clone(),
|
||||
pair_ids,
|
||||
pair_type_ids,
|
||||
pair_tokens,
|
||||
pair_offsets,
|
||||
pair_special_tokens,
|
||||
pair_attention_mask,
|
||||
encoding.take_overflowing(),
|
||||
);
|
||||
|
||||
new_encoding.merge_with(new_pair_encoding);
|
||||
}
|
||||
|
||||
Ok(new_encoding)
|
||||
}
|
||||
}
|
Reference in New Issue
Block a user