mirror of
https://github.com/mii443/tokenizers.git
synced 2025-08-31 04:29:21 +00:00
Python - PyModel uses a RwLock to allow modifications
This commit is contained in:
@ -32,10 +32,6 @@ pub struct PyModel {
|
|||||||
}
|
}
|
||||||
|
|
||||||
impl PyModel {
|
impl PyModel {
|
||||||
pub(crate) fn new(model: Arc<RwLock<ModelWrapper>>) -> Self {
|
|
||||||
PyModel { model }
|
|
||||||
}
|
|
||||||
|
|
||||||
pub(crate) fn get_as_subtype(&self) -> PyResult<PyObject> {
|
pub(crate) fn get_as_subtype(&self) -> PyResult<PyObject> {
|
||||||
let base = self.clone();
|
let base = self.clone();
|
||||||
let gil = Python::acquire_gil();
|
let gil = Python::acquire_gil();
|
||||||
@ -81,6 +77,17 @@ impl Model for PyModel {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
impl<I> From<I> for PyModel
|
||||||
|
where
|
||||||
|
I: Into<ModelWrapper>,
|
||||||
|
{
|
||||||
|
fn from(model: I) -> Self {
|
||||||
|
Self {
|
||||||
|
model: Arc::new(RwLock::new(model.into())),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
#[pymethods]
|
#[pymethods]
|
||||||
impl PyModel {
|
impl PyModel {
|
||||||
#[new]
|
#[new]
|
||||||
@ -262,7 +269,7 @@ impl PyBPE {
|
|||||||
"Error while initializing BPE: {}",
|
"Error while initializing BPE: {}",
|
||||||
e
|
e
|
||||||
))),
|
))),
|
||||||
Ok(bpe) => Ok((PyBPE {}, PyModel::new(Arc::new(RwLock::new(bpe.into()))))),
|
Ok(bpe) => Ok((PyBPE {}, bpe.into())),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -432,10 +439,7 @@ impl PyWordPiece {
|
|||||||
"Error while initializing WordPiece: {}",
|
"Error while initializing WordPiece: {}",
|
||||||
e
|
e
|
||||||
))),
|
))),
|
||||||
Ok(wordpiece) => Ok((
|
Ok(wordpiece) => Ok((PyWordPiece {}, wordpiece.into())),
|
||||||
PyWordPiece {},
|
|
||||||
PyModel::new(Arc::new(RwLock::new(wordpiece.into()))),
|
|
||||||
)),
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -579,15 +583,9 @@ impl PyWordLevel {
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
Ok((
|
Ok((PyWordLevel {}, model.into()))
|
||||||
PyWordLevel {},
|
|
||||||
PyModel::new(Arc::new(RwLock::new(model.into()))),
|
|
||||||
))
|
|
||||||
} else {
|
} else {
|
||||||
Ok((
|
Ok((PyWordLevel {}, WordLevel::default().into()))
|
||||||
PyWordLevel {},
|
|
||||||
PyModel::new(Arc::new(RwLock::new(WordLevel::default().into()))),
|
|
||||||
))
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -661,15 +659,9 @@ impl PyUnigram {
|
|||||||
let model = Unigram::from(vocab, unk_id).map_err(|e| {
|
let model = Unigram::from(vocab, unk_id).map_err(|e| {
|
||||||
exceptions::PyException::new_err(format!("Error while loading Unigram: {}", e))
|
exceptions::PyException::new_err(format!("Error while loading Unigram: {}", e))
|
||||||
})?;
|
})?;
|
||||||
Ok((
|
Ok((PyUnigram {}, model.into()))
|
||||||
PyUnigram {},
|
|
||||||
PyModel::new(Arc::new(RwLock::new(model.into()))),
|
|
||||||
))
|
|
||||||
}
|
}
|
||||||
(None, None) => Ok((
|
(None, None) => Ok((PyUnigram {}, Unigram::default().into())),
|
||||||
PyUnigram {},
|
|
||||||
PyModel::new(Arc::new(RwLock::new(Unigram::default().into()))),
|
|
||||||
)),
|
|
||||||
_ => Err(exceptions::PyValueError::new_err(
|
_ => Err(exceptions::PyValueError::new_err(
|
||||||
"`vocab` and `unk_id` must be both specified",
|
"`vocab` and `unk_id` must be both specified",
|
||||||
)),
|
)),
|
||||||
@ -681,13 +673,12 @@ impl PyUnigram {
|
|||||||
mod test {
|
mod test {
|
||||||
use crate::models::PyModel;
|
use crate::models::PyModel;
|
||||||
use pyo3::prelude::*;
|
use pyo3::prelude::*;
|
||||||
use std::sync::{Arc, RwLock};
|
|
||||||
use tk::models::bpe::BPE;
|
use tk::models::bpe::BPE;
|
||||||
use tk::models::ModelWrapper;
|
use tk::models::ModelWrapper;
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn get_subtype() {
|
fn get_subtype() {
|
||||||
let py_model = PyModel::new(Arc::new(RwLock::new(BPE::default().into())));
|
let py_model = PyModel::from(BPE::default());
|
||||||
let py_bpe = py_model.get_as_subtype().unwrap();
|
let py_bpe = py_model.get_as_subtype().unwrap();
|
||||||
let gil = Python::acquire_gil();
|
let gil = Python::acquire_gil();
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
@ -703,7 +694,7 @@ mod test {
|
|||||||
let rs_wrapper: ModelWrapper = rs_bpe.into();
|
let rs_wrapper: ModelWrapper = rs_bpe.into();
|
||||||
let rs_wrapper_ser = serde_json::to_string(&rs_wrapper).unwrap();
|
let rs_wrapper_ser = serde_json::to_string(&rs_wrapper).unwrap();
|
||||||
|
|
||||||
let py_model = PyModel::new(Arc::new(RwLock::new(rs_wrapper)));
|
let py_model = PyModel::from(rs_wrapper);
|
||||||
let py_ser = serde_json::to_string(&py_model).unwrap();
|
let py_ser = serde_json::to_string(&py_model).unwrap();
|
||||||
assert_eq!(py_ser, rs_bpe_ser);
|
assert_eq!(py_ser, rs_bpe_ser);
|
||||||
assert_eq!(py_ser, rs_wrapper_ser);
|
assert_eq!(py_ser, rs_wrapper_ser);
|
||||||
|
@ -1,5 +1,4 @@
|
|||||||
use std::collections::HashMap;
|
use std::collections::HashMap;
|
||||||
use std::sync::{Arc, RwLock};
|
|
||||||
|
|
||||||
use numpy::PyArray1;
|
use numpy::PyArray1;
|
||||||
use pyo3::exceptions;
|
use pyo3::exceptions;
|
||||||
@ -457,8 +456,7 @@ impl PyTokenizer {
|
|||||||
}
|
}
|
||||||
|
|
||||||
fn __getnewargs__<'p>(&self, py: Python<'p>) -> PyResult<&'p PyTuple> {
|
fn __getnewargs__<'p>(&self, py: Python<'p>) -> PyResult<&'p PyTuple> {
|
||||||
let model: PyObject =
|
let model = PyModel::from(BPE::default()).into_py(py);
|
||||||
PyModel::new(Arc::new(RwLock::new(BPE::default().into()))).into_py(py);
|
|
||||||
let args = PyTuple::new(py, vec![model]);
|
let args = PyTuple::new(py, vec![model]);
|
||||||
Ok(args)
|
Ok(args)
|
||||||
}
|
}
|
||||||
@ -1181,9 +1179,7 @@ mod test {
|
|||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn serialize() {
|
fn serialize() {
|
||||||
let mut tokenizer = Tokenizer::new(PyModel::new(Arc::new(RwLock::new(
|
let mut tokenizer = Tokenizer::new(PyModel::from(BPE::default()));
|
||||||
tk::models::bpe::BPE::default().into(),
|
|
||||||
))));
|
|
||||||
tokenizer.with_normalizer(PyNormalizer::new(PyNormalizerTypeWrapper::Sequence(vec![
|
tokenizer.with_normalizer(PyNormalizer::new(PyNormalizerTypeWrapper::Sequence(vec![
|
||||||
Arc::new(NFKC.into()),
|
Arc::new(NFKC.into()),
|
||||||
Arc::new(Lowercase.into()),
|
Arc::new(Lowercase.into()),
|
||||||
|
Reference in New Issue
Block a user