From c59b216baaa84a64a000dc6fa898c3c77ae12cc5 Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Fri, 18 Sep 2020 11:51:46 +0200 Subject: [PATCH] Fixing convert/check scripts. --- bindings/python/scripts/convert.py | 2 +- bindings/python/scripts/spm_parity_check.py | 8 +++++--- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/bindings/python/scripts/convert.py b/bindings/python/scripts/convert.py index 82d44bbb..668b0515 100644 --- a/bindings/python/scripts/convert.py +++ b/bindings/python/scripts/convert.py @@ -397,7 +397,7 @@ def check(pretrained, filename): tok_total_time += tok - trans if ids != tok_ids: - if check_details(line, ids, tok_ids, tokenizer, transformer_tokenizer): + if check_details(line, ids, tok_ids, transformer_tokenizer, tokenizer): continue assert ids == tok_ids, f"Error in line {i}: {line} {ids} != {tok_ids}" diff --git a/bindings/python/scripts/spm_parity_check.py b/bindings/python/scripts/spm_parity_check.py index 2fed7860..5821f0d9 100644 --- a/bindings/python/scripts/spm_parity_check.py +++ b/bindings/python/scripts/spm_parity_check.py @@ -137,7 +137,7 @@ def check_diff(spm_diff, tok_diff, sp, tok): return False -def check_details(line, spm_ids, tok_ids, tok, sp): +def check_details(line, spm_ids, tok_ids, sp, tok): # Encoding can be the same with same result AAA -> A + AA vs AA + A # We can check that we use at least exactly the same number of tokens. for i, (spm_id, tok_id) in enumerate(zip(spm_ids, tok_ids)): @@ -173,7 +173,9 @@ def check_details(line, spm_ids, tok_ids, tok, sp): for j in possible_matches: if check_diff( spm_ids[first : first + i], tok_ids[first : first + j], sp, tok - ) and check_diff(spm_ids[first + i : last], tok_ids[first + j : last], sp, tok): + ) and check_details( + line, spm_ids[first + i : last], tok_ids[first + j : last], sp, tok, + ): return True print(f"Spm: {[tok.decode([spm_ids[i]]) for i in range(first, last)]}") @@ -241,7 +243,7 @@ def check_encode(args): print(f"SPM: {spm_total_time} - TOK: {tok_total_time}") if ids != encoded.ids: - if check_details(line, ids, encoded.ids, tok, sp): + if check_details(line, ids, encoded.ids, sp, tok): imperfect += 1 continue else: