Python - PyModel uses a RwLock to allow modifications

This commit is contained in:
Anthony MOI
2020-11-13 16:17:20 -05:00
committed by Anthony MOI
parent dd399d2ad0
commit 7f3cfebf45
2 changed files with 21 additions and 34 deletions

View File

@ -32,10 +32,6 @@ pub struct PyModel {
}
impl PyModel {
pub(crate) fn new(model: Arc<RwLock<ModelWrapper>>) -> Self {
PyModel { model }
}
pub(crate) fn get_as_subtype(&self) -> PyResult<PyObject> {
let base = self.clone();
let gil = Python::acquire_gil();
@ -81,6 +77,17 @@ impl Model for PyModel {
}
}
impl<I> From<I> for PyModel
where
I: Into<ModelWrapper>,
{
fn from(model: I) -> Self {
Self {
model: Arc::new(RwLock::new(model.into())),
}
}
}
#[pymethods]
impl PyModel {
#[new]
@ -262,7 +269,7 @@ impl PyBPE {
"Error while initializing BPE: {}",
e
))),
Ok(bpe) => Ok((PyBPE {}, PyModel::new(Arc::new(RwLock::new(bpe.into()))))),
Ok(bpe) => Ok((PyBPE {}, bpe.into())),
}
}
}
@ -432,10 +439,7 @@ impl PyWordPiece {
"Error while initializing WordPiece: {}",
e
))),
Ok(wordpiece) => Ok((
PyWordPiece {},
PyModel::new(Arc::new(RwLock::new(wordpiece.into()))),
)),
Ok(wordpiece) => Ok((PyWordPiece {}, wordpiece.into())),
}
}
}
@ -579,15 +583,9 @@ impl PyWordLevel {
}
};
Ok((
PyWordLevel {},
PyModel::new(Arc::new(RwLock::new(model.into()))),
))
Ok((PyWordLevel {}, model.into()))
} else {
Ok((
PyWordLevel {},
PyModel::new(Arc::new(RwLock::new(WordLevel::default().into()))),
))
Ok((PyWordLevel {}, WordLevel::default().into()))
}
}
@ -661,15 +659,9 @@ impl PyUnigram {
let model = Unigram::from(vocab, unk_id).map_err(|e| {
exceptions::PyException::new_err(format!("Error while loading Unigram: {}", e))
})?;
Ok((
PyUnigram {},
PyModel::new(Arc::new(RwLock::new(model.into()))),
))
Ok((PyUnigram {}, model.into()))
}
(None, None) => Ok((
PyUnigram {},
PyModel::new(Arc::new(RwLock::new(Unigram::default().into()))),
)),
(None, None) => Ok((PyUnigram {}, Unigram::default().into())),
_ => Err(exceptions::PyValueError::new_err(
"`vocab` and `unk_id` must be both specified",
)),
@ -681,13 +673,12 @@ impl PyUnigram {
mod test {
use crate::models::PyModel;
use pyo3::prelude::*;
use std::sync::{Arc, RwLock};
use tk::models::bpe::BPE;
use tk::models::ModelWrapper;
#[test]
fn get_subtype() {
let py_model = PyModel::new(Arc::new(RwLock::new(BPE::default().into())));
let py_model = PyModel::from(BPE::default());
let py_bpe = py_model.get_as_subtype().unwrap();
let gil = Python::acquire_gil();
assert_eq!(
@ -703,7 +694,7 @@ mod test {
let rs_wrapper: ModelWrapper = rs_bpe.into();
let rs_wrapper_ser = serde_json::to_string(&rs_wrapper).unwrap();
let py_model = PyModel::new(Arc::new(RwLock::new(rs_wrapper)));
let py_model = PyModel::from(rs_wrapper);
let py_ser = serde_json::to_string(&py_model).unwrap();
assert_eq!(py_ser, rs_bpe_ser);
assert_eq!(py_ser, rs_wrapper_ser);

View File

@ -1,5 +1,4 @@
use std::collections::HashMap;
use std::sync::{Arc, RwLock};
use numpy::PyArray1;
use pyo3::exceptions;
@ -457,8 +456,7 @@ impl PyTokenizer {
}
fn __getnewargs__<'p>(&self, py: Python<'p>) -> PyResult<&'p PyTuple> {
let model: PyObject =
PyModel::new(Arc::new(RwLock::new(BPE::default().into()))).into_py(py);
let model = PyModel::from(BPE::default()).into_py(py);
let args = PyTuple::new(py, vec![model]);
Ok(args)
}
@ -1181,9 +1179,7 @@ mod test {
#[test]
fn serialize() {
let mut tokenizer = Tokenizer::new(PyModel::new(Arc::new(RwLock::new(
tk::models::bpe::BPE::default().into(),
))));
let mut tokenizer = Tokenizer::new(PyModel::from(BPE::default()));
tokenizer.with_normalizer(PyNormalizer::new(PyNormalizerTypeWrapper::Sequence(vec![
Arc::new(NFKC.into()),
Arc::new(Lowercase.into()),