[remove black] And use ruff (#1436)

* nits

* Fixing deps.

* Ruff update.

* Import order matters.

* Fix.

* Revert ruff fix.

* Visualizer.

* Putting back the imports.

---------

Co-authored-by: Nicolas Patry <patry.nicolas@protonmail.com>
Author: Arthur
Date: 2024-03-12 21:24:21 +11:00
Committed by: GitHub
Parent: 72a1973cd1
Commit: 29fef1e7aa
29 changed files with 258 additions and 169 deletions
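
The diffs below are almost entirely mechanical: black's default 88-column wrapping is undone wherever a call now fits under ruff's longer line limit (119 columns, judging from the rewrapped lines; the configured value itself isn't shown on this page). A runnable sketch of the pattern, with stand-in names rather than the real converter code:

```python
def tokenizer_of(vocab, merges, unk):
    # Stand-in for the real Tokenizer(BPE(...)) construction.
    return {"vocab": vocab, "merges": merges, "unk": unk}

# black at 88 columns wraps the call:
tok_a = tokenizer_of(
    {"a": 0}, [("a", "b")], "<unk>"
)
# ruff format at line-length 119 keeps it on one line; the object is identical:
tok_b = tokenizer_of({"a": 0}, [("a", "b")], "<unk>")
assert tok_a == tok_b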

File: tokenizer converter script (SpmConverter and subclasses; filename not preserved in this view)

@@ -80,9 +80,7 @@ class SpmConverter(Converter):
             tokenizer = Tokenizer(Unigram(vocab, unk_id))
         elif model_type == 2:
             vocab, merges = SentencePieceExtractor(self.original_tokenizer.vocab_file).extract()
-            tokenizer = Tokenizer(
-                BPE(vocab, merges, unk_token=proto.trainer_spec.unk_piece, fuse_unk=True)
-            )
+            tokenizer = Tokenizer(BPE(vocab, merges, unk_token=proto.trainer_spec.unk_piece, fuse_unk=True))
         else:
             raise Exception(
                 "You're trying to run a `Unigram` model but you're file was trained with a different algorithm"
@@ -105,12 +103,8 @@ class SpmConverter(Converter):
         replacement = "▁"
         add_prefix_space = True
-        tokenizer.pre_tokenizer = Metaspace(
-            replacement=replacement, add_prefix_space=add_prefix_space
-        )
-        tokenizer.decoder = decoders.Metaspace(
-            replacement=replacement, add_prefix_space=add_prefix_space
-        )
+        tokenizer.pre_tokenizer = Metaspace(replacement=replacement, add_prefix_space=add_prefix_space)
+        tokenizer.decoder = decoders.Metaspace(replacement=replacement, add_prefix_space=add_prefix_space)
 
         post_processor = self.post_processor(tokenizer)
         if post_processor:
             tokenizer.post_processor = post_processor
@@ -124,9 +118,7 @@ class SpmConverter(Converter):
 class AlbertConverter(SpmConverter):
     def vocab(self, proto):
         return [
-            (piece.piece, piece.score)
-            if check_number_comma(piece.piece)
-            else (piece.piece, piece.score - 100)
+            (piece.piece, piece.score) if check_number_comma(piece.piece) else (piece.piece, piece.score - 100)
             for piece in proto.pieces
         ]
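
Both this hunk and the XLNet one below demote vocab pieces that fail `check_number_comma` by subtracting 100 from their score. A self-contained sketch of the effect; the predicate body here is an assumption for illustration, not copied from this file:

```python
def check_number_comma(piece: str) -> bool:
    # Assumed behavior: flag pieces like "1," (a digit followed by a comma).
    return len(piece) < 2 or piece[1] != "," or not piece[0].isdigit()

pieces = [("hello", -3.0), ("1,", -5.0)]
vocab = [(p, s) if check_number_comma(p) else (p, s - 100) for p, s in pieces]
print(vocab)  # [('hello', -3.0), ('1,', -105.0)]: the "1," piece is heavily demoted
```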
@@ -261,9 +253,7 @@ class XLMRobertaConverter(SpmConverter):
 class XLNetConverter(SpmConverter):
     def vocab(self, proto):
         return [
-            (piece.piece, piece.score)
-            if check_number_comma(piece.piece)
-            else (piece.piece, piece.score - 100)
+            (piece.piece, piece.score) if check_number_comma(piece.piece) else (piece.piece, piece.score - 100)
             for piece in proto.pieces
         ]
@@ -420,9 +410,7 @@ def main():
     print(f"|{'-'*model_len}|{'-'*status_len}|{'-'*speedup_len}|")
     for pretrained in args.models:
         status, speedup = check(pretrained, args.filename)
-        print(
-            f"|{pretrained:<{model_len}}|{status:^{status_len}}|{speedup:^{speedup_len - 1}.2f}x|"
-        )
+        print(f"|{pretrained:<{model_len}}|{status:^{status_len}}|{speedup:^{speedup_len - 1}.2f}x|")
 
 
 if __name__ == "__main__":
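
The print above uses nested format specs: the width inside `:<{model_len}` is itself substituted at runtime. A self-contained sketch with made-up widths and values:

```python
model_len, status_len, speedup_len = 20, 8, 10
# Header and separator rows, with widths taken from variables:
print(f"|{'model':<{model_len}}|{'status':^{status_len}}|{'speedup':^{speedup_len}}|")
print(f"|{'-' * model_len}|{'-' * status_len}|{'-' * speedup_len}|")
# A data row: left-align the name, center the rest, format the float to 2 places.
print(f"|{'gpt2':<{model_len}}|{'ok':^{status_len}}|{1.87:^{speedup_len - 1}.2f}x|")
```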

File: vocab extractor script (SentencePieceExtractor / YouTokenToMeExtractor; filename not preserved in this view)

@@ -59,7 +59,6 @@ class YouTokenToMeExtractor:
     def extract(self) -> Tuple[Dict[str, int], List[Tuple]]:
         with open(self._model, "r") as model_f:
-
             # Retrieve information
             nb_pieces, nb_merges = map(int, model_f.readline().split())
             vocab, merges = {}, []
@@ -97,9 +96,7 @@ if __name__ == "__main__":
         choices=["sentencepiece", "youtokentome"],
         help="Indicate the format of the file.",
     )
-    parser.add_argument(
-        "--model", type=str, required=True, help="SentencePiece model to extract vocab from."
-    )
+    parser.add_argument("--model", type=str, required=True, help="SentencePiece model to extract vocab from.")
     parser.add_argument(
         "--vocab-output-path",
         type=str,
@@ -128,9 +125,7 @@ if __name__ == "__main__":
         args.model = f.name
 
     # Allocate extractor
-    extractor = (
-        SentencePieceExtractor if args.provider == "sentencepiece" else YouTokenToMeExtractor
-    )
+    extractor = SentencePieceExtractor if args.provider == "sentencepiece" else YouTokenToMeExtractor
     extractor = extractor(args.model)
 
     logger.info(f"Using {type(extractor).__name__}")
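
The hunk above relies on classes being first-class values: the conditional expression selects a class object, which is then called to instantiate it. A minimal stand-alone sketch (the classes here are stubs, not the real extractors):

```python
class SentencePieceExtractor:
    def __init__(self, model):
        self.model = model

class YouTokenToMeExtractor:
    def __init__(self, model):
        self.model = model

provider = "sentencepiece"
# Select the class first, then call it; classes are ordinary values in Python.
extractor = SentencePieceExtractor if provider == "sentencepiece" else YouTokenToMeExtractor
extractor = extractor("spm.model")
print(type(extractor).__name__)  # SentencePieceExtractor
```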

File: SPM parity-check script (check_train / check_diff / check_encode; filename not preserved in this view)

@@ -121,9 +121,7 @@ def check_train(args):
             break
 
     print(f"Tokenizer used {tokenizer_tokens}, where spm used {spm_tokens}")
-    assert (
-        tokenizer_tokens < spm_tokens
-    ), "Our trainer should be at least more efficient than the SPM one"
+    assert tokenizer_tokens < spm_tokens, "Our trainer should be at least more efficient than the SPM one"
     print("Ok our trainer is at least more efficient than the SPM one")
@@ -131,9 +129,7 @@ def check_diff(spm_diff, tok_diff, sp, tok):
     if spm_diff == list(reversed(tok_diff)):
         # AAA -> AA+A vs A+AA case.
         return True
-    elif len(spm_diff) == len(tok_diff) and tok.decode(spm_diff) == tok.decode(
-        tok_diff
-    ):
+    elif len(spm_diff) == len(tok_diff) and tok.decode(spm_diff) == tok.decode(tok_diff):
         # Second order OK
         # Barrich -> Barr + ich vs Bar + rich
         return True
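
`check_diff` above accepts two tokenizations as equivalent when they only segment the same text differently. A hedged, runnable illustration of both branches (the ids and the decode table are invented):

```python
# Case 1: mirror-image split, AAA -> AA+A vs A+AA.
spm_diff, tok_diff = [11, 7], [7, 11]
assert spm_diff == list(reversed(tok_diff))

# Case 2: same length, same decoded text: "Barrich" -> Barr+ich vs Bar+rich.
pieces = {21: "Barr", 22: "ich", 23: "Bar", 24: "rich"}

def decode(ids):  # stand-in for tok.decode
    return "".join(pieces[i] for i in ids)

spm_diff, tok_diff = [21, 22], [23, 24]
assert len(spm_diff) == len(tok_diff) and decode(spm_diff) == decode(tok_diff)
```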
@@ -173,24 +169,17 @@ def check_details(line, spm_ids, tok_ids, sp, tok):
 
         spms = Counter(spm_ids[first:last])
         toks = Counter(tok_ids[first:last])
-        removable_tokens = {
-            spm_ for (spm_, si) in spms.items() if toks.get(spm_, 0) == si
-        }
+        removable_tokens = {spm_ for (spm_, si) in spms.items() if toks.get(spm_, 0) == si}
         min_width = 3
         for i in range(last - first - min_width):
-            if all(
-                spm_ids[first + i + j] in removable_tokens for j in range(min_width)
-            ):
+            if all(spm_ids[first + i + j] in removable_tokens for j in range(min_width)):
                 possible_matches = [
                     k
                     for k in range(last - first - min_width)
-                    if tok_ids[first + k : first + k + min_width]
-                    == spm_ids[first + i : first + i + min_width]
+                    if tok_ids[first + k : first + k + min_width] == spm_ids[first + i : first + i + min_width]
                 ]
                 for j in possible_matches:
-                    if check_diff(
-                        spm_ids[first : first + i], tok_ids[first : first + j], sp, tok
-                    ) and check_details(
+                    if check_diff(spm_ids[first : first + i], tok_ids[first : first + j], sp, tok) and check_details(
                         line,
                         spm_ids[first + i : last],
                         tok_ids[first + j : last],
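
The `removable_tokens` set in the hunk above keeps only the ids whose occurrence counts match on both sides, so a window made of them is a candidate realignment point. A small sketch of that Counter comparison (ids invented):

```python
from collections import Counter

spm_ids = [5, 9, 9, 3]
tok_ids = [9, 5, 3, 9]
spms, toks = Counter(spm_ids), Counter(tok_ids)
# Keep ids that occur equally often on both sides: the two sequences differ
# only in ordering over these, so they can be matched up one-for-one.
removable = {i for i, n in spms.items() if toks.get(i, 0) == n}
print(sorted(removable))  # [3, 5, 9]
```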
@@ -210,9 +199,7 @@ def check_details(line, spm_ids, tok_ids, sp, tok):
     wrong = tok.decode(spm_ids[first:last])
     print()
     if has_color:
-        print(
-            f"{colored(ok_start, 'grey')}{colored(wrong, 'red')}{colored(ok_end, 'grey')}"
-        )
+        print(f"{colored(ok_start, 'grey')}{colored(wrong, 'red')}{colored(ok_end, 'grey')}")
     else:
         print(wrong)
     return False
@@ -251,9 +238,7 @@ def check_encode(args):
 
         if args.verbose:
             if i % 10000 == 0:
-                print(
-                    f"({perfect} / {imperfect} / {wrong} ----- {perfect + imperfect + wrong})"
-                )
+                print(f"({perfect} / {imperfect} / {wrong} ----- {perfect + imperfect + wrong})")
                 print(f"SPM: {spm_total_time} - TOK: {tok_total_time}")
 
         if ids != encoded.ids:
@@ -265,13 +250,13 @@ def check_encode(args):
         else:
             perfect += 1
 
-        assert ids == encoded.ids, f"line {i}: {line} : \n\n{ids}\n{encoded.ids}\n{list(zip(encoded.ids, encoded.tokens))}"
+        assert (
+            ids == encoded.ids
+        ), f"line {i}: {line} : \n\n{ids}\n{encoded.ids}\n{list(zip(encoded.ids, encoded.tokens))}"
 
     print(f"({perfect} / {imperfect} / {wrong} ----- {perfect + imperfect + wrong})")
     total = perfect + imperfect + wrong
-    print(
-        f"Accuracy {perfect * 100 / total:.2f} Slowdown : {tok_total_time/ spm_total_time:.2f}"
-    )
+    print(f"Accuracy {perfect * 100 / total:.2f} Slowdown : {tok_total_time/ spm_total_time:.2f}")
 
 
 if __name__ == "__main__":