Fixing convert/check scripts.

This commit is contained in:
Nicolas Patry
2020-09-18 11:51:46 +02:00
parent c0b9229833
commit c59b216baa
2 changed files with 6 additions and 4 deletions

View File

@ -397,7 +397,7 @@ def check(pretrained, filename):
tok_total_time += tok - trans tok_total_time += tok - trans
if ids != tok_ids: if ids != tok_ids:
if check_details(line, ids, tok_ids, tokenizer, transformer_tokenizer): if check_details(line, ids, tok_ids, transformer_tokenizer, tokenizer):
continue continue
assert ids == tok_ids, f"Error in line {i}: {line} {ids} != {tok_ids}" assert ids == tok_ids, f"Error in line {i}: {line} {ids} != {tok_ids}"

View File

@ -137,7 +137,7 @@ def check_diff(spm_diff, tok_diff, sp, tok):
return False return False
def check_details(line, spm_ids, tok_ids, tok, sp): def check_details(line, spm_ids, tok_ids, sp, tok):
# Encoding can be the same with same result AAA -> A + AA vs AA + A # Encoding can be the same with same result AAA -> A + AA vs AA + A
# We can check that we use at least exactly the same number of tokens. # We can check that we use at least exactly the same number of tokens.
for i, (spm_id, tok_id) in enumerate(zip(spm_ids, tok_ids)): for i, (spm_id, tok_id) in enumerate(zip(spm_ids, tok_ids)):
@ -173,7 +173,9 @@ def check_details(line, spm_ids, tok_ids, tok, sp):
for j in possible_matches: for j in possible_matches:
if check_diff( if check_diff(
spm_ids[first : first + i], tok_ids[first : first + j], sp, tok spm_ids[first : first + i], tok_ids[first : first + j], sp, tok
) and check_diff(spm_ids[first + i : last], tok_ids[first + j : last], sp, tok): ) and check_details(
line, spm_ids[first + i : last], tok_ids[first + j : last], sp, tok,
):
return True return True
print(f"Spm: {[tok.decode([spm_ids[i]]) for i in range(first, last)]}") print(f"Spm: {[tok.decode([spm_ids[i]]) for i in range(first, last)]}")
@ -241,7 +243,7 @@ def check_encode(args):
print(f"SPM: {spm_total_time} - TOK: {tok_total_time}") print(f"SPM: {spm_total_time} - TOK: {tok_total_time}")
if ids != encoded.ids: if ids != encoded.ids:
if check_details(line, ids, encoded.ids, tok, sp): if check_details(line, ids, encoded.ids, sp, tok):
imperfect += 1 imperfect += 1
continue continue
else: else: