[remove black] And use ruff (#1436)

* nits
* Fixing deps.
* Ruff update.
* Import order matters.
* Fix.
* Revert ruff fix.
* Visualizer.
* Putting back the imports.

Co-authored-by: Nicolas Patry <patry.nicolas@protonmail.com>
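The change set below is intended to be purely a formatter swap: black is dropped and ruff, apparently configured with a longer line-length limit, reflows the same statements, mostly collapsing calls that black had split across several lines. A quick, standalone way to convince yourself that such a hunk is behavior-preserving is to compare the ASTs of the old and new forms. The two snippets here are copied from the first hunk; the check itself is only an illustration and is not part of the commit:

import ast

old = """
tokenizer = Tokenizer(
    BPE(vocab, merges, unk_token=proto.trainer_spec.unk_piece, fuse_unk=True)
)
"""
new = "tokenizer = Tokenizer(BPE(vocab, merges, unk_token=proto.trainer_spec.unk_piece, fuse_unk=True))"

# Identical ASTs (ast.dump omits line/column attributes by default)
# mean the reformatting cannot change runtime behavior.
print(ast.dump(ast.parse(old)) == ast.dump(ast.parse(new)))  # True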
@@ -80,9 +80,7 @@ class SpmConverter(Converter):
             tokenizer = Tokenizer(Unigram(vocab, unk_id))
         elif model_type == 2:
             vocab, merges = SentencePieceExtractor(self.original_tokenizer.vocab_file).extract()
-            tokenizer = Tokenizer(
-                BPE(vocab, merges, unk_token=proto.trainer_spec.unk_piece, fuse_unk=True)
-            )
+            tokenizer = Tokenizer(BPE(vocab, merges, unk_token=proto.trainer_spec.unk_piece, fuse_unk=True))
         else:
             raise Exception(
                 "You're trying to run a `Unigram` model but you're file was trained with a different algorithm"
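For readers skimming the hunk: `SpmConverter.tokenizer` builds either a `Unigram` or a `BPE` model depending on `proto.trainer_spec.model_type`, and the only change here is that the `Tokenizer(BPE(...))` call now fits on one line. A minimal sketch of the two constructors, with a toy vocab and merges invented purely for illustration (the real ones come from the SentencePiece proto and `SentencePieceExtractor`):

from tokenizers import Tokenizer
from tokenizers.models import BPE, Unigram

# Toy data, invented for illustration only.
vocab = {"<unk>": 0, "h": 1, "i": 2, "hi": 3}
merges = [("h", "i")]

bpe_tok = Tokenizer(BPE(vocab, merges, unk_token="<unk>", fuse_unk=True))
print(bpe_tok.encode("hi").tokens)  # expected: ['hi']

# Unigram takes (piece, score) pairs plus the index of the unknown piece.
uni_tok = Tokenizer(Unigram([("<unk>", 0.0), ("hi", -1.0), ("h", -2.0), ("i", -2.0)], 0))
print(uni_tok.encode("hi").tokens)  # expected: ['hi']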
@@ -105,12 +103,8 @@ class SpmConverter(Converter):

         replacement = "▁"
         add_prefix_space = True
-        tokenizer.pre_tokenizer = Metaspace(
-            replacement=replacement, add_prefix_space=add_prefix_space
-        )
-        tokenizer.decoder = decoders.Metaspace(
-            replacement=replacement, add_prefix_space=add_prefix_space
-        )
+        tokenizer.pre_tokenizer = Metaspace(replacement=replacement, add_prefix_space=add_prefix_space)
+        tokenizer.decoder = decoders.Metaspace(replacement=replacement, add_prefix_space=add_prefix_space)
         post_processor = self.post_processor(tokenizer)
         if post_processor:
             tokenizer.post_processor = post_processor
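The collapsed assignments wire the same replacement character into both the pre-tokenizer and the decoder so that encode and decode round-trip. A small sketch with a toy vocab, assuming the tokenizers version targeted by this commit (where `Metaspace` still accepts `add_prefix_space`):

from tokenizers import Tokenizer, decoders
from tokenizers.models import BPE
from tokenizers.pre_tokenizers import Metaspace

# Toy vocab/merges, invented for illustration.
vocab = {"<unk>": 0, "▁": 1, "h": 2, "i": 3, "hi": 4, "▁hi": 5}
merges = [("h", "i"), ("▁", "hi")]
tok = Tokenizer(BPE(vocab, merges, unk_token="<unk>"))

tok.pre_tokenizer = Metaspace(replacement="▁", add_prefix_space=True)
tok.decoder = decoders.Metaspace(replacement="▁", add_prefix_space=True)

enc = tok.encode("hi")
print(enc.tokens)           # expected: ['▁hi']
print(tok.decode(enc.ids))  # expected: 'hi' (the decoder strips the prefix-space marker again)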
@@ -124,9 +118,7 @@ class SpmConverter(Converter):
 class AlbertConverter(SpmConverter):
     def vocab(self, proto):
         return [
-            (piece.piece, piece.score)
-            if check_number_comma(piece.piece)
-            else (piece.piece, piece.score - 100)
+            (piece.piece, piece.score) if check_number_comma(piece.piece) else (piece.piece, piece.score - 100)
             for piece in proto.pieces
         ]

@@ -261,9 +253,7 @@ class XLMRobertaConverter(SpmConverter):
 class XLNetConverter(SpmConverter):
     def vocab(self, proto):
         return [
-            (piece.piece, piece.score)
-            if check_number_comma(piece.piece)
-            else (piece.piece, piece.score - 100)
+            (piece.piece, piece.score) if check_number_comma(piece.piece) else (piece.piece, piece.score - 100)
             for piece in proto.pieces
         ]

@@ -420,9 +410,7 @@ def main():
     print(f"|{'-'*model_len}|{'-'*status_len}|{'-'*speedup_len}|")
     for pretrained in args.models:
         status, speedup = check(pretrained, args.filename)
-        print(
-            f"|{pretrained:<{model_len}}|{status:^{status_len}}|{speedup:^{speedup_len - 1}.2f}x|"
-        )
+        print(f"|{pretrained:<{model_len}}|{status:^{status_len}}|{speedup:^{speedup_len - 1}.2f}x|")


 if __name__ == "__main__":
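Cosmetic as the change is, the collapsed print packs a lot into one f-string: the column widths are interpolated into the format specs, `<` left-aligns the model name and `^` centres the status and speed-up columns. A standalone illustration with made-up values:

# Made-up widths and values, only to show the nested format specifiers at work.
model_len, status_len, speedup_len = 20, 10, 8
pretrained, status, speedup = "bert-base-uncased", "OK", 1.2345
print(f"|{'-'*model_len}|{'-'*status_len}|{'-'*speedup_len}|")
print(f"|{pretrained:<{model_len}}|{status:^{status_len}}|{speedup:^{speedup_len - 1}.2f}x|")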
@@ -59,7 +59,6 @@ class YouTokenToMeExtractor:

     def extract(self) -> Tuple[Dict[str, int], List[Tuple]]:
         with open(self._model, "r") as model_f:
-
             # Retrieve information
             nb_pieces, nb_merges = map(int, model_f.readline().split())
             vocab, merges = {}, []
@@ -97,9 +96,7 @@ if __name__ == "__main__":
         choices=["sentencepiece", "youtokentome"],
         help="Indicate the format of the file.",
     )
-    parser.add_argument(
-        "--model", type=str, required=True, help="SentencePiece model to extract vocab from."
-    )
+    parser.add_argument("--model", type=str, required=True, help="SentencePiece model to extract vocab from.")
     parser.add_argument(
         "--vocab-output-path",
         type=str,
@@ -128,9 +125,7 @@ if __name__ == "__main__":
         args.model = f.name

     # Allocate extractor
-    extractor = (
-        SentencePieceExtractor if args.provider == "sentencepiece" else YouTokenToMeExtractor
-    )
+    extractor = SentencePieceExtractor if args.provider == "sentencepiece" else YouTokenToMeExtractor
     extractor = extractor(args.model)

     logger.info(f"Using {type(extractor).__name__}")
@@ -121,9 +121,7 @@ def check_train(args):
             break

     print(f"Tokenizer used {tokenizer_tokens}, where spm used {spm_tokens}")
-    assert (
-        tokenizer_tokens < spm_tokens
-    ), "Our trainer should be at least more efficient than the SPM one"
+    assert tokenizer_tokens < spm_tokens, "Our trainer should be at least more efficient than the SPM one"
     print("Ok our trainer is at least more efficient than the SPM one")


@@ -131,9 +129,7 @@ def check_diff(spm_diff, tok_diff, sp, tok):
     if spm_diff == list(reversed(tok_diff)):
         # AAA -> AA+A vs A+AA case.
         return True
-    elif len(spm_diff) == len(tok_diff) and tok.decode(spm_diff) == tok.decode(
-        tok_diff
-    ):
+    elif len(spm_diff) == len(tok_diff) and tok.decode(spm_diff) == tok.decode(tok_diff):
         # Second order OK
         # Barrich -> Barr + ich vs Bar + rich
         return True
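For context, the two branches being reflowed here accept a mismatch when the two tokenizers merely segment the same span differently: either the same pieces in reverse order, or different pieces that decode to the same text. A toy illustration with made-up token IDs and a stub in place of `tok.decode`:

pieces = {1: "AA", 2: "A", 3: "Barr", 4: "ich", 5: "Bar", 6: "rich"}  # made-up vocabulary

def decode(ids):
    return "".join(pieces[i] for i in ids)

# First-order case: AAA split as AA + A by one tokenizer and A + AA by the other.
spm_diff, tok_diff = [1, 2], [2, 1]
print(spm_diff == list(reversed(tok_diff)))  # True

# Second-order case: Barrich split as Barr + ich vs Bar + rich.
spm_diff, tok_diff = [3, 4], [5, 6]
print(len(spm_diff) == len(tok_diff) and decode(spm_diff) == decode(tok_diff))  # True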
@@ -173,24 +169,17 @@ def check_details(line, spm_ids, tok_ids, sp, tok):
         spms = Counter(spm_ids[first:last])
         toks = Counter(tok_ids[first:last])

-        removable_tokens = {
-            spm_ for (spm_, si) in spms.items() if toks.get(spm_, 0) == si
-        }
+        removable_tokens = {spm_ for (spm_, si) in spms.items() if toks.get(spm_, 0) == si}
         min_width = 3
         for i in range(last - first - min_width):
-            if all(
-                spm_ids[first + i + j] in removable_tokens for j in range(min_width)
-            ):
+            if all(spm_ids[first + i + j] in removable_tokens for j in range(min_width)):
                 possible_matches = [
                     k
                     for k in range(last - first - min_width)
-                    if tok_ids[first + k : first + k + min_width]
-                    == spm_ids[first + i : first + i + min_width]
+                    if tok_ids[first + k : first + k + min_width] == spm_ids[first + i : first + i + min_width]
                 ]
                 for j in possible_matches:
-                    if check_diff(
-                        spm_ids[first : first + i], tok_ids[first : first + j], sp, tok
-                    ) and check_details(
+                    if check_diff(spm_ids[first : first + i], tok_ids[first : first + j], sp, tok) and check_details(
                         line,
                         spm_ids[first + i : last],
                         tok_ids[first + j : last],
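The reflowed loop scans the disagreement region for a window of `min_width` tokens that occur with equal multiplicity on both sides, then looks for the same window on the tokenizer side so the problem can be split and checked recursively. A simplified, self-contained sketch of that window search with toy ID sequences in place of real tokenizer output:

from collections import Counter

# Toy disagreement region; the real code works on spm_ids[first:last] and tok_ids[first:last].
spm_ids = [7, 8, 1, 2, 3, 9]
tok_ids = [7, 5, 1, 2, 3, 6]
first, last = 0, len(spm_ids)
min_width = 3

spms = Counter(spm_ids[first:last])
toks = Counter(tok_ids[first:last])
removable_tokens = {spm_ for (spm_, si) in spms.items() if toks.get(spm_, 0) == si}

for i in range(last - first - min_width):
    if all(spm_ids[first + i + j] in removable_tokens for j in range(min_width)):
        matches = [
            k
            for k in range(last - first - min_width)
            if tok_ids[first + k : first + k + min_width] == spm_ids[first + i : first + i + min_width]
        ]
        print(i, matches)  # prints "2 [2]": both sides agree on the 3-token run [1, 2, 3]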
@@ -210,9 +199,7 @@ def check_details(line, spm_ids, tok_ids, sp, tok):
     wrong = tok.decode(spm_ids[first:last])
     print()
     if has_color:
-        print(
-            f"{colored(ok_start, 'grey')}{colored(wrong, 'red')}{colored(ok_end, 'grey')}"
-        )
+        print(f"{colored(ok_start, 'grey')}{colored(wrong, 'red')}{colored(ok_end, 'grey')}")
     else:
         print(wrong)
     return False
@@ -251,9 +238,7 @@ def check_encode(args):

         if args.verbose:
             if i % 10000 == 0:
-                print(
-                    f"({perfect} / {imperfect} / {wrong} ----- {perfect + imperfect + wrong})"
-                )
+                print(f"({perfect} / {imperfect} / {wrong} ----- {perfect + imperfect + wrong})")
                 print(f"SPM: {spm_total_time} - TOK: {tok_total_time}")

         if ids != encoded.ids:
@@ -265,13 +250,13 @@ def check_encode(args):
         else:
             perfect += 1

-        assert ids == encoded.ids, f"line {i}: {line} : \n\n{ids}\n{encoded.ids}\n{list(zip(encoded.ids, encoded.tokens))}"
+        assert (
+            ids == encoded.ids
+        ), f"line {i}: {line} : \n\n{ids}\n{encoded.ids}\n{list(zip(encoded.ids, encoded.tokens))}"

     print(f"({perfect} / {imperfect} / {wrong} ----- {perfect + imperfect + wrong})")
     total = perfect + imperfect + wrong
-    print(
-        f"Accuracy {perfect * 100 / total:.2f} Slowdown : {tok_total_time/ spm_total_time:.2f}"
-    )
+    print(f"Accuracy {perfect * 100 / total:.2f} Slowdown : {tok_total_time/ spm_total_time:.2f}")


 if __name__ == "__main__":