Expose post_process on the Tokenizer

Anthony MOI
2020-03-25 17:54:48 -04:00
parent 9ce895550b
commit 9bd9e0b3c1
6 changed files with 126 additions and 2 deletions

View File

@@ -9,7 +9,7 @@ use tk::tokenizer::{Offsets, PaddingDirection};
 #[pyclass(dict)]
 #[repr(transparent)]
 pub struct Encoding {
-    encoding: tk::tokenizer::Encoding,
+    pub encoding: tk::tokenizer::Encoding,
 }
 
 impl Encoding {

View File

@@ -343,6 +343,25 @@ impl Tokenizer {
         })
     }
 
+    #[args(pair = "None", add_special_tokens = true)]
+    fn post_process(
+        &self,
+        encoding: &Encoding,
+        pair: Option<&Encoding>,
+        add_special_tokens: bool,
+    ) -> PyResult<Encoding> {
+        ToPyResult(
+            self.tokenizer
+                .post_process(
+                    encoding.encoding.clone(),
+                    pair.map(|p| p.encoding.clone()),
+                    add_special_tokens,
+                )
+                .map(Encoding::new),
+        )
+        .into()
+    }
+
     #[getter]
     fn get_model(&self) -> PyResult<Model> {
         Ok(Model {

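Note: the binding above surfaces the Rust `Tokenizer::post_process` in Python. A minimal sketch of the exposed call, assuming `tokenizer` is an already-built `tokenizers.Tokenizer` and the input strings are illustrative:

    # post_process runs the same final stage that encode applies internally:
    # truncation -> PostProcessor -> padding.
    first = tokenizer.encode("Hello world")
    second = tokenizer.encode("How are you?")

    # Single encoding:
    processed = tokenizer.post_process(first)

    # Pair of encodings, merged with special tokens added by the PostProcessor:
    merged = tokenizer.post_process(first, pair=second, add_special_tokens=True)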
View File

@@ -33,6 +33,7 @@ class Encoding:
         Returns:
             The resulting Encoding
         """
+        pass
     @property
     def ids(self) -> List[int]:
         """ The tokenized ids """
@@ -441,3 +442,27 @@ class Tokenizer:
             The number of tokens that were added to the vocabulary
         """
         pass
+
+    def post_process(
+        self, encoding: Encoding, pair: Optional[Encoding] = None, add_special_tokens: bool = True
+    ) -> Encoding:
+        """ Apply all the post-processing steps to the given encodings.
+        The various steps are:
+            1. Truncate according to global params (provided to `enable_truncation`)
+            2. Apply the PostProcessor
+            3. Pad according to global params. (provided to `enable_padding`)
+
+        Args:
+            encoding: Encoding:
+                The main Encoding to post process
+
+            pair: Optional[Encoding]:
+                An optional pair Encoding
+
+            add_special_tokens: bool:
+                Whether to add special tokens
+
+        Returns:
+            The resulting Encoding
+        """
+        pass
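The three numbered steps map directly onto the tokenizer's global settings. A hedged sketch of how they interact, assuming an existing `Tokenizer` instance and current parameter names (the values are illustrative):

    # Step 1 reads the truncation params, step 3 the padding params.
    tokenizer.enable_truncation(max_length=128)
    tokenizer.enable_padding(length=128)

    encoding = tokenizer.encode("a fairly long input sentence")
    processed = tokenizer.post_process(encoding, add_special_tokens=True)
    # processed is truncated to fit the 128-token limit, run through the
    # PostProcessor (step 2), then padded up to 128 tokens.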

View File

@@ -314,3 +314,28 @@ class BaseTokenizer:
                 The name of the tokenizer, to be used in the saved files
         """
         return self._tokenizer.model.save(directory, name=name)
+
+    def post_process(
+        self, encoding: Encoding, pair: Optional[Encoding] = None, add_special_tokens: bool = True
+    ) -> Encoding:
+        """ Apply all the post-processing steps to the given encodings.
+
+        The various steps are:
+            1. Truncate according to global params (provided to `enable_truncation`)
+            2. Apply the PostProcessor
+            3. Pad according to global params. (provided to `enable_padding`)
+
+        Args:
+            encoding: Encoding:
+                The main Encoding to post process
+
+            pair: Optional[Encoding]:
+                An optional pair Encoding
+
+            add_special_tokens: bool:
+                Whether to add special tokens
+
+        Returns:
+            The resulting Encoding
+        """
+        return self._tokenizer.post_process(encoding, pair, add_special_tokens)
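Because `BaseTokenizer` only delegates to the wrapped `Tokenizer`, the new method is available on every concrete implementation. A hypothetical re-padding use, assuming current parameter names; the vocab path is a placeholder:

    from tokenizers import BertWordPieceTokenizer

    bert = BertWordPieceTokenizer("vocab.txt")  # placeholder vocab file
    enc = bert.encode("reuse me")               # encode already post-processes

    # Re-apply padding to the same Encoding under new global params,
    # without re-adding the [CLS]/[SEP] that encode already inserted.
    bert.enable_padding(length=32)
    padded = bert.post_process(enc, add_special_tokens=False)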