Implement __new__ for PostProcessors
Allows PostProcessors to be instantiated through the Python class constructor.
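As a quick illustration, a minimal before/after sketch of the Python-side usage this enables; the token strings, ids, and import path below are example values, not taken from this commit:

from tokenizers.processors import BertProcessing

# Before: instantiation went through a static factory method
# processor = BertProcessing.new(("[SEP]", 102), ("[CLS]", 101))

# After: the regular Python class constructor works directly
processor = BertProcessing(("[SEP]", 102), ("[CLS]", 101))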
@@ -8,31 +8,31 @@ pub struct PostProcessor {
     pub processor: Container<dyn tk::tokenizer::PostProcessor + Sync>,
 }
 
-#[pyclass]
+#[pyclass(extends=PostProcessor)]
 pub struct BertProcessing {}
 #[pymethods]
 impl BertProcessing {
-    #[staticmethod]
-    fn new(sep: (String, u32), cls: (String, u32)) -> PyResult<PostProcessor> {
-        Ok(PostProcessor {
+    #[new]
+    fn new(obj: &PyRawObject, sep: (String, u32), cls: (String, u32)) -> PyResult<()> {
+        Ok(obj.init(PostProcessor {
             processor: Container::Owned(Box::new(tk::processors::bert::BertProcessing::new(
                 sep, cls,
             ))),
-        })
+        }))
     }
 }
 
 
-#[pyclass]
+#[pyclass(extends=PostProcessor)]
 pub struct RobertaProcessing {}
 #[pymethods]
 impl RobertaProcessing {
-    #[staticmethod]
-    fn new(sep: (String, u32), cls: (String, u32)) -> PyResult<PostProcessor> {
-        Ok(PostProcessor {
+    #[new]
+    fn new(obj: &PyRawObject, sep: (String, u32), cls: (String, u32)) -> PyResult<()> {
+        Ok(obj.init(PostProcessor {
             processor: Container::Owned(Box::new(tk::processors::roberta::RobertaProcessing::new(
                 sep, cls,
             ))),
-        })
+        }))
     }
 }
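Because each wrapper now uses #[pyclass(extends=PostProcessor)] and initializes the base through obj.init(...), the subclass relationship becomes visible from Python. A hedged sketch, assuming both classes are exported from the same module:

from tokenizers.processors import BertProcessing, PostProcessor

processor = BertProcessing(("[SEP]", 102), ("[CLS]", 101))  # example (token, id) pairs
# extends=PostProcessor makes this a real subclass on the Python side
assert isinstance(processor, PostProcessor)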
@@ -7,20 +7,23 @@ from .base_tokenizer import BaseTokenizer
 
 from typing import Optional, List, Union
 
+
 class BertWordPieceTokenizer(BaseTokenizer):
     """ Bert WordPiece Tokenizer """
 
-    def __init__(self,
-                 vocab_file: Optional[str]=None,
-                 add_special_tokens: bool=True,
-                 unk_token: str="[UNK]",
-                 sep_token: str="[SEP]",
-                 cls_token: str="[CLS]",
-                 clean_text: bool=True,
-                 handle_chinese_chars: bool=True,
-                 strip_accents: bool=True,
-                 lowercase: bool=True,
-                 wordpieces_prefix: str="##"):
+    def __init__(
+        self,
+        vocab_file: Optional[str] = None,
+        add_special_tokens: bool = True,
+        unk_token: str = "[UNK]",
+        sep_token: str = "[SEP]",
+        cls_token: str = "[CLS]",
+        clean_text: bool = True,
+        handle_chinese_chars: bool = True,
+        strip_accents: bool = True,
+        lowercase: bool = True,
+        wordpieces_prefix: str = "##",
+    ):
 
         if vocab_file is not None:
             tokenizer = Tokenizer(WordPiece.from_files(vocab_file, unk_token=unk_token))
@@ -44,9 +47,8 @@ class BertWordPieceTokenizer(BaseTokenizer):
         if cls_token_id is None:
             raise TypeError("cls_token not found in the vocabulary")
 
-        tokenizer.post_processor = BertProcessing.new(
-            (sep_token, sep_token_id),
-            (cls_token, cls_token_id)
+        tokenizer.post_processor = BertProcessing(
+            (sep_token, sep_token_id), (cls_token, cls_token_id)
         )
         tokenizer.decoders = decoders.WordPiece(prefix=wordpieces_prefix)
 
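For reference, a usage sketch of the constructor as reformatted above; "bert-vocab.txt" is a placeholder path and the import location is assumed:

from tokenizers.implementations import BertWordPieceTokenizer

# Placeholder vocab path; all other keyword arguments keep the defaults shown in the signature
tokenizer = BertWordPieceTokenizer("bert-vocab.txt", lowercase=True, wordpieces_prefix="##")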
@@ -16,8 +16,7 @@ class BertProcessing:
     - a CLS token
     """
 
-    @staticmethod
-    def new(sep: Tuple[str, int], cls: Tuple[str, int]) -> PostProcessor:
+    def __init__(self, sep: Tuple[str, int], cls: Tuple[str, int]) -> None:
         """ Instantiate a new BertProcessing with the given tokens
 
         Args:
@@ -32,7 +31,6 @@ class BertProcessing:
         """
         pass
-
 
 
 class RobertaProcessing:
     """ RobertaProcessing
@@ -42,8 +40,7 @@ class RobertaProcessing:
     - a CLS token
     """
 
-    @staticmethod
-    def new(sep: Tuple[str, int], cls: Tuple[str, int]) -> PostProcessor:
+    def __init__(self, sep: Tuple[str, int], cls: Tuple[str, int]) -> None:
         """ Instantiate a new RobertaProcessing with the given tokens
 
         Args:
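Following the updated stubs, both processors are constructed directly rather than via .new(). A sketch for the RoBERTa side; the (token, id) pairs are illustrative, not tied to any particular vocabulary:

from tokenizers.processors import RobertaProcessing

# Arguments follow the Tuple[str, int] signature from the stub above
processor = RobertaProcessing(("</s>", 2), ("<s>", 0))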