Fixing convert/check scripts.

This commit is contained in:
Nicolas Patry
2020-09-18 11:51:46 +02:00
parent c0b9229833
commit c59b216baa
2 changed files with 6 additions and 4 deletions

View File

@ -397,7 +397,7 @@ def check(pretrained, filename):
tok_total_time += tok - trans
if ids != tok_ids:
if check_details(line, ids, tok_ids, tokenizer, transformer_tokenizer):
if check_details(line, ids, tok_ids, transformer_tokenizer, tokenizer):
continue
assert ids == tok_ids, f"Error in line {i}: {line} {ids} != {tok_ids}"

View File

@ -137,7 +137,7 @@ def check_diff(spm_diff, tok_diff, sp, tok):
return False
def check_details(line, spm_ids, tok_ids, tok, sp):
def check_details(line, spm_ids, tok_ids, sp, tok):
# Encoding can be the same with same result AAA -> A + AA vs AA + A
# We can check that we use at least exactly the same number of tokens.
for i, (spm_id, tok_id) in enumerate(zip(spm_ids, tok_ids)):
@ -173,7 +173,9 @@ def check_details(line, spm_ids, tok_ids, tok, sp):
for j in possible_matches:
if check_diff(
spm_ids[first : first + i], tok_ids[first : first + j], sp, tok
) and check_diff(spm_ids[first + i : last], tok_ids[first + j : last], sp, tok):
) and check_details(
line, spm_ids[first + i : last], tok_ids[first + j : last], sp, tok,
):
return True
print(f"Spm: {[tok.decode([spm_ids[i]]) for i in range(first, last)]}")
@ -241,7 +243,7 @@ def check_encode(args):
print(f"SPM: {spm_total_time} - TOK: {tok_total_time}")
if ids != encoded.ids:
if check_details(line, ids, encoded.ids, tok, sp):
if check_details(line, ids, encoded.ids, sp, tok):
imperfect += 1
continue
else: