mirror of
https://github.com/mii443/tokenizers.git
synced 2025-09-01 23:09:34 +00:00
Expose post_process on the Tokenizer
This commit is contained in:
@ -2,6 +2,7 @@ extern crate tokenizers as tk;
|
|||||||
|
|
||||||
use crate::container::Container;
|
use crate::container::Container;
|
||||||
use crate::decoders::JsDecoder;
|
use crate::decoders::JsDecoder;
|
||||||
|
use crate::encoding::JsEncoding;
|
||||||
use crate::models::JsModel;
|
use crate::models::JsModel;
|
||||||
use crate::normalizers::JsNormalizer;
|
use crate::normalizers::JsNormalizer;
|
||||||
use crate::pre_tokenizers::JsPreTokenizer;
|
use crate::pre_tokenizers::JsPreTokenizer;
|
||||||
@ -712,6 +713,60 @@ declare_types! {
|
|||||||
Ok(cx.undefined().upcast())
|
Ok(cx.undefined().upcast())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
method postProcess(mut cx) {
|
||||||
|
// postProcess(
|
||||||
|
// encoding: Encoding,
|
||||||
|
// pair?: Encoding,
|
||||||
|
// addSpecialTokens: boolean = true
|
||||||
|
// ): Encoding
|
||||||
|
|
||||||
|
let encoding = {
|
||||||
|
let encoding = cx.argument::<JsEncoding>(0)?;
|
||||||
|
let guard = cx.lock();
|
||||||
|
let encoding = encoding
|
||||||
|
.borrow(&guard)
|
||||||
|
.encoding
|
||||||
|
.execute(|e| *e.unwrap().clone());
|
||||||
|
encoding
|
||||||
|
};
|
||||||
|
let pair = cx.argument_opt(1).map(|item| {
|
||||||
|
if item.downcast::<JsUndefined>().is_ok() {
|
||||||
|
None
|
||||||
|
} else {
|
||||||
|
item.downcast::<JsEncoding>().map(|e| {
|
||||||
|
let guard = cx.lock();
|
||||||
|
let encoding = e.borrow(&guard)
|
||||||
|
.encoding
|
||||||
|
.execute(|e| *e.unwrap().clone());
|
||||||
|
encoding
|
||||||
|
}).ok()
|
||||||
|
}
|
||||||
|
}).flatten();
|
||||||
|
let add_special_tokens = cx
|
||||||
|
.argument_opt(2)
|
||||||
|
.map(|arg| Ok(arg.downcast::<JsBoolean>().or_throw(&mut cx)?.value()))
|
||||||
|
.unwrap_or(Ok(true))?;
|
||||||
|
|
||||||
|
let encoding = {
|
||||||
|
let this = cx.this();
|
||||||
|
let guard = cx.lock();
|
||||||
|
let encoding = this.borrow(&guard)
|
||||||
|
.tokenizer.post_process(encoding, pair, add_special_tokens);
|
||||||
|
encoding
|
||||||
|
};
|
||||||
|
let encoding = encoding
|
||||||
|
.map_err(|e| cx.throw_error::<_, ()>(format!("{}", e)).unwrap_err())?;
|
||||||
|
|
||||||
|
let mut js_encoding = JsEncoding::new::<_, JsEncoding, _>(&mut cx, vec![])?;
|
||||||
|
let guard = cx.lock();
|
||||||
|
js_encoding
|
||||||
|
.borrow_mut(&guard)
|
||||||
|
.encoding
|
||||||
|
.to_owned(Box::new(encoding));
|
||||||
|
|
||||||
|
Ok(js_encoding.upcast())
|
||||||
|
}
|
||||||
|
|
||||||
method getModel(mut cx) {
|
method getModel(mut cx) {
|
||||||
// getModel(): Model
|
// getModel(): Model
|
||||||
|
|
||||||
|
@ -9,7 +9,7 @@ use tk::tokenizer::{Offsets, PaddingDirection};
|
|||||||
#[pyclass(dict)]
|
#[pyclass(dict)]
|
||||||
#[repr(transparent)]
|
#[repr(transparent)]
|
||||||
pub struct Encoding {
|
pub struct Encoding {
|
||||||
encoding: tk::tokenizer::Encoding,
|
pub encoding: tk::tokenizer::Encoding,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Encoding {
|
impl Encoding {
|
||||||
|
@ -343,6 +343,25 @@ impl Tokenizer {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[args(pair = "None", add_special_tokens = true)]
|
||||||
|
fn post_process(
|
||||||
|
&self,
|
||||||
|
encoding: &Encoding,
|
||||||
|
pair: Option<&Encoding>,
|
||||||
|
add_special_tokens: bool,
|
||||||
|
) -> PyResult<Encoding> {
|
||||||
|
ToPyResult(
|
||||||
|
self.tokenizer
|
||||||
|
.post_process(
|
||||||
|
encoding.encoding.clone(),
|
||||||
|
pair.map(|p| p.encoding.clone()),
|
||||||
|
add_special_tokens,
|
||||||
|
)
|
||||||
|
.map(Encoding::new),
|
||||||
|
)
|
||||||
|
.into()
|
||||||
|
}
|
||||||
|
|
||||||
#[getter]
|
#[getter]
|
||||||
fn get_model(&self) -> PyResult<Model> {
|
fn get_model(&self) -> PyResult<Model> {
|
||||||
Ok(Model {
|
Ok(Model {
|
||||||
|
@ -33,6 +33,7 @@ class Encoding:
|
|||||||
Returns:
|
Returns:
|
||||||
The resulting Encoding
|
The resulting Encoding
|
||||||
"""
|
"""
|
||||||
|
pass
|
||||||
@property
|
@property
|
||||||
def ids(self) -> List[int]:
|
def ids(self) -> List[int]:
|
||||||
""" The tokenized ids """
|
""" The tokenized ids """
|
||||||
@ -441,3 +442,27 @@ class Tokenizer:
|
|||||||
The number of tokens that were added to the vocabulary
|
The number of tokens that were added to the vocabulary
|
||||||
"""
|
"""
|
||||||
pass
|
pass
|
||||||
|
def post_process(
|
||||||
|
self, encoding: Encoding, pair: Optional[Encoding] = None, add_special_tokens: bool = True
|
||||||
|
) -> Encoding:
|
||||||
|
""" Apply all the post-processing steps to the given encodings.
|
||||||
|
|
||||||
|
The various steps are:
|
||||||
|
1. Truncate according to global params (provided to `enable_truncation`)
|
||||||
|
2. Apply the PostProcessor
|
||||||
|
3. Pad according to global params. (provided to `enable_padding`)
|
||||||
|
|
||||||
|
Args:
|
||||||
|
encoding: Encoding:
|
||||||
|
The main Encoding to post process
|
||||||
|
|
||||||
|
pair: Optional[Encoding]:
|
||||||
|
An optional pair Encoding
|
||||||
|
|
||||||
|
add_special_tokens: bool:
|
||||||
|
Whether to add special tokens
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The resulting Encoding
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
@ -314,3 +314,28 @@ class BaseTokenizer:
|
|||||||
The name of the tokenizer, to be used in the saved files
|
The name of the tokenizer, to be used in the saved files
|
||||||
"""
|
"""
|
||||||
return self._tokenizer.model.save(directory, name=name)
|
return self._tokenizer.model.save(directory, name=name)
|
||||||
|
|
||||||
|
def post_process(
|
||||||
|
self, encoding: Encoding, pair: Optional[Encoding] = None, add_special_tokens: bool = True
|
||||||
|
) -> Encoding:
|
||||||
|
""" Apply all the post-processing steps to the given encodings.
|
||||||
|
|
||||||
|
The various steps are:
|
||||||
|
1. Truncate according to global params (provided to `enable_truncation`)
|
||||||
|
2. Apply the PostProcessor
|
||||||
|
3. Pad according to global params. (provided to `enable_padding`)
|
||||||
|
|
||||||
|
Args:
|
||||||
|
encoding: Encoding:
|
||||||
|
The main Encoding to post process
|
||||||
|
|
||||||
|
pair: Optional[Encoding]:
|
||||||
|
An optional pair Encoding
|
||||||
|
|
||||||
|
add_special_tokens: bool:
|
||||||
|
Whether to add special tokens
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The resulting Encoding
|
||||||
|
"""
|
||||||
|
return self._tokenizer.post_process(encoding, pair, add_special_tokens)
|
||||||
|
@ -656,7 +656,7 @@ impl Tokenizer {
|
|||||||
}
|
}
|
||||||
|
|
||||||
/// Post processing logic, handling the case where there is no PostProcessor set
|
/// Post processing logic, handling the case where there is no PostProcessor set
|
||||||
fn post_process(
|
pub fn post_process(
|
||||||
&self,
|
&self,
|
||||||
encoding: Encoding,
|
encoding: Encoding,
|
||||||
pair_encoding: Option<Encoding>,
|
pair_encoding: Option<Encoding>,
|
||||||
|
Reference in New Issue
Block a user