Added SentencePiece and YouTokenToMe model extractors.

Signed-off-by: Morgan Funtowicz <morgan@huggingface.co>
Morgan Funtowicz
2020-01-08 22:55:00 +01:00
parent 3b2e19f52c
commit be10f542ce


@@ -0,0 +1,133 @@
from argparse import ArgumentParser
from json import dump
from logging import basicConfig, getLogger
from os import linesep, remove
from os.path import exists
from tempfile import NamedTemporaryFile
from typing import Dict, List, Tuple

from requests import get
from sentencepiece import SentencePieceProcessor
from tqdm import trange, tqdm

basicConfig()
logger = getLogger()

class SentencePieceExtractor:
    """
    Extractor implementation for SentencePiece trained models.
    https://github.com/google/sentencepiece
    """

    def __init__(self, model: str):
        # Load the trained SentencePiece model
        self.sp = SentencePieceProcessor()
        self.sp.Load(model)

    def extract(self) -> Tuple[Dict[str, int], List[Tuple]]:
        sp = self.sp

        # Vocab: map each piece to its id
        vocab = {sp.id_to_piece(index): index for index in trange(sp.GetPieceSize())}

        # Merges: keep a pair of distinct pieces when their concatenation is
        # itself a known piece with a non-zero score (unknown concatenations
        # map to the unk id, whose score is 0.0)
        merges = []
        for piece_l in tqdm(vocab.keys(), total=sp.GetPieceSize()):
            for piece_r in vocab.keys():
                if piece_l != piece_r:
                    merge = sp.PieceToId(f"{piece_l}{piece_r}")
                    score = sp.GetScore(merge)

                    if score != 0.:
                        merges += [(piece_l, piece_r)]

        return vocab, merges
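
# A minimal usage sketch for the extractor above (the model path below is
# hypothetical, not part of this commit):
#
#     extractor = SentencePieceExtractor('spm.model')
#     vocab, merges = extractor.extract()
#     # vocab: Dict[str, int] mapping piece -> id
#     # merges: List[Tuple] of (piece_left, piece_right) pairs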

class YouTokenToMeExtractor:
    """
    Extractor implementation for YouTokenToMe trained models.
    The model file is laid out as follows:
        vocab_size nb_merges
        piece piece_id
        ...(repeated vocab_size times)
        piece_id_left piece_id_right piece_id
        ...(repeated nb_merges times)
    """

    def __init__(self, model: str):
        self._model = model

    def extract(self) -> Tuple[Dict[str, int], List[Tuple]]:
        with open(self._model, 'r') as model_f:
            # Retrieve the number of pieces and merges
            nb_pieces, nb_merges = map(int, model_f.readline().split())
            vocab, merges = {}, []

            # Vocab: each line holds a character code and the id it maps to
            for _ in trange(nb_pieces):
                piece, piece_id = map(int, model_f.readline().split())
                vocab[piece_id] = chr(piece)

            # Merges: each line holds the ids of the left and right pieces
            # and the id assigned to their concatenation
            for _ in trange(nb_merges):
                piece_id_l, piece_id_r, piece = map(int, model_f.readline().split())
                piece_l, piece_r = vocab[piece_id_l], vocab[piece_id_r]
                vocab[piece] = f"{piece_l}{piece_r}"
                merges += [(piece_l, piece_r)]

            # Special tokens
            unk, pad, bos, eos = map(int, model_f.readline().split())
            vocab[unk] = '<unk>'
            vocab[pad] = '<pad>'
            vocab[bos] = '<bos>'
            vocab[eos] = '<eos>'

            # Invert the mapping so vocab goes piece -> id, matching
            # SentencePieceExtractor's output
            vocab = dict(zip(vocab.values(), vocab.keys()))
            return vocab, merges
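
# Both extractors expose the same contract, so downstream code can treat them
# interchangeably; a hypothetical sketch (file name made up):
#
#     vocab, merges = YouTokenToMeExtractor('yttm.model').extract()
#     assert all(piece in vocab for pair in merges for piece in pair)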

if __name__ == '__main__':
    parser = ArgumentParser('SentencePiece vocab extractor')
    parser.add_argument('--provider', type=str, required=True, choices=['sentencepiece', 'youtokentome'],
                        help='Indicate the format of the file.')
    parser.add_argument('--model', type=str, required=True, help='SentencePiece model to extract vocab from.')
    parser.add_argument('--vocab-output-path', type=str, required=True,
                        help='Path where the vocab.json file will be extracted')
    parser.add_argument('--merges-output-path', type=str, required=True,
                        help='Path where the merges file will be extracted')

    # Parse cli arguments
    args = parser.parse_args()
    try:
        if args.model.startswith('http'):
            # Save the remote model to a local temporary file
            with NamedTemporaryFile('wb', delete=False) as f:
                logger.info(f'Writing content from {args.model} to {f.name}')
                response = get(args.model, allow_redirects=True)
                f.write(response.content)

                args.remote_model = args.model
                args.model = f.name

        # Allocate extractor
        extractor = SentencePieceExtractor if args.provider == 'sentencepiece' else YouTokenToMeExtractor
        extractor = extractor(args.model)

        logger.info(f'Using {type(extractor).__name__}')

        # Open the output files and extract the model information
        with open(args.vocab_output_path, 'w') as vocab_f:
            with open(args.merges_output_path, 'w') as merges_f:
                # Do the extraction
                vocab, merges = extractor.extract()

                # Save content
                dump(vocab, vocab_f)
                merges_f.writelines(map(lambda x: f"{x[0]} {x[1]}{linesep}", merges))
    finally:
        # If the model was downloaded from the internet, clean up the temporary file
        if hasattr(args, 'remote_model') and exists(args.model):
            remove(args.model)
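
# Example invocation (hypothetical file names, shown for illustration):
#
#     python sentencepiece_extractor.py \
#         --provider sentencepiece \
#         --model spm.model \
#         --vocab-output-path vocab.json \
#         --merges-output-path merges.txt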