Python - Trainers can get/set their attributes

This commit is contained in:
Anthony MOI
2020-11-24 17:46:58 -05:00
committed by Anthony MOI
parent 3eb7ef6d0a
commit a351d1c604
7 changed files with 679 additions and 51 deletions

View File

@ -1,5 +1,5 @@
use std::collections::HashMap;
use std::sync::Arc;
use std::sync::{Arc, RwLock};
use pyo3::exceptions;
use pyo3::prelude::*;
@ -10,6 +10,7 @@ use tokenizers as tk;
use crate::models::PyModel;
use crate::tokenizer::PyAddedToken;
use crate::utils::PyChar;
/// Base class for all trainers
///
@ -19,19 +20,15 @@ use crate::tokenizer::PyAddedToken;
#[derive(Clone)]
#[text_signature = "(self, vocab_size=30000, min_frequency=0,show_progress=True, special_tokens=[],limit_alphabet=None, initial_alphabet = [], continuing_subword_prefix=None, end_of_word_suffix=None)"]
pub struct PyTrainer {
pub trainer: Arc<TrainerWrapper>,
pub trainer: Arc<RwLock<TrainerWrapper>>,
}
impl PyTrainer {
pub(crate) fn new(trainer: Arc<TrainerWrapper>) -> Self {
PyTrainer { trainer }
}
pub(crate) fn get_as_subtype(&self) -> PyResult<PyObject> {
let base = self.clone();
let gil = Python::acquire_gil();
let py = gil.python();
Ok(match self.trainer.as_ref() {
Ok(match *self.trainer.as_ref().read().unwrap() {
TrainerWrapper::BpeTrainer(_) => Py::new(py, (PyBpeTrainer {}, base))?.into_py(py),
TrainerWrapper::WordPieceTrainer(_) => {
Py::new(py, (PyWordPieceTrainer {}, base))?.into_py(py)
@ -50,7 +47,7 @@ impl Trainer for PyTrainer {
type Model = PyModel;
fn should_show_progress(&self) -> bool {
self.trainer.should_show_progress()
self.trainer.read().unwrap().should_show_progress()
}
fn train(
@ -58,11 +55,14 @@ impl Trainer for PyTrainer {
words: HashMap<String, u32>,
model: &mut PyModel,
) -> tk::Result<Vec<tk::AddedToken>> {
self.trainer.train(words, &mut model.model.write().unwrap())
self.trainer
.read()
.unwrap()
.train(words, &mut model.model.write().unwrap())
}
fn process_tokens(&self, words: &mut HashMap<String, u32>, tokens: Vec<String>) {
self.trainer.process_tokens(words, tokens)
self.trainer.read().unwrap().process_tokens(words, tokens)
}
}
@ -72,11 +72,37 @@ where
{
fn from(trainer: I) -> Self {
PyTrainer {
trainer: trainer.into().into(),
trainer: Arc::new(RwLock::new(trainer.into())),
}
}
}
macro_rules! getter {
($self: ident, $variant: ident, $($name: tt)+) => {{
let super_ = $self.as_ref();
if let TrainerWrapper::$variant(ref trainer) = *super_.trainer.read().unwrap() {
trainer.$($name)+
} else {
unreachable!()
}
}};
}
macro_rules! setter {
($self: ident, $variant: ident, $name: ident, $value: expr) => {{
let super_ = $self.as_ref();
if let TrainerWrapper::$variant(ref mut trainer) = *super_.trainer.write().unwrap() {
trainer.$name = $value;
}
}};
($self: ident, $variant: ident, @$name: ident, $value: expr) => {{
let super_ = $self.as_ref();
if let TrainerWrapper::$variant(ref mut trainer) = *super_.trainer.write().unwrap() {
trainer.$name($value);
}
}};
}
/// Trainer capable of training a BPE model
///
/// Args:
@ -110,6 +136,122 @@ where
pub struct PyBpeTrainer {}
#[pymethods]
impl PyBpeTrainer {
#[getter]
fn get_vocab_size(self_: PyRef<Self>) -> usize {
getter!(self_, BpeTrainer, vocab_size)
}
#[setter]
fn set_vocab_size(self_: PyRef<Self>, vocab_size: usize) {
setter!(self_, BpeTrainer, vocab_size, vocab_size);
}
#[getter]
fn get_min_frequency(self_: PyRef<Self>) -> u32 {
getter!(self_, BpeTrainer, min_frequency)
}
#[setter]
fn set_min_frequency(self_: PyRef<Self>, freq: u32) {
setter!(self_, BpeTrainer, min_frequency, freq);
}
#[getter]
fn get_show_progress(self_: PyRef<Self>) -> bool {
getter!(self_, BpeTrainer, show_progress)
}
#[setter]
fn set_show_progress(self_: PyRef<Self>, show_progress: bool) {
setter!(self_, BpeTrainer, show_progress, show_progress);
}
#[getter]
fn get_special_tokens(self_: PyRef<Self>) -> Vec<PyAddedToken> {
getter!(
self_,
BpeTrainer,
special_tokens
.iter()
.map(|tok| tok.clone().into())
.collect()
)
}
#[setter]
fn set_special_tokens(self_: PyRef<Self>, special_tokens: &PyList) -> PyResult<()> {
setter!(
self_,
BpeTrainer,
special_tokens,
special_tokens
.into_iter()
.map(|token| {
if let Ok(content) = token.extract::<String>() {
Ok(tk::tokenizer::AddedToken::from(content, true))
} else if let Ok(mut token) = token.extract::<PyRefMut<PyAddedToken>>() {
token.is_special_token = true;
Ok(token.get_token())
} else {
Err(exceptions::PyTypeError::new_err(
"Special tokens must be a List[Union[str, AddedToken]]",
))
}
})
.collect::<PyResult<Vec<_>>>()?
);
Ok(())
}
#[getter]
fn get_limit_alphabet(self_: PyRef<Self>) -> Option<usize> {
getter!(self_, BpeTrainer, limit_alphabet)
}
#[setter]
fn set_limit_alphabet(self_: PyRef<Self>, limit: Option<usize>) {
setter!(self_, BpeTrainer, limit_alphabet, limit);
}
#[getter]
fn get_initial_alphabet(self_: PyRef<Self>) -> Vec<String> {
getter!(
self_,
BpeTrainer,
initial_alphabet.iter().map(|c| c.to_string()).collect()
)
}
#[setter]
fn set_initial_alphabet(self_: PyRef<Self>, alphabet: Vec<PyChar>) {
setter!(
self_,
BpeTrainer,
initial_alphabet,
alphabet.into_iter().map(|c| c.0).collect()
);
}
#[getter]
fn get_continuing_subword_prefix(self_: PyRef<Self>) -> Option<String> {
getter!(self_, BpeTrainer, continuing_subword_prefix.clone())
}
#[setter]
fn set_continuing_subword_prefix(self_: PyRef<Self>, prefix: Option<String>) {
setter!(self_, BpeTrainer, continuing_subword_prefix, prefix);
}
#[getter]
fn get_end_of_word_suffix(self_: PyRef<Self>) -> Option<String> {
getter!(self_, BpeTrainer, end_of_word_suffix.clone())
}
#[setter]
fn set_end_of_word_suffix(self_: PyRef<Self>, suffix: Option<String>) {
setter!(self_, BpeTrainer, end_of_word_suffix, suffix);
}
#[new]
#[args(kwargs = "**")]
pub fn new(kwargs: Option<&PyDict>) -> PyResult<(Self, PyTrainer)> {
@ -162,10 +304,7 @@ impl PyBpeTrainer {
};
}
}
Ok((
PyBpeTrainer {},
PyTrainer::new(Arc::new(builder.build().into())),
))
Ok((PyBpeTrainer {}, builder.build().into()))
}
}
@ -203,6 +342,122 @@ impl PyBpeTrainer {
pub struct PyWordPieceTrainer {}
#[pymethods]
impl PyWordPieceTrainer {
#[getter]
fn get_vocab_size(self_: PyRef<Self>) -> usize {
getter!(self_, WordPieceTrainer, vocab_size())
}
#[setter]
fn set_vocab_size(self_: PyRef<Self>, vocab_size: usize) {
setter!(self_, WordPieceTrainer, @set_vocab_size, vocab_size);
}
#[getter]
fn get_min_frequency(self_: PyRef<Self>) -> u32 {
getter!(self_, WordPieceTrainer, min_frequency())
}
#[setter]
fn set_min_frequency(self_: PyRef<Self>, freq: u32) {
setter!(self_, WordPieceTrainer, @set_min_frequency, freq);
}
#[getter]
fn get_show_progress(self_: PyRef<Self>) -> bool {
getter!(self_, WordPieceTrainer, show_progress())
}
#[setter]
fn set_show_progress(self_: PyRef<Self>, show_progress: bool) {
setter!(self_, WordPieceTrainer, @set_show_progress, show_progress);
}
#[getter]
fn get_special_tokens(self_: PyRef<Self>) -> Vec<PyAddedToken> {
getter!(
self_,
WordPieceTrainer,
special_tokens()
.iter()
.map(|tok| tok.clone().into())
.collect()
)
}
#[setter]
fn set_special_tokens(self_: PyRef<Self>, special_tokens: &PyList) -> PyResult<()> {
setter!(
self_,
WordPieceTrainer,
@set_special_tokens,
special_tokens
.into_iter()
.map(|token| {
if let Ok(content) = token.extract::<String>() {
Ok(tk::tokenizer::AddedToken::from(content, true))
} else if let Ok(mut token) = token.extract::<PyRefMut<PyAddedToken>>() {
token.is_special_token = true;
Ok(token.get_token())
} else {
Err(exceptions::PyTypeError::new_err(
"Special tokens must be a List[Union[str, AddedToken]]",
))
}
})
.collect::<PyResult<Vec<_>>>()?
);
Ok(())
}
#[getter]
fn get_limit_alphabet(self_: PyRef<Self>) -> Option<usize> {
getter!(self_, WordPieceTrainer, limit_alphabet())
}
#[setter]
fn set_limit_alphabet(self_: PyRef<Self>, limit: Option<usize>) {
setter!(self_, WordPieceTrainer, @set_limit_alphabet, limit);
}
#[getter]
fn get_initial_alphabet(self_: PyRef<Self>) -> Vec<String> {
getter!(
self_,
WordPieceTrainer,
initial_alphabet().iter().map(|c| c.to_string()).collect()
)
}
#[setter]
fn set_initial_alphabet(self_: PyRef<Self>, alphabet: Vec<PyChar>) {
setter!(
self_,
WordPieceTrainer,
@set_initial_alphabet,
alphabet.into_iter().map(|c| c.0).collect()
);
}
#[getter]
fn get_continuing_subword_prefix(self_: PyRef<Self>) -> Option<String> {
getter!(self_, WordPieceTrainer, continuing_subword_prefix().clone())
}
#[setter]
fn set_continuing_subword_prefix(self_: PyRef<Self>, prefix: Option<String>) {
setter!(self_, WordPieceTrainer, @set_continuing_subword_prefix, prefix);
}
#[getter]
fn get_end_of_word_suffix(self_: PyRef<Self>) -> Option<String> {
getter!(self_, WordPieceTrainer, end_of_word_suffix().clone())
}
#[setter]
fn set_end_of_word_suffix(self_: PyRef<Self>, suffix: Option<String>) {
setter!(self_, WordPieceTrainer, @set_end_of_word_suffix, suffix);
}
#[new]
#[args(kwargs = "**")]
pub fn new(kwargs: Option<&PyDict>) -> PyResult<(Self, PyTrainer)> {
@ -256,10 +511,7 @@ impl PyWordPieceTrainer {
}
}
Ok((
PyWordPieceTrainer {},
PyTrainer::new(Arc::new(builder.build().into())),
))
Ok((PyWordPieceTrainer {}, builder.build().into()))
}
}
@ -281,6 +533,73 @@ impl PyWordPieceTrainer {
pub struct PyWordLevelTrainer {}
#[pymethods]
impl PyWordLevelTrainer {
#[getter]
fn get_vocab_size(self_: PyRef<Self>) -> usize {
getter!(self_, WordLevelTrainer, vocab_size)
}
#[setter]
fn set_vocab_size(self_: PyRef<Self>, vocab_size: usize) {
setter!(self_, WordLevelTrainer, vocab_size, vocab_size);
}
#[getter]
fn get_min_frequency(self_: PyRef<Self>) -> u32 {
getter!(self_, WordLevelTrainer, min_frequency)
}
#[setter]
fn set_min_frequency(self_: PyRef<Self>, freq: u32) {
setter!(self_, WordLevelTrainer, min_frequency, freq);
}
#[getter]
fn get_show_progress(self_: PyRef<Self>) -> bool {
getter!(self_, WordLevelTrainer, show_progress)
}
#[setter]
fn set_show_progress(self_: PyRef<Self>, show_progress: bool) {
setter!(self_, WordLevelTrainer, show_progress, show_progress);
}
#[getter]
fn get_special_tokens(self_: PyRef<Self>) -> Vec<PyAddedToken> {
getter!(
self_,
WordLevelTrainer,
special_tokens
.iter()
.map(|tok| tok.clone().into())
.collect()
)
}
#[setter]
fn set_special_tokens(self_: PyRef<Self>, special_tokens: &PyList) -> PyResult<()> {
setter!(
self_,
WordLevelTrainer,
special_tokens,
special_tokens
.into_iter()
.map(|token| {
if let Ok(content) = token.extract::<String>() {
Ok(tk::tokenizer::AddedToken::from(content, true))
} else if let Ok(mut token) = token.extract::<PyRefMut<PyAddedToken>>() {
token.is_special_token = true;
Ok(token.get_token())
} else {
Err(exceptions::PyTypeError::new_err(
"Special tokens must be a List[Union[str, AddedToken]]",
))
}
})
.collect::<PyResult<Vec<_>>>()?
);
Ok(())
}
#[new]
#[args(kwargs = "**")]
pub fn new(kwargs: Option<&PyDict>) -> PyResult<(Self, PyTrainer)> {
@ -327,12 +646,10 @@ impl PyWordLevelTrainer {
Ok((
PyWordLevelTrainer {},
PyTrainer::new(Arc::new(
builder
.build()
.expect("WordLevelTrainerBuilder cannot fail")
.into(),
)),
builder
.build()
.expect("WordLevelTrainerBuilder cannot fail")
.into(),
))
}
}
@ -359,6 +676,82 @@ impl PyWordLevelTrainer {
pub struct PyUnigramTrainer {}
#[pymethods]
impl PyUnigramTrainer {
#[getter]
fn get_vocab_size(self_: PyRef<Self>) -> u32 {
getter!(self_, UnigramTrainer, vocab_size)
}
#[setter]
fn set_vocab_size(self_: PyRef<Self>, vocab_size: u32) {
setter!(self_, UnigramTrainer, vocab_size, vocab_size);
}
#[getter]
fn get_show_progress(self_: PyRef<Self>) -> bool {
getter!(self_, UnigramTrainer, show_progress)
}
#[setter]
fn set_show_progress(self_: PyRef<Self>, show_progress: bool) {
setter!(self_, UnigramTrainer, show_progress, show_progress);
}
#[getter]
fn get_special_tokens(self_: PyRef<Self>) -> Vec<PyAddedToken> {
getter!(
self_,
UnigramTrainer,
special_tokens
.iter()
.map(|tok| tok.clone().into())
.collect()
)
}
#[setter]
fn set_special_tokens(self_: PyRef<Self>, special_tokens: &PyList) -> PyResult<()> {
setter!(
self_,
UnigramTrainer,
special_tokens,
special_tokens
.into_iter()
.map(|token| {
if let Ok(content) = token.extract::<String>() {
Ok(tk::tokenizer::AddedToken::from(content, true))
} else if let Ok(mut token) = token.extract::<PyRefMut<PyAddedToken>>() {
token.is_special_token = true;
Ok(token.get_token())
} else {
Err(exceptions::PyTypeError::new_err(
"Special tokens must be a List[Union[str, AddedToken]]",
))
}
})
.collect::<PyResult<Vec<_>>>()?
);
Ok(())
}
#[getter]
fn get_initial_alphabet(self_: PyRef<Self>) -> Vec<String> {
getter!(
self_,
UnigramTrainer,
initial_alphabet.iter().map(|c| c.to_string()).collect()
)
}
#[setter]
fn set_initial_alphabet(self_: PyRef<Self>, alphabet: Vec<PyChar>) {
setter!(
self_,
UnigramTrainer,
initial_alphabet,
alphabet.into_iter().map(|c| c.0).collect()
);
}
#[new]
#[args(kwargs = "**")]
pub fn new(kwargs: Option<&PyDict>) -> PyResult<(Self, PyTrainer)> {
@ -416,9 +809,6 @@ impl PyUnigramTrainer {
builder.build().map_err(|e| {
exceptions::PyException::new_err(format!("Cannot build UnigramTrainer: {}", e))
})?;
Ok((
PyUnigramTrainer {},
PyTrainer::new(Arc::new(trainer.into())),
))
Ok((PyUnigramTrainer {}, trainer.into()))
}
}