Remove Send + Sync requirements from Model.

This commit is contained in:
Sebastian Puetz
2020-07-31 17:16:04 +02:00
committed by Anthony MOI
parent 42b810488f
commit aaf8e932b1
3 changed files with 20 additions and 11 deletions

View File

@@ -14,17 +14,18 @@ use tk::{Model, Token};
use tokenizers as tk; use tokenizers as tk;
use super::error::ToPyResult; use super::error::ToPyResult;
use tk::models::ModelWrapper;
/// A Model represents some tokenization algorithm like BPE or Word /// A Model represents some tokenization algorithm like BPE or Word
/// This class cannot be constructed directly. Please use one of the concrete models. /// This class cannot be constructed directly. Please use one of the concrete models.
#[pyclass(module = "tokenizers.models", name=Model)] #[pyclass(module = "tokenizers.models", name=Model)]
#[derive(Clone)] #[derive(Clone)]
pub struct PyModel { pub struct PyModel {
pub model: Arc<dyn Model>, pub model: Arc<ModelWrapper>,
} }
impl PyModel { impl PyModel {
pub(crate) fn new(model: Arc<dyn Model>) -> Self { pub(crate) fn new(model: Arc<ModelWrapper>) -> Self {
PyModel { model } PyModel { model }
} }
} }
@@ -83,7 +84,7 @@ impl PyModel {
// Instantiate a default empty model. This doesn't really make sense, but we need // Instantiate a default empty model. This doesn't really make sense, but we need
// to be able to instantiate an empty model for pickle capabilities. // to be able to instantiate an empty model for pickle capabilities.
Ok(PyModel { Ok(PyModel {
model: Arc::new(BPE::default()), model: Arc::new(BPE::default().into()),
}) })
} }
@@ -175,7 +176,7 @@ impl PyBPE {
"Error while initializing BPE: {}", "Error while initializing BPE: {}",
e e
))), ))),
Ok(bpe) => Ok((PyBPE {}, PyModel::new(Arc::new(bpe)))), Ok(bpe) => Ok((PyBPE {}, PyModel::new(Arc::new(bpe.into())))),
} }
} }
} }
@@ -220,7 +221,7 @@ impl PyWordPiece {
"Error while initializing WordPiece", "Error while initializing WordPiece",
)) ))
} }
Ok(wordpiece) => Ok((PyWordPiece {}, PyModel::new(Arc::new(wordpiece)))), Ok(wordpiece) => Ok((PyWordPiece {}, PyModel::new(Arc::new(wordpiece.into())))),
} }
} }
} }
@@ -253,10 +254,10 @@ impl PyWordLevel {
"Error while initializing WordLevel", "Error while initializing WordLevel",
)) ))
} }
Ok(model) => Ok((PyWordLevel {}, PyModel::new(Arc::new(model)))), Ok(model) => Ok((PyWordLevel {}, PyModel::new(Arc::new(model.into())))),
} }
} else { } else {
Ok((PyWordLevel {}, PyModel::new(Arc::new(WordLevel::default())))) Ok((PyWordLevel {}, PyModel::new(Arc::new(WordLevel::default().into()))))
} }
} }
} }

View File

@@ -318,7 +318,7 @@ impl PyTokenizer {
} }
fn __getnewargs__<'p>(&self, py: Python<'p>) -> PyResult<&'p PyTuple> { fn __getnewargs__<'p>(&self, py: Python<'p>) -> PyResult<&'p PyTuple> {
let model: PyObject = PyModel::new(Arc::new(BPE::default())).into_py(py); let model: PyObject = PyModel::new(Arc::new(BPE::default().into())).into_py(py);
let args = PyTuple::new(py, vec![model]); let args = PyTuple::new(py, vec![model]);
Ok(args) Ok(args)
} }

View File

@@ -70,7 +70,7 @@ pub trait PreTokenizer: Send + Sync {
#[typetag::serde(tag = "type")] #[typetag::serde(tag = "type")]
/// Represents a model used during Tokenization (like BPE or Word or Unigram). /// Represents a model used during Tokenization (like BPE or Word or Unigram).
pub trait Model: Send + Sync { pub trait Model {
/// Tokenize the given sequence into multiple underlying `Token`. The `offsets` on the `Token` /// Tokenize the given sequence into multiple underlying `Token`. The `offsets` on the `Token`
/// are expected to be relative to the given sequence. /// are expected to be relative to the given sequence.
fn tokenize(&self, sequence: &str) -> Result<Vec<Token>>; fn tokenize(&self, sequence: &str) -> Result<Vec<Token>>;
@@ -709,7 +709,10 @@ where
&self, &self,
inputs: Vec<E>, inputs: Vec<E>,
add_special_tokens: bool, add_special_tokens: bool,
) -> Result<Vec<Encoding>> { ) -> Result<Vec<Encoding>>
where
M: Send + Sync,
{
let mut encodings = inputs let mut encodings = inputs
.into_maybe_par_iter() .into_maybe_par_iter()
.map(|input| self.encode(input, add_special_tokens)) .map(|input| self.encode(input, add_special_tokens))
@@ -749,7 +752,10 @@ where
&self, &self,
sentences: Vec<Vec<u32>>, sentences: Vec<Vec<u32>>,
skip_special_tokens: bool, skip_special_tokens: bool,
) -> Result<Vec<String>> { ) -> Result<Vec<String>>
where
M: Send + Sync,
{
sentences sentences
.into_maybe_par_iter() .into_maybe_par_iter()
.map(|sentence| self.decode(sentence, skip_special_tokens)) .map(|sentence| self.decode(sentence, skip_special_tokens))
@@ -761,6 +767,7 @@ where
where where
T: Trainer<Model = MN>, T: Trainer<Model = MN>,
MN: Model, MN: Model,
M: Send + Sync,
{ {
let max_read = 1_000_000; let max_read = 1_000_000;
let mut len = 0; let mut len = 0;
@@ -849,6 +856,7 @@ where
where where
T: Trainer<Model = TM>, T: Trainer<Model = TM>,
TM: Model, TM: Model,
M: Send + Sync,
{ {
let words = self.word_count(trainer, files)?; let words = self.word_count(trainer, files)?;