mirror of
https://github.com/mii443/tokenizers.git
synced 2025-08-23 00:35:35 +00:00
Adding pickling support for trainers (#949)
* TMP. * Adding support for pickling Python trainers. * Removed unwarranted files and applied missed naming updates. * Stubbing. * Making sure the serialized format is verified in the Python tests.
This commit is contained in:
@ -1,5 +1,6 @@
|
||||
import os
|
||||
import pytest
|
||||
import copy
|
||||
import pickle
|
||||
|
||||
from tokenizers import (
|
||||
@ -14,7 +15,7 @@ from tokenizers import (
|
||||
from ..utils import data_dir, train_files
|
||||
|
||||
|
||||
class TestBPETrainer:
|
||||
class TestBpeTrainer:
|
||||
def test_can_modify(self):
|
||||
trainer = trainers.BpeTrainer(
|
||||
vocab_size=12345,
|
||||
@ -57,6 +58,21 @@ class TestBPETrainer:
|
||||
trainer.end_of_word_suffix = None
|
||||
assert trainer.continuing_subword_prefix == None
|
||||
|
||||
def test_can_pickle(self):
    """A BpeTrainer serializes to the expected JSON state and survives
    both pickling and deep-copying without losing information."""

    def fresh():
        # Build a trainer with one non-default field so the round-trip
        # demonstrably preserves configured values.
        return trainers.BpeTrainer(min_frequency=12)

    expected_state = b"""{"BpeTrainer":{"min_frequency":12,"vocab_size":30000,"show_progress":true,"special_tokens":[],"limit_alphabet":null,"initial_alphabet":[],"continuing_subword_prefix":null,"end_of_word_suffix":null,"words":{}}}"""
    assert fresh().__getstate__() == expected_state

    restored = pickle.loads(pickle.dumps(fresh()))
    assert isinstance(restored, trainers.BpeTrainer)

    assert isinstance(copy.deepcopy(fresh()), trainers.BpeTrainer)

    # Make sure everything is correct: re-pickling the restored trainer
    # must produce the same bytes as pickling a fresh one.
    assert pickle.dumps(restored) == pickle.dumps(fresh())
|
||||
|
||||
|
||||
class TestWordPieceTrainer:
|
||||
def test_can_modify(self):
|
||||
@ -101,6 +117,11 @@ class TestWordPieceTrainer:
|
||||
trainer.end_of_word_suffix = None
|
||||
assert trainer.continuing_subword_prefix == None
|
||||
|
||||
def test_can_pickle(self):
    """A default WordPieceTrainer survives a pickle round-trip."""
    revived = pickle.loads(pickle.dumps(trainers.WordPieceTrainer()))
    assert isinstance(revived, trainers.WordPieceTrainer)
|
||||
|
||||
|
||||
class TestWordLevelTrainer:
|
||||
def test_can_modify(self):
|
||||
@ -126,6 +147,11 @@ class TestWordLevelTrainer:
|
||||
trainer.special_tokens = []
|
||||
assert trainer.special_tokens == []
|
||||
|
||||
def test_can_pickle(self):
    """A default WordLevelTrainer survives a pickle round-trip."""
    revived = pickle.loads(pickle.dumps(trainers.WordLevelTrainer()))
    assert isinstance(revived, trainers.WordLevelTrainer)
|
||||
|
||||
|
||||
class TestUnigram:
|
||||
def test_train(self, train_files):
|
||||
@ -157,6 +183,11 @@ class TestUnigram:
|
||||
trainer = trainers.BpeTrainer(special_tokens=["<unk>"], show_progress=False)
|
||||
bpe_tokenizer.train([train_files["small"]], trainer=trainer)
|
||||
|
||||
def test_can_pickle(self):
    """A default UnigramTrainer survives a pickle round-trip."""
    revived = pickle.loads(pickle.dumps(trainers.UnigramTrainer()))
    assert isinstance(revived, trainers.UnigramTrainer)
|
||||
|
||||
def test_train_with_special_tokens(self):
|
||||
filename = "tests/data/dummy-unigram-special_tokens-train.txt"
|
||||
with open(filename, "w") as f:
|
||||
|
Reference in New Issue
Block a user