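"""
Convert slow sentencepiece-based `transformers` tokenizers into fast `tokenizers`
pipelines and check that both produce identical ids on a reference text file.
"""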
import transformers
from tokenizers.implementations import SentencePieceUnigramTokenizer, BaseTokenizer
from tokenizers.processors import TemplateProcessing
from tokenizers.models import Unigram, BPE
from tokenizers import decoders
from tokenizers import Tokenizer, Regex
from tokenizers.normalizers import (
    StripAccents,
    NFKD,
    Lowercase,
    Sequence,
    BertNormalizer,
    Precompiled,
    Replace,
)
from tokenizers.pre_tokenizers import (
    Digits,
    WhitespaceSplit,
    Metaspace,
    Sequence as PSequence,
)
import json
import unicodedata
import sys
import os
import datetime
import argparse

sys.path.append(".")

from spm_parity_check import check_details
from sentencepiece_extractor import SentencePieceExtractor


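# Pieces like "1," (a digit immediately followed by a trailing comma) return False here;
# the Albert and XLNet converters below lower the score of such pieces.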
def check_number_comma(piece: str) -> bool:
    return len(piece) < 2 or piece[-1] != "," or not piece[-2].isdigit()


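# Parse the sentencepiece .model file with the standalone protobuf schema,
# so the `sentencepiece` package itself is never needed.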
def get_proto(filename: str):
    try:
        import sys

        sys.path.append(".")

        import sentencepiece_model_pb2 as model
    except Exception:
        raise Exception(
            "You don't seem to have the required protobuf file. To use this function, run "
            "`pip install protobuf` and "
            "`wget https://raw.githubusercontent.com/google/sentencepiece/master/python/sentencepiece_model_pb2.py` "
            "so we can read the internals of your spm file. `pip install sentencepiece` is not required."
        )

    m = model.ModelProto()
    with open(filename, "rb") as f:
        m.ParseFromString(f.read())
    return m


class Converter:
    def __init__(self, original_tokenizer):
        self.original_tokenizer = original_tokenizer

    def converted(self) -> Tokenizer:
        raise NotImplementedError()


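# Base converter for sentencepiece-backed tokenizers: reads the .model protobuf and
# rebuilds it as a `tokenizers` Unigram (model_type == 1) or BPE (model_type == 2) model.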
class SpmConverter(Converter):
    def __init__(self, *args):
        super().__init__(*args)
        self.proto = get_proto(self.original_tokenizer.vocab_file)

    def vocab(self, proto):
        return [(piece.piece, piece.score) for piece in proto.pieces]

    def unk_id(self, proto):
        return proto.trainer_spec.unk_id

    def tokenizer(self, proto):
        model_type = proto.trainer_spec.model_type
        vocab = self.vocab(proto)
        unk_id = self.unk_id(proto)
        if model_type == 1:
            tokenizer = Tokenizer(Unigram(vocab, unk_id))
        elif model_type == 2:
            vocab, merges = SentencePieceExtractor(self.original_tokenizer.vocab_file).extract()
            tokenizer = Tokenizer(BPE(vocab, merges, unk_token=proto.trainer_spec.unk_piece, fuse_unk=True))
        else:
            raise Exception(
                "Only Unigram (model_type == 1) and BPE (model_type == 2) sentencepiece models are supported,"
                f" but your file was trained with model_type == {model_type}"
            )

        return tokenizer

    def normalizer(self, proto):
        precompiled_charsmap = proto.normalizer_spec.precompiled_charsmap
        return Sequence([Precompiled(precompiled_charsmap), Replace(Regex(" {2,}"), " ")])

    def post_processor(self, tokenizer):
        return None

    def converted(self):
        tokenizer = self.tokenizer(self.proto)

        # Assemble the full pipeline around the model
        tokenizer.normalizer = self.normalizer(self.proto)

        replacement = "▁"
        prepend_scheme = "always"
        tokenizer.pre_tokenizer = Metaspace(replacement=replacement, prepend_scheme=prepend_scheme)
        tokenizer.decoder = decoders.Metaspace(replacement=replacement, prepend_scheme=prepend_scheme)
        post_processor = self.post_processor(tokenizer)
        if post_processor:
            tokenizer.post_processor = post_processor

        # TODO: what parameters should we give?
        parameters = {}

        return BaseTokenizer(tokenizer, parameters)


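# Model-specific converters: each overrides vocab / unk_id / normalizer / post_processor
# as needed to match the corresponding slow `transformers` tokenizer.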
class AlbertConverter(SpmConverter):
    def vocab(self, proto):
        return [
            (piece.piece, piece.score) if check_number_comma(piece.piece) else (piece.piece, piece.score - 100)
            for piece in proto.pieces
        ]

    def normalizer(self, proto):
        normalizers = [Replace("``", '"'), Replace("''", '"')]
        if not self.original_tokenizer.keep_accents:
            normalizers.append(NFKD())
            normalizers.append(StripAccents())
        if self.original_tokenizer.do_lower_case:
            normalizers.append(Lowercase())

        precompiled_charsmap = proto.normalizer_spec.precompiled_charsmap
        normalizers.append(Precompiled(precompiled_charsmap))
        normalizers.append(Replace(Regex(" {2,}"), " "))
        return Sequence(normalizers)

    def post_processor(self, tokenizer):
        return TemplateProcessing(
            seq_a=["[CLS]", "$0", "[SEP]"],
            seq_b=["$1", "[SEP]"],
            special_tokens=[
                ("[CLS]", tokenizer.get_vocab()["[CLS]"]),
                ("[SEP]", tokenizer.get_vocab()["[SEP]"]),
            ],
        )


class CamembertConverter(SpmConverter):
    def vocab(self, proto):
        vocab = [
            ("<s>NOTUSED", 0.0),
            ("<pad>", 0.0),
            ("</s>NOTUSED", 0.0),
            ("<unk>", 0.0),
        ]
        vocab += [(piece.piece, piece.score) for piece in proto.pieces]
        return vocab

    def unk_id(self, proto):
        # See the position of <unk> in the vocab above
        return 3

    def post_processor(self, tokenizer):
        return TemplateProcessing(
            seq_a=["<s>", "$0", "</s>"],
            seq_b=["$1", "</s>"],
            special_tokens=[
                ("<s>", tokenizer.get_vocab()["<s>"]),
                ("</s>", tokenizer.get_vocab()["</s>"]),
            ],
        )


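# MBart prepends its four special tokens and appends its fixed set of
# language-code tokens around the original sentencepiece pieces.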
class MBartConverter(SpmConverter):
    def vocab(self, proto):
        vocab = [
            ("<s>", 0.0),
            ("<pad>", 0.0),
            ("</s>", 0.0),
            ("<unk>", 0.0),
        ]
        vocab += [(piece.piece, piece.score) for piece in proto.pieces[3:]]
        vocab += [
            ("ar_AR", 0.0),
            ("cs_CZ", 0.0),
            ("de_DE", 0.0),
            ("en_XX", 0.0),
            ("es_XX", 0.0),
            ("et_EE", 0.0),
            ("fi_FI", 0.0),
            ("fr_XX", 0.0),
            ("gu_IN", 0.0),
            ("hi_IN", 0.0),
            ("it_IT", 0.0),
            ("ja_XX", 0.0),
            ("kk_KZ", 0.0),
            ("ko_KR", 0.0),
            ("lt_LT", 0.0),
            ("lv_LV", 0.0),
            ("my_MM", 0.0),
            ("ne_NP", 0.0),
            ("nl_XX", 0.0),
            ("ro_RO", 0.0),
            ("ru_RU", 0.0),
            ("si_LK", 0.0),
            ("tr_TR", 0.0),
            ("vi_VN", 0.0),
            ("zh_CN", 0.0),
        ]
        return vocab

    def unk_id(self, proto):
        return 3

    def post_processor(self, tokenizer):
        return TemplateProcessing(
            seq_a=["$0", "</s>", "en_XX"],
            seq_b=["$1", "</s>"],
            special_tokens=[
                ("en_XX", tokenizer.get_vocab()["en_XX"]),
                ("</s>", tokenizer.get_vocab()["</s>"]),
            ],
        )


class XLMRobertaConverter(SpmConverter):
    def vocab(self, proto):
        vocab = [
            ("<s>", 0.0),
            ("<pad>", 0.0),
            ("</s>", 0.0),
            ("<unk>", 0.0),
        ]
        vocab += [(piece.piece, piece.score) for piece in proto.pieces[3:]]
        return vocab

    def unk_id(self, proto):
        unk_id = 3
        return unk_id

    def post_processor(self, tokenizer):
        return TemplateProcessing(
            seq_a=["<s>", "$0", "</s>"],
            seq_b=["$1", "</s>"],
            special_tokens=[
                ("<s>", tokenizer.get_vocab()["<s>"]),
                ("</s>", tokenizer.get_vocab()["</s>"]),
            ],
        )


class XLNetConverter(SpmConverter):
    def vocab(self, proto):
        return [
            (piece.piece, piece.score) if check_number_comma(piece.piece) else (piece.piece, piece.score - 100)
            for piece in proto.pieces
        ]

    def normalizer(self, proto):
        normalizers = [Replace("``", '"'), Replace("''", '"')]
        if not self.original_tokenizer.keep_accents:
            normalizers.append(NFKD())
            normalizers.append(StripAccents())
        if self.original_tokenizer.do_lower_case:
            normalizers.append(Lowercase())

        precompiled_charsmap = proto.normalizer_spec.precompiled_charsmap
        normalizers.append(Precompiled(precompiled_charsmap))
        normalizers.append(Replace(Regex(" {2,}"), " "))
        return Sequence(normalizers)

    def post_processor(self, tokenizer):
        return TemplateProcessing(
            seq_a=["$0", "<sep>", "<cls>"],
            seq_b=["$1", "<sep>"],
            special_tokens=[
                ("<sep>", tokenizer.get_vocab()["<sep>"]),
                ("<cls>", tokenizer.get_vocab()["<cls>"]),
            ],
        )


class ReformerConverter(SpmConverter):
    pass


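# Pegasus reserves `offset` placeholder tokens right after <pad> and <eos>,
# so every original sentencepiece id (including unk_id) is shifted by `offset`.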
class PegasusConverter(SpmConverter):
    offset = 103

    def vocab(self, proto):
        vocab = [
            (self.original_tokenizer.pad_token, 0),
            (self.original_tokenizer.eos_token, 0),
        ]
        vocab += [(f"unk_{i}", -100) for i in range(2, 2 + self.offset)]
        vocab += [(piece.piece, piece.score) for piece in proto.pieces[2:]]
        return vocab

    def unk_id(self, proto):
        return proto.trainer_spec.unk_id + self.offset

    def post_processor(self, tokenizer):
        eos = self.original_tokenizer.eos_token
        return TemplateProcessing(
            seq_a=["$0", eos],
            seq_b=["$1", eos],
            special_tokens=[(eos, tokenizer.get_vocab()[eos])],
        )


class T5Converter(SpmConverter):
    def post_processor(self, tokenizer):
        return TemplateProcessing(
            seq_a=["$0", "</s>"],
            seq_b=["$1", "</s>"],
            special_tokens=[("</s>", tokenizer.get_vocab()["</s>"])],
        )


CONVERTERS = {
    "AlbertTokenizer": AlbertConverter,
    "CamembertTokenizer": CamembertConverter,
    "XLMRobertaTokenizer": XLMRobertaConverter,
    "MBartTokenizer": MBartConverter,
    "XLNetTokenizer": XLNetConverter,
    "ReformerTokenizer": ReformerConverter,
    "PegasusTokenizer": PegasusConverter,
    "T5Tokenizer": T5Converter,
}


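# Encode every line of `filename` with both the original `transformers` tokenizer and the
# converted one, assert the ids match, save the converted tokenizer, and report the speedup.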
def check(pretrained, filename):
    transformer_tokenizer = transformers.AutoTokenizer.from_pretrained(pretrained)
    converter_class = CONVERTERS[transformer_tokenizer.__class__.__name__]
    tokenizer = converter_class(transformer_tokenizer).converted()

    now = datetime.datetime.now
    trans_total_time = datetime.timedelta(seconds=0)
    tok_total_time = datetime.timedelta(seconds=0)

    with open(filename, "r") as f:
        for i, line in enumerate(f):
            line = line.strip()

            start = now()
            ids = transformer_tokenizer.encode(line)
            trans = now()
            tok_ids = tokenizer.encode(line).ids
            tok = now()

            trans_total_time += trans - start
            tok_total_time += tok - trans

            if ids != tok_ids:
                if check_details(line, ids, tok_ids, transformer_tokenizer, tokenizer):
                    continue
            assert ids == tok_ids, f"Error in line {i}: {line} {ids} != {tok_ids}"

    tokenizer.save(f"{pretrained.replace('/', '-')}.json")
    return ("OK", trans_total_time / tok_total_time)


def main():
    pretraineds = [
        "albert-base-v1",
        "albert-large-v1",
        "albert-xlarge-v1",
        "albert-xxlarge-v1",
        "albert-base-v2",
        "albert-large-v2",
        "albert-xlarge-v2",
        "albert-xxlarge-v2",
        "camembert-base",
        "xlm-roberta-base",
        "xlm-roberta-large",
        "xlm-roberta-large-finetuned-conll02-dutch",
        "xlm-roberta-large-finetuned-conll02-spanish",
        "xlm-roberta-large-finetuned-conll03-english",
        "xlm-roberta-large-finetuned-conll03-german",
        "facebook/mbart-large-en-ro",
        "facebook/mbart-large-cc25",
        "xlnet-base-cased",
        "xlnet-large-cased",
        "google/reformer-crime-and-punishment",
        "t5-small",
        "google/pegasus-large",
    ]
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--filename",
        required=True,
        type=str,
        help="The file that will be encoded with both tokenizers to check that the conversion worked",
    )
    parser.add_argument(
        "--models",
        type=lambda s: s.split(","),
        default=pretraineds,
        help=f"Comma-separated list of pretrained tokenizers to test against (default: {pretraineds})",
    )
    args = parser.parse_args()

    print(args.filename)

    model_len = 50
    status_len = 6
    speedup_len = 8
    print(f"|{'Model':^{model_len}}|{'Status':^{status_len}}|{'Speedup':^{speedup_len}}|")
    print(f"|{'-'*model_len}|{'-'*status_len}|{'-'*speedup_len}|")
    for pretrained in args.models:
        status, speedup = check(pretrained, args.filename)
        print(f"|{pretrained:<{model_len}}|{status:^{status_len}}|{speedup:^{speedup_len - 1}.2f}x|")


if __name__ == "__main__":
    main()