pyo3: update to 0.17 (#1066)

* python: update bindings to edition 2021

* python: update to pyo3 0.17

* Updating testing.

Co-authored-by: Nicolas Patry <patry.nicolas@protonmail.com>
This commit is contained in:
David Hewitt
2022-10-05 15:59:01 +01:00
committed by GitHub
parent 6113666624
commit 8129dd3309
14 changed files with 213 additions and 204 deletions

View File

@ -27,10 +27,8 @@ impl PyTrainer {
pub(crate) fn new(trainer: Arc<RwLock<TrainerWrapper>>) -> Self {
PyTrainer { trainer }
}
pub(crate) fn get_as_subtype(&self) -> PyResult<PyObject> {
pub(crate) fn get_as_subtype(&self, py: Python<'_>) -> PyResult<PyObject> {
let base = self.clone();
let gil = Python::acquire_gil();
let py = gil.python();
Ok(match *self.trainer.as_ref().read().unwrap() {
TrainerWrapper::BpeTrainer(_) => Py::new(py, (PyBpeTrainer {}, base))?.into_py(py),
TrainerWrapper::WordPieceTrainer(_) => {
@ -857,6 +855,17 @@ impl PyUnigramTrainer {
}
}
/// Trainers Module
#[pymodule]
pub fn trainers(_py: Python, m: &PyModule) -> PyResult<()> {
m.add_class::<PyTrainer>()?;
m.add_class::<PyBpeTrainer>()?;
m.add_class::<PyWordPieceTrainer>()?;
m.add_class::<PyWordLevelTrainer>()?;
m.add_class::<PyUnigramTrainer>()?;
Ok(())
}
#[cfg(test)]
mod tests {
use super::*;
@ -864,12 +873,10 @@ mod tests {
#[test]
fn get_subtype() {
let py_trainer = PyTrainer::new(Arc::new(RwLock::new(BpeTrainer::default().into())));
let py_bpe = py_trainer.get_as_subtype().unwrap();
let gil = Python::acquire_gil();
assert_eq!(
"BpeTrainer",
py_bpe.as_ref(gil.python()).get_type().name().unwrap()
);
Python::with_gil(|py| {
let py_trainer = PyTrainer::new(Arc::new(RwLock::new(BpeTrainer::default().into())));
let py_bpe = py_trainer.get_as_subtype(py).unwrap();
assert_eq!("BpeTrainer", py_bpe.as_ref(py).get_type().name().unwrap());
})
}
}