Addressing first pass of comments.

This commit is contained in:
Nicolas Patry
2020-09-23 11:29:17 +02:00
parent 1cd4824273
commit 8f8156fd2c
10 changed files with 196 additions and 125 deletions

View File

@@ -17,7 +17,11 @@ except Exception:
def main():
parser = ArgumentParser("SentencePiece parity checker")
parser.add_argument(
"--input-file", "-i", type=str, required=True, help="Which files do you want to train from",
"--input-file",
"-i",
type=str,
required=True,
help="Which files do you want to train from",
)
parser.add_argument(
"--model-file",
@@ -28,13 +32,22 @@ def main():
help="Use a pretrained token file",
)
parser.add_argument(
"--model-prefix", type=str, default="spm_parity", help="Model prefix for spm_train",
"--model-prefix",
type=str,
default="spm_parity",
help="Model prefix for spm_train",
)
parser.add_argument(
"--vocab-size", "-v", type=int, default=8000, help="Vocab size for spm_train",
"--vocab-size",
"-v",
type=int,
default=8000,
help="Vocab size for spm_train",
)
parser.add_argument(
"--verbose", action="store_true", help="Verbosity",
"--verbose",
action="store_true",
help="Verbosity",
)
parser.add_argument(
"--train",
@@ -160,10 +173,14 @@ def check_details(line, spm_ids, tok_ids, sp, tok):
spms = Counter(spm_ids[first:last])
toks = Counter(tok_ids[first:last])
removable_tokens = {spm_ for (spm_, si) in spms.items() if toks.get(spm_, 0) == si}
removable_tokens = {
spm_ for (spm_, si) in spms.items() if toks.get(spm_, 0) == si
}
min_width = 3
for i in range(last - first - min_width):
if all(spm_ids[first + i + j] in removable_tokens for j in range(min_width)):
if all(
spm_ids[first + i + j] in removable_tokens for j in range(min_width)
):
possible_matches = [
k
for k in range(last - first - min_width)
@@ -174,7 +191,11 @@ def check_details(line, spm_ids, tok_ids, sp, tok):
if check_diff(
spm_ids[first : first + i], tok_ids[first : first + j], sp, tok
) and check_details(
line, spm_ids[first + i : last], tok_ids[first + j : last], sp, tok,
line,
spm_ids[first + i : last],
tok_ids[first + j : last],
sp,
tok,
):
return True
@@ -189,7 +210,9 @@ def check_details(line, spm_ids, tok_ids, sp, tok):
wrong = tok.decode(spm_ids[first:last])
print()
if has_color:
print(f"{colored(ok_start, 'grey')}{colored(wrong, 'red')}{colored(ok_end, 'grey')}")
print(
f"{colored(ok_start, 'grey')}{colored(wrong, 'red')}{colored(ok_end, 'grey')}"
)
else:
print(wrong)
return False
@@ -203,17 +226,8 @@ def check_encode(args):
tok = tokenizers.SentencePieceUnigramTokenizer.from_spm(args.model_file)
else:
vocab = [(sp.id_to_piece(i), sp.get_score(i)) for i in range(sp.piece_size())]
vocab_filename = f"{args.model_file}.json"
unk_id = sp.unk_id()
data = {"unk_id": unk_id, "vocab": vocab}
try:
with open(vocab_filename, "w") as f:
json.dump(data, f, indent=4)
tok = tokenizers.SentencePieceUnigramTokenizer(vocab_filename)
finally:
os.remove(vocab_filename)
tok = tokenizers.SentencePieceUnigramTokenizer(vocab, unk_id)
perfect = 0
imperfect = 0
@@ -255,7 +269,9 @@ def check_encode(args):
print(f"({perfect} / {imperfect} / {wrong} ----- {perfect + imperfect + wrong})")
total = perfect + imperfect + wrong
print(f"Accuracy {perfect * 100 / total:.2f} Slowdown : {tok_total_time/ spm_total_time:.2f}")
print(
f"Accuracy {perfect * 100 / total:.2f} Slowdown : {tok_total_time/ spm_total_time:.2f}"
)
if __name__ == "__main__":