A Model can return its associated Trainer

This commit is contained in:
Anthony MOI
2020-10-07 17:44:58 -04:00
committed by Anthony MOI
parent 059d43b265
commit c230183cf6
10 changed files with 130 additions and 11 deletions

View File

@ -47,15 +47,34 @@ use crate::tokenizer::PyAddedToken;
/// Returns:
/// Trainer
#[pyclass(name=Trainer)]
#[derive(Clone)]
#[text_signature = "(self, vocab_size=30000, min_frequency=0,show_progress=True, special_tokens=[],limit_alphabet=None, initial_alphabet = [], continuing_subword_prefix=None, end_of_word_suffix=None)"]
pub struct PyTrainer {
pub trainer: TrainerWrapper,
pub trainer: Arc<TrainerWrapper>,
}
impl PyTrainer {
pub fn new(trainer: TrainerWrapper) -> Self {
pub(crate) fn new(trainer: Arc<TrainerWrapper>) -> Self {
PyTrainer { trainer }
}
/// Returns a clone of this trainer wrapped in the Python subclass that
/// corresponds to the wrapped `TrainerWrapper` variant.
///
/// Each `Py::new` call receives a `(subclass, base)` pair, where the base is
/// the clone of `self`. Any error from `Py::new` is propagated with `?`.
pub(crate) fn get_as_subtype(&self) -> PyResult<PyObject> {
    let parent = self.clone();
    let guard = Python::acquire_gil();
    let py = guard.python();
    // One arm per trainer kind; the match is exhaustive on purpose so that a
    // newly added variant fails to compile until it is handled here.
    let object = match &*self.trainer {
        TrainerWrapper::BpeTrainer(_) => {
            Py::new(py, (PyBpeTrainer {}, parent))?.into_py(py)
        }
        TrainerWrapper::WordPieceTrainer(_) => {
            Py::new(py, (PyWordPieceTrainer {}, parent))?.into_py(py)
        }
        TrainerWrapper::WordLevelTrainer(_) => {
            Py::new(py, (PyWordLevelTrainer {}, parent))?.into_py(py)
        }
        TrainerWrapper::UnigramTrainer(_) => {
            Py::new(py, (PyUnigramTrainer {}, parent))?.into_py(py)
        }
    };
    Ok(object)
}
}
impl Trainer for PyTrainer {
@ -77,6 +96,17 @@ impl Trainer for PyTrainer {
}
}
impl<I> From<I> for PyTrainer
where
I: Into<TrainerWrapper>,
{
fn from(trainer: I) -> Self {
PyTrainer {
trainer: trainer.into().into(),
}
}
}
/// Capable of training a BPE model
// The struct itself has no fields: it is declared as extending `PyTrainer`,
// which holds the wrapped trainer, and is exposed under the name `BpeTrainer`.
#[pyclass(extends=PyTrainer, name=BpeTrainer)]
pub struct PyBpeTrainer {}
@ -138,7 +168,10 @@ impl PyBpeTrainer {
};
}
}
Ok((PyBpeTrainer {}, PyTrainer::new(builder.build().into())))
Ok((
PyBpeTrainer {},
PyTrainer::new(Arc::new(builder.build().into())),
))
}
}
@ -237,7 +270,7 @@ impl PyWordPieceTrainer {
Ok((
PyWordPieceTrainer {},
PyTrainer::new(builder.build().into()),
PyTrainer::new(Arc::new(builder.build().into())),
))
}
}
@ -301,7 +334,10 @@ impl PyWordLevelTrainer {
}
}
Ok((PyWordLevelTrainer {}, PyTrainer::new(trainer.into())))
Ok((
PyWordLevelTrainer {},
PyTrainer::new(Arc::new(trainer.into())),
))
}
}
@ -388,6 +424,9 @@ impl PyUnigramTrainer {
builder.build().map_err(|e| {
exceptions::PyException::new_err(format!("Cannot build UnigramTrainer: {}", e))
})?;
Ok((PyUnigramTrainer {}, PyTrainer::new(trainer.into())))
Ok((
PyUnigramTrainer {},
PyTrainer::new(Arc::new(trainer.into())),
))
}
}