tokenizers/bindings/python/scripts/spm_parity_check.py
Nicolas Patry 330876ae02 Improvements on spm parity: (#401)
* Removing all pre_tokenizer logic from Unigram algorithm.

* Improving the parity check *a lot*.

- We can now detect a lot more errors
- Special cases have been added temporarily.

* Adding 2 new normalizers that mimic spm's default behavior.

* Adding `encoding_optimized` version of the `encode` algorithm.

- Removes Lattice allocation.
- Changes trie `common_prefix_search` to return an iterator to avoid
  allocation of the full results.

* Trie<char> -> Trie<u8>: another speed improvement.

* [WIP] Attempt to create a Precompiled Normalizer from SPM to be 100%
compliant with arbitrary models.

* Adding a new `Precompiled` Normalizer that is replacing `SpmNmtNfkc`.

- It provides direct compatibility with `Spm` by directly using the normalizer
  spec embedded within spm files, replacing all the custom rules and removing
  the need for us to maintain any rules of our own.
- We need the `nom` dependency to parse `spm`'s binary format.
- We need to add the `sentencepiece_model_pb2.py` file to be able to read
  the proto file.
- We reimplemented their `Darts::DoubleArray` compact trie format.

* Fixing a bug with Precompiled normalizer.

* Fixing some edge cases (now in tests) with this weird precompiled
normalizer.

It seems even a carefully hand-crafted trie does not prevent one from shooting
oneself in the foot. Sorry, future reader.

* Keep the API stable for this PR (API changes should come later, see #409).

- Removed sentencepiece_model_pb2 from the bindings and added instructions to
  make `from_spm` work (see the sketch at the end of this message).

* Adding model check in `from_spm`.

* Addressing @n1t0's comments.

* Adding a check to make sure alignments stay correct.

Also added a bit more documentation on how Precompiled works.

* Extracting `Precompiled` into its own `spm_precompiled` crate.

* Using ranges in `do_nmt`.
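
For context, a minimal sketch of the kind of data `from_spm` and `Precompiled` rely on, assuming `sentencepiece_model_pb2.py` has been generated from sentencepiece's `sentencepiece_model.proto` (this snippet is not part of the repository, and the model path is just an example):

    # Hypothetical inspection snippet.
    import sentencepiece_model_pb2 as model_pb2

    m = model_pb2.ModelProto()
    with open("spm_parity.model", "rb") as f:  # any trained spm model file
        m.ParseFromString(f.read())

    print(m.normalizer_spec.name)                       # e.g. "nmt_nfkc"
    print(len(m.normalizer_spec.precompiled_charsmap))  # the blob `Precompiled` parses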
2020-09-15 22:21:02 +02:00


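# Parity checker between the original `sentencepiece` implementation and
# tokenizers' SentencePieceUnigramTokenizer: it compares either the trainers
# (--train) or the encoders, optionally loading the spm model directly with
# its own normalizer (--from-spm).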
import tokenizers
from argparse import ArgumentParser
import sentencepiece as spm
from collections import Counter
import json
import os
import datetime

try:
    from termcolor import colored

    has_color = True
except Exception:
    has_color = False


def main():
    parser = ArgumentParser("SentencePiece parity checker")
    parser.add_argument(
        "--input-file", "-i", type=str, required=True, help="Which text file to check (and train from)",
    )
    parser.add_argument(
        "--model-file",
        "-m",
        type=str,
        required=False,
        default=None,
        help="Use a pretrained spm model file instead of training one",
    )
    parser.add_argument(
        "--model-prefix", type=str, default="spm_parity", help="Model prefix for spm_train",
    )
    parser.add_argument(
        "--vocab-size", "-v", type=int, default=8000, help="Vocab size for spm_train",
    )
    parser.add_argument(
        "--verbose", action="store_true", help="Verbosity",
    )
    parser.add_argument(
        "--train",
        action="store_true",
        help="Instead of checking the encoder part, we check the trainer part",
    )
    parser.add_argument(
        "--from-spm",
        action="store_true",
        help="Directly load the spm file with its own normalizer",
    )
    args = parser.parse_args()

    trained = False
    if args.model_file is None:
        spm.SentencePieceTrainer.Train(
            f"--input={args.input_file} --model_prefix={args.model_prefix}"
            f" --character_coverage=1.0"
            f" --max_sentence_length=40000"
            f" --num_threads=1"
            f" --vocab_size={args.vocab_size}"
        )
        trained = True
        args.model_file = f"{args.model_prefix}.model"

    try:
        if args.train:
            check_train(args)
        else:
            check_encode(args)
    finally:
        if trained:
            os.remove(f"{args.model_prefix}.model")
            os.remove(f"{args.model_prefix}.vocab")


def check_train(args):
    sp = spm.SentencePieceProcessor()
    sp.Load(args.model_file)

    tokenizer = tokenizers.SentencePieceUnigramTokenizer()
    tokenizer.train(args.input_file, show_progress=False)

    spm_tokens = 0
    tokenizer_tokens = 0

    with open(args.input_file, "r") as f:
        for i, line in enumerate(f):
            line = line.strip()

            ids = sp.EncodeAsIds(line)
            encoded = tokenizer.encode(line)

            spm_tokens += len(ids)
            tokenizer_tokens += len(encoded.ids)

    vocab = [0 for i in range(args.vocab_size)]
    spm_vocab = [0 for i in range(args.vocab_size)]

    for token, index in tokenizer.get_vocab().items():
        vocab[index] = token

    for i in range(args.vocab_size):
        spm_vocab[i] = sp.id_to_piece(i)

    # 0 is unk in tokenizers; 0, 1, 2 are unk, bos, eos in spm by default.
    for i, (token, spm_token) in enumerate(zip(vocab[1:], spm_vocab[3:])):
        if token != spm_token:
            print(f"First different token is token {i} ({token} != {spm_token})")
            break

    print(f"Tokenizer used {tokenizer_tokens}, whereas spm used {spm_tokens}")
    assert (
        tokenizer_tokens < spm_tokens
    ), "Our trainer should be more efficient than the SPM one"
    print("Ok, our trainer is more efficient than the SPM one")


def check_diff(spm_diff, tok_diff, sp, tok):
    """Return True if the differing id spans are an acceptable (equivalent) segmentation."""
    if spm_diff == list(reversed(tok_diff)):
        # AAA -> AA+A vs A+AA case.
        return True
    elif len(spm_diff) == len(tok_diff) and tok.decode(spm_diff) == tok.decode(tok_diff):
        # Second order OK
        # Barrich -> Barr + ich vs Bar + rich
        return True

    spm_reencoded = sp.encode(sp.decode(spm_diff))
    tok_reencoded = tok.encode(tok.decode(spm_diff)).ids
    if spm_reencoded != spm_diff and spm_reencoded == tok_reencoded:
        # Type 3 error.
        # Snehagatha ->
        #     Sne, h, aga, th, a
        #     Sne, ha, gat, ha
        # Re-encoding the decoded diff with sp does not even recover what spm
        # originally gave us, but it does match the tokenizer output.
        return True
    return False


def check_details(line, spm_ids, tok_ids, tok, sp):
    """Inspect a mismatching line; return True if the difference is only a benign re-segmentation."""
    # Encodings can differ yet decode to the same text: AAA -> A + AA vs AA + A.
    # Trim the common prefix and suffix so we only look at the differing span.
    for i, (spm_id, tok_id) in enumerate(zip(spm_ids, tok_ids)):
        if spm_id != tok_id:
            break
    first = i
    for i, (spm_id, tok_id) in enumerate(zip(reversed(spm_ids), reversed(tok_ids))):
        if spm_id != tok_id:
            break
    last = len(spm_ids) - i

    spm_diff = spm_ids[first:last]
    tok_diff = tok_ids[first:last]

    if check_diff(spm_diff, tok_diff, sp, tok):
        return True

    if last - first > 5:
        # We might have a single problem occurring twice; attempt to subdivide
        # the disjoint tokens into smaller problems.
        spms = Counter(spm_ids[first:last])
        toks = Counter(tok_ids[first:last])

        removable_tokens = {spm_ for (spm_, si) in spms.items() if toks.get(spm_, 0) == si}
        min_width = 3
        for i in range(last - first - min_width):
            if all(spm_ids[first + i + j] in removable_tokens for j in range(min_width)):
                possible_matches = [
                    k
                    for k in range(last - first - min_width)
                    if tok_ids[first + k : first + k + min_width]
                    == spm_ids[first + i : first + i + min_width]
                ]
                for j in possible_matches:
                    if check_diff(
                        spm_ids[first : first + i], tok_ids[first : first + j], sp, tok
                    ) and check_diff(spm_ids[first + i : last], tok_ids[first + j : last], sp, tok):
                        return True

    ok_start = tok.decode(spm_ids[:first])
    ok_end = tok.decode(spm_ids[last:])
    wrong = tok.decode(spm_ids[first:last])
    print()
    if has_color:
        print(f"{colored(ok_start, 'grey')}{colored(wrong, 'red')}{colored(ok_end, 'grey')}")
    else:
        print(wrong)

    print(f"Spm: {[tok.decode([spm_ids[i]]) for i in range(first, last)]}")
    print(f"Tok: {[tok.decode([tok_ids[i]]) for i in range(first, last)]}")
    return False


def check_encode(args):
    sp = spm.SentencePieceProcessor()
    sp.Load(args.model_file)

    if args.from_spm:
        tok = tokenizers.SentencePieceUnigramTokenizer.from_spm(args.model_file)
    else:
        vocab = [(sp.id_to_piece(i), sp.get_score(i)) for i in range(sp.piece_size())]
        vocab_filename = f"{args.model_file}.json"
        unk_id = sp.unk_id()
        # Dump the vocab as {"unk_id": ..., "vocab": [[piece, score], ...]} for the tokenizer.
        data = {"unk_id": unk_id, "vocab": vocab}
        try:
            with open(vocab_filename, "w") as f:
                json.dump(data, f, indent=4)
            tok = tokenizers.SentencePieceUnigramTokenizer(vocab_filename)
        finally:
            os.remove(vocab_filename)

    perfect = 0
    imperfect = 0
    wrong = 0
    now = datetime.datetime.now
    spm_total_time = datetime.timedelta(seconds=0)
    tok_total_time = datetime.timedelta(seconds=0)

    with open(args.input_file, "r", encoding="utf-8-sig") as f:
        for i, line in enumerate(f):
            line = line.strip()

            start = now()
            ids = sp.EncodeAsIds(line)
            spm_time = now()
            encoded = tok.encode(line)
            tok_time = now()

            spm_total_time += spm_time - start
            tok_total_time += tok_time - spm_time

            if args.verbose and i % 10000 == 0:
                print(f"({perfect} / {imperfect} / {wrong} ----- {perfect + imperfect + wrong})")
                print(f"SPM: {spm_total_time} - TOK: {tok_total_time}")

            if ids != encoded.ids:
                if check_details(line, ids, encoded.ids, tok, sp):
                    imperfect += 1
                    continue
                else:
                    wrong += 1
            else:
                perfect += 1

            assert ids == encoded.ids, f"line {i}: {line} : {ids} != {encoded.ids}"

    print(f"({perfect} / {imperfect} / {wrong} ----- {perfect + imperfect + wrong})")
    total = perfect + imperfect + wrong
    print(f"Accuracy {perfect * 100 / total:.2f} Slowdown: {tok_total_time / spm_total_time:.2f}")


if __name__ == "__main__":
    main()
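
For reference, the checker can be invoked along these lines (file names below are placeholders):

    # Train a throwaway spm model and check encoding parity against it
    python spm_parity_check.py --input-file corpus.txt --vocab-size 8000 --verbose

    # Check encoding parity against an existing spm model, loaded directly
    # with its embedded (Precompiled) normalizer
    python spm_parity_check.py --input-file corpus.txt --model-file my.model --from-spm

    # Compare the trainers instead of the encoders
    python spm_parity_check.py --input-file corpus.txt --train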