Adding pickling support for trainers (#949)

* TMP.

* Adding support for pickling Python trainers.

* Remove not warranted files + missed naming updates.

* Stubbing.

* Making sure serialized format is written in python tests.
This commit is contained in:
Nicolas Patry
2022-03-14 12:18:11 +01:00
committed by GitHub
parent 71ae5421eb
commit 4b6055d4fb
11 changed files with 298 additions and 196 deletions

View File

@ -1,5 +1,6 @@
import os
import pytest
import copy
import pickle
from tokenizers import (
@ -14,7 +15,7 @@ from tokenizers import (
from ..utils import data_dir, train_files
class TestBPETrainer:
class TestBpeTrainer:
def test_can_modify(self):
trainer = trainers.BpeTrainer(
vocab_size=12345,
@ -57,6 +58,21 @@ class TestBPETrainer:
trainer.end_of_word_suffix = None
assert trainer.continuing_subword_prefix == None
def test_can_pickle(self):
assert (
trainers.BpeTrainer(min_frequency=12).__getstate__()
== b"""{"BpeTrainer":{"min_frequency":12,"vocab_size":30000,"show_progress":true,"special_tokens":[],"limit_alphabet":null,"initial_alphabet":[],"continuing_subword_prefix":null,"end_of_word_suffix":null,"words":{}}}"""
)
assert isinstance(
pickle.loads(pickle.dumps(trainers.BpeTrainer(min_frequency=12))), trainers.BpeTrainer
)
assert isinstance(copy.deepcopy(trainers.BpeTrainer(min_frequency=12)), trainers.BpeTrainer)
# Make sure everything is correct
assert pickle.dumps(
pickle.loads(pickle.dumps(trainers.BpeTrainer(min_frequency=12)))
) == pickle.dumps(trainers.BpeTrainer(min_frequency=12))
class TestWordPieceTrainer:
def test_can_modify(self):
@ -101,6 +117,11 @@ class TestWordPieceTrainer:
trainer.end_of_word_suffix = None
assert trainer.continuing_subword_prefix == None
def test_can_pickle(self):
assert isinstance(
pickle.loads(pickle.dumps(trainers.WordPieceTrainer())), trainers.WordPieceTrainer
)
class TestWordLevelTrainer:
def test_can_modify(self):
@ -126,6 +147,11 @@ class TestWordLevelTrainer:
trainer.special_tokens = []
assert trainer.special_tokens == []
def test_can_pickle(self):
assert isinstance(
pickle.loads(pickle.dumps(trainers.WordLevelTrainer())), trainers.WordLevelTrainer
)
class TestUnigram:
def test_train(self, train_files):
@ -157,6 +183,11 @@ class TestUnigram:
trainer = trainers.BpeTrainer(special_tokens=["<unk>"], show_progress=False)
bpe_tokenizer.train([train_files["small"]], trainer=trainer)
def test_can_pickle(self):
assert isinstance(
pickle.loads(pickle.dumps(trainers.UnigramTrainer())), trainers.UnigramTrainer
)
def test_train_with_special_tokens(self):
filename = "tests/data/dummy-unigram-special_tokens-train.txt"
with open(filename, "w") as f: