mirror of
https://github.com/mii443/tokenizers.git
synced 2025-08-30 12:09:30 +00:00
Python - PyModel uses a RwLock to allow modifications
This commit is contained in:
@ -32,10 +32,6 @@ pub struct PyModel {
|
||||
}
|
||||
|
||||
impl PyModel {
|
||||
pub(crate) fn new(model: Arc<RwLock<ModelWrapper>>) -> Self {
|
||||
PyModel { model }
|
||||
}
|
||||
|
||||
pub(crate) fn get_as_subtype(&self) -> PyResult<PyObject> {
|
||||
let base = self.clone();
|
||||
let gil = Python::acquire_gil();
|
||||
@ -81,6 +77,17 @@ impl Model for PyModel {
|
||||
}
|
||||
}
|
||||
|
||||
impl<I> From<I> for PyModel
|
||||
where
|
||||
I: Into<ModelWrapper>,
|
||||
{
|
||||
fn from(model: I) -> Self {
|
||||
Self {
|
||||
model: Arc::new(RwLock::new(model.into())),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[pymethods]
|
||||
impl PyModel {
|
||||
#[new]
|
||||
@ -262,7 +269,7 @@ impl PyBPE {
|
||||
"Error while initializing BPE: {}",
|
||||
e
|
||||
))),
|
||||
Ok(bpe) => Ok((PyBPE {}, PyModel::new(Arc::new(RwLock::new(bpe.into()))))),
|
||||
Ok(bpe) => Ok((PyBPE {}, bpe.into())),
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -432,10 +439,7 @@ impl PyWordPiece {
|
||||
"Error while initializing WordPiece: {}",
|
||||
e
|
||||
))),
|
||||
Ok(wordpiece) => Ok((
|
||||
PyWordPiece {},
|
||||
PyModel::new(Arc::new(RwLock::new(wordpiece.into()))),
|
||||
)),
|
||||
Ok(wordpiece) => Ok((PyWordPiece {}, wordpiece.into())),
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -579,15 +583,9 @@ impl PyWordLevel {
|
||||
}
|
||||
};
|
||||
|
||||
Ok((
|
||||
PyWordLevel {},
|
||||
PyModel::new(Arc::new(RwLock::new(model.into()))),
|
||||
))
|
||||
Ok((PyWordLevel {}, model.into()))
|
||||
} else {
|
||||
Ok((
|
||||
PyWordLevel {},
|
||||
PyModel::new(Arc::new(RwLock::new(WordLevel::default().into()))),
|
||||
))
|
||||
Ok((PyWordLevel {}, WordLevel::default().into()))
|
||||
}
|
||||
}
|
||||
|
||||
@ -661,15 +659,9 @@ impl PyUnigram {
|
||||
let model = Unigram::from(vocab, unk_id).map_err(|e| {
|
||||
exceptions::PyException::new_err(format!("Error while loading Unigram: {}", e))
|
||||
})?;
|
||||
Ok((
|
||||
PyUnigram {},
|
||||
PyModel::new(Arc::new(RwLock::new(model.into()))),
|
||||
))
|
||||
Ok((PyUnigram {}, model.into()))
|
||||
}
|
||||
(None, None) => Ok((
|
||||
PyUnigram {},
|
||||
PyModel::new(Arc::new(RwLock::new(Unigram::default().into()))),
|
||||
)),
|
||||
(None, None) => Ok((PyUnigram {}, Unigram::default().into())),
|
||||
_ => Err(exceptions::PyValueError::new_err(
|
||||
"`vocab` and `unk_id` must be both specified",
|
||||
)),
|
||||
@ -681,13 +673,12 @@ impl PyUnigram {
|
||||
mod test {
|
||||
use crate::models::PyModel;
|
||||
use pyo3::prelude::*;
|
||||
use std::sync::{Arc, RwLock};
|
||||
use tk::models::bpe::BPE;
|
||||
use tk::models::ModelWrapper;
|
||||
|
||||
#[test]
|
||||
fn get_subtype() {
|
||||
let py_model = PyModel::new(Arc::new(RwLock::new(BPE::default().into())));
|
||||
let py_model = PyModel::from(BPE::default());
|
||||
let py_bpe = py_model.get_as_subtype().unwrap();
|
||||
let gil = Python::acquire_gil();
|
||||
assert_eq!(
|
||||
@ -703,7 +694,7 @@ mod test {
|
||||
let rs_wrapper: ModelWrapper = rs_bpe.into();
|
||||
let rs_wrapper_ser = serde_json::to_string(&rs_wrapper).unwrap();
|
||||
|
||||
let py_model = PyModel::new(Arc::new(RwLock::new(rs_wrapper)));
|
||||
let py_model = PyModel::from(rs_wrapper);
|
||||
let py_ser = serde_json::to_string(&py_model).unwrap();
|
||||
assert_eq!(py_ser, rs_bpe_ser);
|
||||
assert_eq!(py_ser, rs_wrapper_ser);
|
||||
|
@ -1,5 +1,4 @@
|
||||
use std::collections::HashMap;
|
||||
use std::sync::{Arc, RwLock};
|
||||
|
||||
use numpy::PyArray1;
|
||||
use pyo3::exceptions;
|
||||
@ -457,8 +456,7 @@ impl PyTokenizer {
|
||||
}
|
||||
|
||||
fn __getnewargs__<'p>(&self, py: Python<'p>) -> PyResult<&'p PyTuple> {
|
||||
let model: PyObject =
|
||||
PyModel::new(Arc::new(RwLock::new(BPE::default().into()))).into_py(py);
|
||||
let model = PyModel::from(BPE::default()).into_py(py);
|
||||
let args = PyTuple::new(py, vec![model]);
|
||||
Ok(args)
|
||||
}
|
||||
@ -1181,9 +1179,7 @@ mod test {
|
||||
|
||||
#[test]
|
||||
fn serialize() {
|
||||
let mut tokenizer = Tokenizer::new(PyModel::new(Arc::new(RwLock::new(
|
||||
tk::models::bpe::BPE::default().into(),
|
||||
))));
|
||||
let mut tokenizer = Tokenizer::new(PyModel::from(BPE::default()));
|
||||
tokenizer.with_normalizer(PyNormalizer::new(PyNormalizerTypeWrapper::Sequence(vec![
|
||||
Arc::new(NFKC.into()),
|
||||
Arc::new(Lowercase.into()),
|
||||
|
Reference in New Issue
Block a user