Some suggestions from @narsil

Anthony MOI
2020-09-22 11:59:40 -04:00
committed by Anthony MOI
parent 31b81f109b
commit b24a2fc178
6 changed files with 60 additions and 22 deletions

View File

@@ -1,6 +1,8 @@
 import jieba
-from tokenizers import Tokenizer, Regex
+from typing import List
+from tokenizers import Tokenizer, Regex, NormalizedString, PreTokenizedString
 from tokenizers.models import BPE
 from tokenizers.pre_tokenizers import PreTokenizer
 from tokenizers.normalizers import Normalizer
@@ -8,24 +10,52 @@ from tokenizers.decoders import Decoder
 class JiebaPreTokenizer:
-    def jieba_split(self, i, normalized):
-        return [normalized[w[1] : w[2]] for w in jieba.tokenize(str(normalized))]
+    def jieba_split(self, i: int, normalized_string: NormalizedString) -> List[NormalizedString]:
+        splits = []
+        # we need to call `str(normalized_string)` because jieba expects a str,
+        # not a NormalizedString
+        for token, start, stop in jieba.tokenize(str(normalized_string)):
+            splits.append(normalized_string[start:stop])
 
-    def pre_tokenize(self, pretok):
-        # Let's call split on the PreTokenizedString to split using `self.split`
-        # Here we can call `pretok.split` multiple times if we want to apply
-        # different algorithm
+        return splits
+        # We can also easily do it in one line:
+        # return [normalized_string[w[1] : w[2]] for w in jieba.tokenize(str(normalized_string))]
+
+    def odd_number_split(
+        self, i: int, normalized_string: NormalizedString
+    ) -> List[NormalizedString]:
+        # Just an odd example...
+        splits = []
+        last = 0
+        for (i, char) in enumerate(str(normalized_string)):
+            if char.isnumeric() and int(char) % 2 == 1:
+                splits.append(normalized_string[last:i])
+                last = i
+        # Don't forget the last one
+        splits.append(normalized_string[last:])
+        return splits
+
+    def pre_tokenize(self, pretok: PreTokenizedString):
+        # Let's call split on the PreTokenizedString to split using `self.jieba_split`
         pretok.split(self.jieba_split)
+        # Here we can call `pretok.split` multiple times if we want to apply
+        # different algorithm, but we generally just need to call it once.
+        pretok.split(self.odd_number_split)
 
 
 class CustomDecoder:
-    def decode(self, tokens):
+    def decode(self, tokens: List[str]) -> str:
         return "".join(tokens)
 
 
 class CustomNormalizer:
-    def normalize(self, normalized):
+    def normalize(self, normalized: NormalizedString):
         # Most of these can be replaced by a `Sequence` combining some provided Normalizer,
         # (ie Sequence([ NFKC(), Replace(Regex("\s+"), " "), Lowercase() ])
         # and it should be the prefered way. That being said, here is an example of the kind
         # of things that can be done here:
         normalized.nfkc()
         normalized.filter(lambda char: not char.isnumeric())
         normalized.replace(Regex("\s+"), " ")
         normalized.lowercase()
@@ -36,12 +66,17 @@ tok.normalizer = Normalizer.custom(CustomNormalizer())
 tok.pre_tokenizer = PreTokenizer.custom(JiebaPreTokenizer())
 tok.decoder = Decoder.custom(CustomDecoder())
 
-input1 = "永和服装饰品有限公司"
-print("PreTokenize:", input1)
-print(tok.pre_tokenizer.pre_tokenize_str(input1))
+input = "永和服装饰品有限公司"
+print("PreTokenize:", input)
+print(tok.pre_tokenizer.pre_tokenize_str(input))
 # [('永和', (0, 2)), ('服装', (2, 4)), ('饰品', (4, 6)), ('有限公司', (6, 10))]
 
-input2 = "𝔢𝔩𝔩𝔬 𝔱𝔥𝔢𝔯𝔢 𝓂𝓎 𝒹𝒶𝓇 𝕕𝕖𝕒𝕣 𝕗𝕣𝕚𝕖𝕟𝕕!"
-print("Normalize:", input2)
-print(tok.normalizer.normalize_str(input2))
-# hello there my dear dear friend!
+input = "112233"
+print("PreTokenize:", input)
+print(tok.pre_tokenizer.pre_tokenize_str(input))
+# [('1', (0, 1)), ('122', (1, 4)), ('3', (4, 5)), ('3', (5, 6))]
+
+input = "1234 𝔢𝔩𝔩𝔬 𝔱𝔥𝔢𝔯𝔢 𝓂𝓎 𝒹𝒶𝓇 𝕕𝕖𝕒𝕣 𝕗𝕣𝕚𝕖𝕟𝕕!"
+print("Normalize:", input)
+print(tok.normalizer.normalize_str(input))
+# " hello there my dear dear friend!"

View File

@@ -73,11 +73,11 @@ class PreTokenizedString:
                 The string sequence used to initialize this PreTokenizedString
         """
         pass
-    def split(self, func: Callable[[NormalizedString], List[NormalizedString]]):
+    def split(self, func: Callable[[index, NormalizedString], List[NormalizedString]]):
         """ Split the PreTokenizedString using the given `func`
 
         Args:
-            func: Callable[[NormalizedString], List[NormalizedString]]:
+            func: Callable[[index, NormalizedString], List[NormalizedString]]:
                 The function used to split each underlying split.
                 It is expected to return a list of `NormalizedString`, that represent the new
                 splits. If the given `NormalizedString` does not need any splitting, we can
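
For reference, a callable matching this documented signature could look like the sketch below. It is not part of the commit; whitespace_split is a hypothetical helper, and it assumes PreTokenizedString can be constructed from a plain str as described in the docstring above:

    from typing import List
    from tokenizers import NormalizedString, PreTokenizedString

    def whitespace_split(index: int, normalized: NormalizedString) -> List[NormalizedString]:
        # split each existing piece on spaces, keeping the offset tracking
        # that NormalizedString slicing provides
        splits = []
        last = 0
        for i, char in enumerate(str(normalized)):
            if char == " ":
                splits.append(normalized[last:i])
                last = i + 1
        splits.append(normalized[last:])
        return splits

    pretok = PreTokenizedString("hello world")
    pretok.split(whitespace_split)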

View File

@@ -273,8 +273,8 @@ impl PyObjectProtocol<'p> for PyNormalizedString {
     fn __repr__(&self) -> String {
         format!(
             r#"NormalizedString(original="{}", normalized="{}")"#,
-            self.normalized.get(),
-            self.normalized.get_original()
+            self.normalized.get_original(),
+            self.normalized.get()
         )
     }
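
Seen from the Python side, the corrected argument order means repr now reports the pre-normalization text under original and the current text under normalized. A small sketch, not part of the commit, using a built-in normalizer for illustration:

    from tokenizers import NormalizedString
    from tokenizers.normalizers import Lowercase

    n = NormalizedString("Hey There!")
    Lowercase().normalize(n)  # normalize the string in place
    print(repr(n))
    # expected: NormalizedString(original="Hey There!", normalized="hey there!")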

View File

@@ -16,7 +16,7 @@ fn split(pretok: &mut PreTokenizedString, func: &PyAny) -> PyResult<()> {
     if !func.is_callable() {
         Err(exceptions::PyTypeError::new_err(
             "`split` expect a callable with the signature: \
-            `fn(normalized: NormalizedString) -> List[NormalizedString]`",
+            `fn(index: int, normalized: NormalizedString) -> List[NormalizedString]`",
         ))
     } else {
         ToPyResult(pretok.split(|i, normalized| {
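
From Python, this is the error path taken when something that is not callable is handed to split; the message now spells out the index argument. A small sketch, not part of the commit, assuming a PreTokenizedString built from a plain str:

    from tokenizers import PreTokenizedString

    pretok = PreTokenizedString("some text")
    try:
        pretok.split("not a callable")
    except TypeError as err:
        # the message should reference the updated signature:
        # `fn(index: int, normalized: NormalizedString) -> List[NormalizedString]`
        print(err)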

View File

@@ -130,5 +130,5 @@ class TestCustomNormalizer:
         normalized = NormalizedString("Hey there!")
         normalizer.normalize(normalized)
 
-        assert repr(normalized) == 'NormalizedString(original="Hey you!", normalized="Hey there!")'
+        assert repr(normalized) == 'NormalizedString(original="Hey there!", normalized="Hey you!")'
         assert str(normalized) == "Hey you!"

View File

@@ -124,10 +124,13 @@ class TestDigits:
 class TestCustomPreTokenizer:
     class BadCustomPretok:
         def pre_tokenize(self, pretok, wrong):
+            # This method does not have the right signature: it takes one too many arg
             pass
 
     class GoodCustomPretok:
         def split(self, n, normalized):
+            # Here we just test that we can return a List[NormalizedString], it
+            # does not really make sense to return twice the same otherwise
             return [normalized, normalized]
 
         def pre_tokenize(self, pretok):