diff --git a/bindings/python/src/processors.rs b/bindings/python/src/processors.rs
index e77b78fb..3d17f679 100644
--- a/bindings/python/src/processors.rs
+++ b/bindings/python/src/processors.rs
@@ -8,31 +8,31 @@ pub struct PostProcessor {
     pub processor: Container<dyn tk::tokenizer::PostProcessor + Sync>,
 }
 
-#[pyclass]
+#[pyclass(extends=PostProcessor)]
 pub struct BertProcessing {}
 #[pymethods]
 impl BertProcessing {
-    #[staticmethod]
-    fn new(sep: (String, u32), cls: (String, u32)) -> PyResult<PostProcessor> {
-        Ok(PostProcessor {
+    #[new]
+    fn new(obj: &PyRawObject, sep: (String, u32), cls: (String, u32)) -> PyResult<()> {
+        Ok(obj.init(PostProcessor {
             processor: Container::Owned(Box::new(tk::processors::bert::BertProcessing::new(
                 sep, cls,
             ))),
-        })
+        }))
     }
 }
 
-#[pyclass]
+#[pyclass(extends=PostProcessor)]
 pub struct RobertaProcessing {}
 #[pymethods]
 impl RobertaProcessing {
-    #[staticmethod]
-    fn new(sep: (String, u32), cls: (String, u32)) -> PyResult<PostProcessor> {
-        Ok(PostProcessor {
+    #[new]
+    fn new(obj: &PyRawObject, sep: (String, u32), cls: (String, u32)) -> PyResult<()> {
+        Ok(obj.init(PostProcessor {
             processor: Container::Owned(Box::new(tk::processors::roberta::RobertaProcessing::new(
                 sep, cls,
             ))),
-        })
+        }))
     }
-}
\ No newline at end of file
+}
diff --git a/bindings/python/tokenizers/implementations/bert_wordpiece.py b/bindings/python/tokenizers/implementations/bert_wordpiece.py
index dfc7c66d..6920bd27 100644
--- a/bindings/python/tokenizers/implementations/bert_wordpiece.py
+++ b/bindings/python/tokenizers/implementations/bert_wordpiece.py
@@ -7,20 +7,23 @@ from .base_tokenizer import BaseTokenizer
 
 from typing import Optional, List, Union
 
+
 class BertWordPieceTokenizer(BaseTokenizer):
     """ Bert WordPiece Tokenizer """
 
-    def __init__(self,
-                 vocab_file: Optional[str]=None,
-                 add_special_tokens: bool=True,
-                 unk_token: str="[UNK]",
-                 sep_token: str="[SEP]",
-                 cls_token: str="[CLS]",
-                 clean_text: bool=True,
-                 handle_chinese_chars: bool=True,
-                 strip_accents: bool=True,
-                 lowercase: bool=True,
-                 wordpieces_prefix: str="##"):
+    def __init__(
+        self,
+        vocab_file: Optional[str] = None,
+        add_special_tokens: bool = True,
+        unk_token: str = "[UNK]",
+        sep_token: str = "[SEP]",
+        cls_token: str = "[CLS]",
+        clean_text: bool = True,
+        handle_chinese_chars: bool = True,
+        strip_accents: bool = True,
+        lowercase: bool = True,
+        wordpieces_prefix: str = "##",
+    ):
 
         if vocab_file is not None:
             tokenizer = Tokenizer(WordPiece.from_files(vocab_file, unk_token=unk_token))
@@ -44,9 +47,8 @@ class BertWordPieceTokenizer(BaseTokenizer):
             if cls_token_id is None:
                 raise TypeError("cls_token not found in the vocabulary")
 
-            tokenizer.post_processor = BertProcessing.new(
-                (sep_token, sep_token_id),
-                (cls_token, cls_token_id)
+            tokenizer.post_processor = BertProcessing(
+                (sep_token, sep_token_id), (cls_token, cls_token_id)
             )
             tokenizer.decoders = decoders.WordPiece(prefix=wordpieces_prefix)
 
diff --git a/bindings/python/tokenizers/processors/__init__.pyi b/bindings/python/tokenizers/processors/__init__.pyi
index 8508a122..4c037eae 100644
--- a/bindings/python/tokenizers/processors/__init__.pyi
+++ b/bindings/python/tokenizers/processors/__init__.pyi
@@ -16,8 +16,7 @@ class BertProcessing:
         - a CLS token
     """
 
-    @staticmethod
-    def new(sep: Tuple[str, int], cls: Tuple[str, int]) -> PostProcessor:
+    def __init__(self, sep: Tuple[str, int], cls: Tuple[str, int]) -> None:
         """ Instantiate a new BertProcessing with the given tokens
 
         Args:
@@ -32,7 +31,6 @@ class BertProcessing:
         """
         pass
 
-
 class RobertaProcessing:
     """ RobertaProcessing
 
@@ -42,8 +40,7 @@ class RobertaProcessing:
         - a CLS token
     """
 
-    @staticmethod
-    def new(sep: Tuple[str, int], cls: Tuple[str, int]) -> PostProcessor:
+    def __init__(self, sep: Tuple[str, int], cls: Tuple[str, int]) -> None:
         """ Instantiate a new RobertaProcessing with the given tokens
 
         Args:
@@ -56,4 +53,4 @@ class RobertaProcessing:
         Returns:
             PostProcessor
         """
-        pass
\ No newline at end of file
+        pass
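
Note: the sketch below is illustrative and not part of the patch. It shows the user-facing
effect of replacing the .new() static factories with real constructors; the token ids
101/102 match bert-base-uncased but are only example values.

    from tokenizers.processors import BertProcessing

    # Before this patch, processors were built through a static factory that
    # returned a plain PostProcessor:
    #     post_processor = BertProcessing.new(("[SEP]", 102), ("[CLS]", 101))

    # After this patch, BertProcessing (and RobertaProcessing) extend
    # PostProcessor on the Rust side, so the regular constructor is used and
    # the instance can be assigned to Tokenizer.post_processor directly:
    post_processor = BertProcessing(("[SEP]", 102), ("[CLS]", 101))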