diff --git a/bindings/python/py_src/tokenizers/trainers/__init__.py b/bindings/python/py_src/tokenizers/trainers/__init__.py
index 05243aa5..22f94c50 100644
--- a/bindings/python/py_src/tokenizers/trainers/__init__.py
+++ b/bindings/python/py_src/tokenizers/trainers/__init__.py
@@ -4,4 +4,5 @@ from .. import trainers
 Trainer = trainers.Trainer
 BpeTrainer = trainers.BpeTrainer
 UnigramTrainer = trainers.UnigramTrainer
+WordLevelTrainer = trainers.WordLevelTrainer
 WordPieceTrainer = trainers.WordPieceTrainer
diff --git a/bindings/python/src/lib.rs b/bindings/python/src/lib.rs
index 7cd727ff..d43de428 100644
--- a/bindings/python/src/lib.rs
+++ b/bindings/python/src/lib.rs
@@ -44,6 +44,7 @@ fn trainers(_py: Python, m: &PyModule) -> PyResult<()> {
     m.add_class::<trainers::PyTrainer>()?;
     m.add_class::<trainers::PyBpeTrainer>()?;
     m.add_class::<trainers::PyWordPieceTrainer>()?;
+    m.add_class::<trainers::PyWordLevelTrainer>()?;
     m.add_class::<trainers::PyUnigramTrainer>()?;
     Ok(())
 }
diff --git a/bindings/python/src/trainers.rs b/bindings/python/src/trainers.rs
index 3db21851..96743331 100644
--- a/bindings/python/src/trainers.rs
+++ b/bindings/python/src/trainers.rs
@@ -242,6 +242,69 @@ impl PyWordPieceTrainer {
     }
 }
 
+/// Capable of training a WordLevel model
+///
+/// Args:
+///     vocab_size: unsigned int:
+///         The size of the final vocabulary, including all tokens.
+///
+///     min_frequency: unsigned int:
+///         The minimum frequency a word must have in order to be part of the vocabulary.
+///
+///     show_progress: boolean:
+///         Whether to show progress bars while training.
+///
+///     special_tokens: List[Union[str, AddedToken]]:
+///         A list of special tokens the model should know of.
+///
+/// Returns:
+///     Trainer
+#[pyclass(extends=PyTrainer, name=WordLevelTrainer)]
+pub struct PyWordLevelTrainer {}
+#[pymethods]
+impl PyWordLevelTrainer {
+    /// Create a new WordLevelTrainer with the given configuration
+    #[new]
+    #[args(kwargs = "**")]
+    pub fn new(kwargs: Option<&PyDict>) -> PyResult<(Self, PyTrainer)> {
+        let mut trainer = tk::models::wordlevel::WordLevelTrainer::default();
+
+        if let Some(kwargs) = kwargs {
+            for (key, val) in kwargs {
+                let key: &str = key.extract()?;
+                match key {
+                    "vocab_size" => trainer.vocab_size = val.extract()?,
+                    "min_frequency" => trainer.min_frequency = val.extract()?,
+                    "show_progress" => trainer.show_progress = val.extract()?,
+                    "special_tokens" => {
+                        trainer.special_tokens = val
+                            .cast_as::<PyList>()?
+                            .into_iter()
+                            .map(|token| {
+                                if let Ok(content) = token.extract::<String>() {
+                                    Ok(PyAddedToken::from(content, Some(true)).get_token())
+                                } else if let Ok(mut token) =
+                                    token.extract::<PyRefMut<PyAddedToken>>()
+                                {
+                                    token.is_special_token = true;
+                                    Ok(token.get_token())
+                                } else {
+                                    Err(exceptions::PyTypeError::new_err(
+                                        "special_tokens must be a List[Union[str, AddedToken]]",
+                                    ))
+                                }
+                            })
+                            .collect::<PyResult<Vec<_>>>()?
+                    }
+                    _ => println!("Ignored unknown kwargs option {}", key),
+                }
+            }
+        }
+
+        Ok((PyWordLevelTrainer {}, PyTrainer::new(trainer.into())))
+    }
+}
+
 /// Capable of training a Unigram model
 ///
 /// Args:
diff --git a/tokenizers/src/models/wordlevel/mod.rs b/tokenizers/src/models/wordlevel/mod.rs
index 7e1ac440..88d731d7 100644
--- a/tokenizers/src/models/wordlevel/mod.rs
+++ b/tokenizers/src/models/wordlevel/mod.rs
@@ -8,6 +8,10 @@ use std::io::{BufReader, Read, Write};
 use std::path::{Path, PathBuf};
 
 mod serialization;
+mod trainer;
+
+// Re-export
+pub use trainer::*;
 
 type Vocab = HashMap<String, u32>;
 
diff --git a/tokenizers/src/models/wordlevel/trainer.rs b/tokenizers/src/models/wordlevel/trainer.rs
new file mode 100644
index 00000000..9fa2b664
--- /dev/null
+++ b/tokenizers/src/models/wordlevel/trainer.rs
@@ -0,0 +1,116 @@
+use super::WordLevel;
+use crate::{AddedToken, Result, Trainer};
+use std::collections::HashMap;
+
+pub struct WordLevelTrainer {
+    /// The minimum frequency a word must have to be part of the vocabulary
+    pub min_frequency: u32,
+    /// The target vocabulary size
+    pub vocab_size: usize,
+    /// Whether to show progress while training
+    pub show_progress: bool,
+    /// A list of special tokens that the model should know of
+    pub special_tokens: Vec<AddedToken>,
+}
+
+impl Default for WordLevelTrainer {
+    fn default() -> Self {
+        Self {
+            min_frequency: 0,
+            vocab_size: 30_000,
+            show_progress: true,
+            special_tokens: vec![],
+        }
+    }
+}
+
+impl WordLevelTrainer {
+    fn train(&self, word_counts: HashMap<String, u32>) -> Result<(WordLevel, Vec<AddedToken>)> {
+        let mut ordered_counts = word_counts.into_iter().collect::<Vec<_>>();
+        ordered_counts.sort_by_key(|(_, n)| std::cmp::Reverse(*n));
+        let word_level = WordLevel::builder()
+            .vocab(
+                self.special_tokens
+                    .iter()
+                    .map(|token| token.content.clone())
+                    .chain(
+                        ordered_counts
+                            .into_iter()
+                            .filter(|(_, n)| *n >= self.min_frequency)
+                            .map(|(w, _)| w),
+                    )
+                    .take(self.vocab_size)
+                    .enumerate()
+                    .map(|(i, w)| (w, i as u32))
+                    .collect(),
+            )
+            .build();
+
+        Ok((word_level, self.special_tokens.clone()))
+    }
+}
+
+impl Trainer for WordLevelTrainer {
+    type Model = WordLevel;
+
+    /// Train a WordLevel model
+    fn train(&self, word_counts: HashMap<String, u32>) -> Result<(WordLevel, Vec<AddedToken>)> {
+        self.train(word_counts)
+    }
+
+    /// Whether we should show progress
+    fn should_show_progress(&self) -> bool {
+        self.show_progress
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    #[test]
+    fn test_train() {
+        let word_counts: HashMap<String, u32> = [
+            ("the".into(), 25),
+            ("roses".into(), 22),
+            ("are".into(), 24),
+            ("red".into(), 12),
+            ("voilets".into(), 10),
+            ("blue".into(), 16),
+        ]
+        .iter()
+        .cloned()
+        .collect();
+
+        let mut trainer = WordLevelTrainer::default();
+        trainer.vocab_size = 5;
+
+        let (model, _) = trainer.train(word_counts.clone()).unwrap();
+        let expected_vocab: HashMap<String, u32> = [
+            ("the".into(), 0),
+            ("are".into(), 1),
+            ("roses".into(), 2),
+            ("blue".into(), 3),
+            ("red".into(), 4),
+        ]
+        .iter()
+        .cloned()
+        .collect();
+        assert_eq!(model.vocab, expected_vocab);
+
+        // If we specify a min_frequency
+        trainer.min_frequency = 15;
+        let (model, _) = trainer.train(word_counts).unwrap();
+        let expected_vocab: HashMap<String, u32> = [
+            ("the".into(), 0),
+            ("are".into(), 1),
+            ("roses".into(), 2),
+            ("blue".into(), 3),
+        ]
+        .iter()
+        .cloned()
+        .collect();
+
+        assert_eq!(model.vocab, expected_vocab);
+    }
+}
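
For reference, here is a minimal sketch of how the new trainer could be driven directly from Rust, modelled on the unit test above. It is an illustration, not part of the patch: the word counts are invented, and the `tokenizers::Trainer` / `tokenizers::Result` paths assume crate-root re-exports matching the `use crate::{AddedToken, Result, Trainer}` import in the new file.

```rust
use std::collections::HashMap;

use tokenizers::models::wordlevel::WordLevelTrainer;
use tokenizers::{Result, Trainer}; // assumed crate-root re-exports

fn main() -> Result<()> {
    // Hypothetical word counts, e.g. gathered by a pre-tokenization pass.
    let word_counts: HashMap<String, u32> = [
        ("hello".into(), 12),
        ("world".into(), 5),
        ("rare".into(), 1),
    ]
    .iter()
    .cloned()
    .collect();

    let mut trainer = WordLevelTrainer::default();
    trainer.vocab_size = 2;
    trainer.min_frequency = 2; // "rare" falls below the threshold and is dropped

    // `train` is the `Trainer` trait method implemented in this patch; it
    // returns the trained WordLevel model plus the special tokens to add.
    let (_model, _special_tokens) = trainer.train(word_counts)?;
    Ok(())
}
```

From Python, once the bindings are rebuilt, the same configuration is exposed as `tokenizers.trainers.WordLevelTrainer(vocab_size=..., min_frequency=..., show_progress=..., special_tokens=[...])`, mirroring the kwargs handled in `PyWordLevelTrainer::new` above.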