mirror of
https://github.com/mii443/tokenizers.git
synced 2025-08-22 16:25:30 +00:00
Black pass.
This commit is contained in:
committed by
Anthony MOI
parent
7ed7f0f26a
commit
6887c0f04d
@ -58,7 +58,7 @@ class YouTokenToMeExtractor:
|
||||
self._model = model
|
||||
|
||||
def extract(self) -> Tuple[Dict[str, int], List[Tuple]]:
|
||||
with open(self._model, 'r') as model_f:
|
||||
with open(self._model, "r") as model_f:
|
||||
|
||||
# Retrieve information
|
||||
nb_pieces, nb_merges = map(int, model_f.readline().split())
|
||||
@ -78,34 +78,49 @@ class YouTokenToMeExtractor:
|
||||
|
||||
# Special tokens
|
||||
unk, pad, bos, eos = map(int, model_f.readline().split())
|
||||
vocab[unk] = '<unk>'
|
||||
vocab[pad] = '<pad>'
|
||||
vocab[bos] = '<bos>'
|
||||
vocab[eos] = '<eos>'
|
||||
vocab[unk] = "<unk>"
|
||||
vocab[pad] = "<pad>"
|
||||
vocab[bos] = "<bos>"
|
||||
vocab[eos] = "<eos>"
|
||||
|
||||
# Invert key and value for vocab
|
||||
vocab = dict(zip(vocab.values(), vocab.keys()))
|
||||
return vocab, merges
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
parser = ArgumentParser('SentencePiece vocab extractor')
|
||||
parser.add_argument('--provider', type=str, required=True, choices=['sentencepiece', 'youtokentome'],
|
||||
help='Indicate the format of the file.')
|
||||
parser.add_argument('--model', type=str, required=True, help='SentencePiece model to extract vocab from.')
|
||||
parser.add_argument('--vocab-output-path', type=str, required=True,
|
||||
help='Path where the vocab.json file will be extracted')
|
||||
parser.add_argument('--merges-output-path', type=str, required=True,
|
||||
help='Path where the merges file will be extracted')
|
||||
if __name__ == "__main__":
|
||||
parser = ArgumentParser("SentencePiece vocab extractor")
|
||||
parser.add_argument(
|
||||
"--provider",
|
||||
type=str,
|
||||
required=True,
|
||||
choices=["sentencepiece", "youtokentome"],
|
||||
help="Indicate the format of the file.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--model", type=str, required=True, help="SentencePiece model to extract vocab from."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--vocab-output-path",
|
||||
type=str,
|
||||
required=True,
|
||||
help="Path where the vocab.json file will be extracted",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--merges-output-path",
|
||||
type=str,
|
||||
required=True,
|
||||
help="Path where the merges file will be extracted",
|
||||
)
|
||||
|
||||
# Parse cli arguments
|
||||
args = parser.parse_args()
|
||||
|
||||
try:
|
||||
if args.model.startswith('http'):
|
||||
if args.model.startswith("http"):
|
||||
# Saving model
|
||||
with NamedTemporaryFile('wb', delete=False) as f:
|
||||
logger.info('Writing content from {} to {}'.format(args.model, f.name))
|
||||
with NamedTemporaryFile("wb", delete=False) as f:
|
||||
logger.info("Writing content from {} to {}".format(args.model, f.name))
|
||||
response = get(args.model, allow_redirects=True)
|
||||
f.write(response.content)
|
||||
|
||||
@ -113,14 +128,16 @@ if __name__ == '__main__':
|
||||
args.model = f.name
|
||||
|
||||
# Allocate extractor
|
||||
extractor = SentencePieceExtractor if args.provider == 'sentencepiece' else YouTokenToMeExtractor
|
||||
extractor = (
|
||||
SentencePieceExtractor if args.provider == "sentencepiece" else YouTokenToMeExtractor
|
||||
)
|
||||
extractor = extractor(args.model)
|
||||
|
||||
logger.info(f'Using {type(extractor).__name__}')
|
||||
logger.info(f"Using {type(extractor).__name__}")
|
||||
|
||||
# Open output files and let's extract model information
|
||||
with open(args.vocab_output_path, 'w') as vocab_f:
|
||||
with open(args.merges_output_path, 'w') as merges_f:
|
||||
with open(args.vocab_output_path, "w") as vocab_f:
|
||||
with open(args.merges_output_path, "w") as merges_f:
|
||||
# Do the extraction
|
||||
vocab, merges = extractor.extract()
|
||||
|
||||
@ -129,5 +146,5 @@ if __name__ == '__main__':
|
||||
merges_f.writelines(map(lambda x: f"{x[0]} {x[1]}{linesep}", merges))
|
||||
finally:
|
||||
# If model was downloaded from internet we need to cleanup the tmp folder.
|
||||
if hasattr(args, 'remote_model') and exists(args.model):
|
||||
if hasattr(args, "remote_model") and exists(args.model):
|
||||
remove(args.model)
|
||||
|
Reference in New Issue
Block a user