Python - Trainers can get/set their attributes
@@ -1,6 +1,8 @@
-use std::collections::HashMap;
+use std::collections::{hash_map::DefaultHasher, HashMap};
+use std::hash::{Hash, Hasher};

 use numpy::PyArray1;
+use pyo3::class::basic::CompareOp;
 use pyo3::exceptions;
 use pyo3::prelude::*;
 use pyo3::types::*;
@@ -106,6 +108,19 @@ impl PyAddedToken
     }
 }

+impl From<tk::AddedToken> for PyAddedToken {
+    fn from(token: tk::AddedToken) -> Self {
+        Self {
+            content: token.content,
+            single_word: Some(token.single_word),
+            lstrip: Some(token.lstrip),
+            rstrip: Some(token.rstrip),
+            normalized: Some(token.normalized),
+            is_special_token: !token.normalized,
+        }
+    }
+}
+
 #[pymethods]
 impl PyAddedToken {
     #[new]
@@ -205,6 +220,21 @@ impl PyObjectProtocol for PyAddedToken
             bool_to_python(token.normalized)
         ))
     }
+
+    fn __richcmp__(&self, other: Py<PyAddedToken>, op: CompareOp) -> bool {
+        use CompareOp::*;
+        Python::with_gil(|py| match op {
+            Lt | Le | Gt | Ge => false,
+            Eq => self.get_token() == other.borrow(py).get_token(),
+            Ne => self.get_token() != other.borrow(py).get_token(),
+        })
+    }
+
+    fn __hash__(&self) -> u64 {
+        let mut hasher = DefaultHasher::new();
+        self.get_token().hash(&mut hasher);
+        hasher.finish()
+    }
 }

 struct TextInputSequence<'s>(tk::InputSequence<'s>);
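A note on the `__richcmp__`/`__hash__` pair above: Python requires that objects which compare equal also hash equal, which is why both methods delegate to the same `get_token()` value. Below is a minimal, std-only Rust sketch of that invariant; `Token` is a hypothetical stand-in for `tk::AddedToken`, not the real type.

// Sketch: equality and hashing both derive from the same fields,
// so equal tokens always produce identical hashes.
use std::collections::hash_map::DefaultHasher;
use std::hash::{Hash, Hasher};

#[derive(PartialEq, Hash)]
struct Token {
    content: String,
    lstrip: bool,
}

fn hash_of(t: &Token) -> u64 {
    let mut hasher = DefaultHasher::new();
    t.hash(&mut hasher);
    hasher.finish()
}

fn main() {
    let a = Token { content: "<s>".into(), lstrip: false };
    let b = Token { content: "<s>".into(), lstrip: false };
    assert!(a == b);
    assert_eq!(hash_of(&a), hash_of(&b)); // equal values, equal hashes
}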
@@ -1,5 +1,5 @@
 use std::collections::HashMap;
-use std::sync::Arc;
+use std::sync::{Arc, RwLock};

 use pyo3::exceptions;
 use pyo3::prelude::*;
@@ -10,6 +10,7 @@ use tokenizers as tk;

 use crate::models::PyModel;
 use crate::tokenizer::PyAddedToken;
+use crate::utils::PyChar;

 /// Base class for all trainers
 ///
@@ -19,19 +20,15 @@ use crate::tokenizer::PyAddedToken;
 #[derive(Clone)]
 #[text_signature = "(self, vocab_size=30000, min_frequency=0,show_progress=True, special_tokens=[],limit_alphabet=None, initial_alphabet = [], continuing_subword_prefix=None, end_of_word_suffix=None)"]
 pub struct PyTrainer {
-    pub trainer: Arc<TrainerWrapper>,
+    pub trainer: Arc<RwLock<TrainerWrapper>>,
 }

 impl PyTrainer {
-    pub(crate) fn new(trainer: Arc<TrainerWrapper>) -> Self {
-        PyTrainer { trainer }
-    }
-
     pub(crate) fn get_as_subtype(&self) -> PyResult<PyObject> {
         let base = self.clone();
         let gil = Python::acquire_gil();
         let py = gil.python();
-        Ok(match self.trainer.as_ref() {
+        Ok(match *self.trainer.as_ref().read().unwrap() {
             TrainerWrapper::BpeTrainer(_) => Py::new(py, (PyBpeTrainer {}, base))?.into_py(py),
             TrainerWrapper::WordPieceTrainer(_) => {
                 Py::new(py, (PyWordPieceTrainer {}, base))?.into_py(py)
@@ -50,7 +47,7 @@ impl Trainer for PyTrainer
     type Model = PyModel;

     fn should_show_progress(&self) -> bool {
-        self.trainer.should_show_progress()
+        self.trainer.read().unwrap().should_show_progress()
     }

     fn train(
@@ -58,11 +55,14 @@ impl Trainer for PyTrainer
         words: HashMap<String, u32>,
         model: &mut PyModel,
     ) -> tk::Result<Vec<tk::AddedToken>> {
-        self.trainer.train(words, &mut model.model.write().unwrap())
+        self.trainer
+            .read()
+            .unwrap()
+            .train(words, &mut model.model.write().unwrap())
     }

     fn process_tokens(&self, words: &mut HashMap<String, u32>, tokens: Vec<String>) {
-        self.trainer.process_tokens(words, tokens)
+        self.trainer.read().unwrap().process_tokens(words, tokens)
     }
 }

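The core change in the hunks above: `PyTrainer` now holds `Arc<RwLock<TrainerWrapper>>` instead of `Arc<TrainerWrapper>`, so the Python-facing setters can mutate a trainer that is shared between clones. A minimal std-only sketch of that interior-mutability pattern follows; the `Trainer` enum is a hypothetical stand-in for `TrainerWrapper`.

use std::sync::{Arc, RwLock};

// Stand-in for TrainerWrapper.
enum Trainer {
    Bpe { vocab_size: usize },
}

fn main() {
    // Two clones share one trainer, just as cloned PyTrainers share theirs.
    let shared = Arc::new(RwLock::new(Trainer::Bpe { vocab_size: 30_000 }));
    let alias = Arc::clone(&shared);

    // A setter takes the write lock and mutates in place, the same shape
    // as the new setter! macro body.
    match *alias.write().unwrap() {
        Trainer::Bpe { ref mut vocab_size } => *vocab_size = 20_000,
    }

    // Every reader then observes the change through the read lock.
    match &*shared.read().unwrap() {
        Trainer::Bpe { vocab_size } => assert_eq!(*vocab_size, 20_000),
    }
}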
@@ -72,11 +72,37 @@ where
 {
     fn from(trainer: I) -> Self {
         PyTrainer {
-            trainer: trainer.into().into(),
+            trainer: Arc::new(RwLock::new(trainer.into())),
         }
     }
 }

+macro_rules! getter {
+    ($self: ident, $variant: ident, $($name: tt)+) => {{
+        let super_ = $self.as_ref();
+        if let TrainerWrapper::$variant(ref trainer) = *super_.trainer.read().unwrap() {
+            trainer.$($name)+
+        } else {
+            unreachable!()
+        }
+    }};
+}
+
+macro_rules! setter {
+    ($self: ident, $variant: ident, $name: ident, $value: expr) => {{
+        let super_ = $self.as_ref();
+        if let TrainerWrapper::$variant(ref mut trainer) = *super_.trainer.write().unwrap() {
+            trainer.$name = $value;
+        }
+    }};
+    ($self: ident, $variant: ident, @$name: ident, $value: expr) => {{
+        let super_ = $self.as_ref();
+        if let TrainerWrapper::$variant(ref mut trainer) = *super_.trainer.write().unwrap() {
+            trainer.$name($value);
+        }
+    }};
+}
+
 /// Trainer capable of training a BPE model
 ///
 /// Args:
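The `getter!`/`setter!` macros above hide the lock-and-match boilerplate; the `@name` arm selects a method call rather than a field assignment, for wrappers (like `WordPieceTrainer`) that expose accessors instead of public fields. Here is a hand-expanded, std-only sketch of roughly what `getter!(self_, Bpe, vocab_size)` and `setter!(self_, Bpe, vocab_size, v)` amount to; all types are stand-ins for the real bindings.

use std::sync::{Arc, RwLock};

struct BpeTrainer {
    vocab_size: usize,
}

#[allow(dead_code)]
enum TrainerWrapper {
    Bpe(BpeTrainer),
    Other,
}

struct PyTrainer {
    trainer: Arc<RwLock<TrainerWrapper>>,
}

// Roughly what getter!(self_, Bpe, vocab_size) expands to:
fn get_vocab_size(super_: &PyTrainer) -> usize {
    if let TrainerWrapper::Bpe(ref trainer) = *super_.trainer.read().unwrap() {
        trainer.vocab_size
    } else {
        // Each #[getter] is only exposed on the matching Python subclass,
        // so a mismatched variant cannot occur in practice.
        unreachable!()
    }
}

// Roughly what setter!(self_, Bpe, vocab_size, v) expands to:
fn set_vocab_size(super_: &PyTrainer, v: usize) {
    if let TrainerWrapper::Bpe(ref mut trainer) = *super_.trainer.write().unwrap() {
        trainer.vocab_size = v;
    }
}

fn main() {
    let t = PyTrainer {
        trainer: Arc::new(RwLock::new(TrainerWrapper::Bpe(BpeTrainer {
            vocab_size: 30_000,
        }))),
    };
    set_vocab_size(&t, 12_345);
    assert_eq!(get_vocab_size(&t), 12_345);
}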
@@ -110,6 +136,122 @@ where
 pub struct PyBpeTrainer {}
 #[pymethods]
 impl PyBpeTrainer {
+    #[getter]
+    fn get_vocab_size(self_: PyRef<Self>) -> usize {
+        getter!(self_, BpeTrainer, vocab_size)
+    }
+
+    #[setter]
+    fn set_vocab_size(self_: PyRef<Self>, vocab_size: usize) {
+        setter!(self_, BpeTrainer, vocab_size, vocab_size);
+    }
+
+    #[getter]
+    fn get_min_frequency(self_: PyRef<Self>) -> u32 {
+        getter!(self_, BpeTrainer, min_frequency)
+    }
+
+    #[setter]
+    fn set_min_frequency(self_: PyRef<Self>, freq: u32) {
+        setter!(self_, BpeTrainer, min_frequency, freq);
+    }
+
+    #[getter]
+    fn get_show_progress(self_: PyRef<Self>) -> bool {
+        getter!(self_, BpeTrainer, show_progress)
+    }
+
+    #[setter]
+    fn set_show_progress(self_: PyRef<Self>, show_progress: bool) {
+        setter!(self_, BpeTrainer, show_progress, show_progress);
+    }
+
+    #[getter]
+    fn get_special_tokens(self_: PyRef<Self>) -> Vec<PyAddedToken> {
+        getter!(
+            self_,
+            BpeTrainer,
+            special_tokens
+                .iter()
+                .map(|tok| tok.clone().into())
+                .collect()
+        )
+    }
+
+    #[setter]
+    fn set_special_tokens(self_: PyRef<Self>, special_tokens: &PyList) -> PyResult<()> {
+        setter!(
+            self_,
+            BpeTrainer,
+            special_tokens,
+            special_tokens
+                .into_iter()
+                .map(|token| {
+                    if let Ok(content) = token.extract::<String>() {
+                        Ok(tk::tokenizer::AddedToken::from(content, true))
+                    } else if let Ok(mut token) = token.extract::<PyRefMut<PyAddedToken>>() {
+                        token.is_special_token = true;
+                        Ok(token.get_token())
+                    } else {
+                        Err(exceptions::PyTypeError::new_err(
+                            "Special tokens must be a List[Union[str, AddedToken]]",
+                        ))
+                    }
+                })
+                .collect::<PyResult<Vec<_>>>()?
+        );
+        Ok(())
+    }
+
+    #[getter]
+    fn get_limit_alphabet(self_: PyRef<Self>) -> Option<usize> {
+        getter!(self_, BpeTrainer, limit_alphabet)
+    }
+
+    #[setter]
+    fn set_limit_alphabet(self_: PyRef<Self>, limit: Option<usize>) {
+        setter!(self_, BpeTrainer, limit_alphabet, limit);
+    }
+
+    #[getter]
+    fn get_initial_alphabet(self_: PyRef<Self>) -> Vec<String> {
+        getter!(
+            self_,
+            BpeTrainer,
+            initial_alphabet.iter().map(|c| c.to_string()).collect()
+        )
+    }
+
+    #[setter]
+    fn set_initial_alphabet(self_: PyRef<Self>, alphabet: Vec<PyChar>) {
+        setter!(
+            self_,
+            BpeTrainer,
+            initial_alphabet,
+            alphabet.into_iter().map(|c| c.0).collect()
+        );
+    }
+
+    #[getter]
+    fn get_continuing_subword_prefix(self_: PyRef<Self>) -> Option<String> {
+        getter!(self_, BpeTrainer, continuing_subword_prefix.clone())
+    }
+
+    #[setter]
+    fn set_continuing_subword_prefix(self_: PyRef<Self>, prefix: Option<String>) {
+        setter!(self_, BpeTrainer, continuing_subword_prefix, prefix);
+    }
+
+    #[getter]
+    fn get_end_of_word_suffix(self_: PyRef<Self>) -> Option<String> {
+        getter!(self_, BpeTrainer, end_of_word_suffix.clone())
+    }
+
+    #[setter]
+    fn set_end_of_word_suffix(self_: PyRef<Self>, suffix: Option<String>) {
+        setter!(self_, BpeTrainer, end_of_word_suffix, suffix);
+    }
+
     #[new]
     #[args(kwargs = "**")]
     pub fn new(kwargs: Option<&PyDict>) -> PyResult<(Self, PyTrainer)> {
@@ -162,10 +304,7 @@ impl PyBpeTrainer
                 };
             }
         }
-        Ok((
-            PyBpeTrainer {},
-            PyTrainer::new(Arc::new(builder.build().into())),
-        ))
+        Ok((PyBpeTrainer {}, builder.build().into()))
     }
 }

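With `From<I> for PyTrainer` now doing the `Arc<RwLock<...>>` wrapping itself (see the `-72,11` hunk above), the constructors no longer need `PyTrainer::new`: `builder.build().into()` suffices, which is why `fn new` was deleted. A std-only sketch of that shape, with stand-in types:

use std::sync::{Arc, RwLock};

struct BpeTrainer;

enum TrainerWrapper {
    Bpe(BpeTrainer),
}

struct PyTrainer {
    trainer: Arc<RwLock<TrainerWrapper>>,
}

// One From impl centralizes the Arc<RwLock<..>> wrapping, so every
// constructor can simply call .into() on whatever its builder produced.
impl<I> From<I> for PyTrainer
where
    I: Into<TrainerWrapper>,
{
    fn from(trainer: I) -> Self {
        PyTrainer {
            trainer: Arc::new(RwLock::new(trainer.into())),
        }
    }
}

impl From<BpeTrainer> for TrainerWrapper {
    fn from(t: BpeTrainer) -> Self {
        TrainerWrapper::Bpe(t)
    }
}

fn main() {
    // What `Ok((PyBpeTrainer {}, builder.build().into()))` relies on:
    let py_trainer: PyTrainer = BpeTrainer.into();
    assert!(matches!(
        *py_trainer.trainer.read().unwrap(),
        TrainerWrapper::Bpe(_)
    ));
}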
@@ -203,6 +342,122 @@ impl PyBpeTrainer
 pub struct PyWordPieceTrainer {}
 #[pymethods]
 impl PyWordPieceTrainer {
+    #[getter]
+    fn get_vocab_size(self_: PyRef<Self>) -> usize {
+        getter!(self_, WordPieceTrainer, vocab_size())
+    }
+
+    #[setter]
+    fn set_vocab_size(self_: PyRef<Self>, vocab_size: usize) {
+        setter!(self_, WordPieceTrainer, @set_vocab_size, vocab_size);
+    }
+
+    #[getter]
+    fn get_min_frequency(self_: PyRef<Self>) -> u32 {
+        getter!(self_, WordPieceTrainer, min_frequency())
+    }
+
+    #[setter]
+    fn set_min_frequency(self_: PyRef<Self>, freq: u32) {
+        setter!(self_, WordPieceTrainer, @set_min_frequency, freq);
+    }
+
+    #[getter]
+    fn get_show_progress(self_: PyRef<Self>) -> bool {
+        getter!(self_, WordPieceTrainer, show_progress())
+    }
+
+    #[setter]
+    fn set_show_progress(self_: PyRef<Self>, show_progress: bool) {
+        setter!(self_, WordPieceTrainer, @set_show_progress, show_progress);
+    }
+
+    #[getter]
+    fn get_special_tokens(self_: PyRef<Self>) -> Vec<PyAddedToken> {
+        getter!(
+            self_,
+            WordPieceTrainer,
+            special_tokens()
+                .iter()
+                .map(|tok| tok.clone().into())
+                .collect()
+        )
+    }
+
+    #[setter]
+    fn set_special_tokens(self_: PyRef<Self>, special_tokens: &PyList) -> PyResult<()> {
+        setter!(
+            self_,
+            WordPieceTrainer,
+            @set_special_tokens,
+            special_tokens
+                .into_iter()
+                .map(|token| {
+                    if let Ok(content) = token.extract::<String>() {
+                        Ok(tk::tokenizer::AddedToken::from(content, true))
+                    } else if let Ok(mut token) = token.extract::<PyRefMut<PyAddedToken>>() {
+                        token.is_special_token = true;
+                        Ok(token.get_token())
+                    } else {
+                        Err(exceptions::PyTypeError::new_err(
+                            "Special tokens must be a List[Union[str, AddedToken]]",
+                        ))
+                    }
+                })
+                .collect::<PyResult<Vec<_>>>()?
+        );
+        Ok(())
+    }
+
+    #[getter]
+    fn get_limit_alphabet(self_: PyRef<Self>) -> Option<usize> {
+        getter!(self_, WordPieceTrainer, limit_alphabet())
+    }
+
+    #[setter]
+    fn set_limit_alphabet(self_: PyRef<Self>, limit: Option<usize>) {
+        setter!(self_, WordPieceTrainer, @set_limit_alphabet, limit);
+    }
+
+    #[getter]
+    fn get_initial_alphabet(self_: PyRef<Self>) -> Vec<String> {
+        getter!(
+            self_,
+            WordPieceTrainer,
+            initial_alphabet().iter().map(|c| c.to_string()).collect()
+        )
+    }
+
+    #[setter]
+    fn set_initial_alphabet(self_: PyRef<Self>, alphabet: Vec<PyChar>) {
+        setter!(
+            self_,
+            WordPieceTrainer,
+            @set_initial_alphabet,
+            alphabet.into_iter().map(|c| c.0).collect()
+        );
+    }
+
+    #[getter]
+    fn get_continuing_subword_prefix(self_: PyRef<Self>) -> Option<String> {
+        getter!(self_, WordPieceTrainer, continuing_subword_prefix().clone())
+    }
+
+    #[setter]
+    fn set_continuing_subword_prefix(self_: PyRef<Self>, prefix: Option<String>) {
+        setter!(self_, WordPieceTrainer, @set_continuing_subword_prefix, prefix);
+    }
+
+    #[getter]
+    fn get_end_of_word_suffix(self_: PyRef<Self>) -> Option<String> {
+        getter!(self_, WordPieceTrainer, end_of_word_suffix().clone())
+    }
+
+    #[setter]
+    fn set_end_of_word_suffix(self_: PyRef<Self>, suffix: Option<String>) {
+        setter!(self_, WordPieceTrainer, @set_end_of_word_suffix, suffix);
+    }
+
     #[new]
     #[args(kwargs = "**")]
     pub fn new(kwargs: Option<&PyDict>) -> PyResult<(Self, PyTrainer)> {
@@ -256,10 +511,7 @@ impl PyWordPieceTrainer
             }
         }

-        Ok((
-            PyWordPieceTrainer {},
-            PyTrainer::new(Arc::new(builder.build().into())),
-        ))
+        Ok((PyWordPieceTrainer {}, builder.build().into()))
     }
 }

@@ -281,6 +533,73 @@ impl PyWordPieceTrainer
 pub struct PyWordLevelTrainer {}
 #[pymethods]
 impl PyWordLevelTrainer {
+    #[getter]
+    fn get_vocab_size(self_: PyRef<Self>) -> usize {
+        getter!(self_, WordLevelTrainer, vocab_size)
+    }
+
+    #[setter]
+    fn set_vocab_size(self_: PyRef<Self>, vocab_size: usize) {
+        setter!(self_, WordLevelTrainer, vocab_size, vocab_size);
+    }
+
+    #[getter]
+    fn get_min_frequency(self_: PyRef<Self>) -> u32 {
+        getter!(self_, WordLevelTrainer, min_frequency)
+    }
+
+    #[setter]
+    fn set_min_frequency(self_: PyRef<Self>, freq: u32) {
+        setter!(self_, WordLevelTrainer, min_frequency, freq);
+    }
+
+    #[getter]
+    fn get_show_progress(self_: PyRef<Self>) -> bool {
+        getter!(self_, WordLevelTrainer, show_progress)
+    }
+
+    #[setter]
+    fn set_show_progress(self_: PyRef<Self>, show_progress: bool) {
+        setter!(self_, WordLevelTrainer, show_progress, show_progress);
+    }
+
+    #[getter]
+    fn get_special_tokens(self_: PyRef<Self>) -> Vec<PyAddedToken> {
+        getter!(
+            self_,
+            WordLevelTrainer,
+            special_tokens
+                .iter()
+                .map(|tok| tok.clone().into())
+                .collect()
+        )
+    }
+
+    #[setter]
+    fn set_special_tokens(self_: PyRef<Self>, special_tokens: &PyList) -> PyResult<()> {
+        setter!(
+            self_,
+            WordLevelTrainer,
+            special_tokens,
+            special_tokens
+                .into_iter()
+                .map(|token| {
+                    if let Ok(content) = token.extract::<String>() {
+                        Ok(tk::tokenizer::AddedToken::from(content, true))
+                    } else if let Ok(mut token) = token.extract::<PyRefMut<PyAddedToken>>() {
+                        token.is_special_token = true;
+                        Ok(token.get_token())
+                    } else {
+                        Err(exceptions::PyTypeError::new_err(
+                            "Special tokens must be a List[Union[str, AddedToken]]",
+                        ))
+                    }
+                })
+                .collect::<PyResult<Vec<_>>>()?
+        );
+        Ok(())
+    }
+
     #[new]
     #[args(kwargs = "**")]
     pub fn new(kwargs: Option<&PyDict>) -> PyResult<(Self, PyTrainer)> {
@@ -327,12 +646,10 @@ impl PyWordLevelTrainer

         Ok((
             PyWordLevelTrainer {},
-            PyTrainer::new(Arc::new(
-                builder
-                    .build()
-                    .expect("WordLevelTrainerBuilder cannot fail")
-                    .into(),
-            )),
+            builder
+                .build()
+                .expect("WordLevelTrainerBuilder cannot fail")
+                .into(),
         ))
     }
 }
@@ -359,6 +676,82 @@ impl PyWordLevelTrainer
 pub struct PyUnigramTrainer {}
 #[pymethods]
 impl PyUnigramTrainer {
+    #[getter]
+    fn get_vocab_size(self_: PyRef<Self>) -> u32 {
+        getter!(self_, UnigramTrainer, vocab_size)
+    }
+
+    #[setter]
+    fn set_vocab_size(self_: PyRef<Self>, vocab_size: u32) {
+        setter!(self_, UnigramTrainer, vocab_size, vocab_size);
+    }
+
+    #[getter]
+    fn get_show_progress(self_: PyRef<Self>) -> bool {
+        getter!(self_, UnigramTrainer, show_progress)
+    }
+
+    #[setter]
+    fn set_show_progress(self_: PyRef<Self>, show_progress: bool) {
+        setter!(self_, UnigramTrainer, show_progress, show_progress);
+    }
+
+    #[getter]
+    fn get_special_tokens(self_: PyRef<Self>) -> Vec<PyAddedToken> {
+        getter!(
+            self_,
+            UnigramTrainer,
+            special_tokens
+                .iter()
+                .map(|tok| tok.clone().into())
+                .collect()
+        )
+    }
+
+    #[setter]
+    fn set_special_tokens(self_: PyRef<Self>, special_tokens: &PyList) -> PyResult<()> {
+        setter!(
+            self_,
+            UnigramTrainer,
+            special_tokens,
+            special_tokens
+                .into_iter()
+                .map(|token| {
+                    if let Ok(content) = token.extract::<String>() {
+                        Ok(tk::tokenizer::AddedToken::from(content, true))
+                    } else if let Ok(mut token) = token.extract::<PyRefMut<PyAddedToken>>() {
+                        token.is_special_token = true;
+                        Ok(token.get_token())
+                    } else {
+                        Err(exceptions::PyTypeError::new_err(
+                            "Special tokens must be a List[Union[str, AddedToken]]",
+                        ))
+                    }
+                })
+                .collect::<PyResult<Vec<_>>>()?
+        );
+        Ok(())
+    }
+
+    #[getter]
+    fn get_initial_alphabet(self_: PyRef<Self>) -> Vec<String> {
+        getter!(
+            self_,
+            UnigramTrainer,
+            initial_alphabet.iter().map(|c| c.to_string()).collect()
+        )
+    }
+
+    #[setter]
+    fn set_initial_alphabet(self_: PyRef<Self>, alphabet: Vec<PyChar>) {
+        setter!(
+            self_,
+            UnigramTrainer,
+            initial_alphabet,
+            alphabet.into_iter().map(|c| c.0).collect()
+        );
+    }
+
     #[new]
     #[args(kwargs = "**")]
     pub fn new(kwargs: Option<&PyDict>) -> PyResult<(Self, PyTrainer)> {
@@ -416,9 +809,6 @@ impl PyUnigramTrainer
             builder.build().map_err(|e| {
                 exceptions::PyException::new_err(format!("Cannot build UnigramTrainer: {}", e))
             })?;
-        Ok((
-            PyUnigramTrainer {},
-            PyTrainer::new(Arc::new(trainer.into())),
-        ))
+        Ok((PyUnigramTrainer {}, trainer.into()))
     }
 }

@@ -4,6 +4,7 @@ import pickle

 from tokenizers import (
     SentencePieceUnigramTokenizer,
+    AddedToken,
     models,
     pre_tokenizers,
     normalizers,
@@ -13,6 +14,119 @@ from tokenizers import (
 from ..utils import data_dir, train_files


+class TestBPETrainer:
+    def test_can_modify(self):
+        trainer = trainers.BpeTrainer(
+            vocab_size=12345,
+            min_frequency=12,
+            show_progress=False,
+            special_tokens=["1", "2"],
+            limit_alphabet=13,
+            initial_alphabet=["a", "b", "c"],
+            continuing_subword_prefix="pref",
+            end_of_word_suffix="suf",
+        )
+
+        assert trainer.vocab_size == 12345
+        assert trainer.min_frequency == 12
+        assert trainer.show_progress == False
+        assert trainer.special_tokens == [
+            AddedToken("1"),
+            AddedToken("2"),
+        ]
+        assert trainer.limit_alphabet == 13
+        assert sorted(trainer.initial_alphabet) == ["a", "b", "c"]
+        assert trainer.continuing_subword_prefix == "pref"
+        assert trainer.end_of_word_suffix == "suf"
+
+        # Modify these
+        trainer.vocab_size = 20000
+        assert trainer.vocab_size == 20000
+        trainer.min_frequency = 1
+        assert trainer.min_frequency == 1
+        trainer.show_progress = True
+        assert trainer.show_progress == True
+        trainer.special_tokens = []
+        assert trainer.special_tokens == []
+        trainer.limit_alphabet = None
+        assert trainer.limit_alphabet == None
+        trainer.initial_alphabet = ["d", "z"]
+        assert sorted(trainer.initial_alphabet) == ["d", "z"]
+        trainer.continuing_subword_prefix = None
+        assert trainer.continuing_subword_prefix == None
+        trainer.end_of_word_suffix = None
+        assert trainer.end_of_word_suffix == None
+
+
+class TestWordPieceTrainer:
+    def test_can_modify(self):
+        trainer = trainers.WordPieceTrainer(
+            vocab_size=12345,
+            min_frequency=12,
+            show_progress=False,
+            special_tokens=["1", "2"],
+            limit_alphabet=13,
+            initial_alphabet=["a", "b", "c"],
+            continuing_subword_prefix="pref",
+            end_of_word_suffix="suf",
+        )
+
+        assert trainer.vocab_size == 12345
+        assert trainer.min_frequency == 12
+        assert trainer.show_progress == False
+        assert trainer.special_tokens == [
+            AddedToken("1"),
+            AddedToken("2"),
+        ]
+        assert trainer.limit_alphabet == 13
+        assert sorted(trainer.initial_alphabet) == ["a", "b", "c"]
+        assert trainer.continuing_subword_prefix == "pref"
+        assert trainer.end_of_word_suffix == "suf"
+
+        # Modify these
+        trainer.vocab_size = 20000
+        assert trainer.vocab_size == 20000
+        trainer.min_frequency = 1
+        assert trainer.min_frequency == 1
+        trainer.show_progress = True
+        assert trainer.show_progress == True
+        trainer.special_tokens = []
+        assert trainer.special_tokens == []
+        trainer.limit_alphabet = None
+        assert trainer.limit_alphabet == None
+        trainer.initial_alphabet = ["d", "z"]
+        assert sorted(trainer.initial_alphabet) == ["d", "z"]
+        trainer.continuing_subword_prefix = None
+        assert trainer.continuing_subword_prefix == None
+        trainer.end_of_word_suffix = None
+        assert trainer.end_of_word_suffix == None
+
+
+class TestWordLevelTrainer:
+    def test_can_modify(self):
+        trainer = trainers.WordLevelTrainer(
+            vocab_size=12345, min_frequency=12, show_progress=False, special_tokens=["1", "2"]
+        )
+
+        assert trainer.vocab_size == 12345
+        assert trainer.min_frequency == 12
+        assert trainer.show_progress == False
+        assert trainer.special_tokens == [
+            AddedToken("1"),
+            AddedToken("2"),
+        ]
+
+        # Modify these
+        trainer.vocab_size = 20000
+        assert trainer.vocab_size == 20000
+        trainer.min_frequency = 1
+        assert trainer.min_frequency == 1
+        trainer.show_progress = True
+        assert trainer.show_progress == True
+        trainer.special_tokens = []
+        assert trainer.special_tokens == []
+
+
 class TestUnigram:
     def test_train(self, train_files):
         tokenizer = SentencePieceUnigramTokenizer()
@@ -99,3 +213,29 @@ class TestUnigram:

         with pytest.raises(Exception, match="UnigramTrainer can only train a Unigram"):
             tokenizer.train([], trainer)
+
+    def test_can_modify(self):
+        trainer = trainers.UnigramTrainer(
+            vocab_size=12345,
+            show_progress=False,
+            special_tokens=["1", AddedToken("2", lstrip=True)],
+            initial_alphabet=["a", "b", "c"],
+        )
+
+        assert trainer.vocab_size == 12345
+        assert trainer.show_progress == False
+        assert trainer.special_tokens == [
+            AddedToken("1", normalized=False),
+            AddedToken("2", lstrip=True, normalized=False),
+        ]
+        assert sorted(trainer.initial_alphabet) == ["a", "b", "c"]
+
+        # Modify these
+        trainer.vocab_size = 20000
+        assert trainer.vocab_size == 20000
+        trainer.show_progress = True
+        assert trainer.show_progress == True
+        trainer.special_tokens = []
+        assert trainer.special_tokens == []
+        trainer.initial_alphabet = ["d", "z"]
+        assert sorted(trainer.initial_alphabet) == ["d", "z"]
@@ -154,24 +154,26 @@ impl BpeTrainerBuilder
 /// let mut model = BPE::default();
 /// let special_tokens = trainer.train(word_counts, &mut model).unwrap();
 /// ```
+#[non_exhaustive]
 #[derive(Debug, Clone, PartialEq)]
 pub struct BpeTrainer {
     /// The minimum frequency a pair must have to produce a merge operation
-    min_frequency: u32,
+    pub min_frequency: u32,
     /// The target vocabulary size
-    vocab_size: usize,
+    pub vocab_size: usize,
     /// Whether to show progress while training
-    show_progress: bool,
+    pub show_progress: bool,
     /// A list of special tokens that the model should know of
-    special_tokens: Vec<AddedToken>,
+    pub special_tokens: Vec<AddedToken>,
     /// Whether to limit the number of initial tokens that can be kept before computing merges
-    limit_alphabet: Option<usize>,
+    pub limit_alphabet: Option<usize>,
     /// The initial alphabet we want absolutely to include. This allows to cover
     /// some characters that are not necessarily in the training set
-    initial_alphabet: HashSet<char>,
+    pub initial_alphabet: HashSet<char>,
     /// An optional prefix to use on any subword that exists only behind another one
-    continuing_subword_prefix: Option<String>,
+    pub continuing_subword_prefix: Option<String>,
     /// An optional suffix to characterize an end-of-word subword
-    end_of_word_suffix: Option<String>,
+    pub end_of_word_suffix: Option<String>,
 }

 impl Default for BpeTrainer {
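Making the fields `pub` is what lets the binding's `getter!`/`setter!` macros touch them directly; `#[non_exhaustive]` compensates by forbidding struct literals and exhaustive destructuring in downstream crates, so new options can be added later without a breaking change. A small illustrative sketch with a hypothetical `Config` type; note that `#[non_exhaustive]` only takes effect across crate boundaries:

// In the defining crate:
#[non_exhaustive]
#[derive(Debug, Clone, Default)]
pub struct Config {
    pub vocab_size: usize,
    pub show_progress: bool,
}

fn main() {
    // Inside the defining crate a literal still works; from another crate
    // it would not compile, forcing users through a constructor/builder.
    let mut cfg = Config::default();

    // Public fields stay readable and writable everywhere:
    cfg.vocab_size = 30_000;
    assert_eq!(cfg.vocab_size, 30_000);
    println!("{cfg:?}");
}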
@@ -36,28 +36,29 @@ fn to_log_prob(pieces: &mut [SentencePiece]) {
 }

 /// A `UnigramTrainer` can train a `Unigram` model from `word_counts`.
+#[non_exhaustive]
 #[derive(Builder, Debug, Clone)]
 pub struct UnigramTrainer {
     #[builder(default = "true")]
-    show_progress: bool,
+    pub show_progress: bool,
     #[builder(default = "8000")]
-    vocab_size: u32,
+    pub vocab_size: u32,
     #[builder(default = "2")]
-    n_sub_iterations: u32,
+    pub n_sub_iterations: u32,
     #[builder(default = "0.75")]
-    shrinking_factor: f64,
+    pub shrinking_factor: f64,
    #[builder(default = "vec![]")]
-    special_tokens: Vec<AddedToken>,
+    pub special_tokens: Vec<AddedToken>,
     #[builder(default = "HashSet::new()")]
-    initial_alphabet: HashSet<char>,
+    pub initial_alphabet: HashSet<char>,

     #[builder(default = "None")]
-    unk_token: Option<String>,
+    pub unk_token: Option<String>,

     #[builder(default = "16")]
-    max_piece_length: usize,
+    pub max_piece_length: usize,
     #[builder(default = "1_000_000")]
-    seed_size: usize,
+    pub seed_size: usize,
 }

 impl Default for UnigramTrainer {
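`UnigramTrainer` gets its builder from the `derive_builder` crate: each `#[builder(default = "...")]` string is parsed as the field's default expression, and `build()` returns a `Result` because required fields may be missing. A minimal sketch of that pattern, using a hypothetical `Opts` type and assuming a `derive_builder` dependency in Cargo.toml:

use derive_builder::Builder;

#[derive(Builder, Debug)]
struct Opts {
    // The default is given as a string containing a Rust expression,
    // which the derive compiles into the generated builder.
    #[builder(default = "8000")]
    vocab_size: u32,
    #[builder(default = "true")]
    show_progress: bool,
}

fn main() {
    // The derive generates OptsBuilder with one setter per field;
    // fields left unset fall back to their declared defaults.
    let opts = OptsBuilder::default()
        .vocab_size(12_345)
        .build()
        .expect("all other fields have defaults");
    assert!(opts.show_progress);
    println!("{opts:?}");
}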
@@ -2,20 +2,21 @@ use super::WordLevel;
 use crate::{AddedToken, Result, Trainer};
 use std::collections::HashMap;

+#[non_exhaustive]
 #[derive(Debug, Clone, Builder)]
 pub struct WordLevelTrainer {
     /// The minimum frequency a word must have to be part of the vocabulary
     #[builder(default)]
-    min_frequency: u32,
+    pub min_frequency: u32,
     /// The target vocabulary size
     #[builder(default)]
-    vocab_size: usize,
+    pub vocab_size: usize,
     /// Whether to show progress while training
     #[builder(default)]
-    show_progress: bool,
+    pub show_progress: bool,
     /// A list of special tokens that the model should know of
     #[builder(default)]
-    special_tokens: Vec<AddedToken>,
+    pub special_tokens: Vec<AddedToken>,
 }

 impl Default for WordLevelTrainer {
@@ -85,6 +85,70 @@ pub struct WordPieceTrainer {
 }

 impl WordPieceTrainer {
+    pub fn min_frequency(&self) -> u32 {
+        self.bpe_trainer.min_frequency
+    }
+
+    pub fn set_min_frequency(&mut self, freq: u32) {
+        self.bpe_trainer.min_frequency = freq;
+    }
+
+    pub fn vocab_size(&self) -> usize {
+        self.bpe_trainer.vocab_size
+    }
+
+    pub fn set_vocab_size(&mut self, size: usize) {
+        self.bpe_trainer.vocab_size = size;
+    }
+
+    pub fn show_progress(&self) -> bool {
+        self.bpe_trainer.show_progress
+    }
+
+    pub fn set_show_progress(&mut self, show_progress: bool) {
+        self.bpe_trainer.show_progress = show_progress;
+    }
+
+    pub fn special_tokens(&self) -> &[AddedToken] {
+        &self.bpe_trainer.special_tokens
+    }
+
+    pub fn set_special_tokens(&mut self, special_tokens: Vec<AddedToken>) {
+        self.bpe_trainer.special_tokens = special_tokens;
+    }
+
+    pub fn limit_alphabet(&self) -> Option<usize> {
+        self.bpe_trainer.limit_alphabet
+    }
+
+    pub fn set_limit_alphabet(&mut self, limit: Option<usize>) {
+        self.bpe_trainer.limit_alphabet = limit;
+    }
+
+    pub fn initial_alphabet(&self) -> &HashSet<char> {
+        &self.bpe_trainer.initial_alphabet
+    }
+
+    pub fn set_initial_alphabet(&mut self, alphabet: HashSet<char>) {
+        self.bpe_trainer.initial_alphabet = alphabet;
+    }
+
+    pub fn continuing_subword_prefix(&self) -> &Option<String> {
+        &self.bpe_trainer.continuing_subword_prefix
+    }
+
+    pub fn set_continuing_subword_prefix(&mut self, prefix: Option<String>) {
+        self.bpe_trainer.continuing_subword_prefix = prefix;
+    }
+
+    pub fn end_of_word_suffix(&self) -> &Option<String> {
+        &self.bpe_trainer.end_of_word_suffix
+    }
+
+    pub fn set_end_of_word_suffix(&mut self, suffix: Option<String>) {
+        self.bpe_trainer.end_of_word_suffix = suffix;
+    }
+
     pub fn builder() -> WordPieceTrainerBuilder {
         WordPieceTrainerBuilder::default()
     }
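`WordPieceTrainer` wraps a `BpeTrainer` internally, so rather than exposing public fields it grows accessor pairs that forward to the inner trainer; this is also why the binding's `setter!` macro needs its `@method` arm for this type. A std-only sketch of the delegation pattern, with stand-in types:

// Stand-in for BpeTrainer with a public field.
struct Inner {
    vocab_size: usize,
}

// Stand-in for WordPieceTrainer: the wrapped trainer stays private, and a
// getter/setter pair forwards to it instead of exposing the field.
struct Outer {
    inner: Inner,
}

impl Outer {
    fn vocab_size(&self) -> usize {
        self.inner.vocab_size
    }

    fn set_vocab_size(&mut self, size: usize) {
        self.inner.vocab_size = size;
    }
}

fn main() {
    let mut wp = Outer { inner: Inner { vocab_size: 30_000 } };
    // What setter!(self_, WordPieceTrainer, @set_vocab_size, v) ends up calling:
    wp.set_vocab_size(12_345);
    assert_eq!(wp.vocab_size(), 12_345);
}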