Add ability to train from Iterator

Anthony MOI
2020-11-12 12:58:14 -05:00
committed by Anthony MOI
parent 6e364cb685
commit e0a70f1fb2
11 changed files with 380 additions and 199 deletions

View File

@ -65,6 +65,12 @@ version = "0.1.10"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4785bdd1c96b2a846b2bd7cc02e86b6b3dbf14e7e53446c4f54c92a361040822"
[[package]]
name = "cfg-if"
version = "1.0.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd"
[[package]]
name = "clap"
version = "2.33.3"
@ -106,27 +112,68 @@ dependencies = [
"winapi-util",
]
[[package]]
name = "const_fn"
version = "0.4.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c478836e029dcef17fb47c89023448c64f781a046e0300e257ad8225ae59afab"
[[package]]
name = "crossbeam"
version = "0.8.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "fd01a6eb3daaafa260f6fc94c3a6c36390abc2080e38e3e34ced87393fb77d80"
dependencies = [
"cfg-if 1.0.0",
"crossbeam-channel 0.5.0",
"crossbeam-deque 0.8.0",
"crossbeam-epoch 0.9.0",
"crossbeam-queue",
"crossbeam-utils 0.8.0",
]
[[package]]
name = "crossbeam-channel"
version = "0.4.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b153fe7cbef478c567df0f972e02e6d736db11affe43dfc9c56a9374d1adfb87"
dependencies = [
"crossbeam-utils",
"crossbeam-utils 0.7.2",
"maybe-uninit",
]
[[package]]
name = "crossbeam-channel"
version = "0.5.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "dca26ee1f8d361640700bde38b2c37d8c22b3ce2d360e1fc1c74ea4b0aa7d775"
dependencies = [
"cfg-if 1.0.0",
"crossbeam-utils 0.8.0",
]
[[package]]
name = "crossbeam-deque"
version = "0.7.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9f02af974daeee82218205558e51ec8768b48cf524bd01d550abe5573a608285"
dependencies = [
"crossbeam-epoch",
"crossbeam-utils",
"crossbeam-epoch 0.8.2",
"crossbeam-utils 0.7.2",
"maybe-uninit",
]
[[package]]
name = "crossbeam-deque"
version = "0.8.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "94af6efb46fef72616855b036a624cf27ba656ffc9be1b9a3c931cfc7749a9a9"
dependencies = [
"cfg-if 1.0.0",
"crossbeam-epoch 0.9.0",
"crossbeam-utils 0.8.0",
]
[[package]]
name = "crossbeam-epoch"
version = "0.8.2"
@ -134,14 +181,38 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "058ed274caafc1f60c4997b5fc07bf7dc7cca454af7c6e81edffe5f33f70dace"
dependencies = [
"autocfg",
"cfg-if",
"crossbeam-utils",
"cfg-if 0.1.10",
"crossbeam-utils 0.7.2",
"lazy_static",
"maybe-uninit",
"memoffset",
"scopeguard",
]
[[package]]
name = "crossbeam-epoch"
version = "0.9.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ec0f606a85340376eef0d6d8fec399e6d4a544d648386c6645eb6d0653b27d9f"
dependencies = [
"cfg-if 1.0.0",
"const_fn",
"crossbeam-utils 0.8.0",
"lazy_static",
"memoffset",
"scopeguard",
]
[[package]]
name = "crossbeam-queue"
version = "0.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6b2a58563f049aa3bae172bc4120f093b5901161c629f280a1f40ba55317d774"
dependencies = [
"cfg-if 1.0.0",
"crossbeam-utils 0.8.0",
]
[[package]]
name = "crossbeam-utils"
version = "0.7.2"
@ -149,7 +220,19 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c3c7c73a2d1e9fc0886a08b93e98eb643461230d5f1925e4036204d5f2e261a8"
dependencies = [
"autocfg",
"cfg-if",
"cfg-if 0.1.10",
"lazy_static",
]
[[package]]
name = "crossbeam-utils"
version = "0.8.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ec91540d98355f690a86367e566ecad2e9e579f230230eb7c21398372be73ea5"
dependencies = [
"autocfg",
"cfg-if 1.0.0",
"const_fn",
"lazy_static",
]
@ -269,7 +352,7 @@ version = "0.1.15"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "fc587bc0ec293155d5bfa6b9891ec18a1e330c234f896ea47fbada4cadbe47e6"
dependencies = [
"cfg-if",
"cfg-if 0.1.10",
"libc",
"wasi",
]
@ -350,7 +433,7 @@ version = "0.1.7"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "63312a18f7ea8760cdd0a7c5aac1a619752a246b833545e3e36d1f81f7cd9e66"
dependencies = [
"cfg-if",
"cfg-if 0.1.10",
]
[[package]]
@ -413,7 +496,7 @@ checksum = "db65c6da02e61f55dae90a0ae427b2a5f6b3e8db09f58d10efab23af92592616"
dependencies = [
"arrayvec",
"bitflags",
"cfg-if",
"cfg-if 0.1.10",
"ryu",
"static_assertions",
]
@ -439,7 +522,7 @@ version = "0.4.11"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4fabed175da42fed1fa0746b0ea71f412aa9d35e76e95e59b192c64b9dc2bf8b"
dependencies = [
"cfg-if",
"cfg-if 0.1.10",
]
[[package]]
@ -546,7 +629,7 @@ name = "numpy"
version = "0.11.0"
source = "git+https://github.com/pyo3/rust-numpy/?rev=e331befa27fede78d4662edf08fa0508db39be01#e331befa27fede78d4662edf08fa0508db39be01"
dependencies = [
"cfg-if",
"cfg-if 0.1.10",
"libc",
"ndarray",
"num-complex",
@ -593,7 +676,7 @@ version = "0.8.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c361aa727dd08437f2f1447be8b59a33b0edd15e0fcee698f935613d9efbca9b"
dependencies = [
"cfg-if",
"cfg-if 0.1.10",
"cloudabi",
"instant",
"libc",
@ -755,7 +838,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "dcf6960dc9a5b4ee8d3e4c5787b4a112a8818e0290a42ff664ad60692fdf2032"
dependencies = [
"autocfg",
"crossbeam-deque",
"crossbeam-deque 0.7.3",
"either",
"rayon-core",
]
@ -776,9 +859,9 @@ version = "1.8.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e8c4fec834fb6e6d2dd5eece3c7b432a52f0ba887cf40e595190c4107edc08bf"
dependencies = [
"crossbeam-channel",
"crossbeam-deque",
"crossbeam-utils",
"crossbeam-channel 0.4.4",
"crossbeam-deque 0.7.3",
"crossbeam-utils 0.7.2",
"lazy_static",
"num_cpus",
]
@ -912,7 +995,7 @@ version = "3.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7a6e24d9338a0a5be79593e2fa15a648add6138caa803e2d5bc782c371732ca9"
dependencies = [
"cfg-if",
"cfg-if 0.1.10",
"libc",
"rand",
"redox_syscall",
@ -995,6 +1078,7 @@ dependencies = [
name = "tokenizers-python"
version = "0.9.4"
dependencies = [
"crossbeam",
"env_logger",
"libc",
"ndarray",

View File

@ -18,6 +18,7 @@ pyo3 = "0.12"
numpy = { git = "https://github.com/pyo3/rust-numpy/", rev = "e331befa27fede78d4662edf08fa0508db39be01" }
ndarray = "0.13"
onig = { version = "6.0", default-features = false }
crossbeam = "0.8"
[dependencies.tokenizers]
version = "*"

View File

@ -12,6 +12,7 @@ use tk::tokenizer::{
Model, PaddingDirection, PaddingParams, PaddingStrategy, PostProcessor, TokenizerImpl,
TruncationParams, TruncationStrategy,
};
use tk::utils::iter::ResultShunt;
use tokenizers as tk;
use super::decoders::PyDecoder;
@ -1069,16 +1070,53 @@ impl PyTokenizer {
}
#[args(trainer = "None")]
fn train(&mut self, files: Vec<String>, trainer: Option<&PyTrainer>) -> PyResult<()> {
let trainer =
fn train(&mut self, files: Vec<String>, trainer: Option<&mut PyTrainer>) -> PyResult<()> {
let mut trainer =
trainer.map_or_else(|| self.tokenizer.get_model().get_trainer(), |t| t.clone());
Python::with_gil(|py| {
py.allow_threads(|| {
ToPyResult(self.tokenizer.train(&trainer, files).map(|_| {})).into()
ToPyResult(
self.tokenizer
.train_from_files(&mut trainer, files)
.map(|_| {}),
)
.into()
})
})
}
#[args(trainer = "None")]
fn train_from_iterator(
&mut self,
iterator: &PyAny,
trainer: Option<&mut PyTrainer>,
) -> PyResult<()> {
let mut trainer =
trainer.map_or_else(|| self.tokenizer.get_model().get_trainer(), |t| t.clone());
let (send, recv) = std::sync::mpsc::sync_channel(256);
let mut sender = Some(send);
let iterator: PyIterator = iterator.iter()?;
crossbeam::thread::scope(|s| {
let _train_handle = s.spawn(|_| {
self.tokenizer
.train(&mut trainer, recv.into_iter())
.map(|_| {})
});
ResultShunt::process(iterator.map(|seq| seq?.extract::<&str>()), |iter| {
if let Some(send) = sender.take() {
for seq in iter {
send.send(seq)
.map_err(|e| exceptions::PyException::new_err(e.to_string()))?;
}
}
Ok(())
})?
})
.unwrap()
}
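
train_from_iterator bridges the Python iterator, which must be driven from the current thread, and the trainer, which wants a plain Rust iterator: training runs on a crossbeam scoped thread pulling from a bounded std::sync::mpsc channel, while the current thread drains the Python iterator into the sender. The 256-element bound gives back-pressure so the producer cannot run arbitrarily far ahead of training, dropping the sender closes the channel so recv.into_iter() ends, and ResultShunt short-circuits on the first failing Python item. A self-contained sketch of the same pattern without PyO3, with a word counter standing in for the trainer:

use std::sync::mpsc::sync_channel;

fn main() {
    let source = vec!["a b", "b c", "c"];
    let (send, recv) = sync_channel::<String>(256);

    let total = crossbeam::thread::scope(|s| {
        // Consumer thread: plays the role of `self.tokenizer.train(...)`.
        let consumer = s.spawn(|_| {
            recv.into_iter()
                .map(|line| line.split_whitespace().count())
                .sum::<usize>()
        });
        // Producer: plays the role of the loop draining the Python iterator.
        for line in source {
            send.send(line.to_owned()).expect("consumer hung up");
        }
        drop(send); // close the channel so the consumer's iterator ends
        consumer.join().unwrap()
    })
    .unwrap();

    assert_eq!(total, 5);
}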
/// Apply all the post-processing steps to the given encodings.
///
/// The various steps are:

View File

@ -1,4 +1,3 @@
use std::collections::HashMap;
use std::sync::{Arc, RwLock};
use pyo3::exceptions;
@ -50,19 +49,20 @@ impl Trainer for PyTrainer {
self.trainer.read().unwrap().should_show_progress()
}
fn train(
&self,
words: HashMap<String, u32>,
model: &mut PyModel,
) -> tk::Result<Vec<tk::AddedToken>> {
fn train(&self, model: &mut PyModel) -> tk::Result<Vec<tk::AddedToken>> {
self.trainer
.read()
.unwrap()
.train(words, &mut model.model.write().unwrap())
.train(&mut model.model.write().unwrap())
}
fn process_tokens(&self, words: &mut HashMap<String, u32>, tokens: Vec<String>) {
self.trainer.read().unwrap().process_tokens(words, tokens)
fn feed<I, S, F>(&mut self, iterator: I, process: F) -> tk::Result<()>
where
I: Iterator<Item = S> + Send,
S: AsRef<str> + Send,
F: Fn(&str) -> tk::Result<Vec<String>> + Sync,
{
self.trainer.write().unwrap().feed(iterator, process)
}
}
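
Since feed now takes &mut self and accumulates state inside the trainer, the Python wrapper needs write access to the shared inner trainer: the read-only trait methods keep taking a read lock, while feed takes a write lock. A tiny sketch of that split with a hypothetical shared counter:

use std::sync::{Arc, RwLock};

#[derive(Default)]
struct SharedCounter {
    inner: Arc<RwLock<u64>>,
}

impl SharedCounter {
    // Read-only access, like `train` and `should_show_progress`.
    fn total(&self) -> u64 {
        *self.inner.read().unwrap()
    }
    // Mutation, like the new `feed`; `&mut self` mirrors the trait signature
    // even though the lock alone would allow `&self`.
    fn feed(&mut self, n: u64) {
        *self.inner.write().unwrap() += n;
    }
}

fn main() {
    let mut counter = SharedCounter::default();
    counter.feed(3);
    assert_eq!(counter.total(), 3);
}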

View File

@ -133,6 +133,7 @@ impl BpeTrainerBuilder {
initial_alphabet: self.config.initial_alphabet,
continuing_subword_prefix: self.config.continuing_subword_prefix,
end_of_word_suffix: self.config.end_of_word_suffix,
words: HashMap::new(),
}
}
}
@ -174,6 +175,8 @@ pub struct BpeTrainer {
pub continuing_subword_prefix: Option<String>,
/// An optional suffix to characterize an end-of-word subword
pub end_of_word_suffix: Option<String>,
words: HashMap<String, u32>,
}
impl Default for BpeTrainer {
@ -407,9 +410,9 @@ impl BpeTrainer {
)
}
pub fn train(
pub fn do_train(
&self,
word_counts: HashMap<String, u32>,
word_counts: &HashMap<String, u32>,
model: &mut BPE,
) -> Result<Vec<AddedToken>> {
let mut word_to_id: HashMap<String, u32> = HashMap::with_capacity(self.vocab_size);
@ -425,14 +428,14 @@ impl BpeTrainer {
//
// 2. Compute the initial alphabet
//
self.compute_alphabet(&word_counts, &mut word_to_id, &mut id_to_word);
self.compute_alphabet(word_counts, &mut word_to_id, &mut id_to_word);
//
// 3. Tokenize words
//
self.update_progress(&progress, word_counts.len(), "Tokenize words");
let (words, counts) =
self.tokenize_words(&word_counts, &mut word_to_id, &mut id_to_word, &progress);
self.tokenize_words(word_counts, &mut word_to_id, &mut id_to_word, &progress);
self.finalize_progress(&progress, words.len());
//
@ -586,14 +589,45 @@ impl Trainer for BpeTrainer {
type Model = BPE;
/// Train a BPE model
fn train(&self, word_counts: HashMap<String, u32>, model: &mut BPE) -> Result<Vec<AddedToken>> {
self.train(word_counts, model)
fn train(&self, model: &mut BPE) -> Result<Vec<AddedToken>> {
self.do_train(&self.words, model)
}
/// Whether we should show progress
fn should_show_progress(&self) -> bool {
self.show_progress
}
fn feed<I, S, F>(&mut self, iterator: I, process: F) -> Result<()>
where
I: Iterator<Item = S> + Send,
S: AsRef<str> + Send,
F: Fn(&str) -> Result<Vec<String>> + Sync,
{
let words: Result<HashMap<String, u32>> = iterator
.maybe_par_bridge()
.map(|sequence| {
let words = process(sequence.as_ref())?;
let mut map = HashMap::new();
for word in words {
map.entry(word).and_modify(|c| *c += 1).or_insert(1);
}
Ok(map)
})
.reduce(
|| Ok(HashMap::new()),
|acc, ws| {
let mut acc = acc?;
for (k, v) in ws? {
acc.entry(k).and_modify(|c| *c += v).or_insert(v);
}
Ok(acc)
},
);
self.words = words?;
Ok(())
}
}
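
feed builds the word-frequency map with a parallel map/reduce: every sequence goes through the process callback (the tokenizer's normalization and pre-tokenization), gets counted locally, and the per-sequence maps are merged. maybe_par_bridge is a crate-internal helper that only parallelizes when allowed; the sketch below is a simplified stand-in that uses rayon's par_bridge directly and whitespace splitting in place of the callback:

use rayon::prelude::*;
use std::collections::HashMap;

fn count_words<I>(sequences: I) -> HashMap<String, u32>
where
    I: Iterator<Item = String> + Send,
{
    sequences
        .par_bridge()
        .map(|sequence| {
            // Local count for one sequence (the real code calls `process` here).
            let mut map = HashMap::new();
            for word in sequence.split_whitespace() {
                *map.entry(word.to_owned()).or_insert(0u32) += 1;
            }
            map
        })
        .reduce(HashMap::new, |mut acc, ws| {
            // Merge the per-sequence maps, summing the counts.
            for (k, v) in ws {
                *acc.entry(k).or_insert(0) += v;
            }
            acc
        })
}

fn main() {
    let corpus = vec!["low low lower".to_owned(), "low newest".to_owned()];
    let counts = count_words(corpus.into_iter());
    assert_eq!(counts["low"], 3);
}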
#[cfg(test)]
@ -624,7 +658,7 @@ mod tests {
.min_frequency(2)
.build();
let mut model = BPE::default();
trainer.train(word_counts, &mut model).unwrap();
trainer.do_train(&word_counts, &mut model).unwrap();
// Vocab should contain all of the characters from the `word_counts` mapping
// as well as three merges: 're', 'are', and 'is'.

View File

@ -145,37 +145,38 @@ impl Trainer for TrainerWrapper {
}
}
fn train(
&self,
words: HashMap<String, u32>,
model: &mut ModelWrapper,
) -> Result<Vec<AddedToken>> {
fn train(&self, model: &mut ModelWrapper) -> Result<Vec<AddedToken>> {
match self {
TrainerWrapper::BpeTrainer(t) => match model {
ModelWrapper::BPE(bpe) => t.train(words, bpe),
ModelWrapper::BPE(bpe) => t.train(bpe),
_ => Err("BpeTrainer can only train a BPE".into()),
},
TrainerWrapper::WordPieceTrainer(t) => match model {
ModelWrapper::WordPiece(wp) => t.train(words, wp),
ModelWrapper::WordPiece(wp) => t.train(wp),
_ => Err("WordPieceTrainer can only train a WordPiece".into()),
},
TrainerWrapper::WordLevelTrainer(t) => match model {
ModelWrapper::WordLevel(wl) => t.train(words, wl),
ModelWrapper::WordLevel(wl) => t.train(wl),
_ => Err("WordLevelTrainer can only train a WordLevel".into()),
},
TrainerWrapper::UnigramTrainer(t) => match model {
ModelWrapper::Unigram(u) => t.train(words, u),
ModelWrapper::Unigram(u) => t.train(u),
_ => Err("UnigramTrainer can only train a Unigram".into()),
},
}
}
fn process_tokens(&self, words: &mut HashMap<String, u32>, tokens: Vec<String>) {
fn feed<I, S, F>(&mut self, iterator: I, process: F) -> Result<()>
where
I: Iterator<Item = S> + Send,
S: AsRef<str> + Send,
F: Fn(&str) -> Result<Vec<String>> + Sync,
{
match self {
TrainerWrapper::BpeTrainer(bpe) => bpe.process_tokens(words, tokens),
TrainerWrapper::WordPieceTrainer(wpt) => wpt.process_tokens(words, tokens),
TrainerWrapper::WordLevelTrainer(wpt) => wpt.process_tokens(words, tokens),
TrainerWrapper::UnigramTrainer(wpt) => wpt.process_tokens(words, tokens),
TrainerWrapper::BpeTrainer(bpe) => bpe.feed(iterator, process),
TrainerWrapper::WordPieceTrainer(wpt) => wpt.feed(iterator, process),
TrainerWrapper::WordLevelTrainer(wpt) => wpt.feed(iterator, process),
TrainerWrapper::UnigramTrainer(wpt) => wpt.feed(iterator, process),
}
}
}
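
One reason everything is routed through the TrainerWrapper enum rather than a Box<dyn Trainer>: feed is generic over the iterator and the process closure, and a trait with a generic method is not object-safe, so dynamic dispatch is not an option. The enum keeps static dispatch and forwards to the matching variant, erroring when the model variant does not match. A minimal illustration with hypothetical types:

trait Feedable {
    fn feed<I: Iterator<Item = String>>(&mut self, it: I) -> usize;
}

struct CountItems;
struct CountChars;

impl Feedable for CountItems {
    fn feed<I: Iterator<Item = String>>(&mut self, it: I) -> usize {
        it.count()
    }
}
impl Feedable for CountChars {
    fn feed<I: Iterator<Item = String>>(&mut self, it: I) -> usize {
        it.map(|s| s.len()).sum()
    }
}

// `Box<dyn Feedable>` would not compile because of the generic method;
// an enum wrapper forwards to the concrete type instead.
enum FeedableWrapper {
    Items(CountItems),
    Chars(CountChars),
}

impl Feedable for FeedableWrapper {
    fn feed<I: Iterator<Item = String>>(&mut self, it: I) -> usize {
        match self {
            FeedableWrapper::Items(t) => t.feed(it),
            FeedableWrapper::Chars(t) => t.feed(it),
        }
    }
}

fn main() {
    let mut wrapper = FeedableWrapper::Chars(CountChars);
    let n = wrapper.feed(vec!["ab".to_owned(), "c".to_owned()].into_iter());
    assert_eq!(n, 3);
}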
@ -194,7 +195,7 @@ mod tests {
let trainer = TrainerWrapper::BpeTrainer(BpeTrainer::default());
let mut model = ModelWrapper::Unigram(Unigram::default());
let result = trainer.train(HashMap::new(), &mut model);
let result = trainer.train(&mut model);
assert!(result.is_err());
}
}

View File

@ -1,5 +1,6 @@
use crate::models::unigram::{lattice::Lattice, model::Unigram};
use crate::tokenizer::{AddedToken, Result, Trainer};
use crate::utils::parallelism::*;
use crate::utils::progress::{ProgressBar, ProgressStyle};
use log::debug;
use std::cmp::Reverse;
@ -58,7 +59,9 @@ pub struct UnigramTrainer {
#[builder(default = "16")]
pub max_piece_length: usize,
#[builder(default = "1_000_000")]
pub seed_size: usize,
seed_size: usize,
#[builder(default = "HashMap::new()")]
words: HashMap<String, u32>,
}
impl Default for UnigramTrainer {
@ -451,7 +454,11 @@ impl UnigramTrainer {
.collect();
new_pieces
}
pub fn _train(&self, sentences: Vec<Sentence>, model: &mut Unigram) -> Result<Vec<AddedToken>> {
pub fn do_train(
&self,
sentences: Vec<Sentence>,
model: &mut Unigram,
) -> Result<Vec<AddedToken>> {
let progress = self.setup_progress();
//
// 1. Compute frequent substrings
@ -533,19 +540,46 @@ impl Trainer for UnigramTrainer {
type Model = Unigram;
/// Train a Unigram model
fn train(
&self,
word_counts: HashMap<String, u32>,
model: &mut Unigram,
) -> Result<Vec<AddedToken>> {
let sentences: Vec<_> = word_counts.into_iter().collect();
self._train(sentences, model)
fn train(&self, model: &mut Unigram) -> Result<Vec<AddedToken>> {
let sentences: Vec<_> = self.words.iter().map(|(s, i)| (s.to_owned(), *i)).collect();
self.do_train(sentences, model)
}
/// Whether we should show progress
fn should_show_progress(&self) -> bool {
self.show_progress
}
fn feed<I, S, F>(&mut self, iterator: I, process: F) -> Result<()>
where
I: Iterator<Item = S> + Send,
S: AsRef<str> + Send,
F: Fn(&str) -> Result<Vec<String>> + Sync,
{
let words: Result<HashMap<String, u32>> = iterator
.maybe_par_bridge()
.map(|sequence| {
let words = process(sequence.as_ref())?;
let mut map = HashMap::new();
for word in words {
map.entry(word).and_modify(|c| *c += 1).or_insert(1);
}
Ok(map)
})
.reduce(
|| Ok(HashMap::new()),
|acc, ws| {
let mut acc = acc?;
for (k, v) in ws? {
acc.entry(k).and_modify(|c| *c += v).or_insert(v);
}
Ok(acc)
},
);
self.words = words?;
Ok(())
}
}
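
With the split into feed and train, driving a trainer directly becomes a two-phase affair: feed it sequences together with a process callback (normally the tokenizer's normalization and pre-tokenization), then train the model from the accumulated counts. A hedged usage sketch with default settings, the usual public paths, and a whitespace splitter standing in for the real pipeline:

use tokenizers::models::unigram::{Unigram, UnigramTrainer};
use tokenizers::{Result, Trainer};

fn main() -> Result<()> {
    let mut trainer = UnigramTrainer::default();
    let mut model = Unigram::default();

    // Phase 1: accumulate word counts from any Send iterator of sequences.
    let corpus = vec!["the quick brown fox", "the lazy dog"];
    trainer.feed(corpus.into_iter(), |sequence| {
        Ok(sequence.split_whitespace().map(|w| w.to_owned()).collect())
    })?;

    // Phase 2: train the model from those counts.
    let _special_tokens = trainer.train(&mut model)?;
    Ok(())
}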
#[cfg(test)]
@ -640,10 +674,7 @@ mod tests {
let mut unigram = Unigram::default();
trainer
.train(
HashMap::from_iter(vec![("The".into(), 12), ("are".into(), 11)]),
&mut unigram,
)
.do_train(vec![("The".into(), 12), ("are".into(), 11)], &mut unigram)
.unwrap();
let mut pieces = unigram.iter();
@ -665,10 +696,7 @@ mod tests {
let mut unigram = Unigram::default();
trainer
.train(
HashMap::from_iter(vec![("The".into(), 12), ("are".into(), 11)]),
&mut unigram,
)
.do_train(vec![("The".into(), 12), ("are".into(), 11)], &mut unigram)
.unwrap();
let mut pieces = unigram.iter();
@ -684,10 +712,7 @@ mod tests {
let mut unigram = Unigram::default();
trainer
.train(
HashMap::from_iter(vec![("The".into(), 12), ("are".into(), 11)]),
&mut unigram,
)
.do_train(vec![("The".into(), 12), ("are".into(), 11)], &mut unigram)
.unwrap();
let mut pieces = unigram.iter();
@ -707,10 +732,7 @@ mod tests {
let mut unigram = Unigram::default();
trainer
.train(
HashMap::from_iter(vec![("The".into(), 12), ("are".into(), 11)]),
&mut unigram,
)
.do_train(vec![("The".into(), 12), ("are".into(), 11)], &mut unigram)
.unwrap();
let mut pieces = unigram.iter();

View File

@ -1,4 +1,5 @@
use super::WordLevel;
use crate::utils::parallelism::*;
use crate::{AddedToken, Result, Trainer};
use std::collections::HashMap;
@ -17,6 +18,9 @@ pub struct WordLevelTrainer {
/// A list of special tokens that the model should know of
#[builder(default)]
pub special_tokens: Vec<AddedToken>,
#[builder(default, private)]
words: HashMap<String, u32>,
}
impl Default for WordLevelTrainer {
@ -26,6 +30,7 @@ impl Default for WordLevelTrainer {
vocab_size: 30_000,
show_progress: true,
special_tokens: vec![],
words: HashMap::new(),
}
}
}
@ -35,12 +40,12 @@ impl WordLevelTrainer {
WordLevelTrainerBuilder::default()
}
fn train(
fn do_train(
&self,
word_counts: HashMap<String, u32>,
word_counts: &HashMap<String, u32>,
model: &mut WordLevel,
) -> Result<Vec<AddedToken>> {
let mut ordered_counts = word_counts.into_iter().collect::<Vec<_>>();
let mut ordered_counts = word_counts.iter().collect::<Vec<_>>();
ordered_counts.sort_by_key(|(_, n)| std::cmp::Reverse(*n));
let word_level = WordLevel::builder()
.vocab(
@ -50,8 +55,8 @@ impl WordLevelTrainer {
.chain(
ordered_counts
.into_iter()
.filter(|(_, n)| *n >= self.min_frequency)
.map(|(w, _)| w),
.filter(|(_, n)| **n >= self.min_frequency)
.map(|(w, _)| w.to_owned()),
)
.take(self.vocab_size)
.enumerate()
@ -72,18 +77,45 @@ impl Trainer for WordLevelTrainer {
type Model = WordLevel;
/// Train a WordLevel model
fn train(
&self,
word_counts: HashMap<String, u32>,
model: &mut WordLevel,
) -> Result<Vec<AddedToken>> {
self.train(word_counts, model)
fn train(&self, model: &mut WordLevel) -> Result<Vec<AddedToken>> {
self.do_train(&self.words, model)
}
/// Whether we should show progress
fn should_show_progress(&self) -> bool {
self.show_progress
}
fn feed<I, S, F>(&mut self, iterator: I, process: F) -> Result<()>
where
I: Iterator<Item = S> + Send,
S: AsRef<str> + Send,
F: Fn(&str) -> Result<Vec<String>> + Sync,
{
let words: Result<HashMap<String, u32>> = iterator
.maybe_par_bridge()
.map(|sequence| {
let words = process(sequence.as_ref())?;
let mut map = HashMap::new();
for word in words {
map.entry(word).and_modify(|c| *c += 1).or_insert(1);
}
Ok(map)
})
.reduce(
|| Ok(HashMap::new()),
|acc, ws| {
let mut acc = acc?;
for (k, v) in ws? {
acc.entry(k).and_modify(|c| *c += v).or_insert(v);
}
Ok(acc)
},
);
self.words = words?;
Ok(())
}
}
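
The WordLevel vocabulary itself comes from a simple selection rule over the accumulated counts: order words by descending frequency, drop anything under min_frequency, and cap the result at vocab_size (special tokens are chained in front before the cap in the real code). A hypothetical standalone helper mirroring just the word-selection step:

use std::collections::HashMap;

fn select_vocab(counts: &HashMap<String, u32>, min_frequency: u32, vocab_size: usize) -> Vec<String> {
    let mut ordered: Vec<_> = counts.iter().collect();
    // Most frequent words first.
    ordered.sort_by_key(|(_, n)| std::cmp::Reverse(**n));
    ordered
        .into_iter()
        .filter(|(_, n)| **n >= min_frequency)
        .map(|(w, _)| w.to_owned())
        .take(vocab_size)
        .collect()
}

fn main() {
    let counts: HashMap<String, u32> =
        [("the".into(), 25), ("are".into(), 18), ("rare".into(), 1)].iter().cloned().collect();
    assert_eq!(select_vocab(&counts, 2, 5), vec!["the".to_string(), "are".to_string()]);
}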
#[cfg(test)]
@ -108,7 +140,7 @@ mod tests {
trainer.vocab_size = 5;
let mut model = WordLevel::default();
trainer.train(word_counts.clone(), &mut model).unwrap();
trainer.do_train(&word_counts, &mut model).unwrap();
let expected_vocab: HashMap<String, u32> = [
("the".into(), 0),
("are".into(), 1),
@ -124,7 +156,7 @@ mod tests {
// If we specify a min_frequency
trainer.min_frequency = 15;
let mut model = WordLevel::default();
trainer.train(word_counts, &mut model).unwrap();
trainer.do_train(&word_counts, &mut model).unwrap();
let expected_vocab: HashMap<String, u32> = [
("the".into(), 0),
("are".into(), 1),

View File

@ -1,7 +1,7 @@
use super::WordPiece;
use crate::models::bpe::{BpeTrainer, BpeTrainerBuilder, BPE};
use crate::tokenizer::{AddedToken, Result, Trainer};
use std::collections::{HashMap, HashSet};
use std::collections::HashSet;
/// A `WordPieceTrainerBuilder` can be used to create a `WordPieceTrainer` with a custom
/// configuration.
@ -153,13 +153,9 @@ impl WordPieceTrainer {
WordPieceTrainerBuilder::default()
}
pub fn train(
&self,
word_counts: HashMap<String, u32>,
model: &mut WordPiece,
) -> Result<Vec<AddedToken>> {
pub fn train(&self, model: &mut WordPiece) -> Result<Vec<AddedToken>> {
let mut bpe = BPE::default();
let special_tokens = self.bpe_trainer.train(word_counts, &mut bpe)?;
let special_tokens = self.bpe_trainer.train(&mut bpe)?;
let new_wordpiece = WordPiece::from_bpe(&bpe);
// Transfer the vocab
@ -175,19 +171,20 @@ impl WordPieceTrainer {
impl Trainer for WordPieceTrainer {
type Model = WordPiece;
fn train(
&self,
word_counts: HashMap<String, u32>,
model: &mut WordPiece,
) -> Result<Vec<AddedToken>> {
self.train(word_counts, model)
}
fn process_tokens(&self, mut words: &mut HashMap<String, u32>, tokens: Vec<String>) {
self.bpe_trainer.process_tokens(&mut words, tokens)
fn train(&self, model: &mut WordPiece) -> Result<Vec<AddedToken>> {
self.train(model)
}
fn should_show_progress(&self) -> bool {
self.bpe_trainer.should_show_progress()
}
fn feed<I, S, F>(&mut self, iterator: I, process: F) -> Result<()>
where
I: Iterator<Item = S> + Send,
S: AsRef<str> + Send,
F: Fn(&str) -> Result<Vec<String>> + Sync,
{
self.bpe_trainer.feed(iterator, process)
}
}
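
The trait impl's self.train(model) looks recursive at first glance, but it is not: Rust prefers inherent methods over trait methods during resolution, so the call lands on the pub fn train defined above, which trains a throwaway BPE and converts it with from_bpe. A minimal illustration of that resolution rule with hypothetical types:

struct Wrapper;

impl Wrapper {
    // Inherent method, like `WordPieceTrainer::train` above.
    fn run(&self) -> &'static str {
        "inherent"
    }
}

trait Runner {
    fn run(&self) -> &'static str;
}

impl Runner for Wrapper {
    fn run(&self) -> &'static str {
        // Resolves to the inherent `Wrapper::run`, not to `Runner::run`,
        // so this forwards instead of recursing.
        self.run()
    }
}

fn main() {
    assert_eq!(Wrapper.run(), "inherent"); // inherent method wins
    assert_eq!(Runner::run(&Wrapper), "inherent"); // trait impl forwards to it
}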

View File

@ -532,11 +532,15 @@ mod tests {
fn should_show_progress(&self) -> bool {
true
}
fn train(
&self,
_words: HashMap<String, u32>,
_model: &mut ModelMock,
) -> Result<Vec<AddedToken>> {
fn train(&self, _model: &mut ModelMock) -> Result<Vec<AddedToken>> {
unimplemented!()
}
fn feed<I, S, F>(&mut self, _iterator: I, _process: F) -> Result<()>
where
I: Iterator<Item = S> + Send,
S: AsRef<str> + Send,
F: Fn(&str) -> Result<Vec<String>> + Sync,
{
unimplemented!()
}
}

View File

@ -131,20 +131,13 @@ pub trait Trainer {
fn should_show_progress(&self) -> bool;
/// The actual training method. This will return a new trained Model as well as a list
/// of `special_tokens` to be added directly to the tokenizer along with the model.
fn train(
&self,
words: HashMap<String, u32>,
model: &mut Self::Model,
) -> Result<Vec<AddedToken>>;
/// Process a bunch of token, counting them as relevant.
fn process_tokens(&self, words: &mut HashMap<String, u32>, tokens: Vec<String>) {
for token in tokens {
words
.entry(token.clone())
.and_modify(|c| *c += 1)
.or_insert(1);
}
}
fn train(&self, model: &mut Self::Model) -> Result<Vec<AddedToken>>;
/// Process an iterator of sequences already pre-processed by the Tokenizer
fn feed<I, S, F>(&mut self, iterator: I, process: F) -> Result<()>
where
I: Iterator<Item = S> + Send,
S: AsRef<str> + Send,
F: Fn(&str) -> Result<Vec<String>> + Sync;
}
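
For a trainer implemented outside the crate, the new contract is: accumulate whatever statistics are needed in feed, which receives the raw sequences plus the tokenizer's process callback for normalization and pre-tokenization, then build the model in train from that accumulated state. A hedged skeleton with a hypothetical WordCountTrainer (sequential counting stands in for the parallel version the built-in trainers use; paths assume the crate-root re-exports seen elsewhere in this diff):

use std::collections::HashMap;
use tokenizers::models::wordlevel::WordLevel;
use tokenizers::{AddedToken, Result, Trainer};

#[derive(Default)]
struct WordCountTrainer {
    words: HashMap<String, u32>,
}

impl Trainer for WordCountTrainer {
    type Model = WordLevel;

    fn should_show_progress(&self) -> bool {
        false
    }

    fn train(&self, _model: &mut WordLevel) -> Result<Vec<AddedToken>> {
        // A real trainer would build `_model`'s vocabulary from `self.words` here.
        Ok(vec![])
    }

    fn feed<I, S, F>(&mut self, iterator: I, process: F) -> Result<()>
    where
        I: Iterator<Item = S> + Send,
        S: AsRef<str> + Send,
        F: Fn(&str) -> Result<Vec<String>> + Sync,
    {
        for sequence in iterator {
            for word in process(sequence.as_ref())? {
                *self.words.entry(word).or_insert(0) += 1;
            }
        }
        Ok(())
    }
}

fn main() -> Result<()> {
    let mut trainer = WordCountTrainer::default();
    trainer.feed(["hello world", "hello"].iter(), |s| {
        Ok(s.split_whitespace().map(|w| w.to_owned()).collect())
    })?;
    assert_eq!(trainer.words["hello"], 2);
    Ok(())
}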
#[derive(Debug, Clone, PartialEq)]
@ -969,99 +962,74 @@ where
.collect()
}
/// Train a model and replace our current Model, using the given Trainer
fn word_count<MN, T>(&self, trainer: &T, files: Vec<String>) -> Result<HashMap<String, u32>>
pub fn train_from_files<T>(&mut self, trainer: &mut T, files: Vec<String>) -> Result<&mut Self>
where
T: Trainer<Model = MN> + Sync,
MN: Model,
T: Trainer<Model = M> + Sync,
{
let max_read = 1_000_000;
let mut len = 0;
for file in files.iter() {
len += File::open(file)
.and_then(|f| f.metadata())
.map(|m| m.len())?;
use crate::utils::iter::ResultShunt;
ResultShunt::process(
files.into_iter().flat_map(|filename| {
match File::open(filename) {
Ok(file) => {
let file = BufReader::with_capacity(max_read, file);
// We read new lines using this API instead of the Lines Iterator
// on purpose. We want to keep the `\n` and the potential `\r` between lines
// We use an iterator to be able to chain with par_bridge.
itertools::Either::Left(file.lines_with_ending())
}
Err(e) => itertools::Either::Right(std::iter::once(Err(e))),
}
}),
|iter| self.train(trainer, iter).map(|_| {}),
)??;
Ok(self)
}
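
train_from_files no longer pre-counts words; it turns the files into one lazy stream of lines (keeping line endings via lines_with_ending) and hands that stream to train. ResultShunt is a crate-internal helper that lets the consumer see a plain iterator of strings while the first I/O error is captured and surfaced afterwards, hence the double ?. The public itertools::process_results does the same job (itertools is already used here for Either); a sketch of the idea, with lines() standing in for lines_with_ending:

use std::fs::File;
use std::io::{BufRead, BufReader};

fn count_lines(paths: &[&str]) -> std::io::Result<usize> {
    itertools::process_results(
        paths.iter().flat_map(|path| match File::open(path) {
            Ok(file) => itertools::Either::Left(BufReader::new(file).lines()),
            Err(e) => itertools::Either::Right(std::iter::once(Err(e))),
        }),
        // The closure only ever sees successfully read lines; the first
        // `Err` stops the iteration and becomes the outer error.
        |lines| lines.count(),
    )
}

fn main() {
    println!("{:?}", count_lines(&["Cargo.toml"]));
}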
/// Train a model and replace our current Model, using the given Trainer
pub fn train<T, I, S>(&mut self, trainer: &mut T, sequences: I) -> Result<&mut Self>
where
T: Trainer<Model = M> + Sync,
I: Iterator<Item = S> + Send,
S: AsRef<str> + Send,
{
let (lower, upper) = sequences.size_hint();
let len = upper.unwrap_or(lower) as u64;
let progress = if trainer.should_show_progress() {
let progress = ProgressBar::new(len);
progress.set_style(
ProgressStyle::default_bar()
.template("[{elapsed_precise}] {msg:<40!} {wide_bar} {percent:>19!}"),
.template("[{elapsed_precise}] {msg:<40!} {wide_bar} {pos:<9!}/{len:>9!}"),
);
progress.set_message(&format!("Reading files ({:.2} Mo)", len / 1_000_000));
progress.set_message("Pre-processing sequences");
progress.set_draw_delta(len / 100); // Redraw only every 2%
Some(progress)
} else {
None
};
let words = files
.into_iter()
.map(|filename| -> Result<HashMap<String, u32>> {
let file = File::open(filename)?;
let file = BufReader::with_capacity(max_read, file);
// We read new lines using this API instead of the Lines Iterator
// on purpose. We want to keep the `\n` and potential `\r` between each lines
// We use an iterator to be able to chain with par_bridge.
file.lines_with_ending()
.maybe_par_bridge()
.map_with(
&progress,
|progress, line| -> Result<HashMap<String, u32>> {
let newline = line?;
let b = newline.len();
let mut words = HashMap::new();
let normalized = self.do_normalize(newline)?;
trainer.feed(
sequences.map(|s| {
// if let Some(progress) = &progress {
// progress.inc(1)
// }
s
}),
|seq| {
let normalized = self.do_normalize(seq.as_ref())?;
let pre_tokenized = self.do_pre_tokenize(normalized)?;
trainer.process_tokens(
&mut words,
pre_tokenized
Ok(pre_tokenized
.get_splits(OffsetReferential::Original, OffsetType::Byte)
.into_iter()
.map(|(s, _, _)| s.to_owned())
.collect(),
);
if let Some(pbar) = progress {
pbar.inc(b as u64);
}
Ok(words)
},
)
.reduce(
|| Ok(HashMap::new()),
|acc, ws| {
let mut acc = acc?;
for (k, v) in ws? {
acc.entry(k).and_modify(|c| *c += v).or_insert(v);
}
Ok(acc)
},
)
})
.try_fold(
HashMap::new(),
|mut acc, ws| -> Result<HashMap<String, u32>> {
for (k, v) in ws? {
acc.entry(k).and_modify(|c| *c += v).or_insert(v);
}
Ok(acc)
.collect())
},
)?;
if let Some(pbar) = progress {
pbar.finish();
}
Ok(words)
}
/// Train a model and replace our current Model, using the given Trainer
pub fn train<T>(&mut self, trainer: &T, files: Vec<String>) -> Result<&mut Self>
where
T: Trainer<Model = M> + Sync,
{
let words = self.word_count(trainer, files)?;
let special_tokens = trainer.train(words, &mut self.model)?;
let special_tokens = trainer.train(&mut self.model)?;
self.add_special_tokens(&special_tokens);
Ok(self)
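
Taken together, the old single entry point splits into train_from_files, which behaves as before minus the up-front word counting, and train, which accepts any Send iterator of sequences, so in-memory corpora or streamed datasets can be fed directly. A hedged usage sketch against the concrete Tokenizer wrapper, assuming its usual constructor; the corpus and file path are placeholders:

use tokenizers::models::bpe::{BpeTrainer, BPE};
use tokenizers::models::TrainerWrapper;
use tokenizers::{Result, Tokenizer};

fn main() -> Result<()> {
    let mut tokenizer = Tokenizer::new(BPE::default());
    let mut trainer = TrainerWrapper::BpeTrainer(BpeTrainer::default());

    // New: train from any iterator of sequences held in memory.
    let corpus = vec!["low lower lowest", "new newer newest"];
    tokenizer.train(&mut trainer, corpus.into_iter())?;

    // Old flow, now under its own name:
    // tokenizer.train_from_files(&mut trainer, vec!["data/corpus.txt".into()])?;
    Ok(())
}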