mirror of
https://github.com/mii443/tokenizers.git
synced 2025-08-23 00:35:35 +00:00
Python - Trainers can get/set their attributes
This commit is contained in:
@ -1,5 +1,5 @@
|
||||
use std::collections::HashMap;
|
||||
use std::sync::Arc;
|
||||
use std::sync::{Arc, RwLock};
|
||||
|
||||
use pyo3::exceptions;
|
||||
use pyo3::prelude::*;
|
||||
@ -10,6 +10,7 @@ use tokenizers as tk;
|
||||
|
||||
use crate::models::PyModel;
|
||||
use crate::tokenizer::PyAddedToken;
|
||||
use crate::utils::PyChar;
|
||||
|
||||
/// Base class for all trainers
|
||||
///
|
||||
@ -19,19 +20,15 @@ use crate::tokenizer::PyAddedToken;
|
||||
#[derive(Clone)]
|
||||
#[text_signature = "(self, vocab_size=30000, min_frequency=0,show_progress=True, special_tokens=[],limit_alphabet=None, initial_alphabet = [], continuing_subword_prefix=None, end_of_word_suffix=None)"]
|
||||
pub struct PyTrainer {
|
||||
pub trainer: Arc<TrainerWrapper>,
|
||||
pub trainer: Arc<RwLock<TrainerWrapper>>,
|
||||
}
|
||||
|
||||
impl PyTrainer {
|
||||
pub(crate) fn new(trainer: Arc<TrainerWrapper>) -> Self {
|
||||
PyTrainer { trainer }
|
||||
}
|
||||
|
||||
pub(crate) fn get_as_subtype(&self) -> PyResult<PyObject> {
|
||||
let base = self.clone();
|
||||
let gil = Python::acquire_gil();
|
||||
let py = gil.python();
|
||||
Ok(match self.trainer.as_ref() {
|
||||
Ok(match *self.trainer.as_ref().read().unwrap() {
|
||||
TrainerWrapper::BpeTrainer(_) => Py::new(py, (PyBpeTrainer {}, base))?.into_py(py),
|
||||
TrainerWrapper::WordPieceTrainer(_) => {
|
||||
Py::new(py, (PyWordPieceTrainer {}, base))?.into_py(py)
|
||||
@ -50,7 +47,7 @@ impl Trainer for PyTrainer {
|
||||
type Model = PyModel;
|
||||
|
||||
fn should_show_progress(&self) -> bool {
|
||||
self.trainer.should_show_progress()
|
||||
self.trainer.read().unwrap().should_show_progress()
|
||||
}
|
||||
|
||||
fn train(
|
||||
@ -58,11 +55,14 @@ impl Trainer for PyTrainer {
|
||||
words: HashMap<String, u32>,
|
||||
model: &mut PyModel,
|
||||
) -> tk::Result<Vec<tk::AddedToken>> {
|
||||
self.trainer.train(words, &mut model.model.write().unwrap())
|
||||
self.trainer
|
||||
.read()
|
||||
.unwrap()
|
||||
.train(words, &mut model.model.write().unwrap())
|
||||
}
|
||||
|
||||
fn process_tokens(&self, words: &mut HashMap<String, u32>, tokens: Vec<String>) {
|
||||
self.trainer.process_tokens(words, tokens)
|
||||
self.trainer.read().unwrap().process_tokens(words, tokens)
|
||||
}
|
||||
}
|
||||
|
||||
@ -72,11 +72,37 @@ where
|
||||
{
|
||||
fn from(trainer: I) -> Self {
|
||||
PyTrainer {
|
||||
trainer: trainer.into().into(),
|
||||
trainer: Arc::new(RwLock::new(trainer.into())),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
macro_rules! getter {
|
||||
($self: ident, $variant: ident, $($name: tt)+) => {{
|
||||
let super_ = $self.as_ref();
|
||||
if let TrainerWrapper::$variant(ref trainer) = *super_.trainer.read().unwrap() {
|
||||
trainer.$($name)+
|
||||
} else {
|
||||
unreachable!()
|
||||
}
|
||||
}};
|
||||
}
|
||||
|
||||
macro_rules! setter {
|
||||
($self: ident, $variant: ident, $name: ident, $value: expr) => {{
|
||||
let super_ = $self.as_ref();
|
||||
if let TrainerWrapper::$variant(ref mut trainer) = *super_.trainer.write().unwrap() {
|
||||
trainer.$name = $value;
|
||||
}
|
||||
}};
|
||||
($self: ident, $variant: ident, @$name: ident, $value: expr) => {{
|
||||
let super_ = $self.as_ref();
|
||||
if let TrainerWrapper::$variant(ref mut trainer) = *super_.trainer.write().unwrap() {
|
||||
trainer.$name($value);
|
||||
}
|
||||
}};
|
||||
}
|
||||
|
||||
/// Trainer capable of training a BPE model
|
||||
///
|
||||
/// Args:
|
||||
@ -110,6 +136,122 @@ where
|
||||
pub struct PyBpeTrainer {}
|
||||
#[pymethods]
|
||||
impl PyBpeTrainer {
|
||||
#[getter]
|
||||
fn get_vocab_size(self_: PyRef<Self>) -> usize {
|
||||
getter!(self_, BpeTrainer, vocab_size)
|
||||
}
|
||||
|
||||
#[setter]
|
||||
fn set_vocab_size(self_: PyRef<Self>, vocab_size: usize) {
|
||||
setter!(self_, BpeTrainer, vocab_size, vocab_size);
|
||||
}
|
||||
|
||||
#[getter]
|
||||
fn get_min_frequency(self_: PyRef<Self>) -> u32 {
|
||||
getter!(self_, BpeTrainer, min_frequency)
|
||||
}
|
||||
|
||||
#[setter]
|
||||
fn set_min_frequency(self_: PyRef<Self>, freq: u32) {
|
||||
setter!(self_, BpeTrainer, min_frequency, freq);
|
||||
}
|
||||
|
||||
#[getter]
|
||||
fn get_show_progress(self_: PyRef<Self>) -> bool {
|
||||
getter!(self_, BpeTrainer, show_progress)
|
||||
}
|
||||
|
||||
#[setter]
|
||||
fn set_show_progress(self_: PyRef<Self>, show_progress: bool) {
|
||||
setter!(self_, BpeTrainer, show_progress, show_progress);
|
||||
}
|
||||
|
||||
#[getter]
|
||||
fn get_special_tokens(self_: PyRef<Self>) -> Vec<PyAddedToken> {
|
||||
getter!(
|
||||
self_,
|
||||
BpeTrainer,
|
||||
special_tokens
|
||||
.iter()
|
||||
.map(|tok| tok.clone().into())
|
||||
.collect()
|
||||
)
|
||||
}
|
||||
|
||||
#[setter]
|
||||
fn set_special_tokens(self_: PyRef<Self>, special_tokens: &PyList) -> PyResult<()> {
|
||||
setter!(
|
||||
self_,
|
||||
BpeTrainer,
|
||||
special_tokens,
|
||||
special_tokens
|
||||
.into_iter()
|
||||
.map(|token| {
|
||||
if let Ok(content) = token.extract::<String>() {
|
||||
Ok(tk::tokenizer::AddedToken::from(content, true))
|
||||
} else if let Ok(mut token) = token.extract::<PyRefMut<PyAddedToken>>() {
|
||||
token.is_special_token = true;
|
||||
Ok(token.get_token())
|
||||
} else {
|
||||
Err(exceptions::PyTypeError::new_err(
|
||||
"Special tokens must be a List[Union[str, AddedToken]]",
|
||||
))
|
||||
}
|
||||
})
|
||||
.collect::<PyResult<Vec<_>>>()?
|
||||
);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[getter]
|
||||
fn get_limit_alphabet(self_: PyRef<Self>) -> Option<usize> {
|
||||
getter!(self_, BpeTrainer, limit_alphabet)
|
||||
}
|
||||
|
||||
#[setter]
|
||||
fn set_limit_alphabet(self_: PyRef<Self>, limit: Option<usize>) {
|
||||
setter!(self_, BpeTrainer, limit_alphabet, limit);
|
||||
}
|
||||
|
||||
#[getter]
|
||||
fn get_initial_alphabet(self_: PyRef<Self>) -> Vec<String> {
|
||||
getter!(
|
||||
self_,
|
||||
BpeTrainer,
|
||||
initial_alphabet.iter().map(|c| c.to_string()).collect()
|
||||
)
|
||||
}
|
||||
|
||||
#[setter]
|
||||
fn set_initial_alphabet(self_: PyRef<Self>, alphabet: Vec<PyChar>) {
|
||||
setter!(
|
||||
self_,
|
||||
BpeTrainer,
|
||||
initial_alphabet,
|
||||
alphabet.into_iter().map(|c| c.0).collect()
|
||||
);
|
||||
}
|
||||
|
||||
#[getter]
|
||||
fn get_continuing_subword_prefix(self_: PyRef<Self>) -> Option<String> {
|
||||
getter!(self_, BpeTrainer, continuing_subword_prefix.clone())
|
||||
}
|
||||
|
||||
#[setter]
|
||||
fn set_continuing_subword_prefix(self_: PyRef<Self>, prefix: Option<String>) {
|
||||
setter!(self_, BpeTrainer, continuing_subword_prefix, prefix);
|
||||
}
|
||||
|
||||
#[getter]
|
||||
fn get_end_of_word_suffix(self_: PyRef<Self>) -> Option<String> {
|
||||
getter!(self_, BpeTrainer, end_of_word_suffix.clone())
|
||||
}
|
||||
|
||||
#[setter]
|
||||
fn set_end_of_word_suffix(self_: PyRef<Self>, suffix: Option<String>) {
|
||||
setter!(self_, BpeTrainer, end_of_word_suffix, suffix);
|
||||
}
|
||||
|
||||
#[new]
|
||||
#[args(kwargs = "**")]
|
||||
pub fn new(kwargs: Option<&PyDict>) -> PyResult<(Self, PyTrainer)> {
|
||||
@ -162,10 +304,7 @@ impl PyBpeTrainer {
|
||||
};
|
||||
}
|
||||
}
|
||||
Ok((
|
||||
PyBpeTrainer {},
|
||||
PyTrainer::new(Arc::new(builder.build().into())),
|
||||
))
|
||||
Ok((PyBpeTrainer {}, builder.build().into()))
|
||||
}
|
||||
}
|
||||
|
||||
@ -203,6 +342,122 @@ impl PyBpeTrainer {
|
||||
pub struct PyWordPieceTrainer {}
|
||||
#[pymethods]
|
||||
impl PyWordPieceTrainer {
|
||||
#[getter]
|
||||
fn get_vocab_size(self_: PyRef<Self>) -> usize {
|
||||
getter!(self_, WordPieceTrainer, vocab_size())
|
||||
}
|
||||
|
||||
#[setter]
|
||||
fn set_vocab_size(self_: PyRef<Self>, vocab_size: usize) {
|
||||
setter!(self_, WordPieceTrainer, @set_vocab_size, vocab_size);
|
||||
}
|
||||
|
||||
#[getter]
|
||||
fn get_min_frequency(self_: PyRef<Self>) -> u32 {
|
||||
getter!(self_, WordPieceTrainer, min_frequency())
|
||||
}
|
||||
|
||||
#[setter]
|
||||
fn set_min_frequency(self_: PyRef<Self>, freq: u32) {
|
||||
setter!(self_, WordPieceTrainer, @set_min_frequency, freq);
|
||||
}
|
||||
|
||||
#[getter]
|
||||
fn get_show_progress(self_: PyRef<Self>) -> bool {
|
||||
getter!(self_, WordPieceTrainer, show_progress())
|
||||
}
|
||||
|
||||
#[setter]
|
||||
fn set_show_progress(self_: PyRef<Self>, show_progress: bool) {
|
||||
setter!(self_, WordPieceTrainer, @set_show_progress, show_progress);
|
||||
}
|
||||
|
||||
#[getter]
|
||||
fn get_special_tokens(self_: PyRef<Self>) -> Vec<PyAddedToken> {
|
||||
getter!(
|
||||
self_,
|
||||
WordPieceTrainer,
|
||||
special_tokens()
|
||||
.iter()
|
||||
.map(|tok| tok.clone().into())
|
||||
.collect()
|
||||
)
|
||||
}
|
||||
|
||||
#[setter]
|
||||
fn set_special_tokens(self_: PyRef<Self>, special_tokens: &PyList) -> PyResult<()> {
|
||||
setter!(
|
||||
self_,
|
||||
WordPieceTrainer,
|
||||
@set_special_tokens,
|
||||
special_tokens
|
||||
.into_iter()
|
||||
.map(|token| {
|
||||
if let Ok(content) = token.extract::<String>() {
|
||||
Ok(tk::tokenizer::AddedToken::from(content, true))
|
||||
} else if let Ok(mut token) = token.extract::<PyRefMut<PyAddedToken>>() {
|
||||
token.is_special_token = true;
|
||||
Ok(token.get_token())
|
||||
} else {
|
||||
Err(exceptions::PyTypeError::new_err(
|
||||
"Special tokens must be a List[Union[str, AddedToken]]",
|
||||
))
|
||||
}
|
||||
})
|
||||
.collect::<PyResult<Vec<_>>>()?
|
||||
);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[getter]
|
||||
fn get_limit_alphabet(self_: PyRef<Self>) -> Option<usize> {
|
||||
getter!(self_, WordPieceTrainer, limit_alphabet())
|
||||
}
|
||||
|
||||
#[setter]
|
||||
fn set_limit_alphabet(self_: PyRef<Self>, limit: Option<usize>) {
|
||||
setter!(self_, WordPieceTrainer, @set_limit_alphabet, limit);
|
||||
}
|
||||
|
||||
#[getter]
|
||||
fn get_initial_alphabet(self_: PyRef<Self>) -> Vec<String> {
|
||||
getter!(
|
||||
self_,
|
||||
WordPieceTrainer,
|
||||
initial_alphabet().iter().map(|c| c.to_string()).collect()
|
||||
)
|
||||
}
|
||||
|
||||
#[setter]
|
||||
fn set_initial_alphabet(self_: PyRef<Self>, alphabet: Vec<PyChar>) {
|
||||
setter!(
|
||||
self_,
|
||||
WordPieceTrainer,
|
||||
@set_initial_alphabet,
|
||||
alphabet.into_iter().map(|c| c.0).collect()
|
||||
);
|
||||
}
|
||||
|
||||
#[getter]
|
||||
fn get_continuing_subword_prefix(self_: PyRef<Self>) -> Option<String> {
|
||||
getter!(self_, WordPieceTrainer, continuing_subword_prefix().clone())
|
||||
}
|
||||
|
||||
#[setter]
|
||||
fn set_continuing_subword_prefix(self_: PyRef<Self>, prefix: Option<String>) {
|
||||
setter!(self_, WordPieceTrainer, @set_continuing_subword_prefix, prefix);
|
||||
}
|
||||
|
||||
#[getter]
|
||||
fn get_end_of_word_suffix(self_: PyRef<Self>) -> Option<String> {
|
||||
getter!(self_, WordPieceTrainer, end_of_word_suffix().clone())
|
||||
}
|
||||
|
||||
#[setter]
|
||||
fn set_end_of_word_suffix(self_: PyRef<Self>, suffix: Option<String>) {
|
||||
setter!(self_, WordPieceTrainer, @set_end_of_word_suffix, suffix);
|
||||
}
|
||||
|
||||
#[new]
|
||||
#[args(kwargs = "**")]
|
||||
pub fn new(kwargs: Option<&PyDict>) -> PyResult<(Self, PyTrainer)> {
|
||||
@ -256,10 +511,7 @@ impl PyWordPieceTrainer {
|
||||
}
|
||||
}
|
||||
|
||||
Ok((
|
||||
PyWordPieceTrainer {},
|
||||
PyTrainer::new(Arc::new(builder.build().into())),
|
||||
))
|
||||
Ok((PyWordPieceTrainer {}, builder.build().into()))
|
||||
}
|
||||
}
|
||||
|
||||
@ -281,6 +533,73 @@ impl PyWordPieceTrainer {
|
||||
pub struct PyWordLevelTrainer {}
|
||||
#[pymethods]
|
||||
impl PyWordLevelTrainer {
|
||||
#[getter]
|
||||
fn get_vocab_size(self_: PyRef<Self>) -> usize {
|
||||
getter!(self_, WordLevelTrainer, vocab_size)
|
||||
}
|
||||
|
||||
#[setter]
|
||||
fn set_vocab_size(self_: PyRef<Self>, vocab_size: usize) {
|
||||
setter!(self_, WordLevelTrainer, vocab_size, vocab_size);
|
||||
}
|
||||
|
||||
#[getter]
|
||||
fn get_min_frequency(self_: PyRef<Self>) -> u32 {
|
||||
getter!(self_, WordLevelTrainer, min_frequency)
|
||||
}
|
||||
|
||||
#[setter]
|
||||
fn set_min_frequency(self_: PyRef<Self>, freq: u32) {
|
||||
setter!(self_, WordLevelTrainer, min_frequency, freq);
|
||||
}
|
||||
|
||||
#[getter]
|
||||
fn get_show_progress(self_: PyRef<Self>) -> bool {
|
||||
getter!(self_, WordLevelTrainer, show_progress)
|
||||
}
|
||||
|
||||
#[setter]
|
||||
fn set_show_progress(self_: PyRef<Self>, show_progress: bool) {
|
||||
setter!(self_, WordLevelTrainer, show_progress, show_progress);
|
||||
}
|
||||
|
||||
#[getter]
|
||||
fn get_special_tokens(self_: PyRef<Self>) -> Vec<PyAddedToken> {
|
||||
getter!(
|
||||
self_,
|
||||
WordLevelTrainer,
|
||||
special_tokens
|
||||
.iter()
|
||||
.map(|tok| tok.clone().into())
|
||||
.collect()
|
||||
)
|
||||
}
|
||||
|
||||
#[setter]
|
||||
fn set_special_tokens(self_: PyRef<Self>, special_tokens: &PyList) -> PyResult<()> {
|
||||
setter!(
|
||||
self_,
|
||||
WordLevelTrainer,
|
||||
special_tokens,
|
||||
special_tokens
|
||||
.into_iter()
|
||||
.map(|token| {
|
||||
if let Ok(content) = token.extract::<String>() {
|
||||
Ok(tk::tokenizer::AddedToken::from(content, true))
|
||||
} else if let Ok(mut token) = token.extract::<PyRefMut<PyAddedToken>>() {
|
||||
token.is_special_token = true;
|
||||
Ok(token.get_token())
|
||||
} else {
|
||||
Err(exceptions::PyTypeError::new_err(
|
||||
"Special tokens must be a List[Union[str, AddedToken]]",
|
||||
))
|
||||
}
|
||||
})
|
||||
.collect::<PyResult<Vec<_>>>()?
|
||||
);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[new]
|
||||
#[args(kwargs = "**")]
|
||||
pub fn new(kwargs: Option<&PyDict>) -> PyResult<(Self, PyTrainer)> {
|
||||
@ -327,12 +646,10 @@ impl PyWordLevelTrainer {
|
||||
|
||||
Ok((
|
||||
PyWordLevelTrainer {},
|
||||
PyTrainer::new(Arc::new(
|
||||
builder
|
||||
.build()
|
||||
.expect("WordLevelTrainerBuilder cannot fail")
|
||||
.into(),
|
||||
)),
|
||||
builder
|
||||
.build()
|
||||
.expect("WordLevelTrainerBuilder cannot fail")
|
||||
.into(),
|
||||
))
|
||||
}
|
||||
}
|
||||
@ -359,6 +676,82 @@ impl PyWordLevelTrainer {
|
||||
pub struct PyUnigramTrainer {}
|
||||
#[pymethods]
|
||||
impl PyUnigramTrainer {
|
||||
#[getter]
|
||||
fn get_vocab_size(self_: PyRef<Self>) -> u32 {
|
||||
getter!(self_, UnigramTrainer, vocab_size)
|
||||
}
|
||||
|
||||
#[setter]
|
||||
fn set_vocab_size(self_: PyRef<Self>, vocab_size: u32) {
|
||||
setter!(self_, UnigramTrainer, vocab_size, vocab_size);
|
||||
}
|
||||
|
||||
#[getter]
|
||||
fn get_show_progress(self_: PyRef<Self>) -> bool {
|
||||
getter!(self_, UnigramTrainer, show_progress)
|
||||
}
|
||||
|
||||
#[setter]
|
||||
fn set_show_progress(self_: PyRef<Self>, show_progress: bool) {
|
||||
setter!(self_, UnigramTrainer, show_progress, show_progress);
|
||||
}
|
||||
|
||||
#[getter]
|
||||
fn get_special_tokens(self_: PyRef<Self>) -> Vec<PyAddedToken> {
|
||||
getter!(
|
||||
self_,
|
||||
UnigramTrainer,
|
||||
special_tokens
|
||||
.iter()
|
||||
.map(|tok| tok.clone().into())
|
||||
.collect()
|
||||
)
|
||||
}
|
||||
|
||||
#[setter]
|
||||
fn set_special_tokens(self_: PyRef<Self>, special_tokens: &PyList) -> PyResult<()> {
|
||||
setter!(
|
||||
self_,
|
||||
UnigramTrainer,
|
||||
special_tokens,
|
||||
special_tokens
|
||||
.into_iter()
|
||||
.map(|token| {
|
||||
if let Ok(content) = token.extract::<String>() {
|
||||
Ok(tk::tokenizer::AddedToken::from(content, true))
|
||||
} else if let Ok(mut token) = token.extract::<PyRefMut<PyAddedToken>>() {
|
||||
token.is_special_token = true;
|
||||
Ok(token.get_token())
|
||||
} else {
|
||||
Err(exceptions::PyTypeError::new_err(
|
||||
"Special tokens must be a List[Union[str, AddedToken]]",
|
||||
))
|
||||
}
|
||||
})
|
||||
.collect::<PyResult<Vec<_>>>()?
|
||||
);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[getter]
|
||||
fn get_initial_alphabet(self_: PyRef<Self>) -> Vec<String> {
|
||||
getter!(
|
||||
self_,
|
||||
UnigramTrainer,
|
||||
initial_alphabet.iter().map(|c| c.to_string()).collect()
|
||||
)
|
||||
}
|
||||
|
||||
#[setter]
|
||||
fn set_initial_alphabet(self_: PyRef<Self>, alphabet: Vec<PyChar>) {
|
||||
setter!(
|
||||
self_,
|
||||
UnigramTrainer,
|
||||
initial_alphabet,
|
||||
alphabet.into_iter().map(|c| c.0).collect()
|
||||
);
|
||||
}
|
||||
|
||||
#[new]
|
||||
#[args(kwargs = "**")]
|
||||
pub fn new(kwargs: Option<&PyDict>) -> PyResult<(Self, PyTrainer)> {
|
||||
@ -416,9 +809,6 @@ impl PyUnigramTrainer {
|
||||
builder.build().map_err(|e| {
|
||||
exceptions::PyException::new_err(format!("Cannot build UnigramTrainer: {}", e))
|
||||
})?;
|
||||
Ok((
|
||||
PyUnigramTrainer {},
|
||||
PyTrainer::new(Arc::new(trainer.into())),
|
||||
))
|
||||
Ok((PyUnigramTrainer {}, trainer.into()))
|
||||
}
|
||||
}
|
||||
|
Reference in New Issue
Block a user