Add bindings for Trainer in Python

This commit is contained in:
Anthony MOI
2019-12-03 15:54:15 -05:00
parent 310a2af76b
commit eaafb22511
5 changed files with 74 additions and 1 deletions

View File

@ -3,11 +3,20 @@ mod models;
mod pre_tokenizers;
mod token;
mod tokenizer;
mod trainers;
mod utils;
use pyo3::prelude::*;
use pyo3::wrap_pymodule;
/// Trainers Module
///
/// Registers the `Trainer` and `BpeTrainer` classes on the Python module.
#[pymodule]
fn trainers(_py: Python, module: &PyModule) -> PyResult<()> {
    // The second class is only registered when the first registration succeeded;
    // the first failure is returned to the caller as-is.
    module
        .add_class::<trainers::Trainer>()
        .and_then(|_| module.add_class::<trainers::BpeTrainer>())
}
/// Models Module
#[pymodule]
fn models(_py: Python, m: &PyModule) -> PyResult<()> {
@ -39,5 +48,6 @@ fn tokenizers(_py: Python, m: &PyModule) -> PyResult<()> {
m.add_wrapped(wrap_pymodule!(models))?;
m.add_wrapped(wrap_pymodule!(pre_tokenizers))?;
m.add_wrapped(wrap_pymodule!(decoders))?;
m.add_wrapped(wrap_pymodule!(trainers))?;
Ok(())
}

View File

@ -43,4 +43,11 @@ impl BPE {
}),
}
}
/// Build a `Model` that owns the BPE model returned by `tk::models::bpe::BPE::empty()`.
#[staticmethod]
fn empty() -> Model {
    let bpe = tk::models::bpe::BPE::empty();
    Model {
        model: Container::Owned(Box::new(bpe)),
    }
}
}

View File

@ -7,6 +7,7 @@ use super::decoders::Decoder;
use super::models::Model;
use super::pre_tokenizers::PreTokenizer;
use super::token::Token;
use super::trainers::Trainer;
#[pyclass]
pub struct Tokenizer {
@ -97,5 +98,15 @@ impl Tokenizer {
/// Look up the token associated with `id`.
///
/// Delegates directly to the wrapped tokenizer and returns its result unchanged;
/// `None` is whatever the wrapped tokenizer reports for this id
/// (presumably "id not in the vocabulary" -- confirm against the core crate).
fn id_to_token(&self, id: u32) -> Option<String> {
self.tokenizer.id_to_token(id)
}
/// Train the wrapped tokenizer with the given `trainer` on `files`.
///
/// `files` is handed to the underlying tokenizer's `train` call untouched
/// (presumably a list of file paths -- confirm against the core crate).
///
/// If the underlying call fails, its error is rendered through `Display`
/// and raised as a Python `Exception` carrying that message.
fn train(&mut self, trainer: &Trainer, files: Vec<String>) -> PyResult<()> {
    // `execute` gives the closure access to the trainer held inside the container.
    trainer
        .trainer
        .execute(|inner| match self.tokenizer.train(inner, files) {
            Ok(_) => Ok(()),
            Err(err) => Err(exceptions::Exception::py_err(err.to_string())),
        })
}
}

View File

@ -0,0 +1,45 @@
extern crate tokenizers as tk;
use super::utils::Container;
use pyo3::prelude::*;
use pyo3::types::*;
#[pyclass]
/// Python class wrapping a `tk::tokenizer::Trainer` trait object.
///
/// Instances are produced by `BpeTrainer::new` and passed to `Tokenizer::train`.
pub struct Trainer {
/// The wrapped trainer, stored in a `Container`.
pub trainer: Container<dyn tk::tokenizer::Trainer>,
}
#[pyclass]
/// Python class with no state of its own; its static `new` method builds a
/// `Trainer` backed by `tk::models::bpe::BpeTrainer`.
pub struct BpeTrainer {}
#[pymethods]
impl BpeTrainer {
    /// new(/vocab_size, min_frequency)
    /// --
    ///
    /// Build a `Trainer` backed by a BPE trainer configured from keyword arguments.
    ///
    /// The configuration starts from `BpeTrainerConfig`'s default value. Recognized
    /// keywords are `vocab_size` (extracted as `usize`) and `min_frequency`
    /// (extracted as `u32`). Any other keyword is reported on stdout and ignored.
    /// A key or value that cannot be extracted makes the call return that error.
    #[staticmethod]
    #[args(kwargs = "**")]
    pub fn new(kwargs: Option<&PyDict>) -> PyResult<Trainer> {
        let mut config = tk::models::bpe::BpeTrainerConfig::default();

        if let Some(options) = kwargs {
            for (name, value) in options {
                match name.extract::<&str>()? {
                    "vocab_size" => {
                        config.set_vocab_size(value.extract::<usize>()?);
                    }
                    "min_frequency" => {
                        config.set_min_frequency(value.extract::<u32>()?);
                    }
                    // Unknown options are deliberately non-fatal.
                    unknown => println!("Ignored unknown kwargs option {}", unknown),
                }
            }
        }

        let bpe_trainer = tk::models::bpe::BpeTrainer::new(config);
        Ok(Trainer {
            trainer: Container::Owned(Box::new(bpe_trainer)),
        })
    }
}

View File

@ -1,3 +1,3 @@
__version__ = "0.0.2"
from .tokenizers import Tokenizer, models, decoders, pre_tokenizers
from .tokenizers import Tokenizer, models, decoders, pre_tokenizers, trainers