from argparse import ArgumentParser
from json import dump
from logging import basicConfig, getLogger
from os import remove
from os.path import exists
from tempfile import NamedTemporaryFile
from typing import Dict, List, Tuple

from requests import get
from sentencepiece import SentencePieceProcessor
from tqdm import tqdm, trange


basicConfig()
logger = getLogger()


class SentencePieceExtractor:
    """
    Extractor implementation for SentencePiece trained models.
    https://github.com/google/sentencepiece
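
    Usage sketch (hypothetical model path):
        extractor = SentencePieceExtractor("spm.model")
        vocab, merges = extractor.extract()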
"""
|
|
|
|
    def __init__(self, model: str):
        # Load the SentencePiece model from disk
        self.sp = SentencePieceProcessor()
        self.sp.Load(model)

    def extract(self) -> Tuple[Dict[str, int], List[Tuple]]:
        sp = self.sp
        vocab = {sp.id_to_piece(index): index for index in trange(sp.GetPieceSize())}

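        # A SentencePiece model does not store its BPE merge list explicitly, so it is
        # reconstructed by brute force: (piece_l, piece_r) is taken as a merge iff the
        # concatenation piece_l + piece_r is itself a piece, and candidates are ordered
        # by the id of the merged piece (the script assumes lower ids mean earlier,
        # higher-priority merges). This scan is O(vocab_size^2).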
        # Merges
        merges = []
        for piece_l in tqdm(vocab.keys(), total=sp.GetPieceSize()):
            for piece_r in vocab.keys():
                merge = f"{piece_l}{piece_r}"
                piece_id = vocab.get(merge, None)
                # Compare against None explicitly: 0 is a valid piece id
                if piece_id is not None:
                    merges += [(piece_l, piece_r, piece_id)]
        merges = sorted(merges, key=lambda val: val[2])
        merges = [(val[0], val[1]) for val in merges]

        return vocab, merges


class YouTokenToMeExtractor:
    """
    Extractor implementation for YouTokenToMe trained models format.
    A model file is laid out as follows:
        vocab_size nb_merges
        piece piece_id
        ...(repeated vocab_size times)
        piece_id_left piece_id_right piece_id
        ...(repeated nb_merges times)
        unk_id pad_id bos_id eos_id
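
    For example (hypothetical), three single-character pieces and one merge
    ("a" + "b" -> new piece id 3), followed by the special token ids:
        3 1
        97 0
        98 1
        99 2
        0 1 3
        4 5 6 7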
"""
|
|
|
|
    def __init__(self, model: str):
        self._model = model

    def extract(self) -> Tuple[Dict[str, int], List[Tuple]]:
        with open(self._model, "r") as model_f:
            # Retrieve the global counts: number of pieces and number of merges
            nb_pieces, nb_merges = map(int, model_f.readline().split())
            vocab, merges = {}, []

            # Vocab
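            # Each line is "piece piece_id" where piece is the integer code point
            # of the character, hence the chr() decoding below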
            for _ in trange(nb_pieces):
                piece, piece_id = map(int, model_f.readline().split())
                vocab[piece_id] = chr(piece)

            # Merges
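            # Each line is "piece_id_left piece_id_right piece_id": the merged piece
            # gets a brand new id, which is added to the vocab on the fly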
            for _ in trange(nb_merges):
                piece_id_l, piece_id_r, piece = map(int, model_f.readline().split())
                piece_l, piece_r = vocab[piece_id_l], vocab[piece_id_r]
                vocab[piece] = f"{piece_l}{piece_r}"
                merges += [(piece_l, piece_r)]

            # Special tokens
            unk, pad, bos, eos = map(int, model_f.readline().split())
            vocab[unk] = "<unk>"
            vocab[pad] = "<pad>"
            vocab[bos] = "<bos>"
            vocab[eos] = "<eos>"

        # Invert keys and values: the file was parsed as id -> piece, but the
        # returned vocab must map piece -> id
        vocab = dict(zip(vocab.values(), vocab.keys()))
        return vocab, merges


if __name__ == "__main__":
    parser = ArgumentParser("SentencePiece vocab extractor")
    parser.add_argument(
        "--provider",
        type=str,
        required=True,
        choices=["sentencepiece", "youtokentome"],
        help="Indicate the format of the file.",
    )
    parser.add_argument("--model", type=str, required=True, help="SentencePiece model to extract vocab from.")
    parser.add_argument(
        "--vocab-output-path",
        type=str,
        required=True,
        help="Path where the vocab.json file will be extracted.",
    )
    parser.add_argument(
        "--merges-output-path",
        type=str,
        required=True,
        help="Path where the merges file will be extracted.",
    )
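
    # Example invocation (hypothetical file names, assuming this script is saved
    # as sentencepiece_extractor.py):
    #   python sentencepiece_extractor.py --provider sentencepiece --model spm.model \
    #       --vocab-output-path vocab.json --merges-output-path merges.txt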

    # Parse cli arguments
    args = parser.parse_args()

    try:
        if args.model.startswith("http"):
            # Download the model to a local temporary file
            with NamedTemporaryFile("wb", delete=False) as f:
                logger.info("Writing content from {} to {}".format(args.model, f.name))
                response = get(args.model, allow_redirects=True)
                f.write(response.content)

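            # Keep the original URL around and point args.model at the temporary
            # file, so the finally block below knows there is something to clean up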
            args.remote_model = args.model
            args.model = f.name

        # Allocate extractor
        extractor = SentencePieceExtractor if args.provider == "sentencepiece" else YouTokenToMeExtractor
        extractor = extractor(args.model)

        logger.info(f"Using {type(extractor).__name__}")

        # Open output files and extract the model information
        with open(args.vocab_output_path, "w") as vocab_f:
            with open(args.merges_output_path, "w") as merges_f:
                # Do the extraction
                vocab, merges = extractor.extract()

                # Save content: vocab as JSON, merges as one "left right" pair per line.
                # Plain "\n" is used instead of os.linesep because text-mode files
                # already translate it to the platform line ending on write.
                dump(vocab, vocab_f)
                merges_f.writelines(map(lambda x: f"{x[0]} {x[1]}\n", merges))
    finally:
        # If the model was downloaded from the internet, clean up the temporary file
        if hasattr(args, "remote_model") and exists(args.model):
            remove(args.model)