Improve docs and fix tests around training

Author:    Anthony MOI
Date:      2020-11-27 16:44:17 -05:00
Committer: Anthony MOI
parent 06f6ba3fce
commit 3a8627ce4d
9 changed files with 101 additions and 24 deletions

View File

@@ -15,5 +15,6 @@ def batch_iterator():
     for i in range(0, len(dataset["train"]), batch_length):
         yield dataset["train"][i : i + batch_length]["text"]

 # And finally train
 bpe_tokenizer.train_from_iterator(batch_iterator(), length=len(dataset["train"]))

View File

@@ -1022,6 +1022,45 @@ class Tokenizer:
             :obj:`Optional[int]`: An optional id, :obj:`None` if out of vocabulary
         """
         pass

+    def train(self, files, trainer=None):
+        """
+        Train the Tokenizer using the given files.
+
+        Reads the files line by line, while keeping all the whitespace, even new lines.
+        If you want to train from data stored in-memory, you can check
+        :meth:`~tokenizers.Tokenizer.train_from_iterator`
+
+        Args:
+            files (:obj:`List[str]`):
+                A list of paths to the files that we should use for training
+
+            trainer (:obj:`~tokenizers.trainers.Trainer`, `optional`):
+                An optional trainer that should be used to train our Model
+        """
+        pass
+
+    def train_from_iterator(self, iterator, trainer=None, length=None):
+        """
+        Train the Tokenizer using the provided iterator.
+
+        You can provide anything that is a Python Iterator:
+
+        * A list of sequences :obj:`List[str]`
+        * A generator that yields :obj:`str` or :obj:`List[str]`
+        * A Numpy array of strings
+        * ...
+
+        Args:
+            iterator (:obj:`Iterator`):
+                Any iterator over strings or list of strings
+
+            trainer (:obj:`~tokenizers.trainers.Trainer`, `optional`):
+                An optional trainer that should be used to train our Model
+
+            length (:obj:`int`, `optional`):
+                The total number of sequences in the iterator. This is used to
+                provide meaningful progress tracking
+        """
+        pass
+
     @property
     def truncation(self):
         """

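The `train_from_iterator` stub documented above maps onto usage like the following minimal sketch (not part of this commit): the sample data, batch size, pre-tokenizer, and `BpeTrainer` options are illustrative assumptions, while the method name and the `trainer`/`length` parameters come from the docstring in this diff.

# Sketch only: train a BPE tokenizer from an in-memory iterator, as described
# by the train_from_iterator docstring above. Data and trainer settings are
# placeholders.
from tokenizers import Tokenizer
from tokenizers.models import BPE
from tokenizers.pre_tokenizers import Whitespace
from tokenizers.trainers import BpeTrainer

data = ["Hello world", "How are you doing?", "Train from any Python iterator"]

def batch_iterator(batch_size=2):
    # Yield one batch (a list of strings) at a time
    for i in range(0, len(data), batch_size):
        yield data[i : i + batch_size]

tokenizer = Tokenizer(BPE(unk_token="[UNK]"))
tokenizer.pre_tokenizer = Whitespace()
trainer = BpeTrainer(special_tokens=["[UNK]"])

# `length` is optional and only used for progress reporting
tokenizer.train_from_iterator(batch_iterator(), trainer=trainer, length=len(data))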
View File

@@ -1068,7 +1068,20 @@ impl PyTokenizer {
         Ok(self.tokenizer.add_special_tokens(&tokens))
     }

+    /// Train the Tokenizer using the given files.
+    ///
+    /// Reads the files line by line, while keeping all the whitespace, even new lines.
+    /// If you want to train from data stored in-memory, you can check
+    /// :meth:`~tokenizers.Tokenizer.train_from_iterator`
+    ///
+    /// Args:
+    ///     files (:obj:`List[str]`):
+    ///         A list of paths to the files that we should use for training
+    ///
+    ///     trainer (:obj:`~tokenizers.trainers.Trainer`, `optional`):
+    ///         An optional trainer that should be used to train our Model
     #[args(trainer = "None")]
+    #[text_signature = "(self, files, trainer = None)"]
     fn train(&mut self, files: Vec<String>, trainer: Option<&mut PyTrainer>) -> PyResult<()> {
         let mut trainer =
             trainer.map_or_else(|| self.tokenizer.get_model().get_trainer(), |t| t.clone());
@@ -1084,7 +1097,27 @@ impl PyTokenizer {
         })
     }

+    /// Train the Tokenizer using the provided iterator.
+    ///
+    /// You can provide anything that is a Python Iterator:
+    ///
+    /// * A list of sequences :obj:`List[str]`
+    /// * A generator that yields :obj:`str` or :obj:`List[str]`
+    /// * A Numpy array of strings
+    /// * ...
+    ///
+    /// Args:
+    ///     iterator (:obj:`Iterator`):
+    ///         Any iterator over strings or list of strings
+    ///
+    ///     trainer (:obj:`~tokenizers.trainers.Trainer`, `optional`):
+    ///         An optional trainer that should be used to train our Model
+    ///
+    ///     length (:obj:`int`, `optional`):
+    ///         The total number of sequences in the iterator. This is used to
+    ///         provide meaningful progress tracking
     #[args(trainer = "None", length = "None")]
+    #[text_signature = "(self, iterator, trainer=None, length=None)"]
     fn train_from_iterator(
         &mut self,
         iterator: &PyAny,
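The file-based `train` binding documented above is used the same way from Python; here is a comparable sketch (again not part of this commit), where the file path and the trainer options are placeholders.

# Sketch only: file-based training as described by the `train` doc comment
# above. "data/small.txt" and the trainer options are placeholders; files are
# read line by line, whitespace preserved.
from tokenizers import Tokenizer
from tokenizers.models import BPE
from tokenizers.pre_tokenizers import Whitespace
from tokenizers.trainers import BpeTrainer

tokenizer = Tokenizer(BPE(unk_token="[UNK]"))
tokenizer.pre_tokenizer = Whitespace()
trainer = BpeTrainer(vocab_size=1000, special_tokens=["[UNK]"])

# When `trainer` is omitted, the binding falls back to the model's default
# trainer via get_trainer(), as shown in the Rust code above.
tokenizer.train(["data/small.txt"], trainer=trainer)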

View File

@@ -71,7 +71,7 @@ use std::path::Path;
 fn main() -> Result<()> {
     let vocab_size: usize = 100;

-    let trainer = BpeTrainerBuilder::new()
+    let mut trainer = BpeTrainerBuilder::new()
         .show_progress(true)
         .vocab_size(vocab_size)
         .min_frequency(0)
@@ -97,8 +97,8 @@ fn main() -> Result<()> {
     let pretty = false;
     tokenizer
-        .train(
-            &trainer,
+        .train_from_files(
+            &mut trainer,
             vec!["path/to/vocab.txt".to_string()],
         )?
         .save("tokenizer.json", pretty)?;

View File

@@ -58,7 +58,7 @@
 //! fn main() -> Result<()> {
 //!     let vocab_size: usize = 100;
 //!
-//!     let trainer = BpeTrainerBuilder::new()
+//!     let mut trainer = BpeTrainerBuilder::new()
 //!         .show_progress(true)
 //!         .vocab_size(vocab_size)
 //!         .min_frequency(0)
@@ -84,8 +84,8 @@
 //!
 //!     let pretty = false;
 //!     tokenizer
-//!         .train(
-//!             &trainer,
+//!         .train_from_files(
+//!             &mut trainer,
 //!             vec!["path/to/vocab.txt".to_string()],
 //!         )?
 //!         .save("tokenizer.json", pretty)?;

View File

@@ -138,22 +138,21 @@ impl BpeTrainerBuilder {
     }
 }

-/// In charge of training a `BPE` model from a mapping of words to word counts.
+/// In charge of training a `BPE` model
 ///
 /// # Examples
 ///
 /// ```
-/// use std::collections::HashMap;
 /// use tokenizers::tokenizer::Trainer;
 /// use tokenizers::models::bpe::{BPE, BpeTrainer};
 ///
-/// let word_counts: HashMap<String, u32> = [
-///     (String::from("Hello"), 1),
-///     (String::from("World"), 1),
-/// ].iter().cloned().collect();
-/// let trainer = BpeTrainer::default();
+/// let sequences = vec![ "Hello", "World" ];
+///
+/// let mut trainer = BpeTrainer::default();
+/// trainer.feed(sequences.iter(), |s| Ok(vec![s.to_owned()]));
+///
 /// let mut model = BPE::default();
-/// let special_tokens = trainer.train(word_counts, &mut model).unwrap();
+/// let special_tokens = trainer.train(&mut model).unwrap();
 /// ```
 #[non_exhaustive]
 #[derive(Debug, Clone, PartialEq)]

View File

@@ -20,7 +20,7 @@ fn train_tokenizer() {
         .build()
         .unwrap();
-    let trainer = BpeTrainerBuilder::new()
+    let mut trainer = BpeTrainerBuilder::new()
         .show_progress(false)
         .vocab_size(vocab_size)
         .min_frequency(0)
@@ -35,7 +35,7 @@ fn train_tokenizer() {
     let pretty = true;
     tokenizer
-        .train(&trainer, vec!["data/small.txt".to_string()])
+        .train_from_files(&mut trainer, vec!["data/small.txt".to_string()])
         .unwrap()
         .save("data/tokenizer.json", pretty)
         .unwrap();
@@ -80,7 +80,7 @@ fn quicktour_slow_train() -> tokenizers::Result<()> {
     // START quicktour_init_trainer
     use tokenizers::models::bpe::BpeTrainer;
-    let trainer = BpeTrainer::builder()
+    let mut trainer = BpeTrainer::builder()
         .special_tokens(vec![
             AddedToken::from("[UNK]", true),
             AddedToken::from("[CLS]", true),
@@ -102,7 +102,7 @@ fn quicktour_slow_train() -> tokenizers::Result<()> {
         "data/wikitext-103-raw/wiki.test.raw".into(),
         "data/wikitext-103-raw/wiki.valid.raw".into(),
     ];
-    tokenizer.train(&trainer, files)?;
+    tokenizer.train_from_files(&mut trainer, files)?;
     // END quicktour_train
     // START quicktour_save
     tokenizer.save("data/tokenizer-wiki.json", false)?;
@@ -403,7 +403,7 @@ fn train_pipeline_bert() -> tokenizers::Result<()> {
     // START bert_train_tokenizer
     use tokenizers::models::{wordpiece::WordPieceTrainer, TrainerWrapper};
-    let trainer: TrainerWrapper = WordPieceTrainer::builder()
+    let mut trainer: TrainerWrapper = WordPieceTrainer::builder()
         .vocab_size(30_522)
         .special_tokens(vec![
             AddedToken::from("[UNK]", true),
@@ -419,7 +419,7 @@ fn train_pipeline_bert() -> tokenizers::Result<()> {
         "data/wikitext-103-raw/wiki.test.raw".into(),
         "data/wikitext-103-raw/wiki.valid.raw".into(),
     ];
-    bert_tokenizer.train(&trainer, files)?;
+    bert_tokenizer.train_from_files(&mut trainer, files)?;
     bert_tokenizer.save("data/bert-wiki.json", false)?;
     // END bert_train_tokenizer

View File

@@ -20,9 +20,9 @@ fn bpe_values_after_training() {
     )
     .build()
     .unwrap();
-    let trainer = tokenizer.get_model().get_trainer();
+    let mut trainer = tokenizer.get_model().get_trainer();
     tokenizer
-        .train(&trainer, vec!["./data/small.txt".to_string()])
+        .train_from_files(&mut trainer, vec!["./data/small.txt".to_string()])
         .unwrap();
     assert_eq!(tokenizer.get_model().dropout, Some(0.1));
     assert_eq!(tokenizer.get_model().unk_token, Some("[UNK]".to_string()));

View File

@@ -7,7 +7,7 @@ use std::path::Path;
 use tokenizers::models::unigram::Lattice;
 use tokenizers::models::unigram::Unigram;
 use tokenizers::models::unigram::UnigramTrainer;
-use tokenizers::tokenizer::{Model, Trainer};
+use tokenizers::tokenizer::Model;

 #[test]
 fn test_unigram_from_file() {
@@ -56,7 +56,12 @@ fn test_train_unigram_from_file() {
         .build()
         .unwrap();
     let mut model = Unigram::default();
-    trainer.train(word_counts, &mut model).unwrap();
+    let sentences: Vec<_> = word_counts
+        .iter()
+        .map(|(s, i)| (s.to_owned(), *i))
+        .collect();
+    trainer.do_train(sentences, &mut model).unwrap();
     assert_eq!(model.get_vocab_size(), 719);
 }