PyModel uses a RwLock to allow modifications

This commit is contained in:
Anthony MOI
2020-10-08 19:33:30 -04:00
committed by Anthony MOI
parent 54c7210b2f
commit 284a1dbee7
10 changed files with 76 additions and 69 deletions

View File

@ -1,6 +1,6 @@
use std::collections::HashMap;
use std::path::{Path, PathBuf};
use std::sync::Arc;
use std::sync::{Arc, RwLock};
use crate::token::PyToken;
use crate::trainers::PyTrainer;
@ -24,11 +24,11 @@ use super::error::{deprecation_warning, ToPyResult};
#[derive(Clone, Serialize, Deserialize)]
pub struct PyModel {
#[serde(flatten)]
pub model: Arc<ModelWrapper>,
pub model: Arc<RwLock<ModelWrapper>>,
}
impl PyModel {
pub(crate) fn new(model: Arc<ModelWrapper>) -> Self {
pub(crate) fn new(model: Arc<RwLock<ModelWrapper>>) -> Self {
PyModel { model }
}
@ -36,7 +36,7 @@ impl PyModel {
let base = self.clone();
let gil = Python::acquire_gil();
let py = gil.python();
Ok(match self.model.as_ref() {
Ok(match *self.model.as_ref().read().unwrap() {
ModelWrapper::BPE(_) => Py::new(py, (PyBPE {}, base))?.into_py(py),
ModelWrapper::WordPiece(_) => Py::new(py, (PyWordPiece {}, base))?.into_py(py),
ModelWrapper::WordLevel(_) => Py::new(py, (PyWordLevel {}, base))?.into_py(py),
@ -49,31 +49,31 @@ impl Model for PyModel {
type Trainer = PyTrainer;
fn tokenize(&self, tokens: &str) -> tk::Result<Vec<Token>> {
self.model.tokenize(tokens)
self.model.read().unwrap().tokenize(tokens)
}
fn token_to_id(&self, token: &str) -> Option<u32> {
self.model.token_to_id(token)
self.model.read().unwrap().token_to_id(token)
}
fn id_to_token(&self, id: u32) -> Option<&str> {
self.model.id_to_token(id)
fn id_to_token(&self, id: u32) -> Option<String> {
self.model.read().unwrap().id_to_token(id)
}
fn get_vocab(&self) -> &HashMap<String, u32> {
self.model.get_vocab()
fn get_vocab(&self) -> HashMap<String, u32> {
self.model.read().unwrap().get_vocab()
}
fn get_vocab_size(&self) -> usize {
self.model.get_vocab_size()
self.model.read().unwrap().get_vocab_size()
}
fn save(&self, folder: &Path, name: Option<&str>) -> tk::Result<Vec<PathBuf>> {
self.model.save(folder, name)
self.model.read().unwrap().save(folder, name)
}
fn get_trainer(&self) -> Self::Trainer {
self.model.get_trainer().into()
self.model.read().unwrap().get_trainer().into()
}
}
@ -84,7 +84,7 @@ impl PyModel {
// Instantiate a default empty model. This doesn't really make sense, but we need
// to be able to instantiate an empty model for pickle capabilities.
Ok(PyModel {
model: Arc::new(BPE::default().into()),
model: Arc::new(RwLock::new(BPE::default().into())),
})
}
@ -116,7 +116,7 @@ impl PyModel {
/// Tokenize the given sequence
#[text_signature = "(self, tokens)"]
fn tokenize(&self, tokens: &str) -> PyResult<Vec<PyToken>> {
Ok(ToPyResult(self.model.tokenize(tokens))
Ok(ToPyResult(self.model.read().unwrap().tokenize(tokens))
.into_py()?
.into_iter()
.map(|t| t.into())
@ -126,13 +126,13 @@ impl PyModel {
/// Returns the id associated with the given token
#[text_signature = "(self, tokens)"]
fn token_to_id(&self, token: &str) -> Option<u32> {
self.model.token_to_id(token)
self.model.read().unwrap().token_to_id(token)
}
/// Returns the token associated with the given id
#[text_signature = "(self, id)"]
fn id_to_token(&self, id: u32) -> Option<&str> {
self.model.id_to_token(id)
fn id_to_token(&self, id: u32) -> Option<String> {
self.model.read().unwrap().id_to_token(id)
}
/// Save the current model
@ -142,7 +142,8 @@ impl PyModel {
/// Any file with the same name that already exists in this folder will be overwritten.
#[text_signature = "(self, folder, name)"]
fn save(&self, folder: &str, name: Option<&str>) -> PyResult<Vec<String>> {
let saved: PyResult<Vec<_>> = ToPyResult(self.model.save(Path::new(folder), name)).into();
let saved: PyResult<Vec<_>> =
ToPyResult(self.model.read().unwrap().save(Path::new(folder), name)).into();
Ok(saved?
.into_iter()
@ -151,7 +152,7 @@ impl PyModel {
}
fn get_trainer(&self) -> PyResult<PyObject> {
PyTrainer::from(self.model.get_trainer()).get_as_subtype()
PyTrainer::from(self.model.read().unwrap().get_trainer()).get_as_subtype()
}
}
@ -219,7 +220,7 @@ impl PyBPE {
"Error while initializing BPE: {}",
e
))),
Ok(bpe) => Ok((PyBPE {}, PyModel::new(Arc::new(bpe.into())))),
Ok(bpe) => Ok((PyBPE {}, PyModel::new(Arc::new(RwLock::new(bpe.into()))))),
}
}
}
@ -360,7 +361,10 @@ impl PyWordPiece {
"Error while initializing WordPiece: {}",
e
))),
Ok(wordpiece) => Ok((PyWordPiece {}, PyModel::new(Arc::new(wordpiece.into())))),
Ok(wordpiece) => Ok((
PyWordPiece {},
PyModel::new(Arc::new(RwLock::new(wordpiece.into()))),
)),
}
}
}
@ -476,11 +480,14 @@ impl PyWordLevel {
}
};
Ok((PyWordLevel {}, PyModel::new(Arc::new(model.into()))))
Ok((
PyWordLevel {},
PyModel::new(Arc::new(RwLock::new(model.into()))),
))
} else {
Ok((
PyWordLevel {},
PyModel::new(Arc::new(WordLevel::default().into())),
PyModel::new(Arc::new(RwLock::new(WordLevel::default().into()))),
))
}
}
@ -523,11 +530,14 @@ impl PyUnigram {
let model = Unigram::from(vocab, unk_id).map_err(|e| {
exceptions::PyException::new_err(format!("Error while loading Unigram: {}", e))
})?;
Ok((PyUnigram {}, PyModel::new(Arc::new(model.into()))))
Ok((
PyUnigram {},
PyModel::new(Arc::new(RwLock::new(model.into()))),
))
}
(None, None) => Ok((
PyUnigram {},
PyModel::new(Arc::new(Unigram::default().into())),
PyModel::new(Arc::new(RwLock::new(Unigram::default().into()))),
)),
_ => Err(exceptions::PyValueError::new_err(
"`vocab` and `unk_id` must be both specified",
@ -540,13 +550,13 @@ impl PyUnigram {
mod test {
use crate::models::PyModel;
use pyo3::prelude::*;
use std::sync::Arc;
use std::sync::{Arc, RwLock};
use tk::models::bpe::BPE;
use tk::models::ModelWrapper;
#[test]
fn get_subtype() {
let py_model = PyModel::new(Arc::new(BPE::default().into()));
let py_model = PyModel::new(Arc::new(RwLock::new(BPE::default().into())));
let py_bpe = py_model.get_as_subtype().unwrap();
let gil = Python::acquire_gil();
assert_eq!(
@ -562,19 +572,19 @@ mod test {
let rs_wrapper: ModelWrapper = rs_bpe.into();
let rs_wrapper_ser = serde_json::to_string(&rs_wrapper).unwrap();
let py_model = PyModel::new(Arc::new(rs_wrapper));
let py_model = PyModel::new(Arc::new(RwLock::new(rs_wrapper)));
let py_ser = serde_json::to_string(&py_model).unwrap();
assert_eq!(py_ser, rs_bpe_ser);
assert_eq!(py_ser, rs_wrapper_ser);
let py_model: PyModel = serde_json::from_str(&rs_bpe_ser).unwrap();
match py_model.model.as_ref() {
match *py_model.model.as_ref().read().unwrap() {
ModelWrapper::BPE(_) => (),
_ => panic!("Expected Bert postprocessor."),
}
let py_model: PyModel = serde_json::from_str(&rs_wrapper_ser).unwrap();
match py_model.model.as_ref() {
match *py_model.model.as_ref().read().unwrap() {
ModelWrapper::BPE(_) => (),
_ => panic!("Expected Bert postprocessor."),
}

View File

@ -1,5 +1,5 @@
use std::collections::HashMap;
use std::sync::Arc;
use std::sync::{Arc, RwLock};
use numpy::PyArray1;
use pyo3::exceptions;
@ -457,7 +457,8 @@ impl PyTokenizer {
}
fn __getnewargs__<'p>(&self, py: Python<'p>) -> PyResult<&'p PyTuple> {
let model: PyObject = PyModel::new(Arc::new(BPE::default().into())).into_py(py);
let model: PyObject =
PyModel::new(Arc::new(RwLock::new(BPE::default().into()))).into_py(py);
let args = PyTuple::new(py, vec![model]);
Ok(args)
}
@ -965,7 +966,7 @@ impl PyTokenizer {
/// Returns:
/// :obj:`Optional[str]`: An optional token, :obj:`None` if out of vocabulary
#[text_signature = "(self, id)"]
fn id_to_token(&self, id: u32) -> Option<&str> {
fn id_to_token(&self, id: u32) -> Option<String> {
self.tokenizer.id_to_token(id)
}

View File

@ -89,8 +89,7 @@ impl Trainer for PyTrainer {
words: HashMap<String, u32>,
model: &mut PyModel,
) -> tk::Result<Vec<tk::AddedToken>> {
todo!("FIX THIS");
self.trainer.train(words, &mut model.model)
self.trainer.train(words, &mut model.model.write().unwrap())
}
fn process_tokens(&self, words: &mut HashMap<String, u32>, tokens: Vec<String>) {

View File

@ -318,8 +318,8 @@ impl BPE {
}
}
pub fn get_vocab(&self) -> &Vocab {
&self.vocab
pub fn get_vocab(&self) -> Vocab {
self.vocab.clone()
}
pub fn get_unk_token(&self) -> &Option<String> {
@ -417,8 +417,8 @@ impl BPE {
impl Model for BPE {
type Trainer = BpeTrainer;
fn get_vocab(&self) -> &HashMap<String, u32> {
&self.vocab
fn get_vocab(&self) -> HashMap<String, u32> {
self.vocab.clone()
}
fn get_vocab_size(&self) -> usize {
@ -442,8 +442,8 @@ impl Model for BPE {
self.vocab.get(token).copied()
}
fn id_to_token(&self, id: u32) -> Option<&str> {
self.vocab_r.get(&id).map(String::as_ref)
fn id_to_token(&self, id: u32) -> Option<String> {
self.vocab_r.get(&id).cloned()
}
fn save(&self, folder: &Path, name: Option<&str>) -> Result<Vec<PathBuf>> {

View File

@ -75,7 +75,7 @@ impl Model for ModelWrapper {
}
}
fn id_to_token(&self, id: u32) -> Option<&str> {
fn id_to_token(&self, id: u32) -> Option<String> {
use ModelWrapper::*;
match self {
WordLevel(t) => t.id_to_token(id),
@ -85,7 +85,7 @@ impl Model for ModelWrapper {
}
}
fn get_vocab(&self) -> &HashMap<String, u32> {
fn get_vocab(&self) -> HashMap<String, u32> {
use ModelWrapper::*;
match self {
WordLevel(t) => t.get_vocab(),

View File

@ -409,8 +409,8 @@ impl<'a> Iterator for UnigramIterator<'a> {
impl Model for Unigram {
type Trainer = UnigramTrainer;
fn get_vocab(&self) -> &HashMap<String, u32> {
&self.token_to_ids
fn get_vocab(&self) -> HashMap<String, u32> {
self.token_to_ids.clone()
}
fn get_vocab_size(&self) -> usize {
@ -438,9 +438,9 @@ impl Model for Unigram {
self.token_to_ids.get(token).copied()
}
fn id_to_token(&self, id: u32) -> Option<&str> {
fn id_to_token(&self, id: u32) -> Option<String> {
match self.vocab.get(id as usize) {
Some(item) => Some(&item.0),
Some(item) => Some(item.0.clone()),
None => None,
}
}

View File

@ -183,12 +183,12 @@ impl Model for WordLevel {
self.vocab.get(token).copied()
}
fn id_to_token(&self, id: u32) -> Option<&str> {
self.vocab_r.get(&id).map(String::as_ref)
fn id_to_token(&self, id: u32) -> Option<String> {
self.vocab_r.get(&id).cloned()
}
fn get_vocab(&self) -> &HashMap<String, u32> {
&self.vocab
fn get_vocab(&self) -> HashMap<String, u32> {
self.vocab.clone()
}
fn get_vocab_size(&self) -> usize {

View File

@ -185,10 +185,7 @@ impl WordPiece {
/// Create a `WordPiece` model from a `BPE` model.
pub fn from_bpe(bpe: &BPE) -> Self {
let mut wp = Self::builder()
.vocab(bpe.get_vocab().clone())
.build()
.unwrap();
let mut wp = Self::builder().vocab(bpe.get_vocab()).build().unwrap();
if let Some(unk) = bpe.get_unk_token() {
wp.unk_token = unk.to_owned();
}
@ -202,8 +199,8 @@ impl WordPiece {
impl Model for WordPiece {
type Trainer = WordPieceTrainer;
fn get_vocab(&self) -> &HashMap<String, u32> {
&self.vocab
fn get_vocab(&self) -> HashMap<String, u32> {
self.vocab.clone()
}
fn get_vocab_size(&self) -> usize {
@ -275,8 +272,8 @@ impl Model for WordPiece {
self.vocab.get(token).copied()
}
fn id_to_token(&self, id: u32) -> Option<&str> {
self.vocab_r.get(&id).map(String::as_ref)
fn id_to_token(&self, id: u32) -> Option<String> {
self.vocab_r.get(&id).cloned()
}
fn save(&self, folder: &Path, name: Option<&str>) -> Result<Vec<PathBuf>> {

View File

@ -202,10 +202,10 @@ impl AddedVocabulary {
}
/// Get the token matching the given id if it exists
pub fn id_to_token<'s>(&'s self, id: u32, model: &'s impl Model) -> Option<&'s str> {
pub fn id_to_token(&self, id: u32, model: &impl Model) -> Option<String> {
self.added_tokens_map_r
.get(&id)
.map(|t| t.content.as_ref())
.map(|t| t.content.clone())
.or_else(|| model.id_to_token(id))
}
@ -550,11 +550,11 @@ mod tests {
// Looks up `token` in `self.vocab` and returns a copy of the stored id,
// or `None` when the token is not present.
fn token_to_id(&self, token: &str) -> Option<u32> {
self.vocab.get(token).copied()
}
fn id_to_token(&self, id: u32) -> Option<&str> {
self.vocab_r.get(&id).map(String::as_ref)
fn id_to_token(&self, id: u32) -> Option<String> {
self.vocab_r.get(&id).cloned()
}
fn get_vocab(&self) -> &HashMap<String, u32> {
&self.vocab
fn get_vocab(&self) -> HashMap<String, u32> {
self.vocab.clone()
}
fn get_vocab_size(&self) -> usize {
self.vocab.len()

View File

@ -75,9 +75,9 @@ pub trait Model {
/// Find the ID associated to a string token
fn token_to_id(&self, token: &str) -> Option<u32>;
/// Find the string token associated to an ID
fn id_to_token(&self, id: u32) -> Option<&str>;
fn id_to_token(&self, id: u32) -> Option<String>;
/// Retrieve the entire vocabulary mapping (token -> ID)
fn get_vocab(&self) -> &HashMap<String, u32>;
fn get_vocab(&self) -> HashMap<String, u32>;
/// Retrieve the size of the vocabulary
fn get_vocab_size(&self) -> usize;
/// Save the current `Model` in the given folder, using the given `prefix` for the various
@ -616,7 +616,7 @@ where
}
/// Converts an id to the corresponding token.
pub fn id_to_token(&self, id: u32) -> Option<&str> {
pub fn id_to_token(&self, id: u32) -> Option<String> {
self.added_vocabulary.id_to_token(id, &self.model)
}