Implement __new__ for PostProcessors

Allows PostProcessors to be instantiated through the regular Python class constructor.
Bjarte Johansen
2020-02-07 11:18:41 +01:00
parent 03508826cb
commit f32e0c09fc
3 changed files with 30 additions and 31 deletions
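
In practice, this means the post-processors are now created with the normal Python constructor rather than a static factory. A minimal before/after sketch, assuming the classes are exposed under tokenizers.processors and using placeholder token ids (101/102 stand in for the vocabulary's real [CLS]/[SEP] ids):

from tokenizers.processors import BertProcessing

# Before this commit: a static factory method
# processor = BertProcessing.new(("[SEP]", 102), ("[CLS]", 101))

# After this commit: a plain constructor call
processor = BertProcessing(("[SEP]", 102), ("[CLS]", 101))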

View File

@@ -8,31 +8,31 @@ pub struct PostProcessor {
     pub processor: Container<dyn tk::tokenizer::PostProcessor + Sync>,
 }

-#[pyclass]
+#[pyclass(extends=PostProcessor)]
 pub struct BertProcessing {}
 #[pymethods]
 impl BertProcessing {
-    #[staticmethod]
-    fn new(sep: (String, u32), cls: (String, u32)) -> PyResult<PostProcessor> {
-        Ok(PostProcessor {
+    #[new]
+    fn new(obj: &PyRawObject, sep: (String, u32), cls: (String, u32)) -> PyResult<()> {
+        Ok(obj.init(PostProcessor {
             processor: Container::Owned(Box::new(tk::processors::bert::BertProcessing::new(
                 sep, cls,
             ))),
-        })
+        }))
     }
 }

-#[pyclass]
+#[pyclass(extends=PostProcessor)]
 pub struct RobertaProcessing {}
 #[pymethods]
 impl RobertaProcessing {
-    #[staticmethod]
-    fn new(sep: (String, u32), cls: (String, u32)) -> PyResult<PostProcessor> {
-        Ok(PostProcessor {
+    #[new]
+    fn new(obj: &PyRawObject, sep: (String, u32), cls: (String, u32)) -> PyResult<()> {
+        Ok(obj.init(PostProcessor {
             processor: Container::Owned(Box::new(tk::processors::roberta::RobertaProcessing::new(
                 sep, cls,
             ))),
-        })
+        }))
     }
 }
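
Because the structs are now declared with #[pyclass(extends=PostProcessor)] and initialized through #[new], the Python-side classes should behave as genuine subclasses of PostProcessor. A small sketch of that assumption (it presumes PostProcessor is exported next to the concrete classes; the token ids are placeholders):

from tokenizers.processors import BertProcessing, PostProcessor

processor = BertProcessing(("[SEP]", 102), ("[CLS]", 101))
# extends=PostProcessor should surface as ordinary Python inheritance
assert isinstance(processor, PostProcessor)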

View File

@@ -7,20 +7,23 @@ from .base_tokenizer import BaseTokenizer
 from typing import Optional, List, Union


 class BertWordPieceTokenizer(BaseTokenizer):
     """ Bert WordPiece Tokenizer """

-    def __init__(self,
-                 vocab_file: Optional[str]=None,
-                 add_special_tokens: bool=True,
-                 unk_token: str="[UNK]",
-                 sep_token: str="[SEP]",
-                 cls_token: str="[CLS]",
-                 clean_text: bool=True,
-                 handle_chinese_chars: bool=True,
-                 strip_accents: bool=True,
-                 lowercase: bool=True,
-                 wordpieces_prefix: str="##"):
+    def __init__(
+        self,
+        vocab_file: Optional[str] = None,
+        add_special_tokens: bool = True,
+        unk_token: str = "[UNK]",
+        sep_token: str = "[SEP]",
+        cls_token: str = "[CLS]",
+        clean_text: bool = True,
+        handle_chinese_chars: bool = True,
+        strip_accents: bool = True,
+        lowercase: bool = True,
+        wordpieces_prefix: str = "##",
+    ):

         if vocab_file is not None:
             tokenizer = Tokenizer(WordPiece.from_files(vocab_file, unk_token=unk_token))
@@ -44,9 +47,8 @@ class BertWordPieceTokenizer(BaseTokenizer):
             if cls_token_id is None:
                 raise TypeError("cls_token not found in the vocabulary")

-            tokenizer.post_processor = BertProcessing.new(
-                (sep_token, sep_token_id),
-                (cls_token, cls_token_id)
+            tokenizer.post_processor = BertProcessing(
+                (sep_token, sep_token_id), (cls_token, cls_token_id)
             )
             tokenizer.decoders = decoders.WordPiece(prefix=wordpieces_prefix)
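
The high-level BertWordPieceTokenizer keeps the same public behaviour; only the way it builds its post-processor changes. A rough usage sketch, with a hypothetical vocabulary path:

from tokenizers import BertWordPieceTokenizer

# "vocab.txt" is a placeholder path to a WordPiece vocabulary file
tokenizer = BertWordPieceTokenizer("vocab.txt", lowercase=True)

# The BertProcessing post-processor built in __init__ wraps the sequence
# with the [CLS] ... [SEP] special tokens during encoding.
output = tokenizer.encode("Hello, world!")
print(output.tokens)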

View File

@@ -16,8 +16,7 @@ class BertProcessing:
         - a CLS token
     """

-    @staticmethod
-    def new(sep: Tuple[str, int], cls: Tuple[str, int]) -> PostProcessor:
+    def __init__(self, sep: Tuple[str, int], cls: Tuple[str, int]) -> None:
         """ Instantiate a new BertProcessing with the given tokens

         Args:
@@ -32,7 +31,6 @@
         """
         pass

-
 class RobertaProcessing:
     """ RobertaProcessing
@@ -42,8 +40,7 @@ class RobertaProcessing:
         - a CLS token
     """

-    @staticmethod
-    def new(sep: Tuple[str, int], cls: Tuple[str, int]) -> PostProcessor:
+    def __init__(self, sep: Tuple[str, int], cls: Tuple[str, int]) -> None:
         """ Instantiate a new RobertaProcessing with the given tokens

         Args:
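
RobertaProcessing follows the same constructor signature, taking (token, id) pairs typed as Tuple[str, int]; for a RoBERTa vocabulary the special tokens are typically "</s>" and "<s>". A hedged sketch with placeholder ids:

from tokenizers.processors import RobertaProcessing

# 2 and 0 are example ids for "</s>" and "<s>"; the real values come from the vocabulary
processor = RobertaProcessing(("</s>", 2), ("<s>", 0))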