Mirror of https://github.com/mii443/tokenizers.git
Synced 2025-08-22 16:25:30 +00:00
Addressing first pass of comments.
@@ -76,43 +76,28 @@ class SpmConverter(Converter):
         model_type = proto.trainer_spec.model_type
         vocab = self.vocab(proto)
         unk_id = self.unk_id(proto)
-        filename = self.original_tokenizer.vocab_file

         if model_type == 1:
-            data = {"unk_id": unk_id, "vocab": vocab}
-
-            out_vocab_filename = f"{filename}.json"
-            try:
-                with open(out_vocab_filename, "w") as f:
-                    json.dump(data, f, indent=4)
-
-                tokenizer = Tokenizer(Unigram(out_vocab_filename))
-            finally:
-                os.remove(out_vocab_filename)
+            tokenizer = Tokenizer(Unigram(vocab, unk_id))
         elif model_type == 2:
-            vocab, merges = SentencePieceExtractor(self.original_tokenizer.vocab_file).extract()
+            vocab, merges = SentencePieceExtractor(
+                self.original_tokenizer.vocab_file
+            ).extract()
-            # Open output files and let's extract model information
-            out_vocab_filename = f"{filename}.vocab"
-            out_merge_filename = f"{filename}.merge"
-            try:
-                with open(out_vocab_filename, "w") as vocab_f:
-                    json.dump(vocab, vocab_f)
-                try:
-                    with open(out_merge_filename, "w") as merges_f:
-                        # Save content
-                        merges_f.writelines(map(lambda x: f"{x[0]} {x[1]}{os.linesep}", merges))
-                    tokenizer = Tokenizer(
-                        BPE(
-                            out_vocab_filename,
-                            out_merge_filename,
-                            unk_token=proto.trainer_spec.unk_piece,
-                            fuse_unk=True,
-                        )
-                    )
-                finally:
-                    os.remove(out_merge_filename)
-            finally:
-                os.remove(out_vocab_filename)
+            actual_merges = {}
+            for id_merge, (a, b) in enumerate(merges):
+                id_a = vocab[a]
+                id_b = vocab[b]
+                id_ab = vocab[a + b]
+                actual_merges[(id_a, id_b)] = (id_merge, id_ab)
+            tokenizer = Tokenizer(
+                BPE(
+                    vocab,
+                    actual_merges,
+                    unk_token=proto.trainer_spec.unk_piece,
+                    fuse_unk=True,
+                )
+            )
         else:
             raise Exception(
                 "You're trying to run a `Unigram` model but you're file was trained with a different algorithm"
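The BPE branch above replaces the temporary `.vocab`/`.merge` files with an in-memory merge table: each textual merge pair is mapped to token ids together with its merge rank and the id of the merged token. A minimal standalone sketch of that mapping, using a made-up toy `vocab` and `merges` in place of what `SentencePieceExtractor.extract()` would actually return:

# Toy inputs (hypothetical stand-ins for the extractor's output).
vocab = {"h": 0, "e": 1, "l": 2, "o": 3, "he": 4, "ll": 5, "hell": 6, "hello": 7}
merges = [("h", "e"), ("l", "l"), ("he", "ll"), ("hell", "o")]

actual_merges = {}
for id_merge, (a, b) in enumerate(merges):
    # Map the textual pair to token ids, and remember both the merge rank
    # and the id of the merged token.
    actual_merges[(vocab[a], vocab[b])] = (id_merge, vocab[a + b])

print(actual_merges)
# {(0, 1): (0, 4), (2, 2): (1, 5), (4, 5): (2, 6), (6, 3): (3, 7)}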
@@ -346,7 +331,9 @@ class PegasusConverter(SpmConverter):
         return TemplateProcessing(
             seq_a=["$0", eos],
             seq_b=["$1", eos],
-            special_tokens=[(eos, tokenizer.get_vocab()[eos]),],
+            special_tokens=[
+                (eos, tokenizer.get_vocab()[eos]),
+            ],
         )


@@ -355,7 +342,9 @@ class T5Converter(SpmConverter):
         return TemplateProcessing(
             seq_a=["$0", "</s>"],
             seq_b=["$1", "</s>"],
-            special_tokens=[("</s>", tokenizer.get_vocab()["</s>"]),],
+            special_tokens=[
+                ("</s>", tokenizer.get_vocab()["</s>"]),
+            ],
         )


@@ -447,7 +436,9 @@ def main():
     model_len = 50
     status_len = 6
     speedup_len = 8
-    print(f"|{'Model':^{model_len}}|{'Status':^{status_len}}|{'Speedup':^{speedup_len}}|")
+    print(
+        f"|{'Model':^{model_len}}|{'Status':^{status_len}}|{'Speedup':^{speedup_len}}|"
+    )
     print(f"|{'-'*model_len}|{'-'*status_len}|{'-'*speedup_len}|")
     for pretrained in args.models:
         status, speedup = check(pretrained, args.filename)
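For reference, the report header printed above relies on f-string centering with a computed width (`:^{model_len}`). A tiny self-contained illustration; the widths come from the hunk, the result row is made up:

model_len, status_len, speedup_len = 50, 6, 8

# Header and separator, centered/padded to the chosen column widths.
print(f"|{'Model':^{model_len}}|{'Status':^{status_len}}|{'Speedup':^{speedup_len}}|")
print(f"|{'-'*model_len}|{'-'*status_len}|{'-'*speedup_len}|")
# One hypothetical result row, centered the same way.
print(f"|{'albert-base-v1':^{model_len}}|{'OK':^{status_len}}|{'4.2x':^{speedup_len}}|")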
@@ -17,7 +17,11 @@ except Exception:
 def main():
     parser = ArgumentParser("SentencePiece parity checker")
     parser.add_argument(
-        "--input-file", "-i", type=str, required=True, help="Which files do you want to train from",
+        "--input-file",
+        "-i",
+        type=str,
+        required=True,
+        help="Which files do you want to train from",
     )
     parser.add_argument(
         "--model-file",
@@ -28,13 +32,22 @@ def main():
         help="Use a pretrained token file",
     )
     parser.add_argument(
-        "--model-prefix", type=str, default="spm_parity", help="Model prefix for spm_train",
+        "--model-prefix",
+        type=str,
+        default="spm_parity",
+        help="Model prefix for spm_train",
     )
     parser.add_argument(
-        "--vocab-size", "-v", type=int, default=8000, help="Vocab size for spm_train",
+        "--vocab-size",
+        "-v",
+        type=int,
+        default=8000,
+        help="Vocab size for spm_train",
     )
     parser.add_argument(
-        "--verbose", action="store_true", help="Verbosity",
+        "--verbose",
+        action="store_true",
+        help="Verbosity",
     )
     parser.add_argument(
         "--train",
@@ -160,10 +173,14 @@ def check_details(line, spm_ids, tok_ids, sp, tok):
     spms = Counter(spm_ids[first:last])
     toks = Counter(tok_ids[first:last])

-    removable_tokens = {spm_ for (spm_, si) in spms.items() if toks.get(spm_, 0) == si}
+    removable_tokens = {
+        spm_ for (spm_, si) in spms.items() if toks.get(spm_, 0) == si
+    }
     min_width = 3
     for i in range(last - first - min_width):
-        if all(spm_ids[first + i + j] in removable_tokens for j in range(min_width)):
+        if all(
+            spm_ids[first + i + j] in removable_tokens for j in range(min_width)
+        ):
            possible_matches = [
                k
                for k in range(last - first - min_width)
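A small self-contained illustration of the `Counter` trick used above: an id is "removable" when it occurs the same number of times in both tokenizations of a span, so mismatched positions can be narrowed down to the ids that genuinely differ. The example id sequences are made up:

from collections import Counter

spm_ids = [5, 9, 9, 12, 7]
tok_ids = [5, 9, 9, 3, 7]

spms = Counter(spm_ids)
toks = Counter(tok_ids)

# Keep only ids whose count matches exactly on both sides; those tokens can be
# skipped when searching for where the two tokenizations really diverge.
removable_tokens = {spm_ for (spm_, si) in spms.items() if toks.get(spm_, 0) == si}
print(sorted(removable_tokens))  # [5, 7, 9]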
@@ -174,7 +191,11 @@ def check_details(line, spm_ids, tok_ids, sp, tok):
                if check_diff(
                    spm_ids[first : first + i], tok_ids[first : first + j], sp, tok
                ) and check_details(
-                    line, spm_ids[first + i : last], tok_ids[first + j : last], sp, tok,
+                    line,
+                    spm_ids[first + i : last],
+                    tok_ids[first + j : last],
+                    sp,
+                    tok,
                ):
                    return True

@@ -189,7 +210,9 @@ def check_details(line, spm_ids, tok_ids, sp, tok):
     wrong = tok.decode(spm_ids[first:last])
     print()
     if has_color:
-        print(f"{colored(ok_start, 'grey')}{colored(wrong, 'red')}{colored(ok_end, 'grey')}")
+        print(
+            f"{colored(ok_start, 'grey')}{colored(wrong, 'red')}{colored(ok_end, 'grey')}"
+        )
     else:
         print(wrong)
     return False
@@ -203,17 +226,8 @@ def check_encode(args):
         tok = tokenizers.SentencePieceUnigramTokenizer.from_spm(args.model_file)
     else:
         vocab = [(sp.id_to_piece(i), sp.get_score(i)) for i in range(sp.piece_size())]
-        vocab_filename = f"{args.model_file}.json"
         unk_id = sp.unk_id()
-
-        data = {"unk_id": unk_id, "vocab": vocab}
-        try:
-            with open(vocab_filename, "w") as f:
-                json.dump(data, f, indent=4)
-
-            tok = tokenizers.SentencePieceUnigramTokenizer(vocab_filename)
-        finally:
-            os.remove(vocab_filename)
+        tok = tokenizers.SentencePieceUnigramTokenizer(vocab, unk_id)

     perfect = 0
     imperfect = 0
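The `check_encode` path above now builds the `(piece, score)` vocab in memory and hands it straight to the tokenizer instead of round-tripping through a JSON file. A sketch of driving the same path on its own, assuming a hypothetical `spm_parity.model` produced by `spm_train` and the `(vocab, unk_id)` constructor this change switches to:

import sentencepiece as spm
import tokenizers

sp = spm.SentencePieceProcessor()
sp.Load("spm_parity.model")  # hypothetical model file trained with spm_train

# Rebuild the (piece, score) vocab directly from the SentencePiece model.
vocab = [(sp.id_to_piece(i), sp.get_score(i)) for i in range(sp.piece_size())]
unk_id = sp.unk_id()

# Same in-memory constructor as in the hunk above: no temporary JSON on disk.
tok = tokenizers.SentencePieceUnigramTokenizer(vocab, unk_id)
print(tok.encode("Hello world").tokens)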
@@ -255,7 +269,9 @@ def check_encode(args):

     print(f"({perfect} / {imperfect} / {wrong} ----- {perfect + imperfect + wrong})")
     total = perfect + imperfect + wrong
-    print(f"Accuracy {perfect * 100 / total:.2f} Slowdown : {tok_total_time/ spm_total_time:.2f}")
+    print(
+        f"Accuracy {perfect * 100 / total:.2f} Slowdown : {tok_total_time/ spm_total_time:.2f}"
+    )


 if __name__ == "__main__":