Adding pickling support for trainers (#949)

* TMP.

* Adding support for pickling Python trainers.

* Remove unwarranted files + fix missed naming updates.

* Stubbing.

* Making sure the serialized format is written in the Python tests.
This commit is contained in:
Nicolas Patry
2022-03-14 12:18:11 +01:00
committed by GitHub
parent 71ae5421eb
commit 4b6055d4fb
11 changed files with 298 additions and 196 deletions

View File

@ -1,28 +1,32 @@
use std::sync::{Arc, RwLock};
use pyo3::exceptions;
use pyo3::prelude::*;
use pyo3::types::*;
use tk::models::TrainerWrapper;
use tk::Trainer;
use tokenizers as tk;
use crate::models::PyModel;
use crate::tokenizer::PyAddedToken;
use crate::utils::PyChar;
use pyo3::exceptions;
use pyo3::prelude::*;
use pyo3::types::*;
use serde::{Deserialize, Serialize};
use tk::models::TrainerWrapper;
use tk::Trainer;
use tokenizers as tk;
/// Base class for all trainers
///
/// This class is not supposed to be instantiated directly. Instead, any implementation of a
/// Trainer will return an instance of this class when instantiated.
// NOTE(review): the attribute lines below look like the removed and added sides of a diff shown
// together (two `derive`s that both name `Clone`, `name=Trainer` passed twice, a
// `text_signature` on a class documented as not directly instantiable) — confirm which
// attributes the final file keeps before building.
#[pyclass(name=Trainer, module = "tokenizers.trainers", name=Trainer)]
#[derive(Clone)]
#[text_signature = "(self, vocab_size=30000, min_frequency=0,show_progress=True, special_tokens=[],limit_alphabet=None, initial_alphabet = [], continuing_subword_prefix=None, end_of_word_suffix=None)"]
#[derive(Clone, Deserialize, Serialize)]
pub struct PyTrainer {
// The wrapped trainer is held behind `Arc<RwLock<..>>`. `#[serde(flatten)]` asks serde to
// (de)serialize this field's contents inline instead of nesting them under a `trainer` key.
#[serde(flatten)]
pub trainer: Arc<RwLock<TrainerWrapper>>,
}
impl PyTrainer {
#[cfg(test)]
pub(crate) fn new(trainer: Arc<RwLock<TrainerWrapper>>) -> Self {
    // Test-only constructor: stores the given shared trainer handle as-is.
    Self { trainer }
}
pub(crate) fn get_as_subtype(&self) -> PyResult<PyObject> {
let base = self.clone();
let gil = Python::acquire_gil();
@ -41,6 +45,34 @@ impl PyTrainer {
})
}
}
#[pymethods]
impl PyTrainer {
    // Pickle hook: serializes the wrapped trainer to JSON and hands it to Python as `bytes`.
    //
    // Returns a Python `Exception` carrying the serde error text if serialization fails.
    fn __getstate__(&self, py: Python) -> PyResult<PyObject> {
        let data = serde_json::to_string(&self.trainer).map_err(|e| {
            exceptions::PyException::new_err(format!(
                "Error while attempting to pickle PyTrainer: {}",
                e
            ))
        })?;
        Ok(PyBytes::new(py, data.as_bytes()).to_object(py))
    }

    // Unpickle hook: expects `state` to be the `bytes` produced by `__getstate__`, parses it
    // as JSON and replaces the wrapped trainer with the result.
    //
    // Errors:
    // - the extraction error is forwarded unchanged when `state` is not a `bytes` object;
    // - a Python `Exception` carrying the serde error text is returned when parsing fails.
    // In both cases `self.trainer` is left untouched.
    fn __setstate__(&mut self, py: Python, state: PyObject) -> PyResult<()> {
        // `?` propagates the extraction failure as-is; no manual `match` is needed for that.
        let bytes = state.extract::<&PyBytes>(py)?;
        // The assignment only happens once deserialization has succeeded.
        self.trainer = serde_json::from_slice(bytes.as_bytes()).map_err(|e| {
            exceptions::PyException::new_err(format!(
                "Error while attempting to unpickle PyTrainer: {}",
                e
            ))
        })?;
        Ok(())
    }
}
impl Trainer for PyTrainer {
type Model = PyModel;
@ -820,3 +852,20 @@ impl PyUnigramTrainer {
Ok((PyUnigramTrainer {}, trainer.into()))
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use tk::models::bpe::trainer::BpeTrainer;

    // A `PyTrainer` built around a default BPE trainer must be exposed to Python as the
    // `BpeTrainer` subtype.
    #[test]
    fn get_subtype() {
        let wrapped: TrainerWrapper = BpeTrainer::default().into();
        let trainer = PyTrainer::new(Arc::new(RwLock::new(wrapped)));
        let subtype = trainer.get_as_subtype().unwrap();

        let gil = Python::acquire_gil();
        let py = gil.python();
        let type_name = subtype.as_ref(py).get_type().name();
        assert_eq!("tokenizers.trainers.BpeTrainer", type_name);
    }
}