mirror of
https://github.com/mii443/tokenizers.git
synced 2025-08-31 04:29:21 +00:00
Making convert script machine agnostic.
This commit is contained in:
@ -22,6 +22,7 @@ import unicodedata
|
|||||||
import sys
|
import sys
|
||||||
import os
|
import os
|
||||||
import datetime
|
import datetime
|
||||||
|
import argparse
|
||||||
|
|
||||||
sys.path.append(".")
|
sys.path.append(".")
|
||||||
|
|
||||||
@ -361,18 +362,16 @@ CONVERTERS = {
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
def check(pretrained):
|
def check(pretrained, filename):
|
||||||
transformer_tokenizer = transformers.AutoTokenizer.from_pretrained(pretrained)
|
transformer_tokenizer = transformers.AutoTokenizer.from_pretrained(pretrained)
|
||||||
converter_class = CONVERTERS[transformer_tokenizer.__class__.__name__]
|
converter_class = CONVERTERS[transformer_tokenizer.__class__.__name__]
|
||||||
tokenizer = converter_class(transformer_tokenizer).converted()
|
tokenizer = converter_class(transformer_tokenizer).converted()
|
||||||
|
|
||||||
tokenizer.save(f"{pretrained}.json")
|
|
||||||
|
|
||||||
now = datetime.datetime.now
|
now = datetime.datetime.now
|
||||||
trans_total_time = datetime.timedelta(seconds=0)
|
trans_total_time = datetime.timedelta(seconds=0)
|
||||||
tok_total_time = datetime.timedelta(seconds=0)
|
tok_total_time = datetime.timedelta(seconds=0)
|
||||||
|
|
||||||
with open("/home/nicolas/data/xnli/xnli.txt", "r") as f:
|
with open(filename, "r") as f:
|
||||||
for i, line in enumerate(f):
|
for i, line in enumerate(f):
|
||||||
line = line.strip()
|
line = line.strip()
|
||||||
|
|
||||||
@ -397,6 +396,7 @@ def check(pretrained):
|
|||||||
continue
|
continue
|
||||||
assert ids == tok_ids, f"Error in line {i}: {line} {ids} != {tok_ids}"
|
assert ids == tok_ids, f"Error in line {i}: {line} {ids} != {tok_ids}"
|
||||||
|
|
||||||
|
tokenizer.save(f"{pretrained.replace('/', '-')}.json")
|
||||||
return ("OK", trans_total_time / tok_total_time)
|
return ("OK", trans_total_time / tok_total_time)
|
||||||
|
|
||||||
|
|
||||||
@ -425,14 +425,30 @@ def main():
|
|||||||
"t5-small",
|
"t5-small",
|
||||||
"google/pegasus-large",
|
"google/pegasus-large",
|
||||||
]
|
]
|
||||||
pretraineds = ["google/pegasus-large"]
|
parser = argparse.ArgumentParser()
|
||||||
|
parser.add_argument(
|
||||||
|
"--filename",
|
||||||
|
required=True,
|
||||||
|
type=str,
|
||||||
|
help="The filename that we are going to encode in both versions to check that conversion worked",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--models",
|
||||||
|
type=lambda s: s.split(","),
|
||||||
|
default=pretraineds,
|
||||||
|
help=f"The pretrained tokenizers you want to test agains, (default: {pretraineds})",
|
||||||
|
)
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
print(args.filename)
|
||||||
|
|
||||||
model_len = 50
|
model_len = 50
|
||||||
status_len = 6
|
status_len = 6
|
||||||
speedup_len = 8
|
speedup_len = 8
|
||||||
print(f"|{'Model':^{model_len}}|{'Status':^{status_len}}|{'Speedup':^{speedup_len}}|")
|
print(f"|{'Model':^{model_len}}|{'Status':^{status_len}}|{'Speedup':^{speedup_len}}|")
|
||||||
print(f"|{'-'*model_len}|{'-'*status_len}|{'-'*speedup_len}|")
|
print(f"|{'-'*model_len}|{'-'*status_len}|{'-'*speedup_len}|")
|
||||||
for pretrained in pretraineds:
|
for pretrained in args.models:
|
||||||
status, speedup = check(pretrained)
|
status, speedup = check(pretrained, args.filename)
|
||||||
print(
|
print(
|
||||||
f"|{pretrained:<{model_len}}|{status:^{status_len}}|{speedup:^{speedup_len - 1}.2f}x|"
|
f"|{pretrained:<{model_len}}|{status:^{status_len}}|{speedup:^{speedup_len - 1}.2f}x|"
|
||||||
)
|
)
|
||||||
|
Reference in New Issue
Block a user