Expose post_process on the Tokenizer

This commit is contained in:
Anthony MOI
2020-03-25 17:54:48 -04:00
parent 9ce895550b
commit 9bd9e0b3c1
6 changed files with 126 additions and 2 deletions

View File

@ -2,6 +2,7 @@ extern crate tokenizers as tk;
use crate::container::Container;
use crate::decoders::JsDecoder;
use crate::encoding::JsEncoding;
use crate::models::JsModel;
use crate::normalizers::JsNormalizer;
use crate::pre_tokenizers::JsPreTokenizer;
@ -712,6 +713,60 @@ declare_types! {
Ok(cx.undefined().upcast())
}
method postProcess(mut cx) {
    // JS signature:
    // postProcess(
    //   encoding: Encoding,
    //   pair?: Encoding,
    //   addSpecialTokens: boolean = true
    // ): Encoding
    //
    // Applies the tokenizer's full post-processing pipeline (truncation,
    // PostProcessor, padding — see the core `Tokenizer::post_process`) to an
    // already-produced Encoding, optionally paired with a second Encoding.

    // Extract the mandatory first argument and clone the underlying Rust
    // Encoding out of its JS wrapper. `cx.lock()` + `borrow` is the neon
    // pattern for safely reading JS-owned Rust data.
    // NOTE(review): `.execute(|e| *e.unwrap().clone())` assumes the Container
    // always holds a value here — it panics if empty. Presumably JsEncoding
    // is never constructed empty; verify against the Container type.
    let encoding = {
        let encoding = cx.argument::<JsEncoding>(0)?;
        let guard = cx.lock();
        let encoding = encoding
            .borrow(&guard)
            .encoding
            .execute(|e| *e.unwrap().clone());
        encoding
    };

    // Optional second argument: an explicit `undefined` is treated the same
    // as an absent argument (both yield `None`). A failed downcast to
    // JsEncoding is silently mapped to `None` via `.ok()` rather than thrown.
    let pair = cx.argument_opt(1).map(|item| {
        if item.downcast::<JsUndefined>().is_ok() {
            None
        } else {
            item.downcast::<JsEncoding>().map(|e| {
                let guard = cx.lock();
                let encoding = e.borrow(&guard)
                    .encoding
                    .execute(|e| *e.unwrap().clone());
                encoding
            }).ok()
        }
    }).flatten();

    // Optional third argument; defaults to `true` when absent. Unlike the
    // pair argument, a wrong type here IS thrown back to JS (`or_throw`).
    let add_special_tokens = cx
        .argument_opt(2)
        .map(|arg| Ok(arg.downcast::<JsBoolean>().or_throw(&mut cx)?.value()))
        .unwrap_or(Ok(true))?;

    // Run the core post-processing on `this` tokenizer. The result is kept
    // as a Result so the lock guard is released before we throw on error.
    let encoding = {
        let this = cx.this();
        let guard = cx.lock();
        let encoding = this.borrow(&guard)
            .tokenizer.post_process(encoding, pair, add_special_tokens);
        encoding
    };
    // Convert a Rust error into a JS exception.
    let encoding = encoding
        .map_err(|e| cx.throw_error::<_, ()>(format!("{}", e)).unwrap_err())?;

    // Wrap the resulting Rust Encoding back into a fresh JsEncoding and
    // hand ownership of it to the JS object before returning.
    let mut js_encoding = JsEncoding::new::<_, JsEncoding, _>(&mut cx, vec![])?;
    let guard = cx.lock();
    js_encoding
        .borrow_mut(&guard)
        .encoding
        .to_owned(Box::new(encoding));
    Ok(js_encoding.upcast())
}
method getModel(mut cx) {
// getModel(): Model

View File

@ -9,7 +9,7 @@ use tk::tokenizer::{Offsets, PaddingDirection};
#[pyclass(dict)]
#[repr(transparent)]
pub struct Encoding {
encoding: tk::tokenizer::Encoding,
pub encoding: tk::tokenizer::Encoding,
}
impl Encoding {

View File

@ -343,6 +343,25 @@ impl Tokenizer {
})
}
#[args(pair = "None", add_special_tokens = true)]
/// Python-exposed `Tokenizer.post_process`.
///
/// Runs the core post-processing pipeline (truncation, PostProcessor,
/// padding) on `encoding`, optionally paired with a second `Encoding`.
/// Returns a new `Encoding`; any core error is converted to a Python
/// exception through `ToPyResult`.
fn post_process(
    &self,
    encoding: &Encoding,
    pair: Option<&Encoding>,
    add_special_tokens: bool,
) -> PyResult<Encoding> {
    // The core API takes owned encodings, so clone out of the wrappers.
    let main = encoding.encoding.clone();
    let pair_encoding = pair.map(|p| p.encoding.clone());

    let processed = self
        .tokenizer
        .post_process(main, pair_encoding, add_special_tokens)
        .map(Encoding::new);

    // Bridge the Rust Result into a PyResult.
    ToPyResult(processed).into()
}
#[getter]
fn get_model(&self) -> PyResult<Model> {
Ok(Model {

View File

@ -33,6 +33,7 @@ class Encoding:
Returns:
The resulting Encoding
"""
pass
@property
def ids(self) -> List[int]:
""" The tokenized ids """
@ -441,3 +442,27 @@ class Tokenizer:
The number of tokens that were added to the vocabulary
"""
pass
def post_process(
    self, encoding: Encoding, pair: Optional[Encoding] = None, add_special_tokens: bool = True
) -> Encoding:
    """Apply all the post-processing steps to the given encodings.

    The various steps are:
        1. Truncate according to the global params (provided to `enable_truncation`)
        2. Apply the PostProcessor
        3. Pad according to the global params (provided to `enable_padding`)

    Args:
        encoding: Encoding:
            The main Encoding to post process

        pair: Optional[Encoding]:
            An optional pair Encoding

        add_special_tokens: bool:
            Whether to add special tokens

    Returns:
        The resulting Encoding
    """
    pass

View File

@ -314,3 +314,28 @@ class BaseTokenizer:
The name of the tokenizer, to be used in the saved files
"""
return self._tokenizer.model.save(directory, name=name)
def post_process(
    self, encoding: Encoding, pair: Optional[Encoding] = None, add_special_tokens: bool = True
) -> Encoding:
    """Apply all the post-processing steps to the given encodings.

    The various steps are:
        1. Truncate according to the global params (provided to `enable_truncation`)
        2. Apply the PostProcessor
        3. Pad according to the global params (provided to `enable_padding`)

    Args:
        encoding: Encoding:
            The main Encoding to post process

        pair: Optional[Encoding]:
            An optional pair Encoding

        add_special_tokens: bool:
            Whether to add special tokens

    Returns:
        The resulting Encoding
    """
    # Delegate straight to the underlying (Rust-backed) tokenizer.
    tokenizer = self._tokenizer
    return tokenizer.post_process(encoding, pair, add_special_tokens)

View File

@ -656,7 +656,7 @@ impl Tokenizer {
}
/// Post processing logic, handling the case where there is no PostProcessor set
fn post_process(
pub fn post_process(
&self,
encoding: Encoding,
pair_encoding: Option<Encoding>,