[remove black] And use ruff (#1436)

* nits

* Fixing deps.

* Ruff update.

* Import order matters.

* Fix.

* Revert ruff fix.

* Visualizer.

* Putting back the imports.

---------

Co-authored-by: Nicolas Patry <patry.nicolas@protonmail.com>
This commit is contained in:
Arthur
2024-03-12 21:24:21 +11:00
committed by GitHub
parent 72a1973cd1
commit 29fef1e7aa
29 changed files with 258 additions and 169 deletions

View File

@ -121,9 +121,7 @@ def check_train(args):
break
print(f"Tokenizer used {tokenizer_tokens}, where spm used {spm_tokens}")
assert (
tokenizer_tokens < spm_tokens
), "Our trainer should be at least more efficient than the SPM one"
assert tokenizer_tokens < spm_tokens, "Our trainer should be at least more efficient than the SPM one"
print("Ok our trainer is at least more efficient than the SPM one")
@ -131,9 +129,7 @@ def check_diff(spm_diff, tok_diff, sp, tok):
if spm_diff == list(reversed(tok_diff)):
# AAA -> AA+A vs A+AA case.
return True
elif len(spm_diff) == len(tok_diff) and tok.decode(spm_diff) == tok.decode(
tok_diff
):
elif len(spm_diff) == len(tok_diff) and tok.decode(spm_diff) == tok.decode(tok_diff):
# Second order OK
# Barrich -> Barr + ich vs Bar + rich
return True
@ -173,24 +169,17 @@ def check_details(line, spm_ids, tok_ids, sp, tok):
spms = Counter(spm_ids[first:last])
toks = Counter(tok_ids[first:last])
removable_tokens = {
spm_ for (spm_, si) in spms.items() if toks.get(spm_, 0) == si
}
removable_tokens = {spm_ for (spm_, si) in spms.items() if toks.get(spm_, 0) == si}
min_width = 3
for i in range(last - first - min_width):
if all(
spm_ids[first + i + j] in removable_tokens for j in range(min_width)
):
if all(spm_ids[first + i + j] in removable_tokens for j in range(min_width)):
possible_matches = [
k
for k in range(last - first - min_width)
if tok_ids[first + k : first + k + min_width]
== spm_ids[first + i : first + i + min_width]
if tok_ids[first + k : first + k + min_width] == spm_ids[first + i : first + i + min_width]
]
for j in possible_matches:
if check_diff(
spm_ids[first : first + i], tok_ids[first : first + j], sp, tok
) and check_details(
if check_diff(spm_ids[first : first + i], tok_ids[first : first + j], sp, tok) and check_details(
line,
spm_ids[first + i : last],
tok_ids[first + j : last],
@ -210,9 +199,7 @@ def check_details(line, spm_ids, tok_ids, sp, tok):
wrong = tok.decode(spm_ids[first:last])
print()
if has_color:
print(
f"{colored(ok_start, 'grey')}{colored(wrong, 'red')}{colored(ok_end, 'grey')}"
)
print(f"{colored(ok_start, 'grey')}{colored(wrong, 'red')}{colored(ok_end, 'grey')}")
else:
print(wrong)
return False
@ -251,9 +238,7 @@ def check_encode(args):
if args.verbose:
if i % 10000 == 0:
print(
f"({perfect} / {imperfect} / {wrong} ----- {perfect + imperfect + wrong})"
)
print(f"({perfect} / {imperfect} / {wrong} ----- {perfect + imperfect + wrong})")
print(f"SPM: {spm_total_time} - TOK: {tok_total_time}")
if ids != encoded.ids:
@ -265,13 +250,13 @@ def check_encode(args):
else:
perfect += 1
assert ids == encoded.ids, f"line {i}: {line} : \n\n{ids}\n{encoded.ids}\n{list(zip(encoded.ids, encoded.tokens))}"
assert (
ids == encoded.ids
), f"line {i}: {line} : \n\n{ids}\n{encoded.ids}\n{list(zip(encoded.ids, encoded.tokens))}"
print(f"({perfect} / {imperfect} / {wrong} ----- {perfect + imperfect + wrong})")
total = perfect + imperfect + wrong
print(
f"Accuracy {perfect * 100 / total:.2f} Slowdown : {tok_total_time/ spm_total_time:.2f}"
)
print(f"Accuracy {perfect * 100 / total:.2f} Slowdown : {tok_total_time/ spm_total_time:.2f}")
if __name__ == "__main__":