Implement __new__ for PostProcessors

Allows PostProcessors to be instantiated through the Python class constructor.
This commit is contained in:
Bjarte Johansen
2020-02-07 11:18:41 +01:00
parent 03508826cb
commit f32e0c09fc
3 changed files with 30 additions and 31 deletions

View File

@@ -8,31 +8,31 @@ pub struct PostProcessor {
pub processor: Container<dyn tk::tokenizer::PostProcessor + Sync>, pub processor: Container<dyn tk::tokenizer::PostProcessor + Sync>,
} }
#[pyclass] #[pyclass(extends=PostProcessor)]
pub struct BertProcessing {} pub struct BertProcessing {}
#[pymethods] #[pymethods]
impl BertProcessing { impl BertProcessing {
#[staticmethod] #[new]
fn new(sep: (String, u32), cls: (String, u32)) -> PyResult<PostProcessor> { fn new(obj: &PyRawObject, sep: (String, u32), cls: (String, u32)) -> PyResult<()> {
Ok(PostProcessor { Ok(obj.init(PostProcessor {
processor: Container::Owned(Box::new(tk::processors::bert::BertProcessing::new( processor: Container::Owned(Box::new(tk::processors::bert::BertProcessing::new(
sep, cls, sep, cls,
))), ))),
}) }))
} }
} }
#[pyclass] #[pyclass(extends=PostProcessor)]
pub struct RobertaProcessing {} pub struct RobertaProcessing {}
#[pymethods] #[pymethods]
impl RobertaProcessing { impl RobertaProcessing {
#[staticmethod] #[new]
fn new(sep: (String, u32), cls: (String, u32)) -> PyResult<PostProcessor> { fn new(obj: &PyRawObject, sep: (String, u32), cls: (String, u32)) -> PyResult<()> {
Ok(PostProcessor { Ok(obj.init(PostProcessor {
processor: Container::Owned(Box::new(tk::processors::roberta::RobertaProcessing::new( processor: Container::Owned(Box::new(tk::processors::roberta::RobertaProcessing::new(
sep, cls, sep, cls,
))), ))),
}) }))
} }
} }

View File

@@ -7,20 +7,23 @@ from .base_tokenizer import BaseTokenizer
from typing import Optional, List, Union from typing import Optional, List, Union
class BertWordPieceTokenizer(BaseTokenizer): class BertWordPieceTokenizer(BaseTokenizer):
""" Bert WordPiece Tokenizer """ """ Bert WordPiece Tokenizer """
def __init__(self, def __init__(
vocab_file: Optional[str]=None, self,
add_special_tokens: bool=True, vocab_file: Optional[str] = None,
unk_token: str="[UNK]", add_special_tokens: bool = True,
sep_token: str="[SEP]", unk_token: str = "[UNK]",
cls_token: str="[CLS]", sep_token: str = "[SEP]",
clean_text: bool=True, cls_token: str = "[CLS]",
handle_chinese_chars: bool=True, clean_text: bool = True,
strip_accents: bool=True, handle_chinese_chars: bool = True,
lowercase: bool=True, strip_accents: bool = True,
wordpieces_prefix: str="##"): lowercase: bool = True,
wordpieces_prefix: str = "##",
):
if vocab_file is not None: if vocab_file is not None:
tokenizer = Tokenizer(WordPiece.from_files(vocab_file, unk_token=unk_token)) tokenizer = Tokenizer(WordPiece.from_files(vocab_file, unk_token=unk_token))
@@ -44,9 +47,8 @@ class BertWordPieceTokenizer(BaseTokenizer):
if cls_token_id is None: if cls_token_id is None:
raise TypeError("cls_token not found in the vocabulary") raise TypeError("cls_token not found in the vocabulary")
tokenizer.post_processor = BertProcessing.new( tokenizer.post_processor = BertProcessing(
(sep_token, sep_token_id), (sep_token, sep_token_id), (cls_token, cls_token_id)
(cls_token, cls_token_id)
) )
tokenizer.decoders = decoders.WordPiece(prefix=wordpieces_prefix) tokenizer.decoders = decoders.WordPiece(prefix=wordpieces_prefix)

View File

@@ -16,8 +16,7 @@ class BertProcessing:
- a CLS token - a CLS token
""" """
@staticmethod def __init__(self, sep: Tuple[str, int], cls: Tuple[str, int]) -> None:
def new(sep: Tuple[str, int], cls: Tuple[str, int]) -> PostProcessor:
""" Instantiate a new BertProcessing with the given tokens """ Instantiate a new BertProcessing with the given tokens
Args: Args:
@@ -32,7 +31,6 @@ class BertProcessing:
""" """
pass pass
class RobertaProcessing: class RobertaProcessing:
""" RobertaProcessing """ RobertaProcessing
@@ -42,8 +40,7 @@ class RobertaProcessing:
- a CLS token - a CLS token
""" """
@staticmethod def __init__(self, sep: Tuple[str, int], cls: Tuple[str, int]) -> None:
def new(sep: Tuple[str, int], cls: Tuple[str, int]) -> PostProcessor:
""" Instantiate a new RobertaProcessing with the given tokens """ Instantiate a new RobertaProcessing with the given tokens
Args: Args: