import transformers
from tokenizers.implementations import SentencePieceUnigramTokenizer, BaseTokenizer
from tokenizers.processors import TemplateProcessing
from tokenizers.models import Unigram, BPE
from tokenizers import decoders
from tokenizers import Tokenizer, Regex
from tokenizers.normalizers import (
    StripAccents,
    NFKD,
    Lowercase,
    Sequence,
    BertNormalizer,
    Precompiled,
    Replace,
)
from tokenizers.pre_tokenizers import (
    Digits,
    WhitespaceSplit,
    Metaspace,
    Sequence as PSequence,
)
import json
import unicodedata
import sys
import os
import datetime
import argparse

sys.path.append(".")

from spm_parity_check import check_details
from sentencepiece_extractor import SentencePieceExtractor


def check_number_comma(piece: str) -> bool:
    # False only for pieces ending in a digit followed by a comma; those get a score penalty below.
    return len(piece) < 2 or piece[-1] != "," or not piece[-2].isdigit()


def get_proto(filename: str):
    # Parse the serialized SentencePiece model proto.
    try:
        import sys

        sys.path.append(".")

        import sentencepiece_model_pb2 as model
    except Exception:
        raise Exception(
            "You don't seem to have the required protobuf file. In order to use this function you need to run "
            "`pip install protobuf` and "
            "`wget https://raw.githubusercontent.com/google/sentencepiece/master/python/sentencepiece_model_pb2.py` "
            "so that we can read the internals of your spm_file. `pip install sentencepiece` is not required."
        )

    m = model.ModelProto()
    with open(filename, "rb") as f:
        m.ParseFromString(f.read())
    return m


class Converter:
    def __init__(self, original_tokenizer):
        self.original_tokenizer = original_tokenizer

    def converted(self) -> Tokenizer:
        raise NotImplementedError()


class SpmConverter(Converter):
    def __init__(self, *args):
        super().__init__(*args)
        self.proto = get_proto(self.original_tokenizer.vocab_file)

    def vocab(self, proto):
        return [(piece.piece, piece.score) for piece in proto.pieces]

    def unk_id(self, proto):
        return proto.trainer_spec.unk_id

    def tokenizer(self, proto):
        # model_type in the SentencePiece trainer spec: 1 = Unigram, 2 = BPE.
        model_type = proto.trainer_spec.model_type
        vocab = self.vocab(proto)
        unk_id = self.unk_id(proto)
        if model_type == 1:
            tokenizer = Tokenizer(Unigram(vocab, unk_id))
        elif model_type == 2:
            vocab, merges = SentencePieceExtractor(self.original_tokenizer.vocab_file).extract()
            tokenizer = Tokenizer(
                BPE(vocab, merges, unk_token=proto.trainer_spec.unk_piece, fuse_unk=True)
            )
        else:
            raise Exception(
                "Unsupported SentencePiece model type: only Unigram (1) and BPE (2) files can be converted"
            )
        return tokenizer

    def normalizer(self, proto):
        precompiled_charsmap = proto.normalizer_spec.precompiled_charsmap
        return Sequence([Precompiled(precompiled_charsmap), Replace(Regex(" {2,}"), " ")])

    def post_processor(self, tokenizer):
        return None

    def converted(self):
        tokenizer = self.tokenizer(self.proto)

        # Assemble the tokenizer
        tokenizer.normalizer = self.normalizer(self.proto)

        replacement = "▁"
        prepend_scheme = "always"
        tokenizer.pre_tokenizer = Metaspace(replacement=replacement, prepend_scheme=prepend_scheme)
        tokenizer.decoder = decoders.Metaspace(replacement=replacement, prepend_scheme=prepend_scheme)
        post_processor = self.post_processor(tokenizer)
        if post_processor:
            tokenizer.post_processor = post_processor

        # TODO what parameters should we give?
        parameters = {}

        return BaseTokenizer(tokenizer, parameters)


class AlbertConverter(SpmConverter):
    def vocab(self, proto):
        return [
            (piece.piece, piece.score)
            if check_number_comma(piece.piece)
            else (piece.piece, piece.score - 100)
            for piece in proto.pieces
        ]

    def normalizer(self, proto):
        normalizers = [Replace("``", '"'), Replace("''", '"')]
        if not self.original_tokenizer.keep_accents:
            normalizers.append(NFKD())
            normalizers.append(StripAccents())
        if self.original_tokenizer.do_lower_case:
            normalizers.append(Lowercase())

        precompiled_charsmap = proto.normalizer_spec.precompiled_charsmap
        normalizers.append(Precompiled(precompiled_charsmap))
        normalizers.append(Replace(Regex(" {2,}"), " "))
        return Sequence(normalizers)

    def post_processor(self, tokenizer):
        return TemplateProcessing(
            seq_a=["[CLS]", "$0", "[SEP]"],
            seq_b=["$1", "[SEP]"],
            special_tokens=[
                ("[CLS]", tokenizer.get_vocab()["[CLS]"]),
                ("[SEP]", tokenizer.get_vocab()["[SEP]"]),
            ],
        )


class CamembertConverter(SpmConverter):
    def vocab(self, proto):
        vocab = [
            ("<s>NOTUSED", 0.0),
            ("<pad>", 0.0),
            ("</s>NOTUSED", 0.0),
            ("<unk>", 0.0),
        ]
        vocab += [(piece.piece, piece.score) for piece in proto.pieces]
        return vocab

    def unk_id(self, proto):
        # See vocab unk position
        return 3

    def post_processor(self, tokenizer):
        return TemplateProcessing(
            seq_a=["<s>", "$0", "</s>"],
            seq_b=["$1", "</s>"],
            special_tokens=[
                ("<s>", tokenizer.get_vocab()["<s>"]),
                ("</s>", tokenizer.get_vocab()["</s>"]),
            ],
        )


class MBartConverter(SpmConverter):
    def vocab(self, proto):
        vocab = [
            ("<s>", 0.0),
            ("<pad>", 0.0),
            ("</s>", 0.0),
            ("<unk>", 0.0),
        ]
        vocab += [(piece.piece, piece.score) for piece in proto.pieces[3:]]
        vocab += [
            ("ar_AR", 0.0),
            ("cs_CZ", 0.0),
            ("de_DE", 0.0),
            ("en_XX", 0.0),
            ("es_XX", 0.0),
            ("et_EE", 0.0),
            ("fi_FI", 0.0),
            ("fr_XX", 0.0),
            ("gu_IN", 0.0),
            ("hi_IN", 0.0),
            ("it_IT", 0.0),
            ("ja_XX", 0.0),
            ("kk_KZ", 0.0),
            ("ko_KR", 0.0),
            ("lt_LT", 0.0),
            ("lv_LV", 0.0),
            ("my_MM", 0.0),
            ("ne_NP", 0.0),
            ("nl_XX", 0.0),
            ("ro_RO", 0.0),
            ("ru_RU", 0.0),
            ("si_LK", 0.0),
            ("tr_TR", 0.0),
            ("vi_VN", 0.0),
            ("zh_CN", 0.0),
        ]
        return vocab

    def unk_id(self, proto):
        return 3

    def post_processor(self, tokenizer):
        return TemplateProcessing(
            seq_a=["$0", "</s>", "en_XX"],
            seq_b=["$1", "</s>"],
            special_tokens=[
                ("en_XX", tokenizer.get_vocab()["en_XX"]),
                ("</s>", tokenizer.get_vocab()["</s>"]),
            ],
        )


class XLMRobertaConverter(SpmConverter):
    def vocab(self, proto):
        vocab = [
            ("<s>", 0.0),
            ("<pad>", 0.0),
            ("</s>", 0.0),
            ("<unk>", 0.0),
        ]
        vocab += [(piece.piece, piece.score) for piece in proto.pieces[3:]]
        return vocab

    def unk_id(self, proto):
        unk_id = 3
        return unk_id

    def post_processor(self, tokenizer):
        return TemplateProcessing(
            seq_a=["<s>", "$0", "</s>"],
            seq_b=["$1", "</s>"],
            special_tokens=[
                ("<s>", tokenizer.get_vocab()["<s>"]),
                ("</s>", tokenizer.get_vocab()["</s>"]),
            ],
        )


class XLNetConverter(SpmConverter):
    def vocab(self, proto):
        return [
            (piece.piece, piece.score)
            if check_number_comma(piece.piece)
            else (piece.piece, piece.score - 100)
            for piece in proto.pieces
        ]

    def normalizer(self, proto):
        normalizers = [Replace("``", '"'), Replace("''", '"')]
        if not self.original_tokenizer.keep_accents:
            normalizers.append(NFKD())
            normalizers.append(StripAccents())
        if self.original_tokenizer.do_lower_case:
            normalizers.append(Lowercase())

        precompiled_charsmap = proto.normalizer_spec.precompiled_charsmap
        normalizers.append(Precompiled(precompiled_charsmap))
        normalizers.append(Replace(Regex(" {2,}"), " "))
        return Sequence(normalizers)

    def post_processor(self, tokenizer):
        return TemplateProcessing(
            seq_a=["$0", "<sep>", "<cls>"],
            seq_b=["$1", "<sep>"],
            special_tokens=[
                ("<sep>", tokenizer.get_vocab()["<sep>"]),
                ("<cls>", tokenizer.get_vocab()["<cls>"]),
            ],
        )


class ReformerConverter(SpmConverter):
    pass


class PegasusConverter(SpmConverter):
    # Pegasus reserves `offset` placeholder "unk_*" ids between pad/eos and the SentencePiece pieces.
    offset = 103

    def vocab(self, proto):
        vocab = [
            (self.original_tokenizer.pad_token, 0),
            (self.original_tokenizer.eos_token, 0),
        ]
        vocab += [(f"unk_{i}", -100) for i in range(2, 2 + self.offset)]
        vocab += [(piece.piece, piece.score) for piece in proto.pieces[2:]]
        return vocab

    def unk_id(self, proto):
        return proto.trainer_spec.unk_id + self.offset

    def post_processor(self, tokenizer):
        eos = self.original_tokenizer.eos_token
        return TemplateProcessing(
            seq_a=["$0", eos],
            seq_b=["$1", eos],
            special_tokens=[(eos, tokenizer.get_vocab()[eos])],
        )


class T5Converter(SpmConverter):
    def post_processor(self, tokenizer):
        return TemplateProcessing(
            seq_a=["$0", "</s>"],
            seq_b=["$1", "</s>"],
            special_tokens=[("</s>", tokenizer.get_vocab()["</s>"])],
        )


CONVERTERS = {
    "AlbertTokenizer": AlbertConverter,
    "CamembertTokenizer": CamembertConverter,
    "XLMRobertaTokenizer": XLMRobertaConverter,
    "MBartTokenizer": MBartConverter,
    "XLNetTokenizer": XLNetConverter,
    "ReformerTokenizer": ReformerConverter,
    "PegasusTokenizer": PegasusConverter,
    "T5Tokenizer": T5Converter,
}


def check(pretrained, filename):
    transformer_tokenizer = transformers.AutoTokenizer.from_pretrained(pretrained)
    converter_class = CONVERTERS[transformer_tokenizer.__class__.__name__]
    tokenizer = converter_class(transformer_tokenizer).converted()

    now = datetime.datetime.now
    trans_total_time = datetime.timedelta(seconds=0)
    tok_total_time = datetime.timedelta(seconds=0)
    with open(filename, "r") as f:
        # Encode every line with both tokenizers, compare the ids and time each encoder.
        for i, line in enumerate(f):
            line = line.strip()

            start = now()
            ids = transformer_tokenizer.encode(line)
            trans = now()
            tok_ids = tokenizer.encode(line).ids
            tok = now()

            trans_total_time += trans - start
            tok_total_time += tok - trans

            if ids != tok_ids:
                if check_details(line, ids, tok_ids, transformer_tokenizer, tokenizer):
                    continue

            assert ids == tok_ids, f"Error in line {i}: {line} {ids} != {tok_ids}"

    tokenizer.save(f"{pretrained.replace('/', '-')}.json")
    return ("OK", trans_total_time / tok_total_time)


def main():
    pretraineds = [
        "albert-base-v1",
        "albert-large-v1",
        "albert-xlarge-v1",
        "albert-xxlarge-v1",
        "albert-base-v2",
        "albert-large-v2",
        "albert-xlarge-v2",
        "albert-xxlarge-v2",
        "camembert-base",
        "xlm-roberta-base",
        "xlm-roberta-large",
        "xlm-roberta-large-finetuned-conll02-dutch",
        "xlm-roberta-large-finetuned-conll02-spanish",
        "xlm-roberta-large-finetuned-conll03-english",
        "xlm-roberta-large-finetuned-conll03-german",
        "facebook/mbart-large-en-ro",
        "facebook/mbart-large-cc25",
        "xlnet-base-cased",
        "xlnet-large-cased",
        "google/reformer-crime-and-punishment",
        "t5-small",
        "google/pegasus-large",
    ]
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--filename",
        required=True,
        type=str,
        help="The file to encode with both tokenizers to check that the conversion worked",
    )
    parser.add_argument(
        "--models",
        type=lambda s: s.split(","),
        default=pretraineds,
        help=f"The pretrained tokenizers you want to test against (default: {pretraineds})",
    )
    args = parser.parse_args()

    print(args.filename)

    model_len = 50
    status_len = 6
    speedup_len = 8
    print(f"|{'Model':^{model_len}}|{'Status':^{status_len}}|{'Speedup':^{speedup_len}}|")
    print(f"|{'-' * model_len}|{'-' * status_len}|{'-' * speedup_len}|")
    for pretrained in args.models:
        status, speedup = check(pretrained, args.filename)
        print(f"|{pretrained:<{model_len}}|{status:^{status_len}}|{speedup:^{speedup_len - 1}.2f}x|")


if __name__ == "__main__":
    main()
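
# A minimal usage sketch. The script and data file names below are placeholders
# (assuming this file is saved as `convert.py` and `data.txt` contains one sentence
# per line):
#
#   python convert.py --filename data.txt --models albert-base-v2,t5-small
#
# For each requested model this prints a status/speedup row in the table above and
# writes the converted tokenizer to `<pretrained>.json` (with `/` replaced by `-`)
# in the current directory.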