diff --git a/bindings/node/Makefile b/bindings/node/Makefile
index 60b23c7e..3f09b78a 100644
--- a/bindings/node/Makefile
+++ b/bindings/node/Makefile
@@ -12,7 +12,7 @@ style:
 check-style:
 	npm run lint-check
 
-TESTS_RESOURCES = $(DATA_DIR)/small.txt $(DATA_DIR)/roberta.json $(DATA_DIR)/tokenizer-wiki.json
+TESTS_RESOURCES = $(DATA_DIR)/small.txt $(DATA_DIR)/roberta.json $(DATA_DIR)/tokenizer-wiki.json $(DATA_DIR)/bert-wiki.json
 
 # Launch the test suite
 test: $(TESTS_RESOURCES)
@@ -32,3 +32,7 @@ $(DATA_DIR)/roberta.json :
 $(DATA_DIR)/tokenizer-wiki.json :
 	$(dir_guard)
 	wget https://s3.amazonaws.com/models.huggingface.co/bert/anthony/doc-quicktour/tokenizer.json -O $@
+
+$(DATA_DIR)/bert-wiki.json :
+	$(dir_guard)
+	wget https://s3.amazonaws.com/models.huggingface.co/bert/anthony/doc-pipeline/tokenizer.json -O $@
diff --git a/bindings/node/examples/documentation/pipeline.test.ts b/bindings/node/examples/documentation/pipeline.test.ts
index a81323af..c7d963e3 100644
--- a/bindings/node/examples/documentation/pipeline.test.ts
+++ b/bindings/node/examples/documentation/pipeline.test.ts
@@ -4,9 +4,16 @@ var globRequire = require;
 describe("pipelineExample", () => {
   // This is a hack to let us require using path similar to what the user has to use
   function require(mod: string) {
-    let path = mod.slice("tokenizers/".length);
-    return globRequire("../../lib/" + path);
+    if (mod.startsWith("tokenizers/")) {
+      let path = mod.slice("tokenizers/".length);
+      return globRequire("../../lib/" + path);
+    } else {
+      return globRequire(mod);
+    }
   }
+  let console = {
+    log: (..._args: any[]) => {}
+  };
 
   it("shows pipeline parts", async () => {
     // START reload_tokenizer
@@ -20,7 +27,7 @@ describe("pipelineExample", () => {
     let normalizer = sequenceNormalizer([nfdNormalizer(), stripAccentsNormalizer()]);
     // END setup_normalizer
     // START test_normalizer
-    let normalized = normalizer.normalizeStr("Héllò hôw are ü?")
+    let normalized = normalizer.normalizeString("Héllò hôw are ü?")
     // "Hello how are u?"
     // END test_normalizer
     expect(normalized).toEqual("Hello how are u?");
@@ -28,10 +35,10 @@ describe("pipelineExample", () => {
     tokenizer.setNormalizer(normalizer)
     // END replace_normalizer
     // START setup_pre_tokenizer
-    let { whitespacePreTokenizer } = require("tokenizers/bindings/pre_tokenizers");
+    let { whitespacePreTokenizer } = require("tokenizers/bindings/pre-tokenizers");
 
     var preTokenizer = whitespacePreTokenizer();
-    var preTokenized = preTokenizer.preTokenizeStr("Hello! How are you? I'm fine, thank you.");
+    var preTokenized = preTokenizer.preTokenizeString("Hello! How are you? I'm fine, thank you.");
     // END setup_pre_tokenizer
     expect(preTokenized).toEqual([
       ["Hello", [0, 5]],
@@ -50,16 +57,16 @@ describe("pipelineExample", () => {
       [".", [39, 40]]
     ]);
     // START combine_pre_tokenizer
-    let { sequencePreTokenizer, digitsPreTokenizer } = require("tokenizers/bindings/pre_tokenizers");
+    let { sequencePreTokenizer, digitsPreTokenizer } = require("tokenizers/bindings/pre-tokenizers");
 
     var preTokenizer = sequencePreTokenizer([whitespacePreTokenizer(), digitsPreTokenizer(true)]);
-    var preTokenized = preTokenizer.preTokenizeStr("Call 911!");
+    var preTokenized = preTokenizer.preTokenizeString("Call 911!");
     // END combine_pre_tokenizer
     // START replace_pre_tokenizer
     tokenizer.setPreTokenizer(preTokenizer)
     // END replace_pre_tokenizer
     // START setup_processor
-    let { templateProcessing } = require("tokenizers/bindings/processors");
+    let { templateProcessing } = require("tokenizers/bindings/post-processors");
 
     tokenizer.setPostProcessor(templateProcessing(
       "[CLS] $A [SEP]",
@@ -68,42 +75,21 @@
     ));
     // END setup_processor
     // START test_decoding
-    let output = tokenizer.encode("Hello, y'all! How are you 😁 ?");
+    let { promisify } = require('util');
+    let encode = promisify(tokenizer.encode.bind(tokenizer));
+    let decode = promisify(tokenizer.decode.bind(tokenizer));
+
+    let output = await encode("Hello, y'all! How are you 😁 ?");
     console.log(output.getIds());
     // [1, 27253, 16, 93, 11, 5097, 5, 7961, 5112, 6218, 0, 35, 2]
 
-    tokenizer.decode([1, 27253, 16, 93, 11, 5097, 5, 7961, 5112, 6218, 0, 35, 2]);
+    let decoded = await decode([1, 27253, 16, 93, 11, 5097, 5, 7961, 5112, 6218, 0, 35, 2], true);
     // "Hello , y ' all ! How are you ?"
     // END test_decoding
+    expect(decoded).toEqual("Hello , y ' all ! How are you ?");
   });
 
-  var { Tokenizer } = require("tokenizers/bindings/tokenizer");
-  const slow_bert_training = async (bertTokenizer: typeof Tokenizer) => {
-    let { WordPiece } = require("tokenizers/bindings/models");
-
-    // START bert_train_tokenizer
-    let { wordPieceTrainer } = require("tokenizers/bindings/trainers");
-    let { promisify } = require("util");
-
-    let trainer = wordPieceTrainer({
-      vocabSize: 30522,
-      specialTokens: ["[UNK]", "[CLS]", "[SEP]", "[PAD]", "[MASK]"]
-    });
-    let files = ["test", "train", "valid"].map(split => `data/wikitext-103-raw/wiki.${split}.raw`);
-    bertTokenizer.train(trainer, files);
-
-    let modelFiles = bertTokenizer.getModel.save("data", "bert-wiki");
-    let fromFile = promisify(WordPiece.fromFile);
-    bertTokenizer.setModel(await fromFile(modelFiles[0], {
-      unkToken: "[UNK]"
-    }));
-
-    bertTokenizer.save("data/bert-wiki.json")
-    // END bert_train_tokenizer
-  };
-  console.log(slow_bert_training); // disable unused warning
-
-  it("shows a full bert example", async () => {
+  it.skip("trains the tokenizer", async () => {
     // START bert_setup_tokenizer
     let { Tokenizer } = require("tokenizers/bindings/tokenizer");
     let { WordPiece } = require("tokenizers/bindings/models");
@@ -121,7 +107,7 @@ describe("pipelineExample", () => {
     // START bert_setup_pre_tokenizer
     let { whitespacePreTokenizer } = require("tokenizers/bindings/pre-tokenizers");
 
-    bertTokenizer.setPreTokenizer = whitespacePreTokenizer();
+    bertTokenizer.setPreTokenizer(whitespacePreTokenizer());
     // END bert_setup_pre_tokenizer
     // START bert_setup_processor
     let { templateProcessing } = require("tokenizers/bindings/post-processors");
@@ -132,19 +118,50 @@ describe("pipelineExample", () => {
       "[CLS] $A [SEP] $B:1 [SEP]:1",
       [["[CLS]", 1], ["[SEP]", 2]]
     ));
     // END bert_setup_processor
+    // START bert_train_tokenizer
+    let { wordPieceTrainer } = require("tokenizers/bindings/trainers");
+    let { promisify } = require("util");
+
+    let trainer = wordPieceTrainer({
+      vocabSize: 30522,
+      specialTokens: ["[UNK]", "[CLS]", "[SEP]", "[PAD]", "[MASK]"]
+    });
+    let files = ["test", "train", "valid"].map(split => `data/wikitext-103-raw/wiki.${split}.raw`);
+    bertTokenizer.train(trainer, files);
+
+    let modelFiles = bertTokenizer.getModel().save("data", "bert-wiki");
+    let fromFile = promisify(WordPiece.fromFile);
+    bertTokenizer.setModel(await fromFile(modelFiles[0], {
+      unkToken: "[UNK]"
+    }));
+
+    bertTokenizer.save("data/bert-wiki.json")
+    // END bert_train_tokenizer
+  });
+
+  it("shows a full bert example", async () => {
+    let { Tokenizer } = require("tokenizers/bindings/tokenizer");
+    let bertTokenizer = await Tokenizer.fromFile("data/bert-wiki.json")
+
     // START bert_test_decoding
-    let output = bertTokenizer.encode("Welcome to the 🤗 Tokenizers library.");
+    let { promisify } = require("util");
+    let encode = promisify(bertTokenizer.encode.bind(bertTokenizer));
+    let decode = promisify(bertTokenizer.decode.bind(bertTokenizer));
+
+    let output = await encode("Welcome to the 🤗 Tokenizers library.");
     console.log(output.getTokens());
     // ["[CLS]", "welcome", "to", "the", "[UNK]", "tok", "##eni", "##zer", "##s", "library", ".", "[SEP]"]
-    bertTokenizer.decode(output.getIds());
+    var decoded = await decode(output.getIds(), true);
     // "welcome to the tok ##eni ##zer ##s library ."
     // END bert_test_decoding
+    expect(decoded).toEqual("welcome to the tok ##eni ##zer ##s library .");
     // START bert_proper_decoding
     let { wordPieceDecoder } = require("tokenizers/bindings/decoders");
     bertTokenizer.setDecoder(wordPieceDecoder());
-    bertTokenizer.decode(output.ids);
+    var decoded = await decode(output.getIds(), true);
     // "welcome to the tokenizers library."
     // END bert_proper_decoding
+    expect(decoded).toEqual("welcome to the tokenizers library.");
   });
 });
diff --git a/bindings/python/tests/documentation/test_pipeline.py b/bindings/python/tests/documentation/test_pipeline.py
index 874c740e..f1f8cbca 100644
--- a/bindings/python/tests/documentation/test_pipeline.py
+++ b/bindings/python/tests/documentation/test_pipeline.py
@@ -1,4 +1,4 @@
-from ..utils import data_dir, doc_wiki_tokenizer
+from ..utils import data_dir, doc_wiki_tokenizer, doc_pipeline_bert_tokenizer
 
 from tokenizers import Tokenizer
 
@@ -96,7 +96,8 @@ class TestPipeline:
             == "Hello , y ' all ! How are you ?"
         )
 
-    def bert_example(self):
+    @staticmethod
+    def slow_train():
         # START bert_setup_tokenizer
         from tokenizers import Tokenizer
         from tokenizers.models import WordPiece
@@ -136,20 +137,49 @@ class TestPipeline:
         bert_tokenizer.train(trainer, files)
 
         model_files = bert_tokenizer.model.save("data", "bert-wiki")
-        bert_tokenizer.model = WordPiece(*model_files, unk_token="[UNK]")
+        bert_tokenizer.model = WordPiece.from_file(*model_files, unk_token="[UNK]")
 
         bert_tokenizer.save("data/bert-wiki.json")
         # END bert_train_tokenizer
+
+    def test_bert_example(self, doc_pipeline_bert_tokenizer):
+        try:
+            bert_tokenizer = Tokenizer.from_file("data/bert-wiki.json")
+        except Exception:
+            bert_tokenizer = Tokenizer.from_file(doc_pipeline_bert_tokenizer)
+
 
         # START bert_test_decoding
         output = bert_tokenizer.encode("Welcome to the 🤗 Tokenizers library.")
         print(output.tokens)
         # ["[CLS]", "welcome", "to", "the", "[UNK]", "tok", "##eni", "##zer", "##s", "library", ".", "[SEP]"]
-        bert_tokenizer.decoder(output.ids)
+        bert_tokenizer.decode(output.ids)
         # "welcome to the tok ##eni ##zer ##s library ."
         # END bert_test_decoding
+        assert bert_tokenizer.decode(output.ids) == "welcome to the tok ##eni ##zer ##s library ."
         # START bert_proper_decoding
-        bert_tokenizer.decoder = tokenizers.decoders.WordPiece()
+        from tokenizers import decoders
+
+        bert_tokenizer.decoder = decoders.WordPiece()
         bert_tokenizer.decode(output.ids)
         # "welcome to the tokenizers library."
         # END bert_proper_decoding
+        assert bert_tokenizer.decode(output.ids) == "welcome to the tokenizers library."
+
+
+if __name__ == "__main__":
+    from urllib import request
+    from zipfile import ZipFile
+    import os
+
+    if not os.path.isdir("data/wikitext-103-raw"):
+        print("Downloading wikitext-103...")
+        wiki_text, _ = request.urlretrieve(
+            "https://s3.amazonaws.com/research.metamind.io/wikitext/wikitext-103-raw-v1.zip"
+        )
+        with ZipFile(wiki_text, "r") as z:
+            print("Unzipping in data...")
+            z.extractall("data")
+
+    print("Now training...")
+    TestPipeline.slow_train()
diff --git a/bindings/python/tests/utils.py b/bindings/python/tests/utils.py
index c2706174..02bfb5b2 100644
--- a/bindings/python/tests/utils.py
+++ b/bindings/python/tests/utils.py
@@ -90,6 +90,14 @@ def doc_wiki_tokenizer(data_dir):
     )
 
 
+@pytest.fixture(scope="session")
+def doc_pipeline_bert_tokenizer(data_dir):
+    return download(
+        "https://s3.amazonaws.com/models.huggingface.co/bert/anthony/doc-pipeline/tokenizer.json",
+        "bert-wiki.json",
+    )
+
+
 def multiprocessing_with_parallelism(tokenizer, enabled: bool):
     """
     This helper can be used to test that disabling parallelism avoids dead locks when the