Formatting after rebase.

This commit is contained in:
Nicolas Patry
2020-08-24 13:50:11 +02:00
parent 439305eea0
commit e974cfb1c9
4 changed files with 6 additions and 18 deletions

View File

@ -8,8 +8,7 @@ TextInputSequence = str
PreTokenizedInputSequence = Union[List[str], Tuple[str]] PreTokenizedInputSequence = Union[List[str], Tuple[str]]
TextEncodeInput = Union[TextInputSequence, Tuple[TextInputSequence, TextInputSequence]] TextEncodeInput = Union[TextInputSequence, Tuple[TextInputSequence, TextInputSequence]]
PreTokenizedEncodeInput = Union[ PreTokenizedEncodeInput = Union[
PreTokenizedInputSequence, PreTokenizedInputSequence, Tuple[PreTokenizedInputSequence, PreTokenizedInputSequence],
Tuple[PreTokenizedInputSequence, PreTokenizedInputSequence],
] ]
InputSequence = Union[TextInputSequence, PreTokenizedInputSequence] InputSequence = Union[TextInputSequence, PreTokenizedInputSequence]

View File

@ -13,10 +13,7 @@ class SentencePieceUnigramTokenizer(BaseTokenizer):
""" """
def __init__( def __init__(
self, self, vocab: Optional[str] = None, replacement: str = "", add_prefix_space: bool = True,
vocab: Optional[str] = None,
replacement: str = "",
add_prefix_space: bool = True,
): ):
if vocab is not None: if vocab is not None:
tokenizer = Tokenizer(Unigram(vocab)) tokenizer = Tokenizer(Unigram(vocab))

View File

@ -7,17 +7,10 @@ import json
def main(): def main():
parser = ArgumentParser("SentencePiece parity checker") parser = ArgumentParser("SentencePiece parity checker")
parser.add_argument( parser.add_argument(
"--input-file", "--input-file", "-i", type=str, required=True, help="Which files do you want to train from",
"-i",
type=str,
required=True,
help="Which files do you want to train from",
) )
parser.add_argument( parser.add_argument(
"--model-prefix", "--model-prefix", type=str, default="spm_parity", help="Model prefix for spm_train",
type=str,
default="spm_parity",
help="Model prefix for spm_train",
) )
parser.add_argument( parser.add_argument(
"--vocab-size", "-v", type=int, default=8000, help="Vocab size for spm_train", "--vocab-size", "-v", type=int, default=8000, help="Vocab size for spm_train",
@ -57,9 +50,7 @@ def main():
if len(ids) != len(encoded.ids): if len(ids) != len(encoded.ids):
N = len(ids) N = len(ids)
M = len(encoded.ids) M = len(encoded.ids)
first_index_error = [ first_index_error = [i for i in range(min(N, M)) if ids[i] != encoded.ids[i]][0]
i for i in range(min(N, M)) if ids[i] != encoded.ids[i]
][0]
last_index_error = [ last_index_error = [
min(N, M) - i min(N, M) - i
for i in range(min(N, M)) for i in range(min(N, M))

View File

@ -0,0 +1 @@