Add bindings for Trainer in Python
This commit exposes the `Trainer` trait and a `BpeTrainer` class to Python, registers a new `trainers` submodule in the extension, adds a `Tokenizer.train` method, and adds an empty-model constructor for `BPE`.
bindings/python/src/lib.rs

@@ -3,11 +3,20 @@ mod models;
 mod pre_tokenizers;
 mod token;
 mod tokenizer;
+mod trainers;
 mod utils;
 
 use pyo3::prelude::*;
 use pyo3::wrap_pymodule;
 
+/// Trainers Module
+#[pymodule]
+fn trainers(_py: Python, m: &PyModule) -> PyResult<()> {
+    m.add_class::<trainers::Trainer>()?;
+    m.add_class::<trainers::BpeTrainer>()?;
+    Ok(())
+}
+
 /// Models Module
 #[pymodule]
 fn models(_py: Python, m: &PyModule) -> PyResult<()> {

@@ -39,5 +48,6 @@ fn tokenizers(_py: Python, m: &PyModule) -> PyResult<()> {
     m.add_wrapped(wrap_pymodule!(models))?;
     m.add_wrapped(wrap_pymodule!(pre_tokenizers))?;
     m.add_wrapped(wrap_pymodule!(decoders))?;
+    m.add_wrapped(wrap_pymodule!(trainers))?;
     Ok(())
 }
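Registering the new module with `wrap_pymodule!` makes it available as an attribute of the extension module. A minimal sketch of what this exposes on the Python side (assuming the package re-exports it, as the `__init__.py` change at the bottom of this diff does):

    # The wrapped submodule carries the two classes registered above.
    from tokenizers import trainers

    print(trainers.BpeTrainer)  # the trainer factory bound in trainers.rs below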
bindings/python/src/models.rs

@@ -43,4 +43,11 @@ impl BPE {
             }),
         }
     }
+
+    #[staticmethod]
+    fn empty() -> Model {
+        Model {
+            model: Container::Owned(Box::new(tk::models::bpe::BPE::empty())),
+        }
+    }
 }
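The new `empty` staticmethod yields an untrained BPE model, useful as a starting point when training from scratch rather than loading vocab/merges files. A hedged sketch (assuming `Tokenizer` accepts a `Model` at construction, as elsewhere in these bindings):

    from tokenizers import Tokenizer, models

    # An empty, untrained BPE model instead of one loaded from files.
    tokenizer = Tokenizer(models.BPE.empty())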
bindings/python/src/tokenizer.rs

@@ -7,6 +7,7 @@ use super::decoders::Decoder;
 use super::models::Model;
 use super::pre_tokenizers::PreTokenizer;
 use super::token::Token;
+use super::trainers::Trainer;
 
 #[pyclass]
 pub struct Tokenizer {

@@ -97,5 +98,15 @@ impl Tokenizer {
     fn id_to_token(&self, id: u32) -> Option<String> {
         self.tokenizer.id_to_token(id)
     }
 
+    fn train(&mut self, trainer: &Trainer, files: Vec<String>) -> PyResult<()> {
+        trainer.trainer.execute(|trainer| {
+            if let Err(e) = self.tokenizer.train(trainer, files) {
+                Err(exceptions::Exception::py_err(format!("{}", e)))
+            } else {
+                Ok(())
+            }
+        })
+    }
+
 }
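Because `train` converts any library error into a Python `Exception` through `py_err`, failures surface as ordinary exceptions on the Python side. A hypothetical call ("corpus.txt" is a placeholder path; the trainer comes from the `BpeTrainer` factory in the new file below):

    # Hypothetical usage; "corpus.txt" stands in for real training files.
    trainer = trainers.BpeTrainer.new(vocab_size=30000)
    try:
        tokenizer.train(trainer, ["corpus.txt"])
    except Exception as e:
        print(f"Training failed: {e}")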
bindings/python/src/trainers.rs (new file, 45 lines)

@@ -0,0 +1,45 @@
+extern crate tokenizers as tk;
+
+use super::utils::Container;
+use pyo3::prelude::*;
+use pyo3::types::*;
+
+#[pyclass]
+pub struct Trainer {
+    pub trainer: Container<dyn tk::tokenizer::Trainer>,
+}
+
+#[pyclass]
+pub struct BpeTrainer {}
+#[pymethods]
+impl BpeTrainer {
+    /// new(/vocab_size, min_frequency)
+    /// --
+    ///
+    /// Create a new BpeTrainer with the given configuration
+    #[staticmethod]
+    #[args(kwargs = "**")]
+    pub fn new(kwargs: Option<&PyDict>) -> PyResult<Trainer> {
+        let mut config: tk::models::bpe::BpeTrainerConfig = Default::default();
+        if let Some(kwargs) = kwargs {
+            for (key, val) in kwargs {
+                let key: &str = key.extract()?;
+                match key {
+                    "vocab_size" => {
+                        let size: usize = val.extract()?;
+                        config.set_vocab_size(size);
+                    }
+                    "min_frequency" => {
+                        let freq: u32 = val.extract()?;
+                        config.set_min_frequency(freq);
+                    }
+                    _ => println!("Ignored unknown kwargs option {}", key),
+                };
+            }
+        }
+
+        Ok(Trainer {
+            trainer: Container::Owned(Box::new(tk::models::bpe::BpeTrainer::new(config))),
+        })
+    }
+}
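On the Python side, `new` is a static factory taking only keyword arguments; the two keys handled by the `match` above configure the trainer, and any other key is ignored with a printed warning. A minimal sketch (the parameter values are arbitrary):

    from tokenizers import trainers

    # vocab_size and min_frequency are the only recognized kwargs; anything
    # else prints "Ignored unknown kwargs option ..." and is skipped.
    trainer = trainers.BpeTrainer.new(vocab_size=30000, min_frequency=2)
    trainers.BpeTrainer.new(typo_key=1)  # typo_key ignored, trainer still built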
bindings/python/tokenizers/__init__.py

@@ -1,3 +1,3 @@
 __version__ = "0.0.2"
 
-from .tokenizers import Tokenizer, models, decoders, pre_tokenizers
+from .tokenizers import Tokenizer, models, decoders, pre_tokenizers, trainers