from tokenizers import Tokenizer

from ..utils import data_dir, doc_wiki_tokenizer


disable_printing = True
original_print = print


def print(*args, **kwargs):
    if not disable_printing:
        original_print(*args, **kwargs)


class TestQuicktour:
    # This method contains everything we don't want to run
    @staticmethod
    def slow_train():
        tokenizer, trainer = TestQuicktour.get_tokenizer_trainer()

        # START train
        files = [f"data/wikitext-103-raw/wiki.{split}.raw" for split in ["test", "train", "valid"]]
        tokenizer.train(files, trainer)
        # END train
        # START save
        tokenizer.save("data/tokenizer-wiki.json")
        # END save

    @staticmethod
    def get_tokenizer_trainer():
        # START init_tokenizer
        from tokenizers import Tokenizer
        from tokenizers.models import BPE

        tokenizer = Tokenizer(BPE(unk_token="[UNK]"))
        # END init_tokenizer
        # START init_trainer
        from tokenizers.trainers import BpeTrainer

        trainer = BpeTrainer(special_tokens=["[UNK]", "[CLS]", "[SEP]", "[PAD]", "[MASK]"])
        # END init_trainer
        # START init_pretok
        from tokenizers.pre_tokenizers import Whitespace

        tokenizer.pre_tokenizer = Whitespace()
        # END init_pretok
        return tokenizer, trainer

    def test_quicktour(self, doc_wiki_tokenizer):
        def print(*args, **kwargs):
            pass

        try:
            # START reload_tokenizer
            tokenizer = Tokenizer.from_file("data/tokenizer-wiki.json")
            # END reload_tokenizer
        except Exception:
            tokenizer = Tokenizer.from_file(doc_wiki_tokenizer)

        # START encode
        output = tokenizer.encode("Hello, y'all! How are you 😁 ?")
        # END encode
        # START print_tokens
        print(output.tokens)
        # ["Hello", ",", "y", "'", "all", "!", "How", "are", "you", "[UNK]", "?"]
        # END print_tokens
        assert output.tokens == [
            "Hello",
            ",",
            "y",
            "'",
            "all",
            "!",
            "How",
            "are",
            "you",
            "[UNK]",
            "?",
        ]

        # START print_ids
        print(output.ids)
        # [27253, 16, 93, 11, 5097, 5, 7961, 5112, 6218, 0, 35]
        # END print_ids
        assert output.ids == [27253, 16, 93, 11, 5097, 5, 7961, 5112, 6218, 0, 35]

        # START print_offsets
        print(output.offsets[9])
        # (26, 27)
        # END print_offsets
        assert output.offsets[9] == (26, 27)

        # START use_offsets
        sentence = "Hello, y'all! How are you 😁 ?"
        sentence[26:27]
        # "😁"
        # END use_offsets
        assert sentence[26:27] == "😁"

        # START check_sep
        tokenizer.token_to_id("[SEP]")
        # 2
        # END check_sep
        assert tokenizer.token_to_id("[SEP]") == 2

        # START init_template_processing
        from tokenizers.processors import TemplateProcessing

        tokenizer.post_processor = TemplateProcessing(
            single="[CLS] $A [SEP]",
            pair="[CLS] $A [SEP] $B:1 [SEP]:1",
            special_tokens=[
                ("[CLS]", tokenizer.token_to_id("[CLS]")),
                ("[SEP]", tokenizer.token_to_id("[SEP]")),
            ],
        )
        # END init_template_processing

        # START print_special_tokens
        output = tokenizer.encode("Hello, y'all! How are you 😁 ?")
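        # With the TemplateProcessing post-processor attached above, encoding a
        # single sequence now wraps it in the [CLS] and [SEP] special tokens.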
        print(output.tokens)
        # ["[CLS]", "Hello", ",", "y", "'", "all", "!", "How", "are", "you", "[UNK]", "?", "[SEP]"]
        # END print_special_tokens
        assert output.tokens == [
            "[CLS]",
            "Hello",
            ",",
            "y",
            "'",
            "all",
            "!",
            "How",
            "are",
            "you",
            "[UNK]",
            "?",
            "[SEP]",
        ]

        # START print_special_tokens_pair
        output = tokenizer.encode("Hello, y'all!", "How are you 😁 ?")
        print(output.tokens)
        # ["[CLS]", "Hello", ",", "y", "'", "all", "!", "[SEP]", "How", "are", "you", "[UNK]", "?", "[SEP]"]
        # END print_special_tokens_pair
        assert output.tokens == [
            "[CLS]",
            "Hello",
            ",",
            "y",
            "'",
            "all",
            "!",
            "[SEP]",
            "How",
            "are",
            "you",
            "[UNK]",
            "?",
            "[SEP]",
        ]

        # START print_type_ids
        print(output.type_ids)
        # [0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1]
        # END print_type_ids
        assert output.type_ids == [0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1]

        # START encode_batch
        output = tokenizer.encode_batch(["Hello, y'all!", "How are you 😁 ?"])
        # END encode_batch

        # START encode_batch_pair
        output = tokenizer.encode_batch(
            [["Hello, y'all!", "How are you 😁 ?"], ["Hello to you too!", "I'm fine, thank you!"]]
        )
        # END encode_batch_pair

        # START enable_padding
        tokenizer.enable_padding(pad_id=3, pad_token="[PAD]")
        # END enable_padding

        # START print_batch_tokens
        output = tokenizer.encode_batch(["Hello, y'all!", "How are you 😁 ?"])
        print(output[1].tokens)
        # ["[CLS]", "How", "are", "you", "[UNK]", "?", "[SEP]", "[PAD]"]
        # END print_batch_tokens
        assert output[1].tokens == ["[CLS]", "How", "are", "you", "[UNK]", "?", "[SEP]", "[PAD]"]

        # START print_attention_mask
        print(output[1].attention_mask)
        # [1, 1, 1, 1, 1, 1, 1, 0]
        # END print_attention_mask
        assert output[1].attention_mask == [1, 1, 1, 1, 1, 1, 1, 0]


if __name__ == "__main__":
    import os
    from urllib import request
    from zipfile import ZipFile

    disable_printing = False
    if not os.path.isdir("data/wikitext-103-raw"):
        print("Downloading wikitext-103...")
        wiki_text, _ = request.urlretrieve(
            "https://s3.amazonaws.com/research.metamind.io/wikitext/wikitext-103-raw-v1.zip"
        )
        with ZipFile(wiki_text, "r") as z:
            print("Unzipping in data...")
            z.extractall("data")

    print("Now training...")
    TestQuicktour.slow_train()
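
# A minimal decoding sketch (an illustrative addition, not part of the test
# above). Tokenizer.decode maps ids back to a string; since no decoder is
# configured on this tokenizer, tokens are simply joined with spaces, so the
# round trip does not restore punctuation spacing exactly. The helper name
# `_demo_decode` is hypothetical, and it assumes `data/tokenizer-wiki.json`
# was produced by the training step above.
def _demo_decode(tokenizer_path="data/tokenizer-wiki.json"):
    tokenizer = Tokenizer.from_file(tokenizer_path)
    output = tokenizer.encode("Hello, y'all! How are you 😁 ?")
    # skip_special_tokens=True (the default) drops tokens like [CLS] and [SEP]
    return tokenizer.decode(output.ids, skip_special_tokens=True)
    # e.g. "Hello , y ' all ! How are you ?"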