Mirror of https://github.com/mii443/tokenizers.git (synced 2025-08-22 16:25:30 +00:00)
Added SentencePiece and YouTokenToMe model extractors.
Signed-off-by: Morgan Funtowicz <morgan@huggingface.co>
This commit is contained in:

bindings/python/scripts/sentencepiece_extractor.py (new file, 133 lines)

@@ -0,0 +1,133 @@
from argparse import ArgumentParser
from json import dump
from logging import basicConfig, getLogger
from os import linesep, remove
from os.path import exists
from tempfile import NamedTemporaryFile
from typing import Dict, List, Tuple

from requests import get
from sentencepiece import SentencePieceProcessor
from tqdm import trange, tqdm

basicConfig()
logger = getLogger()

class SentencePieceExtractor:
    """
    Extractor implementation for SentencePiece trained models.
    https://github.com/google/sentencepiece
    """

    def __init__(self, model: str):
        # Load the trained SentencePiece model
        self.sp = SentencePieceProcessor()
        self.sp.Load(model)

    def extract(self) -> Tuple[Dict[str, int], List[Tuple]]:
        sp = self.sp
        vocab = {sp.id_to_piece(index): index for index in trange(sp.GetPieceSize())}

        # Merges: keep a pair (piece_l, piece_r) whenever their concatenation
        # maps to a piece with a non-zero score in the model.
        merges = []
        for piece_l in tqdm(vocab.keys(), total=sp.GetPieceSize()):
            for piece_r in vocab.keys():
                if piece_l != piece_r:
                    merge = sp.PieceToId(f"{piece_l}{piece_r}")
                    score = sp.GetScore(merge)

                    if score != 0.:
                        merges += [(piece_l, piece_r)]

        return vocab, merges

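For reference, the extractor above can also be driven directly from Python. This is only a usage sketch: the module name sentencepiece_extractor and the spm.model path are illustrative placeholders, not part of the commit.

    from sentencepiece_extractor import SentencePieceExtractor

    # "spm.model" is a placeholder path to a trained SentencePiece model
    extractor = SentencePieceExtractor("spm.model")
    vocab, merges = extractor.extract()

    # vocab maps each piece to its id; merges lists the candidate (left, right) pairs
    print(len(vocab), "pieces,", len(merges), "merges")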
class YouTokenToMeExtractor:
    """
    Extractor implementation for the YouTokenToMe trained model format.
    The model file is laid out as follows:
        vocab_size nb_merges
        piece piece_id
        ... (repeated vocab_size times)
        piece_id_left piece_id_right piece_id
        ... (repeated nb_merges times)
    """

    def __init__(self, model: str):
        self._model = model

    def extract(self) -> Tuple[Dict[str, int], List[Tuple]]:
        with open(self._model, 'r') as model_f:

            # Retrieve information
            nb_pieces, nb_merges = map(int, model_f.readline().split())
            vocab, merges = {}, []

            # Vocab: each line gives a piece (as a unicode code point) and its id
            for _ in trange(nb_pieces):
                piece, piece_id = map(int, model_f.readline().split())
                vocab[piece_id] = chr(piece)

            # Merges: each line gives the ids of the two merged pieces and the resulting piece id
            for _ in trange(nb_merges):
                piece_id_l, piece_id_r, piece = map(int, model_f.readline().split())
                piece_l, piece_r = vocab[piece_id_l], vocab[piece_id_r]
                vocab[piece] = f"{piece_l}{piece_r}"
                merges += [(piece_l, piece_r)]

            # Special tokens
            unk, pad, bos, eos = map(int, model_f.readline().split())
            vocab[unk] = '<unk>'
            vocab[pad] = '<pad>'
            vocab[bos] = '<bos>'
            vocab[eos] = '<eos>'

            # Invert key and value for vocab (piece -> id)
            vocab = dict(zip(vocab.values(), vocab.keys()))

        return vocab, merges

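To make the layout described in the docstring concrete, here is a toy, made-up example (hand-written numbers, not a real YouTokenToMe model) and what the extractor would return for it:

    # toy.yttm is a hypothetical hand-written file following the layout above
    with open("toy.yttm", "w") as f:
        f.write("2 1\n")      # vocab_size nb_merges
        f.write("97 4\n")     # code point 97 ('a') has piece id 4
        f.write("98 5\n")     # code point 98 ('b') has piece id 5
        f.write("4 5 6\n")    # merging ids 4 and 5 yields piece id 6 ("ab")
        f.write("0 1 2 3\n")  # unk, pad, bos, eos ids

    vocab, merges = YouTokenToMeExtractor("toy.yttm").extract()
    # vocab  == {'a': 4, 'b': 5, 'ab': 6, '<unk>': 0, '<pad>': 1, '<bos>': 2, '<eos>': 3}
    # merges == [('a', 'b')]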
if __name__ == '__main__':
    parser = ArgumentParser('SentencePiece vocab extractor')
    parser.add_argument('--provider', type=str, required=True, choices=['sentencepiece', 'youtokentome'],
                        help='Indicate the format of the file.')
    parser.add_argument('--model', type=str, required=True, help='SentencePiece model to extract vocab from.')
    parser.add_argument('--vocab-output-path', type=str, required=True,
                        help='Path where the vocab.json file will be extracted')
    parser.add_argument('--merges-output-path', type=str, required=True,
                        help='Path where the merges file will be extracted')

    # Parse cli arguments
    args = parser.parse_args()

    try:
        if args.model.startswith('http'):
            # Saving model
            with NamedTemporaryFile('wb', delete=False) as f:
                logger.info('Writing content from {} to {}'.format(args.model, f.name))
                response = get(args.model, allow_redirects=True)
                f.write(response.content)

                args.remote_model = args.model
                args.model = f.name

        # Allocate extractor
        extractor = SentencePieceExtractor if args.provider == 'sentencepiece' else YouTokenToMeExtractor
        extractor = extractor(args.model)

        logger.info(f'Using {type(extractor).__name__}')

        # Open output files and let's extract model information
        with open(args.vocab_output_path, 'w') as vocab_f:
            with open(args.merges_output_path, 'w') as merges_f:
                # Do the extraction
                vocab, merges = extractor.extract()

                # Save content
                dump(vocab, vocab_f)
                merges_f.writelines(map(lambda x: f"{x[0]} {x[1]}{linesep}", merges))
    finally:
        # If the model was downloaded from the internet we need to clean up the tmp file.
        if hasattr(args, 'remote_model') and exists(args.model):
            remove(args.model)
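As a usage sketch (the model and output file names below are illustrative placeholders), the script is meant to be invoked along these lines:

    python sentencepiece_extractor.py --provider sentencepiece \
        --model spm.model \
        --vocab-output-path vocab.json \
        --merges-output-path merges.txt

The vocab is written as JSON and the merges as one "piece_left piece_right" pair per line, which is the vocab.json / merges.txt layout commonly used by BPE-style tokenizers.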