# Mirror of https://github.com/mii443/tokenizers.git
# (synced 2025-08-23 00:35:35 +00:00; Python, 134 lines, 4.6 KiB)
from argparse import ArgumentParser
|
|
from json import dump
|
|
from logging import basicConfig, getLogger
|
|
from os import linesep, remove
|
|
from os.path import exists
|
|
from tempfile import NamedTemporaryFile
|
|
from typing import Dict, List, Tuple
|
|
|
|
from requests import get
|
|
from sentencepiece import SentencePieceProcessor
|
|
from tqdm import trange, tqdm
|
|
|
|
basicConfig()

# basicConfig() leaves the root logger at WARNING, which would silently drop
# the logger.info(...) progress messages this script emits. Enable INFO for
# this module's logger only, so third-party libraries keep their defaults.
logger = getLogger(__name__)
logger.setLevel('INFO')
|
|
|
|
|
|
class SentencePieceExtractor:
    """
    Extractor implementation for SentencePiece trained models.
    https://github.com/google/sentencepiece
    """

    def __init__(self, model: str):
        # Get SentencePiece
        self.sp = SentencePieceProcessor()
        self.sp.Load(model)

    def extract(self) -> Tuple[Dict[str, int], List[Tuple]]:
        """
        Return ``(vocab, merges)`` for the loaded model.

        ``vocab`` maps every piece to its id. ``merges`` lists every pair
        ``(left, right)`` of vocab pieces whose concatenation is itself a
        vocab piece that is not a control or unknown piece. The list is
        ordered as described in ``_merges_from_vocab``.
        """
        sp = self.sp
        vocab = {sp.id_to_piece(index): index for index in trange(sp.GetPieceSize())}

        # Control / unknown pieces are special tokens, not merge results, so
        # never report them as the product of a merge even if their text
        # happens to split into two vocab pieces.
        # NOTE(review): user-defined symbols are not filtered here; confirm
        # whether they should be for the models being converted.
        special_ids = frozenset(
            index for index in vocab.values()
            if sp.IsControl(index) or sp.IsUnknown(index)
        )
        return vocab, self._merges_from_vocab(vocab, special_ids)

    @staticmethod
    def _merges_from_vocab(vocab: Dict[str, int], skip_ids: frozenset = frozenset()) -> List[Tuple[str, str]]:
        """
        List every way of splitting a vocab piece into two vocab pieces.

        :param vocab: mapping piece -> id.
        :param skip_ids: ids of pieces that must not appear as a merge result.
        :return: ``(left, right)`` pairs such that ``left``, ``right`` and
            ``left + right`` are all in ``vocab``. Pairs of identical pieces
            (e.g. ``("a", "a")`` -> ``"aa"``) are legitimate and kept.

        Splitting each piece at every interior position costs
        O(len(vocab) * max_piece_length) dict lookups, instead of testing all
        len(vocab) ** 2 ordered pairs of pieces.

        The result is sorted by the id of the merged piece, with ties broken by
        the id of the left piece. This assumes lower ids mean higher-priority
        merges, as in BPE-trained models. TODO confirm for other model types.
        """
        ranked = []
        for piece, piece_id in vocab.items():
            if piece_id in skip_ids:
                continue
            for split in range(1, len(piece)):
                left, right = piece[:split], piece[split:]
                if left in vocab and right in vocab:
                    ranked.append((piece_id, vocab[left], left, right))

        # (piece_id, left_id) is unique per pair, so the sort never has to
        # compare the string fields.
        ranked.sort()
        return [(left, right) for _, _, left, right in ranked]
|
|
|
|
|
|
class YouTokenToMeExtractor:
    """
    Extractor implementation for YouTokenToMe trained models format.

    Models are laid out as follows (all values are integers):
        vocab_size nb_merges
        piece piece_id
        ...(repeated vocab_size times; piece is a Unicode code point)
        piece_id_left piece_id_right piece_id
        ...(repeated nb_merges times)
        unk_id pad_id bos_id eos_id
    """

    def __init__(self, model: str):
        self._model = model

    def extract(self) -> Tuple[Dict[str, int], List[Tuple]]:
        """Read the model file and return ``(vocab, merges)``; see ``_parse``."""
        # Every line is parsed as integers, so the content is plain ASCII;
        # pin the encoding so decoding does not depend on the platform locale.
        with open(self._model, 'r', encoding='utf-8') as model_f:
            return self._parse(model_f, trange)

    @staticmethod
    def _parse(model_f, progress=range) -> Tuple[Dict[str, int], List[Tuple]]:
        """
        Parse a YouTokenToMe model from a text stream.

        :param model_f: object with a ``readline()`` method, positioned at the
            start of the model.
        :param progress: callable used in place of ``range`` for the vocab and
            merge loops (``extract`` passes ``trange`` to show progress bars).
        :return: ``(vocab, merges)`` where ``vocab`` maps piece -> id and
            ``merges`` lists ``(left_piece, right_piece)`` in file order.
        :raises ValueError: if a line does not have the expected number of
            integer fields.
        :raises KeyError: if a merge references an id not defined before it.
        """
        # Retrieve information
        nb_pieces, nb_merges = map(int, model_f.readline().split())
        id_to_piece, merges = {}, []

        # Vocab: single characters, stored as code points
        for _ in progress(nb_pieces):
            code_point, piece_id = map(int, model_f.readline().split())
            id_to_piece[piece_id] = chr(code_point)

        # Merges: each defines a new piece as the concatenation of two earlier ones
        for _ in progress(nb_merges):
            piece_id_l, piece_id_r, piece_id = map(int, model_f.readline().split())
            piece_l, piece_r = id_to_piece[piece_id_l], id_to_piece[piece_id_r]
            id_to_piece[piece_id] = f"{piece_l}{piece_r}"
            merges.append((piece_l, piece_r))

        # Special tokens. A negative id means the token is not part of the
        # vocabulary. This assumes YouTokenToMe writes -1 for a disabled
        # special token; confirm against the yttm version in use. Skip such
        # tokens instead of emitting a bogus negative id; several disabled
        # tokens would otherwise also overwrite each other under the same key.
        unk, pad, bos, eos = map(int, model_f.readline().split())
        for special_id, special_piece in ((unk, '<unk>'), (pad, '<pad>'), (bos, '<bos>'), (eos, '<eos>')):
            if special_id >= 0:
                id_to_piece[special_id] = special_piece

        # Invert key and value for vocab
        vocab = {piece: piece_id for piece_id, piece in id_to_piece.items()}
        return vocab, merges
|
|
|
|
|
|
if __name__ == '__main__':
    # Positional argument of ArgumentParser is `prog`, not the description.
    parser = ArgumentParser(description='Extract vocab and merges from a SentencePiece or YouTokenToMe model.')
    parser.add_argument('--provider', type=str, required=True, choices=['sentencepiece', 'youtokentome'],
                        help='Indicate the format of the file.')
    parser.add_argument('--model', type=str, required=True,
                        help='Path or http(s) URL of the model to extract vocab from.')
    parser.add_argument('--vocab-output-path', type=str, required=True,
                        help='Path where the vocab.json file will be extracted')
    parser.add_argument('--merges-output-path', type=str, required=True,
                        help='Path where the merges file will be extracted')

    # Parse cli arguments
    args = parser.parse_args()

    # Path of the temporary copy of a remote model. It is recorded as soon as
    # the file exists so the finally-clause removes it even if the download fails.
    downloaded_model = None
    try:
        if args.model.startswith('http'):
            # Saving model
            with NamedTemporaryFile('wb', delete=False) as f:
                downloaded_model = f.name
                logger.info('Writing content from {} to {}'.format(args.model, f.name))
                response = get(args.model, allow_redirects=True)
                # Fail on HTTP errors rather than handing an error page to the model loader.
                response.raise_for_status()
                f.write(response.content)

            args.model = downloaded_model

        # Allocate extractor
        extractor_cls = SentencePieceExtractor if args.provider == 'sentencepiece' else YouTokenToMeExtractor
        extractor = extractor_cls(args.model)
        logger.info(f'Using {type(extractor).__name__}')

        # Extract before opening (and truncating) the outputs, so a failed
        # extraction leaves any existing output files untouched.
        vocab, merges = extractor.extract()

        # Pieces are arbitrary Unicode text, so write UTF-8 explicitly instead
        # of relying on the platform's default encoding.
        with open(args.vocab_output_path, 'w', encoding='utf-8') as vocab_f:
            dump(vocab, vocab_f)

        # Text mode already translates '\n' into the platform line ending;
        # writing os.linesep here would produce '\r\r\n' on Windows.
        with open(args.merges_output_path, 'w', encoding='utf-8') as merges_f:
            merges_f.writelines(f"{left} {right}\n" for left, right in merges)
    finally:
        # If model was downloaded from internet we need to cleanup the tmp file.
        if downloaded_model is not None and exists(downloaded_model):
            remove(downloaded_model)
|