Black *Version* check.

Nicolas Patry
2020-09-23 11:39:22 +02:00
parent 9b1ef9d895
commit 35ee1968c0
5 changed files with 30 additions and 9 deletions
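
Every hunk below is pure reformatting: re-running Black under a newer release rewrites these files because, starting with Black 20.8b0, a trailing comma left inside brackets is treated as a "magic trailing comma" that forces the call or collection to be exploded one element per line. A minimal runnable sketch of the behavior (the values are illustrative, not taken from this commit):

import pickle

# Before: older Black releases left this on one line.
value = isinstance(pickle.loads(pickle.dumps(42)), int,)

# After: Black >= 20.8b0 sees the magic trailing comma and explodes the
# call one argument per line, which is exactly the shape of the diffs below.
value = isinstance(
    pickle.loads(pickle.dumps(42)),
    int,
)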

View File

@@ -9,7 +9,8 @@ TextInputSequence = str
 PreTokenizedInputSequence = Union[List[str], Tuple[str]]
 TextEncodeInput = Union[TextInputSequence, Tuple[TextInputSequence, TextInputSequence]]
 PreTokenizedEncodeInput = Union[
-    PreTokenizedInputSequence, Tuple[PreTokenizedInputSequence, PreTokenizedInputSequence],
+    PreTokenizedInputSequence,
+    Tuple[PreTokenizedInputSequence, PreTokenizedInputSequence],
 ]
 InputSequence = Union[TextInputSequence, PreTokenizedInputSequence]

View File

@@ -21,7 +21,8 @@ TextInputSequence = str
 PreTokenizedInputSequence = Union[List[str], Tuple[str]]
 TextEncodeInput = Union[TextInputSequence, Tuple[TextInputSequence, TextInputSequence]]
 PreTokenizedEncodeInput = Union[
-    PreTokenizedInputSequence, Tuple[PreTokenizedInputSequence, PreTokenizedInputSequence],
+    PreTokenizedInputSequence,
+    Tuple[PreTokenizedInputSequence, PreTokenizedInputSequence],
 ]
 InputSequence = Union[TextInputSequence, PreTokenizedInputSequence]
@@ -827,7 +828,10 @@ class Tokenizer:
         """
         pass
     def post_process(
-        self, encoding: Encoding, pair: Optional[Encoding] = None, add_special_tokens: bool = True,
+        self,
+        encoding: Encoding,
+        pair: Optional[Encoding] = None,
+        add_special_tokens: bool = True,
     ) -> Encoding:
         """Apply all the post-processing steps to the given encodings.

View File

@@ -21,7 +21,10 @@ class SentencePieceUnigramTokenizer(BaseTokenizer):
     """
     def __init__(
-        self, vocab: Optional[str] = None, replacement: str = "▁", add_prefix_space: bool = True,
+        self,
+        vocab: Optional[str] = None,
+        replacement: str = "▁",
+        add_prefix_space: bool = True,
     ):
         if vocab is not None:
             # Let Unigram(..) fail if only one of them is None
@@ -29,7 +32,12 @@ class SentencePieceUnigramTokenizer(BaseTokenizer):
         else:
             tokenizer = Tokenizer(Unigram())
-        tokenizer.normalizer = normalizers.Sequence([normalizers.Nmt(), normalizers.NFKC(),])
+        tokenizer.normalizer = normalizers.Sequence(
+            [
+                normalizers.Nmt(),
+                normalizers.NFKC(),
+            ]
+        )
         tokenizer.pre_tokenizer = pre_tokenizers.Sequence(
             [
                 pre_tokenizers.WhitespaceSplit(),
@@ -60,7 +68,9 @@ class SentencePieceUnigramTokenizer(BaseTokenizer):
         """ Train the model using the given files """
         trainer = trainers.UnigramTrainer(
-            vocab_size=vocab_size, special_tokens=special_tokens, show_progress=show_progress,
+            vocab_size=vocab_size,
+            special_tokens=special_tokens,
+            show_progress=show_progress,
         )
         if isinstance(files, str):

View File

@@ -18,7 +18,10 @@ class TestBPE:
             BPE(vocab=vocab)
             BPE(merges=merges)
-        assert isinstance(pickle.loads(pickle.dumps(BPE(vocab, merges))), BPE,)
+        assert isinstance(
+            pickle.loads(pickle.dumps(BPE(vocab, merges))),
+            BPE,
+        )
         # Deprecated calls in 0.9
         with pytest.deprecated_call():

View File

@@ -22,7 +22,8 @@ class TestBertProcessing:
         assert isinstance(processor, PostProcessor)
         assert isinstance(processor, BertProcessing)
         assert isinstance(
-            pickle.loads(pickle.dumps(BertProcessing(("[SEP]", 0), ("[CLS]", 1)))), BertProcessing,
+            pickle.loads(pickle.dumps(BertProcessing(("[SEP]", 0), ("[CLS]", 1)))),
+            BertProcessing,
         )
     def test_processing(self):
@@ -94,7 +95,9 @@ class TestTemplateProcessing:
     def get_roberta(self):
         return TemplateProcessing(
-            seq_a="<s> $0 </s>", seq_b="</s> $0 </s>", special_tokens=[("<s>", 0), ("</s>", 1)],
+            seq_a="<s> $0 </s>",
+            seq_b="</s> $0 </s>",
+            special_tokens=[("<s>", 0), ("</s>", 1)],
         )
     def get_t5_squad(self):
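
The version check the commit title refers to is not visible in these hunks. For illustration only, a hedged sketch of what a Black version guard could look like (the pinned release and the guard itself are assumptions, not code from this commit):

import black

# Hypothetical pin: the diff does not show which Black release the repo targets.
EXPECTED = "20.8b0"

if black.__version__ != EXPECTED:
    raise RuntimeError(
        f"Formatting checks require black=={EXPECTED}, found {black.__version__}"
    )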