Black pass.

Nicolas Patry
2020-08-21 16:49:41 +02:00
committed by Anthony MOI
parent 7ed7f0f26a
commit 6887c0f04d
3 changed files with 42 additions and 32 deletions
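For context: this is a mechanical formatting commit, the output of running the Black code formatter over the file below. Black normalizes string quotes to double quotes and explodes call sites that exceed its line length, which is all the hunks below do. A minimal sketch of reproducing the effect through Black's Python API (the default Mode() is an assumption; the repository may pin a different line length in its own config):

import black

# Black reformats source text without executing it; the default Mode() used
# here assumes double-quote normalization and an 88-column line length,
# not settings read from this repository's configuration.
src = "with open(self._model, 'r') as model_f:\n    pass\n"
print(black.format_str(src, mode=black.Mode()))
# -> with open(self._model, "r") as model_f:
#        pass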


@@ -58,7 +58,7 @@ class YouTokenToMeExtractor:
         self._model = model
 
     def extract(self) -> Tuple[Dict[str, int], List[Tuple]]:
-        with open(self._model, 'r') as model_f:
+        with open(self._model, "r") as model_f:
 
             # Retrieve information
             nb_pieces, nb_merges = map(int, model_f.readline().split())
@@ -78,34 +78,49 @@
             # Special tokens
             unk, pad, bos, eos = map(int, model_f.readline().split())
-            vocab[unk] = '<unk>'
-            vocab[pad] = '<pad>'
-            vocab[bos] = '<bos>'
-            vocab[eos] = '<eos>'
+            vocab[unk] = "<unk>"
+            vocab[pad] = "<pad>"
+            vocab[bos] = "<bos>"
+            vocab[eos] = "<eos>"
 
             # Invert key and value for vocab
             vocab = dict(zip(vocab.values(), vocab.keys()))
 
         return vocab, merges
 
 
-if __name__ == '__main__':
-    parser = ArgumentParser('SentencePiece vocab extractor')
-    parser.add_argument('--provider', type=str, required=True, choices=['sentencepiece', 'youtokentome'],
-                        help='Indicate the format of the file.')
-    parser.add_argument('--model', type=str, required=True, help='SentencePiece model to extract vocab from.')
-    parser.add_argument('--vocab-output-path', type=str, required=True,
-                        help='Path where the vocab.json file will be extracted')
-    parser.add_argument('--merges-output-path', type=str, required=True,
-                        help='Path where the merges file will be extracted')
+if __name__ == "__main__":
+    parser = ArgumentParser("SentencePiece vocab extractor")
+    parser.add_argument(
+        "--provider",
+        type=str,
+        required=True,
+        choices=["sentencepiece", "youtokentome"],
+        help="Indicate the format of the file.",
+    )
+    parser.add_argument(
+        "--model", type=str, required=True, help="SentencePiece model to extract vocab from."
+    )
+    parser.add_argument(
+        "--vocab-output-path",
+        type=str,
+        required=True,
+        help="Path where the vocab.json file will be extracted",
+    )
+    parser.add_argument(
+        "--merges-output-path",
+        type=str,
+        required=True,
+        help="Path where the merges file will be extracted",
+    )
 
     # Parse cli arguments
     args = parser.parse_args()
 
     try:
-        if args.model.startswith('http'):
+        if args.model.startswith("http"):
             # Saving model
-            with NamedTemporaryFile('wb', delete=False) as f:
-                logger.info('Writing content from {} to {}'.format(args.model, f.name))
+            with NamedTemporaryFile("wb", delete=False) as f:
+                logger.info("Writing content from {} to {}".format(args.model, f.name))
                 response = get(args.model, allow_redirects=True)
                 f.write(response.content)
 
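An aside on the hunk above: dict(zip(vocab.values(), vocab.keys())) flips the id-to-token table built while reading the model file into the token-to-id mapping returned to the caller. A toy illustration with invented values:

# Toy id->token table, inverted the same way as in extract() above.
vocab = {0: "<unk>", 1: "<pad>", 2: "<bos>", 3: "<eos>"}
vocab = dict(zip(vocab.values(), vocab.keys()))
assert vocab == {"<unk>": 0, "<pad>": 1, "<bos>": 2, "<eos>": 3}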
@@ -113,14 +128,16 @@ if __name__ == '__main__':
             args.model = f.name
 
         # Allocate extractor
-        extractor = SentencePieceExtractor if args.provider == 'sentencepiece' else YouTokenToMeExtractor
+        extractor = (
+            SentencePieceExtractor if args.provider == "sentencepiece" else YouTokenToMeExtractor
+        )
         extractor = extractor(args.model)
 
-        logger.info(f'Using {type(extractor).__name__}')
+        logger.info(f"Using {type(extractor).__name__}")
 
         # Open output files and let's extract model information
-        with open(args.vocab_output_path, 'w') as vocab_f:
-            with open(args.merges_output_path, 'w') as merges_f:
+        with open(args.vocab_output_path, "w") as vocab_f:
+            with open(args.merges_output_path, "w") as merges_f:
                 # Do the extraction
                 vocab, merges = extractor.extract()
@@ -129,5 +146,5 @@ if __name__ == '__main__':
                 merges_f.writelines(map(lambda x: f"{x[0]} {x[1]}{linesep}", merges))
     finally:
         # If model was downloaded from internet we need to cleanup the tmp folder.
-        if hasattr(args, 'remote_model') and exists(args.model):
+        if hasattr(args, "remote_model") and exists(args.model):
             remove(args.model)
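For completeness, a hedged sketch of exercising the script after this change. The module name sentencepiece_extractor and the file paths are assumptions for illustration; only the class name and the CLI flags come from the diff above:

# Hypothetical CLI invocation (script and file names are made up):
#   python sentencepiece_extractor.py --provider youtokentome --model yttm.model \
#       --vocab-output-path vocab.json --merges-output-path merges.txt
#
# Or drive the extractor class directly:
from sentencepiece_extractor import YouTokenToMeExtractor  # assumed module name

vocab, merges = YouTokenToMeExtractor("yttm.model").extract()
print(f"{len(vocab)} tokens, {len(merges)} merges")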