Replace Model and Trainer Containers.

* Implement the changes made necessary by the generic Model in Tokenizer.
* Temporarily disable training in Python since Clone can't be
  derived for Model until all components have been replaced.
* Prefix Python types in Rust with Py.
This commit is contained in:
Sebastian Pütz
2020-07-25 03:57:16 +02:00
committed by Anthony MOI
parent cdef780aa8
commit 83a52c8080
17 changed files with 360 additions and 299 deletions

View File

@ -1,27 +1,57 @@
extern crate tokenizers as tk;
use std::collections::HashMap;
use std::sync::Arc;
use super::utils::Container;
use crate::tokenizer::AddedToken;
use pyo3::exceptions;
use pyo3::prelude::*;
use pyo3::types::*;
use tk::models::TrainerWrapper;
use tk::Trainer;
use tokenizers as tk;
#[pyclass]
pub struct Trainer {
pub trainer: Container<dyn tk::tokenizer::Trainer>,
use crate::models::PyModel;
use crate::tokenizer::PyAddedToken;
// Base trainer class of the Python bindings, exported to Python as `Trainer`.
// The concrete trainers in this file (`PyBpeTrainer`, `PyWordPieceTrainer`)
// declare it as their parent class through `extends=PyTrainer`.
#[pyclass(name=Trainer)]
pub struct PyTrainer {
// Reference-counted handle to the wrapped `tk::models::TrainerWrapper`.
pub trainer: Arc<TrainerWrapper>,
}
#[pyclass(extends=Trainer)]
pub struct BpeTrainer {}
impl PyTrainer {
pub fn new(trainer: Arc<TrainerWrapper>) -> Self {
PyTrainer { trainer }
}
}
impl Trainer for PyTrainer {
type Model = PyModel;
fn should_show_progress(&self) -> bool {
self.trainer.should_show_progress()
}
fn train(&self, words: HashMap<String, u32>) -> tk::Result<(PyModel, Vec<tk::AddedToken>)> {
self.trainer.train(words).map(|(m, t)| {
let m = PyModel { model: Arc::new(m) };
(m, t)
})
}
fn process_tokens(&self, words: &mut HashMap<String, u32>, tokens: Vec<String>) {
self.trainer.process_tokens(words, tokens)
}
}
// BPE trainer class of the Python bindings, exported to Python as
// `BpeTrainer` and declared as a subclass of `PyTrainer`. It has no fields of
// its own; its constructor returns the configured trainer inside the parent
// `PyTrainer` value.
#[pyclass(extends=PyTrainer, name=BpeTrainer)]
pub struct PyBpeTrainer {}
#[pymethods]
impl BpeTrainer {
impl PyBpeTrainer {
/// new(/ vocab_size, min_frequency)
/// --
///
/// Create a new BpeTrainer with the given configuration
#[new]
#[args(kwargs = "**")]
pub fn new(kwargs: Option<&PyDict>) -> PyResult<(Self, Trainer)> {
pub fn new(kwargs: Option<&PyDict>) -> PyResult<(Self, PyTrainer)> {
let mut builder = tk::models::bpe::BpeTrainer::builder();
if let Some(kwargs) = kwargs {
for (key, val) in kwargs {
@ -36,9 +66,9 @@ impl BpeTrainer {
.into_iter()
.map(|token| {
if let Ok(content) = token.extract::<String>() {
Ok(AddedToken::from(content, Some(true)).get_token())
Ok(PyAddedToken::from(content, Some(true)).get_token())
} else if let Ok(mut token) =
token.extract::<PyRefMut<AddedToken>>()
token.extract::<PyRefMut<PyAddedToken>>()
{
token.is_special_token = true;
Ok(token.get_token())
@ -72,25 +102,23 @@ impl BpeTrainer {
}
}
Ok((
BpeTrainer {},
Trainer {
trainer: Container::Owned(Box::new(builder.build())),
},
PyBpeTrainer {},
PyTrainer::new(Arc::new(builder.build().into())),
))
}
}
#[pyclass(extends=Trainer)]
pub struct WordPieceTrainer {}
// WordPiece trainer class of the Python bindings, exported to Python as
// `WordPieceTrainer` and declared as a subclass of `PyTrainer`. It has no
// fields of its own; its constructor returns the configured trainer inside
// the parent `PyTrainer` value.
#[pyclass(extends=PyTrainer, name=WordPieceTrainer)]
pub struct PyWordPieceTrainer {}
#[pymethods]
impl WordPieceTrainer {
impl PyWordPieceTrainer {
/// new(/ vocab_size, min_frequency)
/// --
///
/// Create a new WordPieceTrainer with the given configuration
#[new]
#[args(kwargs = "**")]
pub fn new(kwargs: Option<&PyDict>) -> PyResult<(Self, Trainer)> {
pub fn new(kwargs: Option<&PyDict>) -> PyResult<(Self, PyTrainer)> {
let mut builder = tk::models::wordpiece::WordPieceTrainer::builder();
if let Some(kwargs) = kwargs {
for (key, val) in kwargs {
@ -105,9 +133,9 @@ impl WordPieceTrainer {
.into_iter()
.map(|token| {
if let Ok(content) = token.extract::<String>() {
Ok(AddedToken::from(content, Some(true)).get_token())
Ok(PyAddedToken::from(content, Some(true)).get_token())
} else if let Ok(mut token) =
token.extract::<PyRefMut<AddedToken>>()
token.extract::<PyRefMut<PyAddedToken>>()
{
token.is_special_token = true;
Ok(token.get_token())
@ -142,10 +170,8 @@ impl WordPieceTrainer {
}
Ok((
WordPieceTrainer {},
Trainer {
trainer: Container::Owned(Box::new(builder.build())),
},
PyWordPieceTrainer {},
PyTrainer::new(Arc::new(builder.build().into())),
))
}
}