diff --git a/.gitignore b/.gitignore
index b116e54b..5cde67af 100644
--- a/.gitignore
+++ b/.gitignore
@@ -8,6 +8,7 @@ Cargo.lock
 
 /data
 tokenizers/data
+bindings/python/tests/data
 /docs
 
 __pycache__
diff --git a/bindings/python/tests/__init__.py b/bindings/python/tests/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/bindings/python/tests/bindings/__init__.py b/bindings/python/tests/bindings/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/bindings/python/tests/utils.py b/bindings/python/tests/utils.py
new file mode 100644
index 00000000..c60a68c9
--- /dev/null
+++ b/bindings/python/tests/utils.py
@@ -0,0 +1,60 @@
+import os
+import requests
+import pytest
+
+DATA_PATH = os.path.join("tests", "data")
+
+
+def download(url):
+    """Download ``url`` into ``DATA_PATH`` once and return the local file path.
+
+    The file name is the last path component of ``url``. If that file already
+    exists it is reused and no request is made. ``DATA_PATH`` must already
+    exist (the ``data_dir`` fixture creates it).
+
+    Raises ``requests.HTTPError`` when the server answers with an error status.
+    """
+    filename = url.rsplit("/", 1)[-1]
+    filepath = os.path.join(DATA_PATH, filename)
+    if not os.path.exists(filepath):
+        # Stream into a temporary sibling and move it into place only after the
+        # whole body was written: a failed or interrupted download must not leave
+        # an empty/truncated file that the next run would treat as a valid cache.
+        partial = filepath + ".part"
+        try:
+            with requests.get(url, stream=True) as response:
+                response.raise_for_status()
+                with open(partial, "wb") as f:
+                    for chunk in response.iter_content(1024):
+                        f.write(chunk)
+            os.replace(partial, filepath)
+        finally:
+            # No-op on success (the partial file was renamed away).
+            if os.path.exists(partial):
+                os.remove(partial)
+    return filepath
+
+
+@pytest.fixture(scope="session")
+def data_dir():
+    """Make sure ``DATA_PATH`` exists and return it.
+
+    ``DATA_PATH`` is relative, so the tests have to be started from a
+    working directory whose name ends with ``python``.
+    """
+    assert os.getcwd().endswith("python"), "run the tests from bindings/python"
+    os.makedirs(DATA_PATH, exist_ok=True)
+    return DATA_PATH
+
+
+@pytest.fixture(scope="session")
+def roberta_files(data_dir):
+    """Return local paths of the roberta-base ``vocab`` and ``merges`` files."""
+    return {
+        "vocab": download(
+            "https://s3.amazonaws.com/models.huggingface.co/bert/roberta-base-vocab.json"
+        ),
+        "merges": download(
+            "https://s3.amazonaws.com/models.huggingface.co/bert/roberta-base-merges.txt"
+        ),
+    }