mirror of
https://github.com/mii443/tokenizers.git
synced 2025-08-31 12:39:21 +00:00
Expose post_process on the Tokenizer
This commit is contained in:
@ -2,6 +2,7 @@ extern crate tokenizers as tk;
|
||||
|
||||
use crate::container::Container;
|
||||
use crate::decoders::JsDecoder;
|
||||
use crate::encoding::JsEncoding;
|
||||
use crate::models::JsModel;
|
||||
use crate::normalizers::JsNormalizer;
|
||||
use crate::pre_tokenizers::JsPreTokenizer;
|
||||
@ -712,6 +713,60 @@ declare_types! {
|
||||
Ok(cx.undefined().upcast())
|
||||
}
|
||||
|
||||
method postProcess(mut cx) {
|
||||
// postProcess(
|
||||
// encoding: Encoding,
|
||||
// pair?: Encoding,
|
||||
// addSpecialTokens: boolean = true
|
||||
// ): Encoding
|
||||
|
||||
let encoding = {
|
||||
let encoding = cx.argument::<JsEncoding>(0)?;
|
||||
let guard = cx.lock();
|
||||
let encoding = encoding
|
||||
.borrow(&guard)
|
||||
.encoding
|
||||
.execute(|e| *e.unwrap().clone());
|
||||
encoding
|
||||
};
|
||||
let pair = cx.argument_opt(1).map(|item| {
|
||||
if item.downcast::<JsUndefined>().is_ok() {
|
||||
None
|
||||
} else {
|
||||
item.downcast::<JsEncoding>().map(|e| {
|
||||
let guard = cx.lock();
|
||||
let encoding = e.borrow(&guard)
|
||||
.encoding
|
||||
.execute(|e| *e.unwrap().clone());
|
||||
encoding
|
||||
}).ok()
|
||||
}
|
||||
}).flatten();
|
||||
let add_special_tokens = cx
|
||||
.argument_opt(2)
|
||||
.map(|arg| Ok(arg.downcast::<JsBoolean>().or_throw(&mut cx)?.value()))
|
||||
.unwrap_or(Ok(true))?;
|
||||
|
||||
let encoding = {
|
||||
let this = cx.this();
|
||||
let guard = cx.lock();
|
||||
let encoding = this.borrow(&guard)
|
||||
.tokenizer.post_process(encoding, pair, add_special_tokens);
|
||||
encoding
|
||||
};
|
||||
let encoding = encoding
|
||||
.map_err(|e| cx.throw_error::<_, ()>(format!("{}", e)).unwrap_err())?;
|
||||
|
||||
let mut js_encoding = JsEncoding::new::<_, JsEncoding, _>(&mut cx, vec![])?;
|
||||
let guard = cx.lock();
|
||||
js_encoding
|
||||
.borrow_mut(&guard)
|
||||
.encoding
|
||||
.to_owned(Box::new(encoding));
|
||||
|
||||
Ok(js_encoding.upcast())
|
||||
}
|
||||
|
||||
method getModel(mut cx) {
|
||||
// getModel(): Model
|
||||
|
||||
|
@ -9,7 +9,7 @@ use tk::tokenizer::{Offsets, PaddingDirection};
|
||||
#[pyclass(dict)]
|
||||
#[repr(transparent)]
|
||||
pub struct Encoding {
|
||||
encoding: tk::tokenizer::Encoding,
|
||||
pub encoding: tk::tokenizer::Encoding,
|
||||
}
|
||||
|
||||
impl Encoding {
|
||||
|
@ -343,6 +343,25 @@ impl Tokenizer {
|
||||
})
|
||||
}
|
||||
|
||||
#[args(pair = "None", add_special_tokens = true)]
|
||||
fn post_process(
|
||||
&self,
|
||||
encoding: &Encoding,
|
||||
pair: Option<&Encoding>,
|
||||
add_special_tokens: bool,
|
||||
) -> PyResult<Encoding> {
|
||||
ToPyResult(
|
||||
self.tokenizer
|
||||
.post_process(
|
||||
encoding.encoding.clone(),
|
||||
pair.map(|p| p.encoding.clone()),
|
||||
add_special_tokens,
|
||||
)
|
||||
.map(Encoding::new),
|
||||
)
|
||||
.into()
|
||||
}
|
||||
|
||||
#[getter]
|
||||
fn get_model(&self) -> PyResult<Model> {
|
||||
Ok(Model {
|
||||
|
@ -33,6 +33,7 @@ class Encoding:
|
||||
Returns:
|
||||
The resulting Encoding
|
||||
"""
|
||||
pass
|
||||
@property
|
||||
def ids(self) -> List[int]:
|
||||
""" The tokenized ids """
|
||||
@ -441,3 +442,27 @@ class Tokenizer:
|
||||
The number of tokens that were added to the vocabulary
|
||||
"""
|
||||
pass
|
||||
def post_process(
|
||||
self, encoding: Encoding, pair: Optional[Encoding] = None, add_special_tokens: bool = True
|
||||
) -> Encoding:
|
||||
""" Apply all the post-processing steps to the given encodings.
|
||||
|
||||
The various steps are:
|
||||
1. Truncate according to global params (provided to `enable_truncation`)
|
||||
2. Apply the PostProcessor
|
||||
3. Pad according to global params. (provided to `enable_padding`)
|
||||
|
||||
Args:
|
||||
encoding: Encoding:
|
||||
The main Encoding to post process
|
||||
|
||||
pair: Optional[Encoding]:
|
||||
An optional pair Encoding
|
||||
|
||||
add_special_tokens: bool:
|
||||
Whether to add special tokens
|
||||
|
||||
Returns:
|
||||
The resulting Encoding
|
||||
"""
|
||||
pass
|
||||
|
@ -314,3 +314,28 @@ class BaseTokenizer:
|
||||
The name of the tokenizer, to be used in the saved files
|
||||
"""
|
||||
return self._tokenizer.model.save(directory, name=name)
|
||||
|
||||
def post_process(
|
||||
self, encoding: Encoding, pair: Optional[Encoding] = None, add_special_tokens: bool = True
|
||||
) -> Encoding:
|
||||
""" Apply all the post-processing steps to the given encodings.
|
||||
|
||||
The various steps are:
|
||||
1. Truncate according to global params (provided to `enable_truncation`)
|
||||
2. Apply the PostProcessor
|
||||
3. Pad according to global params. (provided to `enable_padding`)
|
||||
|
||||
Args:
|
||||
encoding: Encoding:
|
||||
The main Encoding to post process
|
||||
|
||||
pair: Optional[Encoding]:
|
||||
An optional pair Encoding
|
||||
|
||||
add_special_tokens: bool:
|
||||
Whether to add special tokens
|
||||
|
||||
Returns:
|
||||
The resulting Encoding
|
||||
"""
|
||||
return self._tokenizer.post_process(encoding, pair, add_special_tokens)
|
||||
|
@ -656,7 +656,7 @@ impl Tokenizer {
|
||||
}
|
||||
|
||||
/// Post processing logic, handling the case where there is no PostProcessor set
|
||||
fn post_process(
|
||||
pub fn post_process(
|
||||
&self,
|
||||
encoding: Encoding,
|
||||
pair_encoding: Option<Encoding>,
|
||||
|
Reference in New Issue
Block a user