# Patch summary (explanatory preamble; text before the first "diff --git"
# header is not part of any hunk).
#
# 1. bindings/python/src/models.rs
#    Adds a `#[staticmethod] fn empty() -> Model` inside `impl WordLevel`.
#    It returns a `Model` whose `model` field is a `Container::Owned` box
#    around `tk::models::wordlevel::WordLevel::default()`.
#
# 2. bindings/python/tests/bindings/test_models.py (new file)
#    Adds TestBPE, TestWordPiece and TestWordLevel. Each test asserts that
#    both `<Model>.empty()` and `<Model>.from_files(...)` return an instance
#    of `Model`. The WordLevel test reuses the RoBERTa `vocab` file.
#
# 3. bindings/python/tests/utils.py
#    Adds a session-scoped `bert_files` fixture returning a dict whose
#    "vocab" entry is the result of `download(...)` on the
#    bert-base-uncased vocab URL.
#
# 4. bindings/python/tokenizers/models/__init__.pyi
#    Adds the `WordLevel.empty()` static-method stub with a docstring.
#
# NOTE(review): this patch was received with its line breaks collapsed.
# The line structure and leading indentation below are reconstructed from
# the hunk headers and conventional Rust/Python formatting -- verify the
# context lines against the target files before applying.
diff --git a/bindings/python/src/models.rs b/bindings/python/src/models.rs
index cd15a568..53c1b5a0 100644
--- a/bindings/python/src/models.rs
+++ b/bindings/python/src/models.rs
@@ -288,4 +288,11 @@ impl WordLevel {
             }),
         }
     }
+
+    #[staticmethod]
+    fn empty() -> Model {
+        Model {
+            model: Container::Owned(Box::new(tk::models::wordlevel::WordLevel::default())),
+        }
+    }
 }
diff --git a/bindings/python/tests/bindings/test_models.py b/bindings/python/tests/bindings/test_models.py
new file mode 100644
index 00000000..6453db8a
--- /dev/null
+++ b/bindings/python/tests/bindings/test_models.py
@@ -0,0 +1,23 @@
+from ..utils import data_dir, roberta_files, bert_files
+
+from tokenizers.models import Model, BPE, WordPiece, WordLevel
+
+
+class TestBPE:
+    def test_instantiate(self, roberta_files):
+        assert isinstance(BPE.empty(), Model)
+        assert isinstance(BPE.from_files(roberta_files["vocab"], roberta_files["merges"]), Model)
+
+
+class TestWordPiece:
+    def test_instantiate(self, bert_files):
+        assert isinstance(WordPiece.empty(), Model)
+        assert isinstance(WordPiece.from_files(bert_files["vocab"]), Model)
+
+
+class TestWordLevel:
+    def test_instantiate(self, roberta_files):
+        assert isinstance(WordLevel.empty(), Model)
+        # The WordLevel model expects a vocab.json using the same format as roberta
+        # so we can just try to load with this file
+        assert isinstance(WordLevel.from_files(roberta_files["vocab"]), Model)
diff --git a/bindings/python/tests/utils.py b/bindings/python/tests/utils.py
index c60a68c9..1ba0781d 100644
--- a/bindings/python/tests/utils.py
+++ b/bindings/python/tests/utils.py
@@ -35,3 +35,12 @@ def roberta_files(data_dir):
             "https://s3.amazonaws.com/models.huggingface.co/bert/roberta-base-merges.txt"
         ),
     }
+
+
+@pytest.fixture(scope="session")
+def bert_files(data_dir):
+    return {
+        "vocab": download(
+            "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-vocab.txt"
+        ),
+    }
diff --git a/bindings/python/tokenizers/models/__init__.pyi b/bindings/python/tokenizers/models/__init__.pyi
index 7c4bce3d..7cd4340d 100644
--- a/bindings/python/tokenizers/models/__init__.pyi
+++ b/bindings/python/tokenizers/models/__init__.pyi
@@ -158,3 +158,7 @@ class WordLevel(Model):
             The unknown token to be used by the model.
         """
         pass
+    @staticmethod
+    def empty() -> Model:
+        """ Instantiate an empty WordLevel Model. """
+        pass