mirror of
https://github.com/mii443/tokenizers.git
synced 2025-08-23 00:35:35 +00:00
A Model can return its associated Trainer
This commit is contained in:
@ -47,15 +47,34 @@ use crate::tokenizer::PyAddedToken;
|
||||
/// Returns:
|
||||
/// Trainer
|
||||
#[pyclass(name=Trainer)]
|
||||
#[derive(Clone)]
|
||||
#[text_signature = "(self, vocab_size=30000, min_frequency=0,show_progress=True, special_tokens=[],limit_alphabet=None, initial_alphabet = [], continuing_subword_prefix=None, end_of_word_suffix=None)"]
|
||||
pub struct PyTrainer {
|
||||
pub trainer: TrainerWrapper,
|
||||
pub trainer: Arc<TrainerWrapper>,
|
||||
}
|
||||
|
||||
impl PyTrainer {
|
||||
pub fn new(trainer: TrainerWrapper) -> Self {
|
||||
pub(crate) fn new(trainer: Arc<TrainerWrapper>) -> Self {
|
||||
PyTrainer { trainer }
|
||||
}
|
||||
|
||||
/// Builds the Python object for this trainer as its concrete subclass.
///
/// Matches on the wrapped `TrainerWrapper` variant and allocates the matching
/// subclass value (`PyBpeTrainer`, `PyWordPieceTrainer`, `PyWordLevelTrainer`
/// or `PyUnigramTrainer`) paired with a clone of `self` as the base part.
///
/// Returns the new object converted to a `PyObject`; an error from `Py::new`
/// is propagated to the caller through `?`.
pub(crate) fn get_as_subtype(&self) -> PyResult<PyObject> {
|
||||
// Clone of `self` that becomes the base half of each `(subclass, base)` tuple below.
let base = self.clone();
|
||||
// Acquire the GIL to obtain the `Python` token required by `Py::new` and `into_py`.
let gil = Python::acquire_gil();
|
||||
let py = gil.python();
|
||||
// One arm per `TrainerWrapper` variant; every arm does the same thing with a different subclass.
Ok(match self.trainer.as_ref() {
|
||||
TrainerWrapper::BpeTrainer(_) => Py::new(py, (PyBpeTrainer {}, base))?.into_py(py),
|
||||
TrainerWrapper::WordPieceTrainer(_) => {
|
||||
Py::new(py, (PyWordPieceTrainer {}, base))?.into_py(py)
|
||||
}
|
||||
TrainerWrapper::WordLevelTrainer(_) => {
|
||||
Py::new(py, (PyWordLevelTrainer {}, base))?.into_py(py)
|
||||
}
|
||||
TrainerWrapper::UnigramTrainer(_) => {
|
||||
Py::new(py, (PyUnigramTrainer {}, base))?.into_py(py)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl Trainer for PyTrainer {
|
||||
@ -77,6 +96,17 @@ impl Trainer for PyTrainer {
|
||||
}
|
||||
}
|
||||
|
||||
/// Builds a `PyTrainer` from any value that converts into a `TrainerWrapper`.
impl<I> From<I> for PyTrainer
|
||||
where
|
||||
I: Into<TrainerWrapper>,
|
||||
{
|
||||
fn from(trainer: I) -> Self {
|
||||
PyTrainer {
|
||||
// The first `.into()` is the `I: Into<TrainerWrapper>` conversion from the bound above;
// the second converts that result into the `trainer` field's type
// (presumably `Arc<TrainerWrapper>` -- confirm against the struct declaration).
trainer: trainer.into().into(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Capable of training a BPE model
///
/// Declared as a Python class named `BpeTrainer` that extends `PyTrainer`.
/// The struct itself has no fields.
|
||||
#[pyclass(extends=PyTrainer, name=BpeTrainer)]
|
||||
pub struct PyBpeTrainer {}
|
||||
@ -138,7 +168,10 @@ impl PyBpeTrainer {
|
||||
};
|
||||
}
|
||||
}
|
||||
Ok((PyBpeTrainer {}, PyTrainer::new(builder.build().into())))
|
||||
Ok((
|
||||
PyBpeTrainer {},
|
||||
PyTrainer::new(Arc::new(builder.build().into())),
|
||||
))
|
||||
}
|
||||
}
|
||||
|
||||
@ -237,7 +270,7 @@ impl PyWordPieceTrainer {
|
||||
|
||||
Ok((
|
||||
PyWordPieceTrainer {},
|
||||
PyTrainer::new(builder.build().into()),
|
||||
PyTrainer::new(Arc::new(builder.build().into())),
|
||||
))
|
||||
}
|
||||
}
|
||||
@ -301,7 +334,10 @@ impl PyWordLevelTrainer {
|
||||
}
|
||||
}
|
||||
|
||||
Ok((PyWordLevelTrainer {}, PyTrainer::new(trainer.into())))
|
||||
Ok((
|
||||
PyWordLevelTrainer {},
|
||||
PyTrainer::new(Arc::new(trainer.into())),
|
||||
))
|
||||
}
|
||||
}
|
||||
|
||||
@ -388,6 +424,9 @@ impl PyUnigramTrainer {
|
||||
builder.build().map_err(|e| {
|
||||
exceptions::PyException::new_err(format!("Cannot build UnigramTrainer: {}", e))
|
||||
})?;
|
||||
Ok((PyUnigramTrainer {}, PyTrainer::new(trainer.into())))
|
||||
Ok((
|
||||
PyUnigramTrainer {},
|
||||
PyTrainer::new(Arc::new(trainer.into())),
|
||||
))
|
||||
}
|
||||
}
|
||||
|
Reference in New Issue
Block a user