mirror of
https://github.com/mii443/tokenizers.git
synced 2025-08-23 00:35:35 +00:00
Replace Model and Trainer Containers.
* Implement changes necessary from generic Model in Tokenizer. * Temporarily disable training in Python since Clone can't be derived for Model until all components have been replaced. * Prefix Python types in Rust with Py.
This commit is contained in:
committed by
Anthony MOI
parent
cdef780aa8
commit
83a52c8080
@ -1,27 +1,57 @@
|
||||
extern crate tokenizers as tk;
|
||||
use std::collections::HashMap;
|
||||
use std::sync::Arc;
|
||||
|
||||
use super::utils::Container;
|
||||
use crate::tokenizer::AddedToken;
|
||||
use pyo3::exceptions;
|
||||
use pyo3::prelude::*;
|
||||
use pyo3::types::*;
|
||||
use tk::models::TrainerWrapper;
|
||||
use tk::Trainer;
|
||||
use tokenizers as tk;
|
||||
|
||||
#[pyclass]
|
||||
pub struct Trainer {
|
||||
pub trainer: Container<dyn tk::tokenizer::Trainer>,
|
||||
use crate::models::PyModel;
|
||||
use crate::tokenizer::PyAddedToken;
|
||||
|
||||
// Base trainer class of the bindings; the `name=Trainer` argument keeps the
// class name `Trainer` while the Rust type carries the `Py` prefix.
#[pyclass(name=Trainer)]
|
||||
pub struct PyTrainer {
|
||||
// Reference-counted handle to the wrapped `tk::models::TrainerWrapper`.
pub trainer: Arc<TrainerWrapper>,
|
||||
}
|
||||
|
||||
#[pyclass(extends=Trainer)]
|
||||
pub struct BpeTrainer {}
|
||||
// Rust-side constructor for `PyTrainer`.
impl PyTrainer {
|
||||
/// Builds a `PyTrainer` around an already shared `TrainerWrapper`.
///
/// The `Arc` is stored as-is in the `trainer` field; nothing is cloned or rebuilt.
pub fn new(trainer: Arc<TrainerWrapper>) -> Self {
|
||||
PyTrainer { trainer }
|
||||
}
|
||||
}
|
||||
|
||||
// Implements the `tk::Trainer` trait for `PyTrainer` by forwarding each call to the
// wrapped `TrainerWrapper` held in `self.trainer`.
impl Trainer for PyTrainer {
|
||||
// Trained models are handed back wrapped in `PyModel`.
type Model = PyModel;
|
||||
|
||||
/// Returns whatever the wrapped trainer reports from `should_show_progress`.
fn should_show_progress(&self) -> bool {
|
||||
self.trainer.should_show_progress()
|
||||
}
|
||||
|
||||
/// Runs the wrapped trainer's `train` on `words`.
///
/// On success the produced model is placed in a new `Arc` inside a `PyModel`, and the
/// accompanying `Vec<tk::AddedToken>` is returned unchanged. An error from the wrapped
/// trainer is returned untouched, since `map` only transforms the `Ok` value.
fn train(&self, words: HashMap<String, u32>) -> tk::Result<(PyModel, Vec<tk::AddedToken>)> {
|
||||
self.trainer.train(words).map(|(m, t)| {
|
||||
// Shadow the raw model with its `PyModel` wrapper.
let m = PyModel { model: Arc::new(m) };
|
||||
(m, t)
|
||||
})
|
||||
}
|
||||
|
||||
/// Forwards `words` and `tokens` to the wrapped trainer's `process_tokens`.
fn process_tokens(&self, words: &mut HashMap<String, u32>, tokens: Vec<String>) {
|
||||
self.trainer.process_tokens(words, tokens)
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
// BPE trainer class: declared as extending `PyTrainer` and named `BpeTrainer`.
// The struct has no fields of its own; its `#[new]` constructor returns the
// `PyTrainer` base value alongside it.
#[pyclass(extends=PyTrainer, name=BpeTrainer)]
|
||||
pub struct PyBpeTrainer {}
|
||||
#[pymethods]
|
||||
impl BpeTrainer {
|
||||
impl PyBpeTrainer {
|
||||
/// new(/ vocab_size, min_frequency)
|
||||
/// --
|
||||
///
|
||||
/// Create a new BpeTrainer with the given configuration
|
||||
#[new]
|
||||
#[args(kwargs = "**")]
|
||||
pub fn new(kwargs: Option<&PyDict>) -> PyResult<(Self, Trainer)> {
|
||||
pub fn new(kwargs: Option<&PyDict>) -> PyResult<(Self, PyTrainer)> {
|
||||
let mut builder = tk::models::bpe::BpeTrainer::builder();
|
||||
if let Some(kwargs) = kwargs {
|
||||
for (key, val) in kwargs {
|
||||
@ -36,9 +66,9 @@ impl BpeTrainer {
|
||||
.into_iter()
|
||||
.map(|token| {
|
||||
if let Ok(content) = token.extract::<String>() {
|
||||
Ok(AddedToken::from(content, Some(true)).get_token())
|
||||
Ok(PyAddedToken::from(content, Some(true)).get_token())
|
||||
} else if let Ok(mut token) =
|
||||
token.extract::<PyRefMut<AddedToken>>()
|
||||
token.extract::<PyRefMut<PyAddedToken>>()
|
||||
{
|
||||
token.is_special_token = true;
|
||||
Ok(token.get_token())
|
||||
@ -72,25 +102,23 @@ impl BpeTrainer {
|
||||
}
|
||||
}
|
||||
Ok((
|
||||
BpeTrainer {},
|
||||
Trainer {
|
||||
trainer: Container::Owned(Box::new(builder.build())),
|
||||
},
|
||||
PyBpeTrainer {},
|
||||
PyTrainer::new(Arc::new(builder.build().into())),
|
||||
))
|
||||
}
|
||||
}
|
||||
|
||||
#[pyclass(extends=Trainer)]
|
||||
pub struct WordPieceTrainer {}
|
||||
// WordPiece trainer class: declared as extending `PyTrainer` and named
// `WordPieceTrainer`. The struct has no fields of its own; its `#[new]`
// constructor returns the `PyTrainer` base value alongside it.
#[pyclass(extends=PyTrainer, name=WordPieceTrainer)]
|
||||
pub struct PyWordPieceTrainer {}
|
||||
#[pymethods]
|
||||
impl WordPieceTrainer {
|
||||
impl PyWordPieceTrainer {
|
||||
/// new(/ vocab_size, min_frequency)
|
||||
/// --
|
||||
///
|
||||
/// Create a new WordPieceTrainer with the given configuration
|
||||
#[new]
|
||||
#[args(kwargs = "**")]
|
||||
pub fn new(kwargs: Option<&PyDict>) -> PyResult<(Self, Trainer)> {
|
||||
pub fn new(kwargs: Option<&PyDict>) -> PyResult<(Self, PyTrainer)> {
|
||||
let mut builder = tk::models::wordpiece::WordPieceTrainer::builder();
|
||||
if let Some(kwargs) = kwargs {
|
||||
for (key, val) in kwargs {
|
||||
@ -105,9 +133,9 @@ impl WordPieceTrainer {
|
||||
.into_iter()
|
||||
.map(|token| {
|
||||
if let Ok(content) = token.extract::<String>() {
|
||||
Ok(AddedToken::from(content, Some(true)).get_token())
|
||||
Ok(PyAddedToken::from(content, Some(true)).get_token())
|
||||
} else if let Ok(mut token) =
|
||||
token.extract::<PyRefMut<AddedToken>>()
|
||||
token.extract::<PyRefMut<PyAddedToken>>()
|
||||
{
|
||||
token.is_special_token = true;
|
||||
Ok(token.get_token())
|
||||
@ -142,10 +170,8 @@ impl WordPieceTrainer {
|
||||
}
|
||||
|
||||
Ok((
|
||||
WordPieceTrainer {},
|
||||
Trainer {
|
||||
trainer: Container::Owned(Box::new(builder.build())),
|
||||
},
|
||||
PyWordPieceTrainer {},
|
||||
PyTrainer::new(Arc::new(builder.build().into())),
|
||||
))
|
||||
}
|
||||
}
|
||||
|
Reference in New Issue
Block a user