Files
tokenizers/bindings/python/tests/documentation/test_quicktour.py
2020-11-02 17:07:27 -05:00

195 lines
6.3 KiB
Python

from ..utils import data_dir, doc_wiki_tokenizer
from tokenizers import Tokenizer
from tokenizers.models import BPE
from tokenizers.trainers import BpeTrainer
from tokenizers.pre_tokenizers import Whitespace
class TestQuicktour:
# This method contains everything we don't want to run
@staticmethod
def slow_train():
tokenizer, trainer = TestQuicktour.get_tokenizer_trainer()
# START train
files = [f"data/wikitext-103-raw/wiki.{split}.raw" for split in ["test", "train", "valid"]]
tokenizer.train(trainer, files)
# END train
# START reload_model
files = tokenizer.model.save("data", "wiki")
tokenizer.model = BPE.from_files(*files, unk_token="[UNK]")
# END reload_model
# START save
tokenizer.save("data/tokenizer-wiki.json")
# END save
@staticmethod
def get_tokenizer_trainer():
# START init_tokenizer
from tokenizers import Tokenizer
from tokenizers.models import BPE
tokenizer = Tokenizer(BPE())
# END init_tokenizer
# START init_trainer
from tokenizers.trainers import BpeTrainer
trainer = BpeTrainer(special_tokens=["[UNK]", "[CLS]", "[SEP]", "[PAD]", "[MASK]"])
# END init_trainer
# START init_pretok
from tokenizers.pre_tokenizers import Whitespace
tokenizer.pre_tokenizer = Whitespace()
# END init_pretok
return tokenizer, trainer
def test_quicktour(self, doc_wiki_tokenizer):
def print(*args, **kwargs):
pass
try:
# START reload_tokenizer
tokenizer = Tokenizer.from_file("data/tokenizer-wiki.json")
# END reload_tokenizer
except Exception:
tokenizer = Tokenizer.from_file(doc_wiki_tokenizer)
# START encode
output = tokenizer.encode("Hello, y'all! How are you 😁 ?")
# END encode
# START print_tokens
print(output.tokens)
# ["Hello", ",", "y", "'", "all", "!", "How", "are", "you", "[UNK]", "?"]
# END print_tokens
assert output.tokens == [
"Hello",
",",
"y",
"'",
"all",
"!",
"How",
"are",
"you",
"[UNK]",
"?",
]
# START print_ids
print(output.ids)
# [27253, 16, 93, 11, 5097, 5, 7961, 5112, 6218, 0, 35]
# END print_ids
assert output.ids == [27253, 16, 93, 11, 5097, 5, 7961, 5112, 6218, 0, 35]
# START print_offsets
print(output.offsets[9])
# (26, 27)
# END print_offsets
assert output.offsets[9] == (26, 27)
# START use_offsets
sentence = "Hello, y'all! How are you 😁 ?"
sentence[26:27]
# "😁"
# END use_offsets
assert sentence[26:27] == "😁"
# START check_sep
tokenizer.token_to_id("[SEP]")
# 2
# END check_sep
assert tokenizer.token_to_id("[SEP]") == 2
# START init_template_processing
from tokenizers.processors import TemplateProcessing
tokenizer.post_processor = TemplateProcessing(
single="[CLS] $A [SEP]",
pair="[CLS] $A [SEP] $B:1 [SEP]:1",
special_tokens=[
("[CLS]", tokenizer.token_to_id("[CLS]")),
("[SEP]", tokenizer.token_to_id("[SEP]")),
],
)
# END init_template_processing
# START print_special_tokens
output = tokenizer.encode("Hello, y'all! How are you 😁 ?")
print(output.tokens)
# ["[CLS]", "Hello", ",", "y", "'", "all", "!", "How", "are", "you", "[UNK]", "?", "[SEP]"]
# END print_special_tokens
assert output.tokens == [
"[CLS]",
"Hello",
",",
"y",
"'",
"all",
"!",
"How",
"are",
"you",
"[UNK]",
"?",
"[SEP]",
]
# START print_special_tokens_pair
output = tokenizer.encode("Hello, y'all!", "How are you 😁 ?")
print(output.tokens)
# ["[CLS]", "Hello", ",", "y", "'", "all", "!", "[SEP]", "How", "are", "you", "[UNK]", "?", "[SEP]"]
# END print_special_tokens_pair
assert output.tokens == [
"[CLS]",
"Hello",
",",
"y",
"'",
"all",
"!",
"[SEP]",
"How",
"are",
"you",
"[UNK]",
"?",
"[SEP]",
]
# START print_type_ids
print(output.type_ids)
# [0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1]
# END print_type_ids
assert output.type_ids == [0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1]
# START encode_batch
output = tokenizer.encode_batch(["Hello, y'all!", "How are you 😁 ?"])
# END encode_batch
# START encode_batch_pair
output = tokenizer.encode_batch(
[["Hello, y'all!", "How are you 😁 ?"], ["Hello to you too!", "I'm fine, thank you!"]]
)
# END encode_batch_pair
# START enable_padding
tokenizer.enable_padding(pad_id=3, pad_token="[PAD]")
# END enable_padding
# START print_batch_tokens
output = tokenizer.encode_batch(["Hello, y'all!", "How are you 😁 ?"])
print(output[1].tokens)
# ["[CLS]", "How", "are", "you", "[UNK]", "?", "[SEP]", "[PAD]"]
# END print_batch_tokens
assert output[1].tokens == ["[CLS]", "How", "are", "you", "[UNK]", "?", "[SEP]", "[PAD]"]
# START print_attention_mask
print(output[1].attention_mask)
# [1, 1, 1, 1, 1, 1, 1, 0]
# END print_attention_mask
assert output[1].attention_mask == [1, 1, 1, 1, 1, 1, 1, 0]
if __name__ == "__main__":
from urllib import request
from zipfile import ZipFile
import os
if not os.path.isdir("data/wikitext-103-raw"):
print("Downloading wikitext-103...")
wiki_text, _ = request.urlretrieve(
"https://s3.amazonaws.com/research.metamind.io/wikitext/wikitext-103-raw-v1.zip"
)
with ZipFile(wiki_text, "r") as z:
print("Unzipping in data...")
z.extractall("data")
print("Now training...")
TestQuicktour.slow_train()