Generate pyi, fix tests and clippy warnings

This commit is contained in:
Anthony MOI
2020-11-19 17:57:58 -05:00
committed by Anthony MOI
parent 5059be1a8d
commit 387b8a1033
7 changed files with 56 additions and 74 deletions

View File

@@ -2,8 +2,13 @@ from ..utils import data_dir, doc_wiki_tokenizer, doc_pipeline_bert_tokenizer
from tokenizers import Tokenizer
disable_printing = True
original_print = print
def print(*args, **kwargs):
pass
if not disable_printing:
original_print(*args, **kwargs)
class TestPipeline:
@@ -103,7 +108,7 @@ class TestPipeline:
from tokenizers import Tokenizer
from tokenizers.models import WordPiece
bert_tokenizer = Tokenizer(WordPiece())
bert_tokenizer = Tokenizer(WordPiece(unk_token="[UNK]"))
# END bert_setup_tokenizer
# START bert_setup_normalizer
from tokenizers import normalizers
@@ -135,10 +140,7 @@ class TestPipeline:
vocab_size=30522, special_tokens=["[UNK]", "[CLS]", "[SEP]", "[PAD]", "[MASK]"]
)
files = [f"data/wikitext-103-raw/wiki.{split}.raw" for split in ["test", "train", "valid"]]
bert_tokenizer.train(trainer, files)
model_files = bert_tokenizer.model.save("data", "bert-wiki")
bert_tokenizer.model = WordPiece.from_file(*model_files, unk_token="[UNK]")
bert_tokenizer.train(files, trainer)
bert_tokenizer.save("data/bert-wiki.json")
# END bert_train_tokenizer
@@ -173,6 +175,7 @@ if __name__ == "__main__":
from zipfile import ZipFile
import os
disable_printing = False
if not os.path.isdir("data/wikitext-103-raw"):
print("Downloading wikitext-103...")
wiki_text, _ = request.urlretrieve(

View File

@@ -4,6 +4,14 @@ from tokenizers.models import BPE
from tokenizers.trainers import BpeTrainer
from tokenizers.pre_tokenizers import Whitespace
disable_printing = True
original_print = print
def print(*args, **kwargs):
if not disable_printing:
original_print(*args, **kwargs)
class TestQuicktour:
# This method contains everything we don't want to run
@@ -13,12 +21,8 @@ class TestQuicktour:
# START train
files = [f"data/wikitext-103-raw/wiki.{split}.raw" for split in ["test", "train", "valid"]]
tokenizer.train(trainer, files)
tokenizer.train(files, trainer)
# END train
# START reload_model
files = tokenizer.model.save("data", "wiki")
tokenizer.model = BPE.from_file(*files, unk_token="[UNK]")
# END reload_model
# START save
tokenizer.save("data/tokenizer-wiki.json")
# END save
@@ -29,7 +33,7 @@ class TestQuicktour:
from tokenizers import Tokenizer
from tokenizers.models import BPE
tokenizer = Tokenizer(BPE())
tokenizer = Tokenizer(BPE(unk_token="[UNK]"))
# END init_tokenizer
# START init_trainer
from tokenizers.trainers import BpeTrainer
@@ -181,6 +185,7 @@ if __name__ == "__main__":
from zipfile import ZipFile
import os
disable_printing = False
if not os.path.isdir("data/wikitext-103-raw"):
print("Downloading wikitext-103...")
wiki_text, _ = request.urlretrieve(