mirror of
https://github.com/mii443/tokenizers.git
synced 2025-08-23 00:35:35 +00:00
Add bindings for Trainer in Python
This commit is contained in:
@ -3,11 +3,20 @@ mod models;
|
|||||||
mod pre_tokenizers;
|
mod pre_tokenizers;
|
||||||
mod token;
|
mod token;
|
||||||
mod tokenizer;
|
mod tokenizer;
|
||||||
|
mod trainers;
|
||||||
mod utils;
|
mod utils;
|
||||||
|
|
||||||
use pyo3::prelude::*;
|
use pyo3::prelude::*;
|
||||||
use pyo3::wrap_pymodule;
|
use pyo3::wrap_pymodule;
|
||||||
|
|
||||||
|
/// Trainers Module
|
||||||
|
#[pymodule]
|
||||||
|
fn trainers(_py: Python, m: &PyModule) -> PyResult<()> {
|
||||||
|
m.add_class::<trainers::Trainer>()?;
|
||||||
|
m.add_class::<trainers::BpeTrainer>()?;
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
/// Models Module
|
/// Models Module
|
||||||
#[pymodule]
|
#[pymodule]
|
||||||
fn models(_py: Python, m: &PyModule) -> PyResult<()> {
|
fn models(_py: Python, m: &PyModule) -> PyResult<()> {
|
||||||
@ -39,5 +48,6 @@ fn tokenizers(_py: Python, m: &PyModule) -> PyResult<()> {
|
|||||||
m.add_wrapped(wrap_pymodule!(models))?;
|
m.add_wrapped(wrap_pymodule!(models))?;
|
||||||
m.add_wrapped(wrap_pymodule!(pre_tokenizers))?;
|
m.add_wrapped(wrap_pymodule!(pre_tokenizers))?;
|
||||||
m.add_wrapped(wrap_pymodule!(decoders))?;
|
m.add_wrapped(wrap_pymodule!(decoders))?;
|
||||||
|
m.add_wrapped(wrap_pymodule!(trainers))?;
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
@ -43,4 +43,11 @@ impl BPE {
|
|||||||
}),
|
}),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[staticmethod]
|
||||||
|
fn empty() -> Model {
|
||||||
|
Model {
|
||||||
|
model: Container::Owned(Box::new(tk::models::bpe::BPE::empty())),
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
@ -7,6 +7,7 @@ use super::decoders::Decoder;
|
|||||||
use super::models::Model;
|
use super::models::Model;
|
||||||
use super::pre_tokenizers::PreTokenizer;
|
use super::pre_tokenizers::PreTokenizer;
|
||||||
use super::token::Token;
|
use super::token::Token;
|
||||||
|
use super::trainers::Trainer;
|
||||||
|
|
||||||
#[pyclass]
|
#[pyclass]
|
||||||
pub struct Tokenizer {
|
pub struct Tokenizer {
|
||||||
@ -97,5 +98,15 @@ impl Tokenizer {
|
|||||||
fn id_to_token(&self, id: u32) -> Option<String> {
|
fn id_to_token(&self, id: u32) -> Option<String> {
|
||||||
self.tokenizer.id_to_token(id)
|
self.tokenizer.id_to_token(id)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn train(&mut self, trainer: &Trainer, files: Vec<String>) -> PyResult<()> {
|
||||||
|
trainer.trainer.execute(|trainer| {
|
||||||
|
if let Err(e) = self.tokenizer.train(trainer, files) {
|
||||||
|
Err(exceptions::Exception::py_err(format!("{}", e)))
|
||||||
|
} else {
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
45
bindings/python/src/trainers.rs
Normal file
45
bindings/python/src/trainers.rs
Normal file
@ -0,0 +1,45 @@
|
|||||||
|
extern crate tokenizers as tk;
|
||||||
|
|
||||||
|
use super::utils::Container;
|
||||||
|
use pyo3::prelude::*;
|
||||||
|
use pyo3::types::*;
|
||||||
|
|
||||||
|
// Python-exposed wrapper around a trainer from the `tokenizers` crate (`tk`).
#[pyclass]
|
||||||
|
pub struct Trainer {
|
||||||
|
// The wrapped trainer, held as a trait object inside the crate-local `Container` helper.
pub trainer: Container<dyn tk::tokenizer::Trainer>,
|
||||||
|
}
|
||||||
|
|
||||||
|
// Field-less class: it holds no state of its own; its `new` static method
// builds a BPE trainer and returns it wrapped in a `Trainer`.
#[pyclass]
|
||||||
|
pub struct BpeTrainer {}
|
||||||
|
#[pymethods]
|
||||||
|
impl BpeTrainer {
|
||||||
|
/// new(/vocab_size, min_frequency)
|
||||||
|
/// --
|
||||||
|
///
|
||||||
|
/// Create a new BpeTrainer with the given configuration
|
||||||
|
#[staticmethod]
|
||||||
|
#[args(kwargs = "**")]
|
||||||
|
pub fn new(kwargs: Option<&PyDict>) -> PyResult<Trainer> {
|
||||||
|
let mut config: tk::models::bpe::BpeTrainerConfig = Default::default();
|
||||||
|
if let Some(kwargs) = kwargs {
|
||||||
|
for (key, val) in kwargs {
|
||||||
|
let key: &str = key.extract()?;
|
||||||
|
match key {
|
||||||
|
"vocab_size" => {
|
||||||
|
let size: usize = val.extract()?;
|
||||||
|
config.set_vocab_size(size);
|
||||||
|
}
|
||||||
|
"min_frequency" => {
|
||||||
|
let freq: u32 = val.extract()?;
|
||||||
|
config.set_min_frequency(freq);
|
||||||
|
}
|
||||||
|
_ => println!("Ignored unknown kwargs option {}", key),
|
||||||
|
};
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(Trainer {
|
||||||
|
trainer: Container::Owned(Box::new(tk::models::bpe::BpeTrainer::new(config))),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
@ -1,3 +1,3 @@
|
|||||||
__version__ = "0.0.2"
|
__version__ = "0.0.2"
|
||||||
|
|
||||||
from .tokenizers import Tokenizer, models, decoders, pre_tokenizers
|
from .tokenizers import Tokenizer, models, decoders, pre_tokenizers, trainers
|
||||||
|
Reference in New Issue
Block a user