Implement __new__ on Normalizers

__new__ allows Normalizers to be initialized as normal Python objects.
This also means that Normalizers are given the correct class name.
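
As a rough illustration (not part of the original commit message; it assumes the
classes are exposed under tokenizers.normalizers, as in the rest of the bindings),
the user-facing change looks like this:

    from tokenizers.normalizers import BertNormalizer

    # Before: normalizers were built through a static factory method.
    # norm = BertNormalizer.new(clean_text=True, lowercase=True)

    # After: plain Python construction, backed by __new__ on the Rust side.
    norm = BertNormalizer(
        clean_text=True,
        handle_chinese_chars=True,
        strip_accents=True,
        lowercase=True,
    )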
Author: Bjarte Johansen
Date:   2020-02-07 10:47:55 +01:00
Parent: be67d51185
Commit: 0e5d81b400

6 changed files with 61 additions and 64 deletions

@@ -10,13 +10,13 @@ pub struct Normalizer {
     pub normalizer: Container<dyn tk::tokenizer::Normalizer + Sync>,
 }
 
-#[pyclass]
+#[pyclass(extends=Normalizer)]
 pub struct BertNormalizer {}
 #[pymethods]
 impl BertNormalizer {
-    #[staticmethod]
+    #[new]
     #[args(kwargs = "**")]
-    fn new(kwargs: Option<&PyDict>) -> PyResult<Normalizer> {
+    fn new(obj: &PyRawObject, kwargs: Option<&PyDict>) -> PyResult<()> {
         let mut clean_text = true;
         let mut handle_chinese_chars = true;
         let mut strip_accents = true;
@@ -35,71 +35,71 @@ impl BertNormalizer {
             }
         }
 
-        Ok(Normalizer {
+        Ok(obj.init(Normalizer {
             normalizer: Container::Owned(Box::new(tk::normalizers::bert::BertNormalizer::new(
                 clean_text,
                 handle_chinese_chars,
                 strip_accents,
                 lowercase,
             ))),
-        })
+        }))
     }
 }
 
-#[pyclass]
+#[pyclass(extends=Normalizer)]
 pub struct NFD {}
 #[pymethods]
 impl NFD {
-    #[staticmethod]
-    fn new() -> PyResult<Normalizer> {
-        Ok(Normalizer {
+    #[new]
+    fn new(obj: &PyRawObject) -> PyResult<()> {
+        Ok(obj.init(Normalizer {
             normalizer: Container::Owned(Box::new(tk::normalizers::unicode::NFD)),
-        })
+        }))
     }
 }
 
-#[pyclass]
+#[pyclass(extends=Normalizer)]
 pub struct NFKD {}
 #[pymethods]
 impl NFKD {
-    #[staticmethod]
-    fn new() -> PyResult<Normalizer> {
-        Ok(Normalizer {
+    #[new]
+    fn new(obj: &PyRawObject) -> PyResult<()> {
+        Ok(obj.init(Normalizer {
             normalizer: Container::Owned(Box::new(tk::normalizers::unicode::NFKD)),
-        })
+        }))
     }
 }
 
-#[pyclass]
+#[pyclass(extends=Normalizer)]
 pub struct NFC {}
 #[pymethods]
 impl NFC {
-    #[staticmethod]
-    fn new() -> PyResult<Normalizer> {
-        Ok(Normalizer {
+    #[new]
+    fn new(obj: &PyRawObject) -> PyResult<()> {
+        Ok(obj.init(Normalizer {
             normalizer: Container::Owned(Box::new(tk::normalizers::unicode::NFC)),
-        })
+        }))
     }
 }
 
-#[pyclass]
+#[pyclass(extends=Normalizer)]
 pub struct NFKC {}
 #[pymethods]
 impl NFKC {
-    #[staticmethod]
-    fn new() -> PyResult<Normalizer> {
-        Ok(Normalizer {
+    #[new]
+    fn new(obj: &PyRawObject) -> PyResult<()> {
+        Ok(obj.init(Normalizer {
             normalizer: Container::Owned(Box::new(tk::normalizers::unicode::NFKC)),
-        })
+        }))
     }
 }
 
-#[pyclass]
+#[pyclass(extends=Normalizer)]
 pub struct Sequence {}
 #[pymethods]
 impl Sequence {
-    #[staticmethod]
-    fn new(normalizers: &PyList) -> PyResult<Normalizer> {
+    #[new]
+    fn new(obj: &PyRawObject, normalizers: &PyList) -> PyResult<()> {
         let normalizers = normalizers
             .iter()
             .map(|n| {
@@ -114,22 +114,22 @@ impl Sequence {
             })
             .collect::<PyResult<_>>()?;
 
-        Ok(Normalizer {
+        Ok(obj.init(Normalizer {
             normalizer: Container::Owned(Box::new(tk::normalizers::utils::Sequence::new(
                 normalizers,
             ))),
-        })
+        }))
     }
 }
 
-#[pyclass]
+#[pyclass(extends=Normalizer)]
 pub struct Lowercase {}
 #[pymethods]
 impl Lowercase {
-    #[staticmethod]
-    fn new() -> PyResult<Normalizer> {
-        Ok(Normalizer {
+    #[new]
+    fn new(obj: &PyRawObject) -> PyResult<()> {
+        Ok(obj.init(Normalizer {
             normalizer: Container::Owned(Box::new(tk::normalizers::utils::Lowercase)),
-        })
+        }))
    }
 }
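
A hedged sketch of what extends=Normalizer buys on the Python side (it assumes
BertNormalizer, NFD and the base Normalizer are all importable from
tokenizers.normalizers): every concrete normalizer is now a real subclass of
Normalizer and reports its own class name, as the commit message states.

    from tokenizers.normalizers import BertNormalizer, NFD, Normalizer

    nfd = NFD()
    bert = BertNormalizer(lowercase=True)

    # The concrete classes extend the base Normalizer pyclass, so isinstance
    # checks against Normalizer hold and type() reports the expected name.
    assert isinstance(nfd, Normalizer)
    assert isinstance(bert, Normalizer)
    assert type(nfd).__name__ == "NFD"
    assert type(bert).__name__ == "BertNormalizer"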

@@ -28,10 +28,12 @@ class BertWordPieceTokenizer(BaseTokenizer):
         tokenizer = Tokenizer(WordPiece.empty())
         tokenizer.add_special_tokens([unk_token, sep_token, cls_token])
 
-        tokenizer.normalizer = BertNormalizer.new(clean_text=clean_text,
-                                                  handle_chinese_chars=handle_chinese_chars,
-                                                  strip_accents=strip_accents,
-                                                  lowercase=lowercase)
+        tokenizer.normalizer = BertNormalizer(
+            clean_text=clean_text,
+            handle_chinese_chars=handle_chinese_chars,
+            strip_accents=strip_accents,
+            lowercase=lowercase,
+        )
         tokenizer.pre_tokenizer = BertPreTokenizer.new()
 
         if add_special_tokens and vocab_file is not None:

@@ -36,12 +36,12 @@ class ByteLevelBPETokenizer(BaseTokenizer):
             normalizers += [unicode_normalizer_from_str(unicode_normalizer)]
 
         if do_lowercase:
-            normalizers += [Lowercase.new()]
+            normalizers += [Lowercase()]
 
         # Create the normalizer structure
         if len(normalizers) > 0:
             if len(normalizers) > 1:
-                tokenizer.normalizer = Sequence.new(normalizers)
+                tokenizer.normalizer = Sequence(normalizers)
             else:
                 tokenizer.normalizer = normalizers[0]

@@ -42,12 +42,12 @@ class CharBPETokenizer(BaseTokenizer):
             normalizers += [unicode_normalizer_from_str(unicode_normalizer)]
 
         if do_lowercase:
-            normalizers += [Lowercase.new()]
+            normalizers += [Lowercase()]
 
         # Create the normalizer structure
         if len(normalizers) > 0:
             if len(normalizers) > 1:
-                tokenizer.normalizer = Sequence.new(normalizers)
+                tokenizer.normalizer = Sequence(normalizers)
             else:
                 tokenizer.normalizer = normalizers[0]
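
Both BPE tokenizers share the selection logic shown above; a standalone sketch
of it, using a hypothetical helper name, might look like this:

    from typing import List, Optional
    from tokenizers.normalizers import Lowercase, Normalizer, Sequence

    def build_normalizer(normalizers: List[Normalizer]) -> Optional[Normalizer]:
        # Hypothetical helper mirroring the logic in the diff: wrap in a
        # Sequence only when more than one normalizer was collected.
        if len(normalizers) > 0:
            if len(normalizers) > 1:
                return Sequence(normalizers)
            return normalizers[0]
        return None

    # Example: lowercasing only, so no Sequence wrapper is needed.
    norm = build_normalizer([Lowercase()])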

@@ -28,7 +28,7 @@ class SentencePieceBPETokenizer(BaseTokenizer):
         tokenizer.add_special_tokens([ unk_token ])
 
-        tokenizer.normalizer = NFKC.new()
+        tokenizer.normalizer = NFKC()
         tokenizer.pre_tokenizer = pre_tokenizers.Metaspace.new(replacement=replacement,
                                                                add_prefix_space=add_prefix_space)
         tokenizer.decoder = decoders.Metaspace.new(replacement=replacement,

@@ -14,11 +14,13 @@ class BertNormalizer(Normalizer):
     This includes cleaning the text, handling accents, chinese chars and lowercasing
     """
 
-    @staticmethod
-    def new(clean_text: Optional[bool]=True,
+    def __init__(
+        self,
+        clean_text: Optional[bool] = True,
         handle_chinese_chars: Optional[bool] = True,
         strip_accents: Optional[bool] = True,
-        lowercase: Optional[bool]=True) -> Normalizer:
+        lowercase: Optional[bool] = True,
+    ) -> None:
         """ Instantiate a BertNormalizer with the given options.
 
         Args:
@@ -43,32 +45,28 @@ class BertNormalizer(Normalizer):
 class NFD(Normalizer):
     """ NFD Unicode Normalizer """
 
-    @staticmethod
-    def new() -> Normalizer:
+    def __init__(self) -> None:
         """ Instantiate a new NFD Normalizer """
         pass
 
 class NFKD(Normalizer):
     """ NFKD Unicode Normalizer """
 
-    @staticmethod
-    def new() -> Normalizer:
+    def __init__(self) -> None:
         """ Instantiate a new NFKD Normalizer """
         pass
 
 class NFC(Normalizer):
     """ NFC Unicode Normalizer """
 
-    @staticmethod
-    def new() -> Normalizer:
+    def __init__(self) -> None:
         """ Instantiate a new NFC Normalizer """
         pass
 
 class NFKC(Normalizer):
     """ NFKC Unicode Normalizer """
 
-    @staticmethod
-    def new() -> Normalizer:
+    def __init__(self) -> None:
         """ Instantiate a new NFKC Normalizer """
         pass
@@ -78,8 +76,7 @@ class Sequence(Normalizer):
     All the normalizers run in sequence in the given order
     """
 
-    @staticmethod
-    def new(normalizers: List[Normalizer]) -> Normalizer:
+    def __init__(self, normalizers: List[Normalizer]) -> None:
         """ Instantiate a new normalization Sequence using the given normalizers
 
         Args:
@@ -91,12 +88,10 @@ class Sequence(Normalizer):
 class Lowercase(Normalizer):
     """ Lowercase Normalizer """
 
-    @staticmethod
-    def new() -> Normalizer:
+    def __init__(self) -> None:
         """ Instantiate a new Lowercase Normalizer """
         pass
 
 def unicode_normalizer_from_str(normalizer: str) -> Normalizer:
     """
     Instanciate unicode normalizer from the normalizer name