Expose post_process on the Tokenizer
@@ -9,7 +9,7 @@ use tk::tokenizer::{Offsets, PaddingDirection};
 #[pyclass(dict)]
 #[repr(transparent)]
 pub struct Encoding {
-    encoding: tk::tokenizer::Encoding,
+    pub encoding: tk::tokenizer::Encoding,
 }
 
 impl Encoding {
@@ -343,6 +343,25 @@ impl Tokenizer {
         })
     }
 
+    #[args(pair = "None", add_special_tokens = true)]
+    fn post_process(
+        &self,
+        encoding: &Encoding,
+        pair: Option<&Encoding>,
+        add_special_tokens: bool,
+    ) -> PyResult<Encoding> {
+        ToPyResult(
+            self.tokenizer
+                .post_process(
+                    encoding.encoding.clone(),
+                    pair.map(|p| p.encoding.clone()),
+                    add_special_tokens,
+                )
+                .map(Encoding::new),
+        )
+        .into()
+    }
+
     #[getter]
     fn get_model(&self) -> PyResult<Model> {
         Ok(Model {
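Seen from Python, the new binding surfaces as a method on `Tokenizer`. A minimal usage sketch, assuming a `tokenizer` whose model, post-processor, truncation, and padding are already configured (that setup is not part of this diff):

    # `tokenizer` is an already-configured tokenizers.Tokenizer instance.
    encoding = tokenizer.encode("Hello world")
    pair = tokenizer.encode("How are you?")

    # Re-runs truncation, the PostProcessor, and padding on the given encodings.
    processed = tokenizer.post_process(encoding, pair, add_special_tokens=True)
    print(processed.ids)

Note that the binding clones both inner encodings before calling into `tk`, since the Rust `post_process` takes its encodings by value; this is also why the `encoding` field on the pyclass becomes `pub` in the first hunk.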
@@ -33,6 +33,7 @@ class Encoding:
         Returns:
             The resulting Encoding
         """
+        pass
    @property
    def ids(self) -> List[int]:
        """ The tokenized ids """
@@ -441,3 +442,27 @@ class Tokenizer:
             The number of tokens that were added to the vocabulary
         """
         pass
+    def post_process(
+        self, encoding: Encoding, pair: Optional[Encoding] = None, add_special_tokens: bool = True
+    ) -> Encoding:
+        """ Apply all the post-processing steps to the given encodings.
+
+        The various steps are:
+            1. Truncate according to global params (provided to `enable_truncation`)
+            2. Apply the PostProcessor
+            3. Pad according to global params. (provided to `enable_padding`)
+
+        Args:
+            encoding: Encoding:
+                The main Encoding to post process
+
+            pair: Optional[Encoding]:
+                An optional pair Encoding
+
+            add_special_tokens: bool:
+                Whether to add special tokens
+
+        Returns:
+            The resulting Encoding
+        """
+        pass
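To make the three documented steps concrete, a hedged sketch of how the global params feed into the call (the truncation and padding values here are illustrative only):

    # Steps 1 and 3 read these globals; step 2 uses the attached PostProcessor.
    tokenizer.enable_truncation(max_length=128)
    tokenizer.enable_padding(pad_id=0, pad_token="[PAD]")

    # Truncates both encodings to fit, applies the PostProcessor
    # (e.g. inserting special tokens), then pads the result.
    final = tokenizer.post_process(encoding, pair)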
@@ -314,3 +314,28 @@ class BaseTokenizer:
             The name of the tokenizer, to be used in the saved files
         """
         return self._tokenizer.model.save(directory, name=name)
+
+    def post_process(
+        self, encoding: Encoding, pair: Optional[Encoding] = None, add_special_tokens: bool = True
+    ) -> Encoding:
+        """ Apply all the post-processing steps to the given encodings.
+
+        The various steps are:
+            1. Truncate according to global params (provided to `enable_truncation`)
+            2. Apply the PostProcessor
+            3. Pad according to global params. (provided to `enable_padding`)
+
+        Args:
+            encoding: Encoding:
+                The main Encoding to post process
+
+            pair: Optional[Encoding]:
+                An optional pair Encoding
+
+            add_special_tokens: bool:
+                Whether to add special tokens
+
+        Returns:
+            The resulting Encoding
+        """
+        return self._tokenizer.post_process(encoding, pair, add_special_tokens)
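Because `BaseTokenizer` only delegates to the wrapped `Tokenizer`, every concrete implementation picks the method up for free. A sketch using `BertWordPieceTokenizer` (the vocab path is a placeholder):

    from tokenizers import BertWordPieceTokenizer

    # Placeholder path; any BERT-style vocab file works here.
    bert = BertWordPieceTokenizer("bert-base-uncased-vocab.txt")

    first = bert.encode("Hello world")
    second = bert.encode("How are you?")

    # Delegates to the Tokenizer.post_process binding shown above.
    result = bert.post_process(first, second, add_special_tokens=True)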