Python - Fix ByteLevel instantiation from state (#621)

This commit is contained in:
Anthony MOI
2021-02-04 10:16:05 -05:00
committed by GitHub
parent 324cb8d380
commit 57200144ca
6 changed files with 30 additions and 4 deletions

View File

@ -1,5 +1,6 @@
import pytest
import pickle
import json
from tokenizers.decoders import Decoder, ByteLevel, WordPiece, Metaspace, BPEDecoder
@ -15,6 +16,12 @@ class TestByteLevel:
decoder = ByteLevel()
assert decoder.decode(["My", "Ġname", "Ġis", "ĠJohn"]) == "My name is John"
def test_manual_reload(self):
byte_level = ByteLevel()
state = json.loads(byte_level.__getstate__())
reloaded = ByteLevel(**state)
assert isinstance(reloaded, ByteLevel)
class TestWordPiece:
def test_instantiate(self):

View File

@ -1,5 +1,6 @@
import pytest
import pickle
import json
from tokenizers.pre_tokenizers import (
PreTokenizer,
@ -39,6 +40,12 @@ class TestByteLevel:
pretok.add_prefix_space = True
assert pretok.add_prefix_space == True
def test_manual_reload(self):
byte_level = ByteLevel()
state = json.loads(byte_level.__getstate__())
reloaded = ByteLevel(**state)
assert isinstance(reloaded, ByteLevel)
class TestSplit:
def test_instantiate(self):

View File

@ -1,5 +1,6 @@
import pytest
import pickle
import json
from ..utils import data_dir, roberta_files
@ -84,6 +85,12 @@ class TestByteLevelProcessing:
assert output.tokens == ["ĠMy", "Ġname", "Ġis", "ĠJohn"]
assert output.offsets == [(0, 2), (3, 7), (8, 10), (11, 15)]
def test_manual_reload(self):
byte_level = ByteLevel()
state = json.loads(byte_level.__getstate__())
reloaded = ByteLevel(**state)
assert isinstance(reloaded, ByteLevel)
class TestTemplateProcessing:
def get_bert(self):