Mirror of https://github.com/mii443/tokenizers.git (synced 2025-08-23 00:35:35 +00:00)
Fixing bad deserialization following inclusion of a default for `Punctuation`. (#884)

* Fixing bad deserialization following inclusion of a default for `Punctuation`.
* Don't remove the type now...
* Adding slow test to run on all the tokenizers of the hub.
* `PartialEq` everywhere.
* Forcing `type` to exist on the `pre_tokenizers`.
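For context, the shape this commit guards against is a tokenizer.json whose `pre_tokenizer` entries lack a "type" field, which a default for `Punctuation` could otherwise silently absorb during deserialization. Below is a minimal, hand-written sketch of such a payload; the dictionary contents (including the "behavior": "Isolated" entry standing in for a type-less `Punctuation`) are illustrative assumptions, not taken from any real model on the hub:

import json

# Illustrative payload only (not from a real model): a "Sequence"
# pre-tokenizer whose second entry has no "type" field. After this fix,
# such a file should be rejected instead of the entry silently
# deserializing as a default Punctuation.
broken = {
    "pre_tokenizer": {
        "type": "Sequence",
        "pretokenizers": [
            {"type": "Whitespace"},
            {"behavior": "Isolated"},  # missing "type"
        ],
    }
}

# The `check` helper added in the test below flags exactly this shape.
assert any("type" not in p for p in broken["pre_tokenizer"]["pretokenizers"])
print(json.dumps(broken, indent=2))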
@@ -1,5 +1,10 @@
-from tokenizers import Tokenizer, models, normalizers
+from tokenizers import Tokenizer
+import os
+import unittest
 
 from .utils import data_dir, albert_base
+import json
+from huggingface_hub import HfApi, hf_hub_url, cached_download
+import tqdm
 
 
 class TestSerialization:
@@ -8,3 +13,70 @@ class TestSerialization:
         # This used to fail because of BufReader that would fail because the
         # file exceeds the buffer capacity
         tokenizer = Tokenizer.from_file(albert_base)
+
+
+def check(tokenizer_file) -> bool:
+    with open(tokenizer_file, "r") as f:
+        data = json.load(f)
+    if "pre_tokenizer" not in data:
+        return True
+    if "type" not in data["pre_tokenizer"]:
+        return False
+    if data["pre_tokenizer"]["type"] == "Sequence":
+        for pre_tok in data["pre_tokenizer"]["pretokenizers"]:
+            if "type" not in pre_tok:
+                return False
+    return True
+
+
+def slow(test_case):
+    """
+    Decorator marking a test as slow.
+
+    Slow tests are skipped by default. Set the RUN_SLOW environment variable to a truthy value to run them.
+
+    """
+    if os.getenv("RUN_SLOW") != "1":
+        return unittest.skip("use `RUN_SLOW=1` to run")(test_case)
+    else:
+        return test_case
+
+
+@slow
+class TestFullDeserialization(unittest.TestCase):
+    def test_full_deserialization_hub(self):
+        # Check we can read this file.
+        # This used to fail because of BufReader that would fail because the
+        # file exceeds the buffer capacity
+        api = HfApi()
+
+        not_loadable = []
+        invalid_pre_tokenizer = []
+
+        # models = api.list_models(filter="transformers")
+        # for model in tqdm.tqdm(models):
+        #     model_id = model.modelId
+        #     for model_file in model.siblings:
+        #         filename = model_file.rfilename
+        #         if filename == "tokenizer.json":
+        #             all_models.append((model_id, filename))
+
+        all_models = [("HueyNemud/das22-10-camembert_pretrained", "tokenizer.json")]
+        for (model_id, filename) in tqdm.tqdm(all_models):
+            tokenizer_file = cached_download(hf_hub_url(model_id, filename=filename))
+
+            is_ok = check(tokenizer_file)
+            if not is_ok:
+                print(f"{model_id} is affected by no type")
+                invalid_pre_tokenizer.append(model_id)
+            try:
+                Tokenizer.from_file(tokenizer_file)
+            except Exception as e:
+                print(f"{model_id} is not loadable: {e}")
+                not_loadable.append(model_id)
+            except:
+                print(f"{model_id} is not loadable: Rust error")
+                not_loadable.append(model_id)
+
+        self.assertEqual(invalid_pre_tokenizer, [])
+        self.assertEqual(not_loadable, [])
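As a usage sketch (not part of the commit): the new hub test is opt-in via the RUN_SLOW environment variable checked by the `slow` decorator above, so a local run from the Python bindings directory might look like the following. The test file path and the pytest invocation are assumptions about the repository layout, not taken from the diff:

import os
import subprocess

# Opt in to the slow test; without RUN_SLOW=1 the @slow decorator skips
# TestFullDeserialization entirely.
env = dict(os.environ, RUN_SLOW="1")
subprocess.run(
    ["python", "-m", "pytest", "tests/test_serialization.py::TestFullDeserialization", "-s"],
    env=env,
    check=True,
)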