Implement __new__ on Normalizers

__new__ allows Normalizers to be initialized as normal Python
objects. This also means that Normalizers are given the correct class
name.
This commit is contained in:
Bjarte Johansen
2020-02-07 10:47:55 +01:00
parent be67d51185
commit 0e5d81b400
6 changed files with 61 additions and 64 deletions

View File

@ -10,13 +10,13 @@ pub struct Normalizer {
pub normalizer: Container<dyn tk::tokenizer::Normalizer + Sync>, pub normalizer: Container<dyn tk::tokenizer::Normalizer + Sync>,
} }
#[pyclass] #[pyclass(extends=Normalizer)]
pub struct BertNormalizer {} pub struct BertNormalizer {}
#[pymethods] #[pymethods]
impl BertNormalizer { impl BertNormalizer {
#[staticmethod] #[new]
#[args(kwargs = "**")] #[args(kwargs = "**")]
fn new(kwargs: Option<&PyDict>) -> PyResult<Normalizer> { fn new(obj: &PyRawObject, kwargs: Option<&PyDict>) -> PyResult<()> {
let mut clean_text = true; let mut clean_text = true;
let mut handle_chinese_chars = true; let mut handle_chinese_chars = true;
let mut strip_accents = true; let mut strip_accents = true;
@ -35,71 +35,71 @@ impl BertNormalizer {
} }
} }
Ok(Normalizer { Ok(obj.init(Normalizer {
normalizer: Container::Owned(Box::new(tk::normalizers::bert::BertNormalizer::new( normalizer: Container::Owned(Box::new(tk::normalizers::bert::BertNormalizer::new(
clean_text, clean_text,
handle_chinese_chars, handle_chinese_chars,
strip_accents, strip_accents,
lowercase, lowercase,
))), ))),
}) }))
} }
} }
#[pyclass] #[pyclass(extends=Normalizer)]
pub struct NFD {} pub struct NFD {}
#[pymethods] #[pymethods]
impl NFD { impl NFD {
#[staticmethod] #[new]
fn new() -> PyResult<Normalizer> { fn new(obj: &PyRawObject) -> PyResult<()> {
Ok(Normalizer { Ok(obj.init(Normalizer {
normalizer: Container::Owned(Box::new(tk::normalizers::unicode::NFD)), normalizer: Container::Owned(Box::new(tk::normalizers::unicode::NFD)),
}) }))
} }
} }
#[pyclass] #[pyclass(extends=Normalizer)]
pub struct NFKD {} pub struct NFKD {}
#[pymethods] #[pymethods]
impl NFKD { impl NFKD {
#[staticmethod] #[new]
fn new() -> PyResult<Normalizer> { fn new(obj: &PyRawObject) -> PyResult<()> {
Ok(Normalizer { Ok(obj.init(Normalizer {
normalizer: Container::Owned(Box::new(tk::normalizers::unicode::NFKD)), normalizer: Container::Owned(Box::new(tk::normalizers::unicode::NFKD)),
}) }))
} }
} }
#[pyclass] #[pyclass(extends=Normalizer)]
pub struct NFC {} pub struct NFC {}
#[pymethods] #[pymethods]
impl NFC { impl NFC {
#[staticmethod] #[new]
fn new() -> PyResult<Normalizer> { fn new(obj: &PyRawObject) -> PyResult<()> {
Ok(Normalizer { Ok(obj.init(Normalizer {
normalizer: Container::Owned(Box::new(tk::normalizers::unicode::NFC)), normalizer: Container::Owned(Box::new(tk::normalizers::unicode::NFC)),
}) }))
} }
} }
#[pyclass] #[pyclass(extends=Normalizer)]
pub struct NFKC {} pub struct NFKC {}
#[pymethods] #[pymethods]
impl NFKC { impl NFKC {
#[staticmethod] #[new]
fn new() -> PyResult<Normalizer> { fn new(obj: &PyRawObject) -> PyResult<()> {
Ok(Normalizer { Ok(obj.init(Normalizer {
normalizer: Container::Owned(Box::new(tk::normalizers::unicode::NFKC)), normalizer: Container::Owned(Box::new(tk::normalizers::unicode::NFKC)),
}) }))
} }
} }
#[pyclass] #[pyclass(extends=Normalizer)]
pub struct Sequence {} pub struct Sequence {}
#[pymethods] #[pymethods]
impl Sequence { impl Sequence {
#[staticmethod] #[new]
fn new(normalizers: &PyList) -> PyResult<Normalizer> { fn new(obj: &PyRawObject, normalizers: &PyList) -> PyResult<()> {
let normalizers = normalizers let normalizers = normalizers
.iter() .iter()
.map(|n| { .map(|n| {
@ -114,22 +114,22 @@ impl Sequence {
}) })
.collect::<PyResult<_>>()?; .collect::<PyResult<_>>()?;
Ok(Normalizer { Ok(obj.init(Normalizer {
normalizer: Container::Owned(Box::new(tk::normalizers::utils::Sequence::new( normalizer: Container::Owned(Box::new(tk::normalizers::utils::Sequence::new(
normalizers, normalizers,
))), ))),
}) }))
} }
} }
#[pyclass] #[pyclass(extends=Normalizer)]
pub struct Lowercase {} pub struct Lowercase {}
#[pymethods] #[pymethods]
impl Lowercase { impl Lowercase {
#[staticmethod] #[new]
fn new() -> PyResult<Normalizer> { fn new(obj: &PyRawObject) -> PyResult<()> {
Ok(Normalizer { Ok(obj.init(Normalizer {
normalizer: Container::Owned(Box::new(tk::normalizers::utils::Lowercase)), normalizer: Container::Owned(Box::new(tk::normalizers::utils::Lowercase)),
}) }))
} }
} }

View File

@ -27,11 +27,13 @@ class BertWordPieceTokenizer(BaseTokenizer):
else: else:
tokenizer = Tokenizer(WordPiece.empty()) tokenizer = Tokenizer(WordPiece.empty())
tokenizer.add_special_tokens([ unk_token, sep_token, cls_token ]) tokenizer.add_special_tokens([unk_token, sep_token, cls_token])
tokenizer.normalizer = BertNormalizer.new(clean_text=clean_text, tokenizer.normalizer = BertNormalizer(
handle_chinese_chars=handle_chinese_chars, clean_text=clean_text,
strip_accents=strip_accents, handle_chinese_chars=handle_chinese_chars,
lowercase=lowercase) strip_accents=strip_accents,
lowercase=lowercase,
)
tokenizer.pre_tokenizer = BertPreTokenizer.new() tokenizer.pre_tokenizer = BertPreTokenizer.new()
if add_special_tokens and vocab_file is not None: if add_special_tokens and vocab_file is not None:

View File

@ -36,12 +36,12 @@ class ByteLevelBPETokenizer(BaseTokenizer):
normalizers += [unicode_normalizer_from_str(unicode_normalizer)] normalizers += [unicode_normalizer_from_str(unicode_normalizer)]
if do_lowercase: if do_lowercase:
normalizers += [Lowercase.new()] normalizers += [Lowercase()]
# Create the normalizer structure # Create the normalizer structure
if len(normalizers) > 0: if len(normalizers) > 0:
if len(normalizers) > 1: if len(normalizers) > 1:
tokenizer.normalizer = Sequence.new(normalizers) tokenizer.normalizer = Sequence(normalizers)
else: else:
tokenizer.normalizer = normalizers[0] tokenizer.normalizer = normalizers[0]

View File

@ -42,12 +42,12 @@ class CharBPETokenizer(BaseTokenizer):
normalizers += [unicode_normalizer_from_str(unicode_normalizer)] normalizers += [unicode_normalizer_from_str(unicode_normalizer)]
if do_lowercase: if do_lowercase:
normalizers += [Lowercase.new()] normalizers += [Lowercase()]
# Create the normalizer structure # Create the normalizer structure
if len(normalizers) > 0: if len(normalizers) > 0:
if len(normalizers) > 1: if len(normalizers) > 1:
tokenizer.normalizer = Sequence.new(normalizers) tokenizer.normalizer = Sequence(normalizers)
else: else:
tokenizer.normalizer = normalizers[0] tokenizer.normalizer = normalizers[0]

View File

@ -28,7 +28,7 @@ class SentencePieceBPETokenizer(BaseTokenizer):
tokenizer.add_special_tokens([ unk_token ]) tokenizer.add_special_tokens([ unk_token ])
tokenizer.normalizer = NFKC.new() tokenizer.normalizer = NFKC()
tokenizer.pre_tokenizer = pre_tokenizers.Metaspace.new(replacement=replacement, tokenizer.pre_tokenizer = pre_tokenizers.Metaspace.new(replacement=replacement,
add_prefix_space=add_prefix_space) add_prefix_space=add_prefix_space)
tokenizer.decoder = decoders.Metaspace.new(replacement=replacement, tokenizer.decoder = decoders.Metaspace.new(replacement=replacement,

View File

@ -14,11 +14,13 @@ class BertNormalizer(Normalizer):
This includes cleaning the text, handling accents, chinese chars and lowercasing This includes cleaning the text, handling accents, chinese chars and lowercasing
""" """
@staticmethod def __init__(
def new(clean_text: Optional[bool]=True, self,
handle_chinese_chars: Optional[bool]=True, clean_text: Optional[bool] = True,
strip_accents: Optional[bool]=True, handle_chinese_chars: Optional[bool] = True,
lowercase: Optional[bool]=True) -> Normalizer: strip_accents: Optional[bool] = True,
lowercase: Optional[bool] = True,
) -> None:
""" Instantiate a BertNormalizer with the given options. """ Instantiate a BertNormalizer with the given options.
Args: Args:
@ -43,32 +45,28 @@ class BertNormalizer(Normalizer):
class NFD(Normalizer): class NFD(Normalizer):
""" NFD Unicode Normalizer """ """ NFD Unicode Normalizer """
@staticmethod def __init__(self) -> None:
def new() -> Normalizer:
""" Instantiate a new NFD Normalizer """ """ Instantiate a new NFD Normalizer """
pass pass
class NFKD(Normalizer): class NFKD(Normalizer):
""" NFKD Unicode Normalizer """ """ NFKD Unicode Normalizer """
@staticmethod def __init__(self) -> None:
def new() -> Normalizer:
""" Instantiate a new NFKD Normalizer """ """ Instantiate a new NFKD Normalizer """
pass pass
class NFC(Normalizer): class NFC(Normalizer):
""" NFC Unicode Normalizer """ """ NFC Unicode Normalizer """
@staticmethod def __init__(self) -> None:
def new() -> Normalizer:
""" Instantiate a new NFC Normalizer """ """ Instantiate a new NFC Normalizer """
pass pass
class NFKC(Normalizer): class NFKC(Normalizer):
""" NFKC Unicode Normalizer """ """ NFKC Unicode Normalizer """
@staticmethod def __init__(self) -> None:
def new() -> Normalizer:
""" Instantiate a new NFKC Normalizer """ """ Instantiate a new NFKC Normalizer """
pass pass
@ -78,8 +76,7 @@ class Sequence(Normalizer):
All the normalizers run in sequence in the given order All the normalizers run in sequence in the given order
""" """
@staticmethod def __init__(self, normalizers: List[Normalizer]) -> None:
def new(normalizers: List[Normalizer]) -> Normalizer:
""" Instantiate a new normalization Sequence using the given normalizers """ Instantiate a new normalization Sequence using the given normalizers
Args: Args:
@ -91,16 +88,14 @@ class Sequence(Normalizer):
class Lowercase(Normalizer): class Lowercase(Normalizer):
""" Lowercase Normalizer """ """ Lowercase Normalizer """
@staticmethod def __init__(self) -> None:
def new() -> Normalizer:
""" Instantiate a new Lowercase Normalizer """ """ Instantiate a new Lowercase Normalizer """
pass pass
def unicode_normalizer_from_str(normalizer: str) -> Normalizer: def unicode_normalizer_from_str(normalizer: str) -> Normalizer:
""" """
Instanciate unicode normalizer from the normalizer name Instanciate unicode normalizer from the normalizer name
:param normalizer: Name of the normalizer :param normalizer: Name of the normalizer
:return: :return:
""" """
pass pass