from argparse import ArgumentParser
from json import dump
from logging import basicConfig, getLogger
from os import linesep, remove
from os.path import exists
from tempfile import NamedTemporaryFile
from typing import Dict, List, Tuple

from requests import get
from sentencepiece import SentencePieceProcessor
from tqdm import trange, tqdm

basicConfig()
logger = getLogger()


class SentencePieceExtractor:
    """
    Extractor implementation for SentencePiece trained models.
    https://github.com/google/sentencepiece
    """

    def __init__(self, model: str):
        # Load the trained SentencePiece model
        self.sp = SentencePieceProcessor()
        self.sp.Load(model)

    def extract(self) -> Tuple[Dict[str, int], List[Tuple]]:
        sp = self.sp
        vocab = {sp.id_to_piece(index): index for index in trange(sp.GetPieceSize())}

        # Merges: a pair (left, right) is a merge if its concatenation is itself a
        # piece in the vocabulary. Recover all such pairs by brute force, then order
        # them by the id of the merged piece (lower id = higher merge priority).
        merges = []
        for piece_l in tqdm(vocab.keys(), total=sp.GetPieceSize()):
            for piece_r in vocab.keys():
                merge = f"{piece_l}{piece_r}"
                piece_id = vocab.get(merge, None)
                # Compare against None explicitly: piece id 0 is valid and falsy
                if piece_id is not None:
                    merges += [(piece_l, piece_r, piece_id)]

        merges = sorted(merges, key=lambda val: val[2])
        merges = [(val[0], val[1]) for val in merges]

        return vocab, merges


class YouTokenToMeExtractor:
    """
    Extractor implementation for YouTokenToMe trained models format.
    Models are formatted as follows:
        vocab_size nb_merges
        piece piece_id
        ...(repeated vocab_size times)
        piece_id_left piece_id_right piece_id
        ...(repeated nb_merges times)
    """

    def __init__(self, model: str):
        self._model = model

    def extract(self) -> Tuple[Dict[str, int], List[Tuple]]:
        with open(self._model, "r") as model_f:
            # Retrieve header information
            nb_pieces, nb_merges = map(int, model_f.readline().split())
            vocab, merges = {}, []

            # Vocab: each line holds a character's code point and its piece id
            for _ in trange(nb_pieces):
                piece, piece_id = map(int, model_f.readline().split())
                vocab[piece_id] = chr(piece)

            # Merges: each line references two existing piece ids and the id of
            # the piece produced by their concatenation
            for _ in trange(nb_merges):
                piece_id_l, piece_id_r, piece = map(int, model_f.readline().split())
                piece_l, piece_r = vocab[piece_id_l], vocab[piece_id_r]
                vocab[piece] = f"{piece_l}{piece_r}"
                merges += [(piece_l, piece_r)]

            # Special tokens
            unk, pad, bos, eos = map(int, model_f.readline().split())
            vocab[unk] = "<unk>"
            vocab[pad] = "<pad>"
            vocab[bos] = "<bos>"
            vocab[eos] = "<eos>"

            # Invert keys and values so the mapping is piece -> id
            vocab = dict(zip(vocab.values(), vocab.keys()))

        return vocab, merges
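
# Both extractors expose the same extract() interface, so they can also be used
# directly from Python rather than through the CLI below. A minimal sketch,
# assuming a hypothetical local model file named "example.model":
#
#     extractor = SentencePieceExtractor("example.model")  # hypothetical path
#     vocab, merges = extractor.extract()
#     print(f"{len(vocab)} pieces, {len(merges)} merges")
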
if __name__ == "__main__":
    parser = ArgumentParser("SentencePiece vocab extractor")
    parser.add_argument(
        "--provider",
        type=str,
        required=True,
        choices=["sentencepiece", "youtokentome"],
        help="Indicate the format of the file.",
    )
    parser.add_argument(
        "--model", type=str, required=True, help="SentencePiece model to extract vocab from."
    )
    parser.add_argument(
        "--vocab-output-path",
        type=str,
        required=True,
        help="Path where the vocab.json file will be extracted",
    )
    parser.add_argument(
        "--merges-output-path",
        type=str,
        required=True,
        help="Path where the merges file will be extracted",
    )

    # Parse cli arguments
    args = parser.parse_args()
    try:
        if args.model.startswith("http"):
            # Download the remote model to a temporary file
            with NamedTemporaryFile("wb", delete=False) as f:
                logger.info("Writing content from {} to {}".format(args.model, f.name))
                response = get(args.model, allow_redirects=True)
                f.write(response.content)

                args.remote_model = args.model
                args.model = f.name

        # Allocate extractor
        extractor = SentencePieceExtractor if args.provider == "sentencepiece" else YouTokenToMeExtractor
        extractor = extractor(args.model)

        logger.info(f"Using {type(extractor).__name__}")

        # Open output files and extract the model information
        with open(args.vocab_output_path, "w") as vocab_f:
            with open(args.merges_output_path, "w") as merges_f:
                # Do the extraction
                vocab, merges = extractor.extract()

                # Save content
                dump(vocab, vocab_f)
                merges_f.writelines(map(lambda x: f"{x[0]} {x[1]}{linesep}", merges))
    finally:
        # If the model was downloaded from the internet, clean up the temporary file.
        if hasattr(args, "remote_model") and exists(args.model):
            remove(args.model)
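
# The emitted vocab/merges file pair follows the usual GPT-2-style BPE
# serialization, so it should be loadable with the Hugging Face `tokenizers`
# library (an assumption; adjust to your target consumer). A minimal sketch,
# with hypothetical output file names:
#
#     from tokenizers import Tokenizer
#     from tokenizers.models import BPE
#
#     tokenizer = Tokenizer(BPE.from_file("vocab.json", "merges.txt"))
#     print(tokenizer.encode("Hello world").tokens)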