Adding pickling support for trainers (#949)

* TMP.

* Adding support for pickling Python trainers.

* Remove unwarranted files + missed naming updates.

* Stubbing.

* Making sure the serialized format is asserted in the Python tests.
Nicolas Patry
2022-03-14 12:18:11 +01:00
committed by GitHub
parent 71ae5421eb
commit 4b6055d4fb
11 changed files with 298 additions and 196 deletions
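
The user-visible effect of this commit is that every Python trainer (BpeTrainer, WordPieceTrainer, WordLevelTrainer, UnigramTrainer) now supports the pickle protocol. A minimal sketch of the behavior, mirroring the tests added below:

    import copy
    import pickle

    from tokenizers import trainers

    # Trainers now survive a pickle round-trip, and the concrete
    # subclass (here BpeTrainer) is preserved.
    trainer = trainers.BpeTrainer(min_frequency=12)
    restored = pickle.loads(pickle.dumps(trainer))
    assert isinstance(restored, trainers.BpeTrainer)

    # copy.deepcopy goes through the same __getstate__/__setstate__ path.
    assert isinstance(copy.deepcopy(trainer), trainers.BpeTrainer)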

View File

@@ -7,19 +7,6 @@ class Trainer:
     Trainer will return an instance of this class when instantiated.
     """
-
-    def __init__(
-        self,
-        vocab_size=30000,
-        min_frequency=0,
-        show_progress=True,
-        special_tokens=[],
-        limit_alphabet=None,
-        initial_alphabet=[],
-        continuing_subword_prefix=None,
-        end_of_word_suffix=None,
-    ):
-        pass
 
 class BpeTrainer(Trainer):
     """
     Trainer capable of training a BPE model
View File

@@ -1,28 +1,32 @@
+use std::sync::{Arc, RwLock};
+
-use pyo3::exceptions;
-use pyo3::prelude::*;
-use pyo3::types::*;
-use tk::models::TrainerWrapper;
-use tk::Trainer;
-use tokenizers as tk;
-
 use crate::models::PyModel;
 use crate::tokenizer::PyAddedToken;
 use crate::utils::PyChar;
+use pyo3::exceptions;
+use pyo3::prelude::*;
+use pyo3::types::*;
+use serde::{Deserialize, Serialize};
+use tk::models::TrainerWrapper;
+use tk::Trainer;
+use tokenizers as tk;
 
 /// Base class for all trainers
 ///
 /// This class is not supposed to be instantiated directly. Instead, any implementation of a
 /// Trainer will return an instance of this class when instantiated.
 #[pyclass(name=Trainer, module = "tokenizers.trainers")]
-#[derive(Clone)]
-#[text_signature = "(self, vocab_size=30000, min_frequency=0, show_progress=True, special_tokens=[], limit_alphabet=None, initial_alphabet=[], continuing_subword_prefix=None, end_of_word_suffix=None)"]
+#[derive(Clone, Deserialize, Serialize)]
 pub struct PyTrainer {
-    pub trainer: TrainerWrapper,
+    #[serde(flatten)]
+    pub trainer: Arc<RwLock<TrainerWrapper>>,
 }
 
 impl PyTrainer {
+    #[cfg(test)]
+    pub(crate) fn new(trainer: Arc<RwLock<TrainerWrapper>>) -> Self {
+        PyTrainer { trainer }
+    }
+
     pub(crate) fn get_as_subtype(&self) -> PyResult<PyObject> {
         let base = self.clone();
         let gil = Python::acquire_gil();
@@ -41,6 +45,34 @@ impl PyTrainer {
         })
     }
 }
 
+#[pymethods]
+impl PyTrainer {
+    fn __getstate__(&self, py: Python) -> PyResult<PyObject> {
+        let data = serde_json::to_string(&self.trainer).map_err(|e| {
+            exceptions::PyException::new_err(format!(
+                "Error while attempting to pickle PyTrainer: {}",
+                e
+            ))
+        })?;
+        Ok(PyBytes::new(py, data.as_bytes()).to_object(py))
+    }
+
+    fn __setstate__(&mut self, py: Python, state: PyObject) -> PyResult<()> {
+        match state.extract::<&PyBytes>(py) {
+            Ok(s) => {
+                let unpickled = serde_json::from_slice(s.as_bytes()).map_err(|e| {
+                    exceptions::PyException::new_err(format!(
+                        "Error while attempting to unpickle PyTrainer: {}",
+                        e
+                    ))
+                })?;
+                self.trainer = unpickled;
+                Ok(())
+            }
+            Err(e) => Err(e),
+        }
+    }
+}
+
 impl Trainer for PyTrainer {
     type Model = PyModel;
@@ -820,3 +852,20 @@ impl PyUnigramTrainer {
         Ok((PyUnigramTrainer {}, trainer.into()))
     }
 }
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use tk::models::bpe::trainer::BpeTrainer;
+
+    #[test]
+    fn get_subtype() {
+        let py_trainer = PyTrainer::new(Arc::new(RwLock::new(BpeTrainer::default().into())));
+        let py_bpe = py_trainer.get_as_subtype().unwrap();
+        let gil = Python::acquire_gil();
+        assert_eq!(
+            "tokenizers.trainers.BpeTrainer",
+            py_bpe.as_ref(gil.python()).get_type().name()
+        );
+    }
+}
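
The __getstate__/__setstate__ pair above serializes the wrapped trainer with serde_json, so the pickled state is a JSON object tagged with the concrete variant name; that tag is what lets __setstate__ restore the right trainer type. A small sketch of inspecting that state from Python (the exact byte string is asserted in the tests below):

    import json

    from tokenizers import trainers

    # __getstate__ returns the serde-serialized trainer as JSON bytes,
    # keyed by the concrete variant name (externally tagged).
    state = trainers.BpeTrainer(min_frequency=12).__getstate__()
    payload = json.loads(state)
    assert list(payload) == ["BpeTrainer"]
    assert payload["BpeTrainer"]["min_frequency"] == 12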

View File

@@ -1,5 +1,6 @@
 import os
 import pytest
+import copy
 import pickle
 
 from tokenizers import (
@@ -14,7 +15,7 @@ from tokenizers import (
 from ..utils import data_dir, train_files
 
 
-class TestBPETrainer:
+class TestBpeTrainer:
     def test_can_modify(self):
         trainer = trainers.BpeTrainer(
             vocab_size=12345,
@@ -57,6 +58,21 @@ class TestBPETrainer:
         trainer.end_of_word_suffix = None
         assert trainer.continuing_subword_prefix == None
 
+    def test_can_pickle(self):
+        assert (
+            trainers.BpeTrainer(min_frequency=12).__getstate__()
+            == b"""{"BpeTrainer":{"min_frequency":12,"vocab_size":30000,"show_progress":true,"special_tokens":[],"limit_alphabet":null,"initial_alphabet":[],"continuing_subword_prefix":null,"end_of_word_suffix":null,"words":{}}}"""
+        )
+        assert isinstance(
+            pickle.loads(pickle.dumps(trainers.BpeTrainer(min_frequency=12))), trainers.BpeTrainer
+        )
+        assert isinstance(copy.deepcopy(trainers.BpeTrainer(min_frequency=12)), trainers.BpeTrainer)
+
+        # Make sure everything is correct
+        assert pickle.dumps(
+            pickle.loads(pickle.dumps(trainers.BpeTrainer(min_frequency=12)))
+        ) == pickle.dumps(trainers.BpeTrainer(min_frequency=12))
 
 
 class TestWordPieceTrainer:
     def test_can_modify(self):
@@ -101,6 +117,11 @@ class TestWordPieceTrainer:
         trainer.end_of_word_suffix = None
         assert trainer.continuing_subword_prefix == None
 
+    def test_can_pickle(self):
+        assert isinstance(
+            pickle.loads(pickle.dumps(trainers.WordPieceTrainer())), trainers.WordPieceTrainer
+        )
 
 
 class TestWordLevelTrainer:
     def test_can_modify(self):
@@ -126,6 +147,11 @@ class TestWordLevelTrainer:
         trainer.special_tokens = []
         assert trainer.special_tokens == []
 
+    def test_can_pickle(self):
+        assert isinstance(
+            pickle.loads(pickle.dumps(trainers.WordLevelTrainer())), trainers.WordLevelTrainer
+        )
 
 
 class TestUnigram:
     def test_train(self, train_files):
@@ -157,6 +183,11 @@ class TestUnigram:
         trainer = trainers.BpeTrainer(special_tokens=["<unk>"], show_progress=False)
         bpe_tokenizer.train([train_files["small"]], trainer=trainer)
 
+    def test_can_pickle(self):
+        assert isinstance(
+            pickle.loads(pickle.dumps(trainers.UnigramTrainer())), trainers.UnigramTrainer
+        )
 
     def test_train_with_special_tokens(self):
         filename = "tests/data/dummy-unigram-special_tokens-train.txt"
         with open(filename, "w") as f: