Mirror of https://github.com/mii443/tokenizers.git, synced 2025-12-03 11:18:29 +00:00
Adding pickling support for trainers (#949)
* TMP.
* Adding support for pickling Python trainers.
* Remove unwarranted files + missed naming updates.
* Stubbing.
* Making sure the serialized format is written in Python tests.
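In practice, this makes every trainer class picklable and deep-copyable. A minimal sketch of the behaviour the commit enables (illustrative, not part of the diff; BpeTrainer stands in for any trainer):

    import copy
    import pickle

    from tokenizers import trainers

    trainer = trainers.BpeTrainer(min_frequency=12)

    # A pickle round-trip preserves both the concrete type and the config.
    restored = pickle.loads(pickle.dumps(trainer))
    assert isinstance(restored, trainers.BpeTrainer)
    assert pickle.dumps(restored) == pickle.dumps(trainer)

    # copy.deepcopy goes through the same __getstate__/__setstate__ hooks.
    clone = copy.deepcopy(trainer)
    assert isinstance(clone, trainers.BpeTrainer)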
bindings/python/py_src/tokenizers/trainers/__init__.pyi

@@ -7,19 +7,6 @@ class Trainer:
     Trainer will return an instance of this class when instantiated.
     """
-
-    def __init__(
-        self,
-        vocab_size=30000,
-        min_frequency=0,
-        show_progress=True,
-        special_tokens=[],
-        limit_alphabet=None,
-        initial_alphabet=[],
-        continuing_subword_prefix=None,
-        end_of_word_suffix=None,
-    ):
-        pass
 
 class BpeTrainer(Trainer):
     """
     Trainer capable of training a BPE model
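The stub drops __init__ from the base class because Trainer is never constructed directly; users instantiate a concrete trainer instead, e.g. (a sketch, not from the diff):

    from tokenizers import trainers

    # Concrete subclasses keep their own signatures; only the abstract base
    # class lost its stubbed __init__.
    trainer = trainers.BpeTrainer(vocab_size=30000, min_frequency=0)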
bindings/python/src/trainers.rs

@@ -1,28 +1,32 @@
+use std::sync::{Arc, RwLock};
+
 use pyo3::exceptions;
 use pyo3::prelude::*;
 use pyo3::types::*;
+use serde::{Deserialize, Serialize};
 use tk::models::TrainerWrapper;
 use tk::Trainer;
 use tokenizers as tk;
 
 use crate::models::PyModel;
 use crate::tokenizer::PyAddedToken;
 use crate::utils::PyChar;
 
 /// Base class for all trainers
 ///
 /// This class is not supposed to be instantiated directly. Instead, any implementation of a
 /// Trainer will return an instance of this class when instantiated.
-#[pyclass(name=Trainer)]
-#[derive(Clone)]
-#[text_signature = "(self, vocab_size=30000, min_frequency=0,show_progress=True, special_tokens=[],limit_alphabet=None, initial_alphabet = [], continuing_subword_prefix=None, end_of_word_suffix=None)"]
+#[pyclass(module = "tokenizers.trainers", name=Trainer)]
+#[derive(Clone, Deserialize, Serialize)]
 pub struct PyTrainer {
+    #[serde(flatten)]
     pub trainer: Arc<RwLock<TrainerWrapper>>,
 }
 
 impl PyTrainer {
+    #[cfg(test)]
     pub(crate) fn new(trainer: Arc<RwLock<TrainerWrapper>>) -> Self {
         PyTrainer { trainer }
     }
     pub(crate) fn get_as_subtype(&self) -> PyResult<PyObject> {
         let base = self.clone();
         let gil = Python::acquire_gil();
@@ -41,6 +45,34 @@ impl PyTrainer {
         })
     }
 }
+#[pymethods]
+impl PyTrainer {
+    fn __getstate__(&self, py: Python) -> PyResult<PyObject> {
+        let data = serde_json::to_string(&self.trainer).map_err(|e| {
+            exceptions::PyException::new_err(format!(
+                "Error while attempting to pickle PyTrainer: {}",
+                e
+            ))
+        })?;
+        Ok(PyBytes::new(py, data.as_bytes()).to_object(py))
+    }
+
+    fn __setstate__(&mut self, py: Python, state: PyObject) -> PyResult<()> {
+        match state.extract::<&PyBytes>(py) {
+            Ok(s) => {
+                let unpickled = serde_json::from_slice(s.as_bytes()).map_err(|e| {
+                    exceptions::PyException::new_err(format!(
+                        "Error while attempting to unpickle PyTrainer: {}",
+                        e
+                    ))
+                })?;
+                self.trainer = unpickled;
+                Ok(())
+            }
+            Err(e) => Err(e),
+        }
+    }
+}
 
 impl Trainer for PyTrainer {
     type Model = PyModel;
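The state handed to pickle by __getstate__ is nothing more than the trainer's serde_json serialization, encoded as bytes; __setstate__ parses those bytes back into the shared TrainerWrapper. A small sketch of inspecting that state from Python (illustrative only):

    import json

    from tokenizers import trainers

    state = trainers.BpeTrainer(min_frequency=12).__getstate__()
    # The bytes are valid JSON; the single top-level key tags the concrete
    # trainer variant so deserialization can rebuild the right wrapper.
    data = json.loads(state)
    assert list(data) == ["BpeTrainer"]
    assert data["BpeTrainer"]["min_frequency"] == 12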
@@ -820,3 +852,20 @@ impl PyUnigramTrainer {
         Ok((PyUnigramTrainer {}, trainer.into()))
     }
 }
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use tk::models::bpe::trainer::BpeTrainer;
+
+    #[test]
+    fn get_subtype() {
+        let py_trainer = PyTrainer::new(Arc::new(RwLock::new(BpeTrainer::default().into())));
+        let py_bpe = py_trainer.get_as_subtype().unwrap();
+        let gil = Python::acquire_gil();
+        assert_eq!(
+            "tokenizers.trainers.BpeTrainer",
+            py_bpe.as_ref(gil.python()).get_type().name()
+        );
+    }
+}
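Seen from Python, the invariant this unit test pins down is that a trainer never surfaces as the opaque base class; it is always exposed as the concrete subclass registered under tokenizers.trainers (a sketch, not part of the commit):

    from tokenizers import trainers

    trainer = trainers.BpeTrainer()
    # get_subtype() above asserts exactly this fully qualified name.
    assert type(trainer).__module__ == "tokenizers.trainers"
    assert type(trainer).__name__ == "BpeTrainer"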
bindings/python/tests/bindings/test_trainers.py

@@ -1,5 +1,6 @@
 import os
 import pytest
 import copy
+import pickle
 
 from tokenizers import (
@@ -14,7 +15,7 @@ from tokenizers import (
 from ..utils import data_dir, train_files
 
 
-class TestBPETrainer:
+class TestBpeTrainer:
     def test_can_modify(self):
         trainer = trainers.BpeTrainer(
             vocab_size=12345,
@@ -57,6 +58,21 @@ class TestBPETrainer:
         trainer.end_of_word_suffix = None
         assert trainer.continuing_subword_prefix == None
 
+    def test_can_pickle(self):
+        assert (
+            trainers.BpeTrainer(min_frequency=12).__getstate__()
+            == b"""{"BpeTrainer":{"min_frequency":12,"vocab_size":30000,"show_progress":true,"special_tokens":[],"limit_alphabet":null,"initial_alphabet":[],"continuing_subword_prefix":null,"end_of_word_suffix":null,"words":{}}}"""
+        )
+        assert isinstance(
+            pickle.loads(pickle.dumps(trainers.BpeTrainer(min_frequency=12))), trainers.BpeTrainer
+        )
+
+        assert isinstance(copy.deepcopy(trainers.BpeTrainer(min_frequency=12)), trainers.BpeTrainer)
+        # Make sure everything is correct
+        assert pickle.dumps(
+            pickle.loads(pickle.dumps(trainers.BpeTrainer(min_frequency=12)))
+        ) == pickle.dumps(trainers.BpeTrainer(min_frequency=12))
+
 
 class TestWordPieceTrainer:
     def test_can_modify(self):
@@ -101,6 +117,11 @@ class TestWordPieceTrainer:
         trainer.end_of_word_suffix = None
         assert trainer.continuing_subword_prefix == None
 
+    def test_can_pickle(self):
+        assert isinstance(
+            pickle.loads(pickle.dumps(trainers.WordPieceTrainer())), trainers.WordPieceTrainer
+        )
+
 
 class TestWordLevelTrainer:
     def test_can_modify(self):
@@ -126,6 +147,11 @@ class TestWordLevelTrainer:
         trainer.special_tokens = []
         assert trainer.special_tokens == []
 
+    def test_can_pickle(self):
+        assert isinstance(
+            pickle.loads(pickle.dumps(trainers.WordLevelTrainer())), trainers.WordLevelTrainer
+        )
+
 
 class TestUnigram:
     def test_train(self, train_files):
@@ -157,6 +183,11 @@ class TestUnigram:
         trainer = trainers.BpeTrainer(special_tokens=["<unk>"], show_progress=False)
         bpe_tokenizer.train([train_files["small"]], trainer=trainer)
 
+    def test_can_pickle(self):
+        assert isinstance(
+            pickle.loads(pickle.dumps(trainers.UnigramTrainer())), trainers.UnigramTrainer
+        )
+
     def test_train_with_special_tokens(self):
         filename = "tests/data/dummy-unigram-special_tokens-train.txt"
         with open(filename, "w") as f: