Adressing first pass of comments.

This commit is contained in:
Nicolas Patry
2020-09-23 11:29:17 +02:00
parent 1cd4824273
commit 8f8156fd2c
10 changed files with 196 additions and 125 deletions

View File

@ -76,43 +76,28 @@ class SpmConverter(Converter):
model_type = proto.trainer_spec.model_type
vocab = self.vocab(proto)
unk_id = self.unk_id(proto)
filename = self.original_tokenizer.vocab_file
if model_type == 1:
data = {"unk_id": unk_id, "vocab": vocab}
out_vocab_filename = f"{filename}.json"
try:
with open(out_vocab_filename, "w") as f:
json.dump(data, f, indent=4)
tokenizer = Tokenizer(Unigram(out_vocab_filename))
finally:
os.remove(out_vocab_filename)
tokenizer = Tokenizer(Unigram(vocab, unk_id))
elif model_type == 2:
vocab, merges = SentencePieceExtractor(self.original_tokenizer.vocab_file).extract()
vocab, merges = SentencePieceExtractor(
self.original_tokenizer.vocab_file
).extract()
# Open output files and let's extract model information
out_vocab_filename = f"{filename}.vocab"
out_merge_filename = f"{filename}.merge"
try:
with open(out_vocab_filename, "w") as vocab_f:
json.dump(vocab, vocab_f)
try:
with open(out_merge_filename, "w") as merges_f:
# Save content
merges_f.writelines(map(lambda x: f"{x[0]} {x[1]}{os.linesep}", merges))
tokenizer = Tokenizer(
BPE(
out_vocab_filename,
out_merge_filename,
unk_token=proto.trainer_spec.unk_piece,
fuse_unk=True,
)
)
finally:
os.remove(out_merge_filename)
finally:
os.remove(out_vocab_filename)
actual_merges = {}
for id_merge, (a, b) in enumerate(merges):
id_a = vocab[a]
id_b = vocab[b]
id_ab = vocab[a + b]
id_ab = vocab[a + b]
actual_merges[(id_a, id_b)] = (id_merge, id_ab)
tokenizer = Tokenizer(
BPE(
vocab,
actual_merges,
unk_token=proto.trainer_spec.unk_piece,
fuse_unk=True,
)
)
else:
raise Exception(
"You're trying to run a `Unigram` model but you're file was trained with a different algorithm"
@ -346,7 +331,9 @@ class PegasusConverter(SpmConverter):
return TemplateProcessing(
seq_a=["$0", eos],
seq_b=["$1", eos],
special_tokens=[(eos, tokenizer.get_vocab()[eos]),],
special_tokens=[
(eos, tokenizer.get_vocab()[eos]),
],
)
@ -355,7 +342,9 @@ class T5Converter(SpmConverter):
return TemplateProcessing(
seq_a=["$0", "</s>"],
seq_b=["$1", "</s>"],
special_tokens=[("</s>", tokenizer.get_vocab()["</s>"]),],
special_tokens=[
("</s>", tokenizer.get_vocab()["</s>"]),
],
)
@ -447,7 +436,9 @@ def main():
model_len = 50
status_len = 6
speedup_len = 8
print(f"|{'Model':^{model_len}}|{'Status':^{status_len}}|{'Speedup':^{speedup_len}}|")
print(
f"|{'Model':^{model_len}}|{'Status':^{status_len}}|{'Speedup':^{speedup_len}}|"
)
print(f"|{'-'*model_len}|{'-'*status_len}|{'-'*speedup_len}|")
for pretrained in args.models:
status, speedup = check(pretrained, args.filename)