Python - Trainers can get/set their attributes

Anthony MOI
2020-11-24 17:46:58 -05:00
committed by Anthony MOI
parent 3eb7ef6d0a
commit a351d1c604
7 changed files with 679 additions and 51 deletions
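In short: every trainer exposed to Python (BpeTrainer, WordPieceTrainer, WordLevelTrainer, UnigramTrainer) now exposes its options as readable and writable properties instead of accepting them only at construction time. A minimal sketch of the new behavior, mirroring the tests added below (values are illustrative):

from tokenizers import trainers

# Options can still be passed at construction time...
trainer = trainers.BpeTrainer(vocab_size=12345, min_frequency=12)
assert trainer.vocab_size == 12345

# ...and can now also be read back and modified afterwards.
trainer.vocab_size = 20000
trainer.min_frequency = 1
assert trainer.vocab_size == 20000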

View File

@@ -1,6 +1,8 @@
use std::collections::HashMap;
use std::collections::{hash_map::DefaultHasher, HashMap};
use std::hash::{Hash, Hasher};
use numpy::PyArray1;
use pyo3::class::basic::CompareOp;
use pyo3::exceptions;
use pyo3::prelude::*;
use pyo3::types::*;
@@ -106,6 +108,19 @@ impl PyAddedToken {
}
}
impl From<tk::AddedToken> for PyAddedToken {
fn from(token: tk::AddedToken) -> Self {
Self {
content: token.content,
single_word: Some(token.single_word),
lstrip: Some(token.lstrip),
rstrip: Some(token.rstrip),
normalized: Some(token.normalized),
is_special_token: !token.normalized,
}
}
}
#[pymethods]
impl PyAddedToken {
#[new]
@@ -205,6 +220,21 @@ impl PyObjectProtocol for PyAddedToken {
bool_to_python(token.normalized)
))
}
fn __richcmp__(&self, other: Py<PyAddedToken>, op: CompareOp) -> bool {
use CompareOp::*;
Python::with_gil(|py| match op {
Lt | Le | Gt | Ge => false,
Eq => self.get_token() == other.borrow(py).get_token(),
Ne => self.get_token() != other.borrow(py).get_token(),
})
}
fn __hash__(&self) -> u64 {
let mut hasher = DefaultHasher::new();
self.get_token().hash(&mut hasher);
hasher.finish()
}
}
struct TextInputSequence<'s>(tk::InputSequence<'s>);
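The __richcmp__ and __hash__ implementations above make AddedToken compare by the value of the wrapped token rather than by object identity, and let it live in hashed collections; this is what allows the new tests below to compare trainer.special_tokens against freshly constructed AddedTokens. A short sketch of the resulting Python behavior:

from tokenizers import AddedToken

# Equality compares the underlying token's fields, not object identity.
assert AddedToken("1") == AddedToken("1")
assert AddedToken("1") != AddedToken("2")

# Ordering comparisons are not defined and always return False.
assert not (AddedToken("1") < AddedToken("2"))

# __hash__ is consistent with __eq__, so tokens work in sets and dicts.
assert len({AddedToken("1"), AddedToken("1"), AddedToken("2")}) == 2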

View File

@@ -1,5 +1,5 @@
use std::collections::HashMap;
use std::sync::Arc;
use std::sync::{Arc, RwLock};
use pyo3::exceptions;
use pyo3::prelude::*;
@@ -10,6 +10,7 @@ use tokenizers as tk;
use crate::models::PyModel;
use crate::tokenizer::PyAddedToken;
use crate::utils::PyChar;
/// Base class for all trainers
///
@@ -19,19 +20,15 @@ use crate::tokenizer::PyAddedToken;
#[derive(Clone)]
#[text_signature = "(self, vocab_size=30000, min_frequency=0, show_progress=True, special_tokens=[], limit_alphabet=None, initial_alphabet=[], continuing_subword_prefix=None, end_of_word_suffix=None)"]
pub struct PyTrainer {
pub trainer: Arc<TrainerWrapper>,
pub trainer: Arc<RwLock<TrainerWrapper>>,
}
impl PyTrainer {
pub(crate) fn new(trainer: Arc<TrainerWrapper>) -> Self {
PyTrainer { trainer }
}
pub(crate) fn get_as_subtype(&self) -> PyResult<PyObject> {
let base = self.clone();
let gil = Python::acquire_gil();
let py = gil.python();
Ok(match self.trainer.as_ref() {
Ok(match *self.trainer.as_ref().read().unwrap() {
TrainerWrapper::BpeTrainer(_) => Py::new(py, (PyBpeTrainer {}, base))?.into_py(py),
TrainerWrapper::WordPieceTrainer(_) => {
Py::new(py, (PyWordPieceTrainer {}, base))?.into_py(py)
@@ -50,7 +47,7 @@ impl Trainer for PyTrainer {
type Model = PyModel;
fn should_show_progress(&self) -> bool {
self.trainer.should_show_progress()
self.trainer.read().unwrap().should_show_progress()
}
fn train(
@@ -58,11 +55,14 @@ impl Trainer for PyTrainer {
words: HashMap<String, u32>,
model: &mut PyModel,
) -> tk::Result<Vec<tk::AddedToken>> {
self.trainer.train(words, &mut model.model.write().unwrap())
self.trainer
.read()
.unwrap()
.train(words, &mut model.model.write().unwrap())
}
fn process_tokens(&self, words: &mut HashMap<String, u32>, tokens: Vec<String>) {
self.trainer.process_tokens(words, tokens)
self.trainer.read().unwrap().process_tokens(words, tokens)
}
}
@@ -72,11 +72,37 @@ where
{
fn from(trainer: I) -> Self {
PyTrainer {
trainer: trainer.into().into(),
trainer: Arc::new(RwLock::new(trainer.into())),
}
}
}
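// getter! takes a read lock on the wrapped TrainerWrapper, matches it against the
// expected variant, and evaluates the given expression on it (unreachable!() guards
// the impossible mismatch); setter! takes a write lock and either assigns a field
// directly or, in the @-prefixed form, calls a setter method on the inner trainer.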
macro_rules! getter {
($self: ident, $variant: ident, $($name: tt)+) => {{
let super_ = $self.as_ref();
if let TrainerWrapper::$variant(ref trainer) = *super_.trainer.read().unwrap() {
trainer.$($name)+
} else {
unreachable!()
}
}};
}
macro_rules! setter {
($self: ident, $variant: ident, $name: ident, $value: expr) => {{
let super_ = $self.as_ref();
if let TrainerWrapper::$variant(ref mut trainer) = *super_.trainer.write().unwrap() {
trainer.$name = $value;
}
}};
($self: ident, $variant: ident, @$name: ident, $value: expr) => {{
let super_ = $self.as_ref();
if let TrainerWrapper::$variant(ref mut trainer) = *super_.trainer.write().unwrap() {
trainer.$name($value);
}
}};
}
/// Trainer capable of training a BPE model
///
/// Args:
@@ -110,6 +136,122 @@ where
pub struct PyBpeTrainer {}
#[pymethods]
impl PyBpeTrainer {
#[getter]
fn get_vocab_size(self_: PyRef<Self>) -> usize {
getter!(self_, BpeTrainer, vocab_size)
}
#[setter]
fn set_vocab_size(self_: PyRef<Self>, vocab_size: usize) {
setter!(self_, BpeTrainer, vocab_size, vocab_size);
}
#[getter]
fn get_min_frequency(self_: PyRef<Self>) -> u32 {
getter!(self_, BpeTrainer, min_frequency)
}
#[setter]
fn set_min_frequency(self_: PyRef<Self>, freq: u32) {
setter!(self_, BpeTrainer, min_frequency, freq);
}
#[getter]
fn get_show_progress(self_: PyRef<Self>) -> bool {
getter!(self_, BpeTrainer, show_progress)
}
#[setter]
fn set_show_progress(self_: PyRef<Self>, show_progress: bool) {
setter!(self_, BpeTrainer, show_progress, show_progress);
}
#[getter]
fn get_special_tokens(self_: PyRef<Self>) -> Vec<PyAddedToken> {
getter!(
self_,
BpeTrainer,
special_tokens
.iter()
.map(|tok| tok.clone().into())
.collect()
)
}
#[setter]
fn set_special_tokens(self_: PyRef<Self>, special_tokens: &PyList) -> PyResult<()> {
setter!(
self_,
BpeTrainer,
special_tokens,
special_tokens
.into_iter()
.map(|token| {
if let Ok(content) = token.extract::<String>() {
Ok(tk::tokenizer::AddedToken::from(content, true))
} else if let Ok(mut token) = token.extract::<PyRefMut<PyAddedToken>>() {
token.is_special_token = true;
Ok(token.get_token())
} else {
Err(exceptions::PyTypeError::new_err(
"Special tokens must be a List[Union[str, AddedToken]]",
))
}
})
.collect::<PyResult<Vec<_>>>()?
);
Ok(())
}
#[getter]
fn get_limit_alphabet(self_: PyRef<Self>) -> Option<usize> {
getter!(self_, BpeTrainer, limit_alphabet)
}
#[setter]
fn set_limit_alphabet(self_: PyRef<Self>, limit: Option<usize>) {
setter!(self_, BpeTrainer, limit_alphabet, limit);
}
#[getter]
fn get_initial_alphabet(self_: PyRef<Self>) -> Vec<String> {
getter!(
self_,
BpeTrainer,
initial_alphabet.iter().map(|c| c.to_string()).collect()
)
}
#[setter]
fn set_initial_alphabet(self_: PyRef<Self>, alphabet: Vec<PyChar>) {
setter!(
self_,
BpeTrainer,
initial_alphabet,
alphabet.into_iter().map(|c| c.0).collect()
);
}
#[getter]
fn get_continuing_subword_prefix(self_: PyRef<Self>) -> Option<String> {
getter!(self_, BpeTrainer, continuing_subword_prefix.clone())
}
#[setter]
fn set_continuing_subword_prefix(self_: PyRef<Self>, prefix: Option<String>) {
setter!(self_, BpeTrainer, continuing_subword_prefix, prefix);
}
#[getter]
fn get_end_of_word_suffix(self_: PyRef<Self>) -> Option<String> {
getter!(self_, BpeTrainer, end_of_word_suffix.clone())
}
#[setter]
fn set_end_of_word_suffix(self_: PyRef<Self>, suffix: Option<String>) {
setter!(self_, BpeTrainer, end_of_word_suffix, suffix);
}
#[new]
#[args(kwargs = "**")]
pub fn new(kwargs: Option<&PyDict>) -> PyResult<(Self, PyTrainer)> {
@@ -162,10 +304,7 @@ impl PyBpeTrainer {
};
}
}
Ok((
PyBpeTrainer {},
PyTrainer::new(Arc::new(builder.build().into())),
))
Ok((PyBpeTrainer {}, builder.build().into()))
}
}
@@ -203,6 +342,122 @@ impl PyBpeTrainer {
pub struct PyWordPieceTrainer {}
#[pymethods]
impl PyWordPieceTrainer {
#[getter]
fn get_vocab_size(self_: PyRef<Self>) -> usize {
getter!(self_, WordPieceTrainer, vocab_size())
}
#[setter]
fn set_vocab_size(self_: PyRef<Self>, vocab_size: usize) {
setter!(self_, WordPieceTrainer, @set_vocab_size, vocab_size);
}
#[getter]
fn get_min_frequency(self_: PyRef<Self>) -> u32 {
getter!(self_, WordPieceTrainer, min_frequency())
}
#[setter]
fn set_min_frequency(self_: PyRef<Self>, freq: u32) {
setter!(self_, WordPieceTrainer, @set_min_frequency, freq);
}
#[getter]
fn get_show_progress(self_: PyRef<Self>) -> bool {
getter!(self_, WordPieceTrainer, show_progress())
}
#[setter]
fn set_show_progress(self_: PyRef<Self>, show_progress: bool) {
setter!(self_, WordPieceTrainer, @set_show_progress, show_progress);
}
#[getter]
fn get_special_tokens(self_: PyRef<Self>) -> Vec<PyAddedToken> {
getter!(
self_,
WordPieceTrainer,
special_tokens()
.iter()
.map(|tok| tok.clone().into())
.collect()
)
}
#[setter]
fn set_special_tokens(self_: PyRef<Self>, special_tokens: &PyList) -> PyResult<()> {
setter!(
self_,
WordPieceTrainer,
@set_special_tokens,
special_tokens
.into_iter()
.map(|token| {
if let Ok(content) = token.extract::<String>() {
Ok(tk::tokenizer::AddedToken::from(content, true))
} else if let Ok(mut token) = token.extract::<PyRefMut<PyAddedToken>>() {
token.is_special_token = true;
Ok(token.get_token())
} else {
Err(exceptions::PyTypeError::new_err(
"Special tokens must be a List[Union[str, AddedToken]]",
))
}
})
.collect::<PyResult<Vec<_>>>()?
);
Ok(())
}
#[getter]
fn get_limit_alphabet(self_: PyRef<Self>) -> Option<usize> {
getter!(self_, WordPieceTrainer, limit_alphabet())
}
#[setter]
fn set_limit_alphabet(self_: PyRef<Self>, limit: Option<usize>) {
setter!(self_, WordPieceTrainer, @set_limit_alphabet, limit);
}
#[getter]
fn get_initial_alphabet(self_: PyRef<Self>) -> Vec<String> {
getter!(
self_,
WordPieceTrainer,
initial_alphabet().iter().map(|c| c.to_string()).collect()
)
}
#[setter]
fn set_initial_alphabet(self_: PyRef<Self>, alphabet: Vec<PyChar>) {
setter!(
self_,
WordPieceTrainer,
@set_initial_alphabet,
alphabet.into_iter().map(|c| c.0).collect()
);
}
#[getter]
fn get_continuing_subword_prefix(self_: PyRef<Self>) -> Option<String> {
getter!(self_, WordPieceTrainer, continuing_subword_prefix().clone())
}
#[setter]
fn set_continuing_subword_prefix(self_: PyRef<Self>, prefix: Option<String>) {
setter!(self_, WordPieceTrainer, @set_continuing_subword_prefix, prefix);
}
#[getter]
fn get_end_of_word_suffix(self_: PyRef<Self>) -> Option<String> {
getter!(self_, WordPieceTrainer, end_of_word_suffix().clone())
}
#[setter]
fn set_end_of_word_suffix(self_: PyRef<Self>, suffix: Option<String>) {
setter!(self_, WordPieceTrainer, @set_end_of_word_suffix, suffix);
}
#[new]
#[args(kwargs = "**")]
pub fn new(kwargs: Option<&PyDict>) -> PyResult<(Self, PyTrainer)> {
@@ -256,10 +511,7 @@ impl PyWordPieceTrainer {
}
}
Ok((
PyWordPieceTrainer {},
PyTrainer::new(Arc::new(builder.build().into())),
))
Ok((PyWordPieceTrainer {}, builder.build().into()))
}
}
@@ -281,6 +533,73 @@ impl PyWordPieceTrainer {
pub struct PyWordLevelTrainer {}
#[pymethods]
impl PyWordLevelTrainer {
#[getter]
fn get_vocab_size(self_: PyRef<Self>) -> usize {
getter!(self_, WordLevelTrainer, vocab_size)
}
#[setter]
fn set_vocab_size(self_: PyRef<Self>, vocab_size: usize) {
setter!(self_, WordLevelTrainer, vocab_size, vocab_size);
}
#[getter]
fn get_min_frequency(self_: PyRef<Self>) -> u32 {
getter!(self_, WordLevelTrainer, min_frequency)
}
#[setter]
fn set_min_frequency(self_: PyRef<Self>, freq: u32) {
setter!(self_, WordLevelTrainer, min_frequency, freq);
}
#[getter]
fn get_show_progress(self_: PyRef<Self>) -> bool {
getter!(self_, WordLevelTrainer, show_progress)
}
#[setter]
fn set_show_progress(self_: PyRef<Self>, show_progress: bool) {
setter!(self_, WordLevelTrainer, show_progress, show_progress);
}
#[getter]
fn get_special_tokens(self_: PyRef<Self>) -> Vec<PyAddedToken> {
getter!(
self_,
WordLevelTrainer,
special_tokens
.iter()
.map(|tok| tok.clone().into())
.collect()
)
}
#[setter]
fn set_special_tokens(self_: PyRef<Self>, special_tokens: &PyList) -> PyResult<()> {
setter!(
self_,
WordLevelTrainer,
special_tokens,
special_tokens
.into_iter()
.map(|token| {
if let Ok(content) = token.extract::<String>() {
Ok(tk::tokenizer::AddedToken::from(content, true))
} else if let Ok(mut token) = token.extract::<PyRefMut<PyAddedToken>>() {
token.is_special_token = true;
Ok(token.get_token())
} else {
Err(exceptions::PyTypeError::new_err(
"Special tokens must be a List[Union[str, AddedToken]]",
))
}
})
.collect::<PyResult<Vec<_>>>()?
);
Ok(())
}
#[new]
#[args(kwargs = "**")]
pub fn new(kwargs: Option<&PyDict>) -> PyResult<(Self, PyTrainer)> {
@@ -327,12 +646,10 @@ impl PyWordLevelTrainer {
Ok((
PyWordLevelTrainer {},
PyTrainer::new(Arc::new(
builder
.build()
.expect("WordLevelTrainerBuilder cannot fail")
.into(),
)),
builder
.build()
.expect("WordLevelTrainerBuilder cannot fail")
.into(),
))
}
}
@@ -359,6 +676,82 @@ impl PyWordLevelTrainer {
pub struct PyUnigramTrainer {}
#[pymethods]
impl PyUnigramTrainer {
#[getter]
fn get_vocab_size(self_: PyRef<Self>) -> u32 {
getter!(self_, UnigramTrainer, vocab_size)
}
#[setter]
fn set_vocab_size(self_: PyRef<Self>, vocab_size: u32) {
setter!(self_, UnigramTrainer, vocab_size, vocab_size);
}
#[getter]
fn get_show_progress(self_: PyRef<Self>) -> bool {
getter!(self_, UnigramTrainer, show_progress)
}
#[setter]
fn set_show_progress(self_: PyRef<Self>, show_progress: bool) {
setter!(self_, UnigramTrainer, show_progress, show_progress);
}
#[getter]
fn get_special_tokens(self_: PyRef<Self>) -> Vec<PyAddedToken> {
getter!(
self_,
UnigramTrainer,
special_tokens
.iter()
.map(|tok| tok.clone().into())
.collect()
)
}
#[setter]
fn set_special_tokens(self_: PyRef<Self>, special_tokens: &PyList) -> PyResult<()> {
setter!(
self_,
UnigramTrainer,
special_tokens,
special_tokens
.into_iter()
.map(|token| {
if let Ok(content) = token.extract::<String>() {
Ok(tk::tokenizer::AddedToken::from(content, true))
} else if let Ok(mut token) = token.extract::<PyRefMut<PyAddedToken>>() {
token.is_special_token = true;
Ok(token.get_token())
} else {
Err(exceptions::PyTypeError::new_err(
"Special tokens must be a List[Union[str, AddedToken]]",
))
}
})
.collect::<PyResult<Vec<_>>>()?
);
Ok(())
}
#[getter]
fn get_initial_alphabet(self_: PyRef<Self>) -> Vec<String> {
getter!(
self_,
UnigramTrainer,
initial_alphabet.iter().map(|c| c.to_string()).collect()
)
}
#[setter]
fn set_initial_alphabet(self_: PyRef<Self>, alphabet: Vec<PyChar>) {
setter!(
self_,
UnigramTrainer,
initial_alphabet,
alphabet.into_iter().map(|c| c.0).collect()
);
}
#[new]
#[args(kwargs = "**")]
pub fn new(kwargs: Option<&PyDict>) -> PyResult<(Self, PyTrainer)> {
@@ -416,9 +809,6 @@ impl PyUnigramTrainer {
builder.build().map_err(|e| {
exceptions::PyException::new_err(format!("Cannot build UnigramTrainer: {}", e))
})?;
Ok((
PyUnigramTrainer {},
PyTrainer::new(Arc::new(trainer.into())),
))
Ok((PyUnigramTrainer {}, trainer.into()))
}
}

View File

@@ -4,6 +4,7 @@ import pickle
from tokenizers import (
SentencePieceUnigramTokenizer,
AddedToken,
models,
pre_tokenizers,
normalizers,
@@ -13,6 +14,119 @@ from tokenizers import (
from ..utils import data_dir, train_files
class TestBPETrainer:
def test_can_modify(self):
trainer = trainers.BpeTrainer(
vocab_size=12345,
min_frequency=12,
show_progress=False,
special_tokens=["1", "2"],
limit_alphabet=13,
initial_alphabet=["a", "b", "c"],
continuing_subword_prefix="pref",
end_of_word_suffix="suf",
)
assert trainer.vocab_size == 12345
assert trainer.min_frequency == 12
assert trainer.show_progress == False
assert trainer.special_tokens == [
AddedToken("1"),
AddedToken("2"),
]
assert trainer.limit_alphabet == 13
assert sorted(trainer.initial_alphabet) == ["a", "b", "c"]
assert trainer.continuing_subword_prefix == "pref"
assert trainer.end_of_word_suffix == "suf"
# Modify these
trainer.vocab_size = 20000
assert trainer.vocab_size == 20000
trainer.min_frequency = 1
assert trainer.min_frequency == 1
trainer.show_progress = True
assert trainer.show_progress == True
trainer.special_tokens = []
assert trainer.special_tokens == []
trainer.limit_alphabet = None
assert trainer.limit_alphabet == None
trainer.initial_alphabet = ["d", "z"]
assert sorted(trainer.initial_alphabet) == ["d", "z"]
trainer.continuing_subword_prefix = None
assert trainer.continuing_subword_prefix == None
trainer.end_of_word_suffix = None
assert trainer.end_of_word_suffix == None
class TestWordPieceTrainer:
def test_can_modify(self):
trainer = trainers.WordPieceTrainer(
vocab_size=12345,
min_frequency=12,
show_progress=False,
special_tokens=["1", "2"],
limit_alphabet=13,
initial_alphabet=["a", "b", "c"],
continuing_subword_prefix="pref",
end_of_word_suffix="suf",
)
assert trainer.vocab_size == 12345
assert trainer.min_frequency == 12
assert trainer.show_progress == False
assert trainer.special_tokens == [
AddedToken("1"),
AddedToken("2"),
]
assert trainer.limit_alphabet == 13
assert sorted(trainer.initial_alphabet) == ["a", "b", "c"]
assert trainer.continuing_subword_prefix == "pref"
assert trainer.end_of_word_suffix == "suf"
# Modify these
trainer.vocab_size = 20000
assert trainer.vocab_size == 20000
trainer.min_frequency = 1
assert trainer.min_frequency == 1
trainer.show_progress = True
assert trainer.show_progress == True
trainer.special_tokens = []
assert trainer.special_tokens == []
trainer.limit_alphabet = None
assert trainer.limit_alphabet == None
trainer.initial_alphabet = ["d", "z"]
assert sorted(trainer.initial_alphabet) == ["d", "z"]
trainer.continuing_subword_prefix = None
assert trainer.continuing_subword_prefix == None
trainer.end_of_word_suffix = None
assert trainer.end_of_word_suffix == None
class TestWordLevelTrainer:
def test_can_modify(self):
trainer = trainers.WordLevelTrainer(
vocab_size=12345, min_frequency=12, show_progress=False, special_tokens=["1", "2"]
)
assert trainer.vocab_size == 12345
assert trainer.min_frequency == 12
assert trainer.show_progress == False
assert trainer.special_tokens == [
AddedToken("1"),
AddedToken("2"),
]
# Modify these
trainer.vocab_size = 20000
assert trainer.vocab_size == 20000
trainer.min_frequency = 1
assert trainer.min_frequency == 1
trainer.show_progress = True
assert trainer.show_progress == True
trainer.special_tokens = []
assert trainer.special_tokens == []
class TestUnigram:
def test_train(self, train_files):
tokenizer = SentencePieceUnigramTokenizer()
@@ -99,3 +213,29 @@ class TestUnigram:
with pytest.raises(Exception, match="UnigramTrainer can only train a Unigram"):
tokenizer.train([], trainer)
def test_can_modify(self):
trainer = trainers.UnigramTrainer(
vocab_size=12345,
show_progress=False,
special_tokens=["1", AddedToken("2", lstrip=True)],
initial_alphabet=["a", "b", "c"],
)
assert trainer.vocab_size == 12345
assert trainer.show_progress == False
assert trainer.special_tokens == [
AddedToken("1", normalized=False),
AddedToken("2", lstrip=True, normalized=False),
]
assert sorted(trainer.initial_alphabet) == ["a", "b", "c"]
# Modify these
trainer.vocab_size = 20000
assert trainer.vocab_size == 20000
trainer.show_progress = True
assert trainer.show_progress == True
trainer.special_tokens = []
assert trainer.special_tokens == []
trainer.initial_alphabet = ["d", "z"]
assert sorted(trainer.initial_alphabet) == ["d", "z"]

View File

@@ -154,24 +154,26 @@ impl BpeTrainerBuilder {
/// let mut model = BPE::default();
/// let special_tokens = trainer.train(word_counts, &mut model).unwrap();
/// ```
#[non_exhaustive]
#[derive(Debug, Clone, PartialEq)]
pub struct BpeTrainer {
/// The minimum frequency a pair must have to produce a merge operation
min_frequency: u32,
pub min_frequency: u32,
/// The target vocabulary size
vocab_size: usize,
pub vocab_size: usize,
/// Whether to show progress while training
show_progress: bool,
pub show_progress: bool,
/// A list of special tokens that the model should know of
special_tokens: Vec<AddedToken>,
pub special_tokens: Vec<AddedToken>,
/// Whether to limit the number of initial tokens that can be kept before computing merges
limit_alphabet: Option<usize>,
pub limit_alphabet: Option<usize>,
/// The initial alphabet we absolutely want to include. This lets us cover
/// some characters that are not necessarily in the training set
initial_alphabet: HashSet<char>,
pub initial_alphabet: HashSet<char>,
/// An optional prefix to use on any subword that exists only behind another one
continuing_subword_prefix: Option<String>,
pub continuing_subword_prefix: Option<String>,
/// An optional suffix to characterize an end-of-word subword
end_of_word_suffix: Option<String>,
pub end_of_word_suffix: Option<String>,
}
impl Default for BpeTrainer {

View File

@@ -36,28 +36,29 @@ fn to_log_prob(pieces: &mut [SentencePiece]) {
}
/// A `UnigramTrainer` can train a `Unigram` model from `word_counts`.
#[non_exhaustive]
#[derive(Builder, Debug, Clone)]
pub struct UnigramTrainer {
#[builder(default = "true")]
show_progress: bool,
pub show_progress: bool,
#[builder(default = "8000")]
vocab_size: u32,
pub vocab_size: u32,
#[builder(default = "2")]
n_sub_iterations: u32,
pub n_sub_iterations: u32,
#[builder(default = "0.75")]
shrinking_factor: f64,
pub shrinking_factor: f64,
#[builder(default = "vec![]")]
special_tokens: Vec<AddedToken>,
pub special_tokens: Vec<AddedToken>,
#[builder(default = "HashSet::new()")]
initial_alphabet: HashSet<char>,
pub initial_alphabet: HashSet<char>,
#[builder(default = "None")]
unk_token: Option<String>,
pub unk_token: Option<String>,
#[builder(default = "16")]
max_piece_length: usize,
pub max_piece_length: usize,
#[builder(default = "1_000_000")]
seed_size: usize,
pub seed_size: usize,
}
impl Default for UnigramTrainer {

View File

@@ -2,20 +2,21 @@ use super::WordLevel;
use crate::{AddedToken, Result, Trainer};
use std::collections::HashMap;
#[non_exhaustive]
#[derive(Debug, Clone, Builder)]
pub struct WordLevelTrainer {
/// The minimum frequency a word must have to be part of the vocabulary
#[builder(default)]
min_frequency: u32,
pub min_frequency: u32,
/// The target vocabulary size
#[builder(default)]
vocab_size: usize,
pub vocab_size: usize,
/// Whether to show progress while training
#[builder(default)]
show_progress: bool,
pub show_progress: bool,
/// A list of special tokens that the model should know of
#[builder(default)]
special_tokens: Vec<AddedToken>,
pub special_tokens: Vec<AddedToken>,
}
impl Default for WordLevelTrainer {

View File

@@ -85,6 +85,70 @@ pub struct WordPieceTrainer {
}
impl WordPieceTrainer {
pub fn min_frequency(&self) -> u32 {
self.bpe_trainer.min_frequency
}
pub fn set_min_frequency(&mut self, freq: u32) {
self.bpe_trainer.min_frequency = freq;
}
pub fn vocab_size(&self) -> usize {
self.bpe_trainer.vocab_size
}
pub fn set_vocab_size(&mut self, size: usize) {
self.bpe_trainer.vocab_size = size;
}
pub fn show_progress(&self) -> bool {
self.bpe_trainer.show_progress
}
pub fn set_show_progress(&mut self, show_progress: bool) {
self.bpe_trainer.show_progress = show_progress;
}
pub fn special_tokens(&self) -> &[AddedToken] {
&self.bpe_trainer.special_tokens
}
pub fn set_special_tokens(&mut self, special_tokens: Vec<AddedToken>) {
self.bpe_trainer.special_tokens = special_tokens;
}
pub fn limit_alphabet(&self) -> Option<usize> {
self.bpe_trainer.limit_alphabet
}
pub fn set_limit_alphabet(&mut self, limit: Option<usize>) {
self.bpe_trainer.limit_alphabet = limit;
}
pub fn initial_alphabet(&self) -> &HashSet<char> {
&self.bpe_trainer.initial_alphabet
}
pub fn set_initial_alphabet(&mut self, alphabet: HashSet<char>) {
self.bpe_trainer.initial_alphabet = alphabet;
}
pub fn continuing_subword_prefix(&self) -> &Option<String> {
&self.bpe_trainer.continuing_subword_prefix
}
pub fn set_continuing_subword_prefix(&mut self, prefix: Option<String>) {
self.bpe_trainer.continuing_subword_prefix = prefix;
}
pub fn end_of_word_suffix(&self) -> &Option<String> {
&self.bpe_trainer.end_of_word_suffix
}
pub fn set_end_of_word_suffix(&mut self, suffix: Option<String>) {
self.bpe_trainer.end_of_word_suffix = suffix;
}
pub fn builder() -> WordPieceTrainerBuilder {
WordPieceTrainerBuilder::default()
}
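WordPieceTrainer has no option fields of its own: it wraps a BpeTrainer, so the accessors above simply delegate to it. At the Python level the properties therefore behave exactly like BpeTrainer's, as the WordPiece test above exercises; for instance:

from tokenizers import trainers

# Each property forwards to the inner BPE trainer.
trainer = trainers.WordPieceTrainer(continuing_subword_prefix="pref")
assert trainer.continuing_subword_prefix == "pref"
trainer.end_of_word_suffix = "suf"
assert trainer.end_of_word_suffix == "suf"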