mirror of
https://github.com/mii443/tokenizers.git
synced 2025-08-23 08:45:38 +00:00
Making convert script machine agnostic.
This commit is contained in:
@ -22,6 +22,7 @@ import unicodedata
|
||||
import sys
|
||||
import os
|
||||
import datetime
|
||||
import argparse
|
||||
|
||||
sys.path.append(".")
|
||||
|
||||
@ -361,18 +362,16 @@ CONVERTERS = {
|
||||
}
|
||||
|
||||
|
||||
def check(pretrained):
|
||||
def check(pretrained, filename):
|
||||
transformer_tokenizer = transformers.AutoTokenizer.from_pretrained(pretrained)
|
||||
converter_class = CONVERTERS[transformer_tokenizer.__class__.__name__]
|
||||
tokenizer = converter_class(transformer_tokenizer).converted()
|
||||
|
||||
tokenizer.save(f"{pretrained}.json")
|
||||
|
||||
now = datetime.datetime.now
|
||||
trans_total_time = datetime.timedelta(seconds=0)
|
||||
tok_total_time = datetime.timedelta(seconds=0)
|
||||
|
||||
with open("/home/nicolas/data/xnli/xnli.txt", "r") as f:
|
||||
with open(filename, "r") as f:
|
||||
for i, line in enumerate(f):
|
||||
line = line.strip()
|
||||
|
||||
@ -397,6 +396,7 @@ def check(pretrained):
|
||||
continue
|
||||
assert ids == tok_ids, f"Error in line {i}: {line} {ids} != {tok_ids}"
|
||||
|
||||
tokenizer.save(f"{pretrained.replace('/', '-')}.json")
|
||||
return ("OK", trans_total_time / tok_total_time)
|
||||
|
||||
|
||||
@ -425,14 +425,30 @@ def main():
|
||||
"t5-small",
|
||||
"google/pegasus-large",
|
||||
]
|
||||
pretraineds = ["google/pegasus-large"]
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--filename",
|
||||
required=True,
|
||||
type=str,
|
||||
help="The filename that we are going to encode in both versions to check that conversion worked",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--models",
|
||||
type=lambda s: s.split(","),
|
||||
default=pretraineds,
|
||||
help=f"The pretrained tokenizers you want to test agains, (default: {pretraineds})",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
print(args.filename)
|
||||
|
||||
model_len = 50
|
||||
status_len = 6
|
||||
speedup_len = 8
|
||||
print(f"|{'Model':^{model_len}}|{'Status':^{status_len}}|{'Speedup':^{speedup_len}}|")
|
||||
print(f"|{'-'*model_len}|{'-'*status_len}|{'-'*speedup_len}|")
|
||||
for pretrained in pretraineds:
|
||||
status, speedup = check(pretrained)
|
||||
for pretrained in args.models:
|
||||
status, speedup = check(pretrained, args.filename)
|
||||
print(
|
||||
f"|{pretrained:<{model_len}}|{status:^{status_len}}|{speedup:^{speedup_len - 1}.2f}x|"
|
||||
)
|
||||
|
Reference in New Issue
Block a user