mirror of
https://github.com/mii443/tokenizers.git
synced 2025-08-23 00:35:35 +00:00
Adding pickling support for trainers (#949)
* TMP. * Adding support for pickling Python trainers. * Remove not warranted files + missed naming updates. * Stubbing. * Making sure serialized format is written in python tests.
This commit is contained in:
@ -1,28 +1,32 @@
|
||||
use std::sync::{Arc, RwLock};
|
||||
|
||||
use pyo3::exceptions;
|
||||
use pyo3::prelude::*;
|
||||
use pyo3::types::*;
|
||||
use tk::models::TrainerWrapper;
|
||||
use tk::Trainer;
|
||||
use tokenizers as tk;
|
||||
|
||||
use crate::models::PyModel;
|
||||
use crate::tokenizer::PyAddedToken;
|
||||
use crate::utils::PyChar;
|
||||
use pyo3::exceptions;
|
||||
use pyo3::prelude::*;
|
||||
use pyo3::types::*;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use tk::models::TrainerWrapper;
|
||||
use tk::Trainer;
|
||||
use tokenizers as tk;
|
||||
|
||||
/// Base class for all trainers
|
||||
///
|
||||
/// This class is not supposed to be instantiated directly. Instead, any implementation of a
|
||||
/// Trainer will return an instance of this class when instantiated.
|
||||
#[pyclass(name=Trainer, module = "tokenizers.trainers", name=Trainer)]
|
||||
#[derive(Clone)]
|
||||
#[text_signature = "(self, vocab_size=30000, min_frequency=0,show_progress=True, special_tokens=[],limit_alphabet=None, initial_alphabet = [], continuing_subword_prefix=None, end_of_word_suffix=None)"]
|
||||
#[derive(Clone, Deserialize, Serialize)]
|
||||
pub struct PyTrainer {
|
||||
#[serde(flatten)]
|
||||
pub trainer: Arc<RwLock<TrainerWrapper>>,
|
||||
}
|
||||
|
||||
impl PyTrainer {
|
||||
#[cfg(test)]
|
||||
pub(crate) fn new(trainer: Arc<RwLock<TrainerWrapper>>) -> Self {
|
||||
PyTrainer { trainer }
|
||||
}
|
||||
pub(crate) fn get_as_subtype(&self) -> PyResult<PyObject> {
|
||||
let base = self.clone();
|
||||
let gil = Python::acquire_gil();
|
||||
@ -41,6 +45,34 @@ impl PyTrainer {
|
||||
})
|
||||
}
|
||||
}
|
||||
#[pymethods]
|
||||
impl PyTrainer {
|
||||
fn __getstate__(&self, py: Python) -> PyResult<PyObject> {
|
||||
let data = serde_json::to_string(&self.trainer).map_err(|e| {
|
||||
exceptions::PyException::new_err(format!(
|
||||
"Error while attempting to pickle PyTrainer: {}",
|
||||
e
|
||||
))
|
||||
})?;
|
||||
Ok(PyBytes::new(py, data.as_bytes()).to_object(py))
|
||||
}
|
||||
|
||||
fn __setstate__(&mut self, py: Python, state: PyObject) -> PyResult<()> {
|
||||
match state.extract::<&PyBytes>(py) {
|
||||
Ok(s) => {
|
||||
let unpickled = serde_json::from_slice(s.as_bytes()).map_err(|e| {
|
||||
exceptions::PyException::new_err(format!(
|
||||
"Error while attempting to unpickle PyTrainer: {}",
|
||||
e
|
||||
))
|
||||
})?;
|
||||
self.trainer = unpickled;
|
||||
Ok(())
|
||||
}
|
||||
Err(e) => Err(e),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Trainer for PyTrainer {
|
||||
type Model = PyModel;
|
||||
@ -820,3 +852,20 @@ impl PyUnigramTrainer {
|
||||
Ok((PyUnigramTrainer {}, trainer.into()))
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use tk::models::bpe::trainer::BpeTrainer;
|
||||
|
||||
#[test]
|
||||
fn get_subtype() {
|
||||
let py_trainer = PyTrainer::new(Arc::new(RwLock::new(BpeTrainer::default().into())));
|
||||
let py_bpe = py_trainer.get_as_subtype().unwrap();
|
||||
let gil = Python::acquire_gil();
|
||||
assert_eq!(
|
||||
"tokenizers.trainers.BpeTrainer",
|
||||
py_bpe.as_ref(gil.python()).get_type().name()
|
||||
);
|
||||
}
|
||||
}
|
||||
|
Reference in New Issue
Block a user