mirror of
https://github.com/mii443/tokenizers.git
synced 2025-08-29 03:29:22 +00:00
Fixing convert/check scripts.
This commit is contained in:
@ -397,7 +397,7 @@ def check(pretrained, filename):
|
|||||||
tok_total_time += tok - trans
|
tok_total_time += tok - trans
|
||||||
|
|
||||||
if ids != tok_ids:
|
if ids != tok_ids:
|
||||||
if check_details(line, ids, tok_ids, tokenizer, transformer_tokenizer):
|
if check_details(line, ids, tok_ids, transformer_tokenizer, tokenizer):
|
||||||
continue
|
continue
|
||||||
assert ids == tok_ids, f"Error in line {i}: {line} {ids} != {tok_ids}"
|
assert ids == tok_ids, f"Error in line {i}: {line} {ids} != {tok_ids}"
|
||||||
|
|
||||||
|
@ -137,7 +137,7 @@ def check_diff(spm_diff, tok_diff, sp, tok):
|
|||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
||||||
def check_details(line, spm_ids, tok_ids, tok, sp):
|
def check_details(line, spm_ids, tok_ids, sp, tok):
|
||||||
# Encoding can be the same with same result AAA -> A + AA vs AA + A
|
# Encoding can be the same with same result AAA -> A + AA vs AA + A
|
||||||
# We can check that we use at least exactly the same number of tokens.
|
# We can check that we use at least exactly the same number of tokens.
|
||||||
for i, (spm_id, tok_id) in enumerate(zip(spm_ids, tok_ids)):
|
for i, (spm_id, tok_id) in enumerate(zip(spm_ids, tok_ids)):
|
||||||
@ -173,7 +173,9 @@ def check_details(line, spm_ids, tok_ids, tok, sp):
|
|||||||
for j in possible_matches:
|
for j in possible_matches:
|
||||||
if check_diff(
|
if check_diff(
|
||||||
spm_ids[first : first + i], tok_ids[first : first + j], sp, tok
|
spm_ids[first : first + i], tok_ids[first : first + j], sp, tok
|
||||||
) and check_diff(spm_ids[first + i : last], tok_ids[first + j : last], sp, tok):
|
) and check_details(
|
||||||
|
line, spm_ids[first + i : last], tok_ids[first + j : last], sp, tok,
|
||||||
|
):
|
||||||
return True
|
return True
|
||||||
|
|
||||||
print(f"Spm: {[tok.decode([spm_ids[i]]) for i in range(first, last)]}")
|
print(f"Spm: {[tok.decode([spm_ids[i]]) for i in range(first, last)]}")
|
||||||
@ -241,7 +243,7 @@ def check_encode(args):
|
|||||||
print(f"SPM: {spm_total_time} - TOK: {tok_total_time}")
|
print(f"SPM: {spm_total_time} - TOK: {tok_total_time}")
|
||||||
|
|
||||||
if ids != encoded.ids:
|
if ids != encoded.ids:
|
||||||
if check_details(line, ids, encoded.ids, tok, sp):
|
if check_details(line, ids, encoded.ids, sp, tok):
|
||||||
imperfect += 1
|
imperfect += 1
|
||||||
continue
|
continue
|
||||||
else:
|
else:
|
||||||
|
Reference in New Issue
Block a user