Implement __new__ on Decoders

Allow decoders to be initialized from python using the class
constructor.
This commit is contained in:
Bjarte Johansen
2020-02-07 11:03:39 +01:00
parent 4971e9608d
commit 03508826cb
6 changed files with 30 additions and 33 deletions

View File

@@ -26,25 +26,25 @@ impl Decoder {
} }
} }
#[pyclass] #[pyclass(extends=Decoder)]
pub struct ByteLevel {} pub struct ByteLevel {}
#[pymethods] #[pymethods]
impl ByteLevel { impl ByteLevel {
#[staticmethod] #[new]
fn new() -> PyResult<Decoder> { fn new(obj: &PyRawObject) -> PyResult<()> {
Ok(Decoder { Ok(obj.init(Decoder {
decoder: Container::Owned(Box::new(tk::decoders::byte_level::ByteLevel::new(false))), decoder: Container::Owned(Box::new(tk::decoders::byte_level::ByteLevel::new(false))),
}) }))
} }
} }
#[pyclass] #[pyclass(extends=Decoder)]
pub struct WordPiece {} pub struct WordPiece {}
#[pymethods] #[pymethods]
impl WordPiece { impl WordPiece {
#[staticmethod] #[new]
#[args(kwargs="**")] #[args(kwargs="**")]
fn new(kwargs: Option<&PyDict>) -> PyResult<Decoder> { fn new(obj: &PyRawObject, kwargs: Option<&PyDict>) -> PyResult<()> {
let mut prefix = String::from("##"); let mut prefix = String::from("##");
if let Some(kwargs) = kwargs { if let Some(kwargs) = kwargs {
@@ -53,19 +53,19 @@ impl WordPiece {
} }
} }
Ok(Decoder { Ok(obj.init(Decoder {
decoder: Container::Owned(Box::new(tk::decoders::wordpiece::WordPiece::new(prefix))), decoder: Container::Owned(Box::new(tk::decoders::wordpiece::WordPiece::new(prefix))),
}) }))
} }
} }
#[pyclass] #[pyclass(extends=Decoder)]
pub struct Metaspace {} pub struct Metaspace {}
#[pymethods] #[pymethods]
impl Metaspace { impl Metaspace {
#[staticmethod] #[new]
#[args(kwargs = "**")] #[args(kwargs = "**")]
fn new(kwargs: Option<&PyDict>) -> PyResult<Decoder> { fn new(obj: &PyRawObject, kwargs: Option<&PyDict>) -> PyResult<()> {
let mut replacement = '▁'; let mut replacement = '▁';
let mut add_prefix_space = true; let mut add_prefix_space = true;
@@ -85,22 +85,22 @@ impl Metaspace {
} }
} }
Ok(Decoder { Ok(obj.init(Decoder {
decoder: Container::Owned(Box::new(tk::decoders::metaspace::Metaspace::new( decoder: Container::Owned(Box::new(tk::decoders::metaspace::Metaspace::new(
replacement, replacement,
add_prefix_space, add_prefix_space,
))), ))),
}) }))
} }
} }
#[pyclass] #[pyclass(extends=Decoder)]
pub struct BPEDecoder {} pub struct BPEDecoder {}
#[pymethods] #[pymethods]
impl BPEDecoder { impl BPEDecoder {
#[staticmethod] #[new]
#[args(kwargs = "**")] #[args(kwargs = "**")]
fn new(kwargs: Option<&PyDict>) -> PyResult<Decoder> { fn new(obj: &PyRawObject, kwargs: Option<&PyDict>) -> PyResult<()> {
let mut suffix = String::from("</w"); let mut suffix = String::from("</w");
if let Some(kwargs) = kwargs { if let Some(kwargs) = kwargs {
@@ -113,9 +113,9 @@ impl BPEDecoder {
} }
} }
Ok(Decoder { Ok(obj.init(Decoder {
decoder: Container::Owned(Box::new(tk::decoders::bpe::BPEDecoder::new(suffix))), decoder: Container::Owned(Box::new(tk::decoders::bpe::BPEDecoder::new(suffix))),
}) }))
} }
} }

View File

@@ -14,8 +14,7 @@ class Decoder:
class ByteLevel: class ByteLevel:
""" ByteLevel Decoder """ """ ByteLevel Decoder """
@staticmethod def __init__(self) -> None:
def new() -> Decoder:
""" Instantiate a new ByteLevel Decoder """ """ Instantiate a new ByteLevel Decoder """
pass pass
@@ -23,7 +22,7 @@ class WordPiece:
""" WordPiece Decoder """ """ WordPiece Decoder """
@staticmethod @staticmethod
def new(prefix: str="##") -> Decoder: def __init__(self, prefix: str = "##") -> Decoder:
""" Instantiate a new WordPiece Decoder """ Instantiate a new WordPiece Decoder
Args: Args:
@@ -35,9 +34,7 @@ class WordPiece:
class Metaspace: class Metaspace:
""" Metaspace decoder """ """ Metaspace decoder """
@staticmethod def __init__(self, replacement: str = "", add_prefix_space: bool = True) -> None:
def new(replacement: str="",
add_prefix_space: bool=True) -> Decoder:
""" Instantiate a new Metaspace """ Instantiate a new Metaspace
Args: Args:
@@ -54,8 +51,7 @@ class Metaspace:
class BPEDecoder: class BPEDecoder:
""" BPEDecoder """ """ BPEDecoder """
@staticmethod def __init__(self, suffix: str = "</w>") -> None:
def new(suffix: str="</w>") -> Decoder:
""" Instantiate a new BPEDecoder """ Instantiate a new BPEDecoder
Args: Args:

View File

@@ -48,7 +48,7 @@ class BertWordPieceTokenizer(BaseTokenizer):
(sep_token, sep_token_id), (sep_token, sep_token_id),
(cls_token, cls_token_id) (cls_token, cls_token_id)
) )
tokenizer.decoders = decoders.WordPiece.new(prefix=wordpieces_prefix) tokenizer.decoders = decoders.WordPiece(prefix=wordpieces_prefix)
parameters = { parameters = {
"model": "BertWordPiece", "model": "BertWordPiece",

View File

@@ -46,7 +46,7 @@ class ByteLevelBPETokenizer(BaseTokenizer):
tokenizer.normalizer = normalizers[0] tokenizer.normalizer = normalizers[0]
tokenizer.pre_tokenizer = pre_tokenizers.ByteLevel.new(add_prefix_space=add_prefix_space) tokenizer.pre_tokenizer = pre_tokenizers.ByteLevel.new(add_prefix_space=add_prefix_space)
tokenizer.decoder = decoders.ByteLevel.new() tokenizer.decoder = decoders.ByteLevel()
parameters = { parameters = {
"model": "ByteLevelBPE", "model": "ByteLevelBPE",

View File

@@ -52,7 +52,7 @@ class CharBPETokenizer(BaseTokenizer):
tokenizer.normalizer = normalizers[0] tokenizer.normalizer = normalizers[0]
tokenizer.pre_tokenizer = pre_tokenizers.WhitespaceSplit.new() tokenizer.pre_tokenizer = pre_tokenizers.WhitespaceSplit.new()
tokenizer.decoder = decoders.BPEDecoder.new(suffix=suffix) tokenizer.decoder = decoders.BPEDecoder(suffix=suffix)
parameters = { parameters = {
"model": "BPE", "model": "BPE",

View File

@@ -31,8 +31,9 @@ class SentencePieceBPETokenizer(BaseTokenizer):
tokenizer.normalizer = NFKC() tokenizer.normalizer = NFKC()
tokenizer.pre_tokenizer = pre_tokenizers.Metaspace.new(replacement=replacement, tokenizer.pre_tokenizer = pre_tokenizers.Metaspace.new(replacement=replacement,
add_prefix_space=add_prefix_space) add_prefix_space=add_prefix_space)
tokenizer.decoder = decoders.Metaspace.new(replacement=replacement, tokenizer.decoder = decoders.Metaspace(
add_prefix_space=add_prefix_space) replacement=replacement, add_prefix_space=add_prefix_space
)
parameters = { parameters = {
"model": "SentencePieceBPE", "model": "SentencePieceBPE",