Mirror of https://github.com/mii443/tokenizers.git (synced 2025-08-23 00:35:35 +00:00)

Commit: Add ability to train from Iterator

bindings/python/Cargo.lock (generated): 118 lines changed
The regenerated lockfile adds the crossbeam 0.8.0 package and its dependency tree as new entries (const_fn 0.4.3, crossbeam-channel 0.5.0, crossbeam-deque 0.8.0, crossbeam-epoch 0.9.0, crossbeam-queue 0.3.0, crossbeam-utils 0.8.0, cfg-if 1.0.0), version-qualifies the existing dependency references that now resolve to two versions (cfg-if 0.1.10, crossbeam-channel 0.4.4, crossbeam-deque 0.7.3, crossbeam-epoch 0.8.2, crossbeam-utils 0.7.2), and adds crossbeam to the tokenizers-python dependency list.
Python bindings: Cargo.toml

@@ -18,6 +18,7 @@
 pyo3 = "0.12"
 numpy = { git = "https://github.com/pyo3/rust-numpy/", rev = "e331befa27fede78d4662edf08fa0508db39be01" }
 ndarray = "0.13"
 onig = { version = "6.0", default-features = false }
+crossbeam = "0.8"

 [dependencies.tokenizers]
 version = "*"
Python bindings: PyTokenizer

@@ -12,6 +12,7 @@ use tk::tokenizer::{
     Model, PaddingDirection, PaddingParams, PaddingStrategy, PostProcessor, TokenizerImpl,
     TruncationParams, TruncationStrategy,
 };
+use tk::utils::iter::ResultShunt;
 use tokenizers as tk;

 use super::decoders::PyDecoder;

@@ -1069,16 +1070,53 @@ impl PyTokenizer {
     }

     #[args(trainer = "None")]
-    fn train(&mut self, files: Vec<String>, trainer: Option<&PyTrainer>) -> PyResult<()> {
-        let trainer =
+    fn train(&mut self, files: Vec<String>, trainer: Option<&mut PyTrainer>) -> PyResult<()> {
+        let mut trainer =
             trainer.map_or_else(|| self.tokenizer.get_model().get_trainer(), |t| t.clone());
         Python::with_gil(|py| {
             py.allow_threads(|| {
-                ToPyResult(self.tokenizer.train(&trainer, files).map(|_| {})).into()
+                ToPyResult(
+                    self.tokenizer
+                        .train_from_files(&mut trainer, files)
+                        .map(|_| {}),
+                )
+                .into()
             })
         })
     }

+    #[args(trainer = "None")]
+    fn train_from_iterator(
+        &mut self,
+        iterator: &PyAny,
+        trainer: Option<&mut PyTrainer>,
+    ) -> PyResult<()> {
+        let mut trainer =
+            trainer.map_or_else(|| self.tokenizer.get_model().get_trainer(), |t| t.clone());
+        let (send, recv) = std::sync::mpsc::sync_channel(256);
+        let mut sender = Some(send);
+        let iterator: PyIterator = iterator.iter()?;
+
+        crossbeam::thread::scope(|s| {
+            let _train_handle = s.spawn(|_| {
+                self.tokenizer
+                    .train(&mut trainer, recv.into_iter())
+                    .map(|_| {})
+            });
+
+            ResultShunt::process(iterator.map(|seq| seq?.extract::<&str>()), |iter| {
+                if let Some(send) = sender.take() {
+                    for seq in iter {
+                        send.send(seq)
+                            .map_err(|e| exceptions::PyException::new_err(e.to_string()))?;
+                    }
+                }
+                Ok(())
+            })?
+        })
+        .unwrap()
+    }
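train_from_iterator bridges a lazy Python iterator to the blocking training call: sequences are pushed through a bounded channel into a scoped worker thread, so the whole input never has to be materialized at once. Below is a minimal standalone sketch of that pattern, assuming the crossbeam crate added by this commit as a dependency; plain strings stand in for the Python iterator and a byte counter stands in for the real training loop, so all names here are illustrative.

use std::sync::mpsc::sync_channel;

fn main() {
    // Bounded channel: the producer blocks once 256 items are in flight,
    // which keeps memory usage flat however large the input is.
    let (send, recv) = sync_channel::<String>(256);

    crossbeam::thread::scope(|s| {
        // Consumer thread: drains the receiver as a plain iterator, the way
        // the training call above consumes recv.into_iter().
        let handle = s.spawn(move |_| recv.into_iter().map(|seq| seq.len()).sum::<usize>());

        // Producer side: send the sequences, then drop the sender so the
        // consumer's iterator terminates.
        for seq in ["hello world", "training from an iterator"] {
            send.send(seq.to_string()).unwrap();
        }
        drop(send);

        println!("saw {} bytes", handle.join().unwrap());
    })
    .unwrap();
}

The same shape explains the error handling above: if the consuming side stops early, send returns an error, which the binding converts into a Python exception.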
     /// Apply all the post-processing steps to the given encodings.
     ///
     /// The various steps are:
Python bindings: PyTrainer

@@ -1,4 +1,3 @@
-use std::collections::HashMap;
 use std::sync::{Arc, RwLock};

 use pyo3::exceptions;

@@ -50,19 +49,20 @@ impl Trainer for PyTrainer {
         self.trainer.read().unwrap().should_show_progress()
     }

-    fn train(
-        &self,
-        words: HashMap<String, u32>,
-        model: &mut PyModel,
-    ) -> tk::Result<Vec<tk::AddedToken>> {
+    fn train(&self, model: &mut PyModel) -> tk::Result<Vec<tk::AddedToken>> {
         self.trainer
             .read()
             .unwrap()
-            .train(words, &mut model.model.write().unwrap())
+            .train(&mut model.model.write().unwrap())
     }

-    fn process_tokens(&self, words: &mut HashMap<String, u32>, tokens: Vec<String>) {
-        self.trainer.read().unwrap().process_tokens(words, tokens)
+    fn feed<I, S, F>(&mut self, iterator: I, process: F) -> tk::Result<()>
+    where
+        I: Iterator<Item = S> + Send,
+        S: AsRef<str> + Send,
+        F: Fn(&str) -> tk::Result<Vec<String>> + Sync,
+    {
+        self.trainer.write().unwrap().feed(iterator, process)
     }
 }
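PyTrainer keeps the wrapped trainer behind an Arc<RwLock<...>> (see the imports above), so the read-only trait methods take a read guard while the new feed, which now needs &mut self to accumulate word counts, takes a write guard. A small standalone sketch of that read/write split; the Counter type and its fields are illustrative, not from the bindings.

use std::collections::HashMap;
use std::sync::{Arc, RwLock};

// Illustrative stand-in for the wrapped trainer state.
struct Counter {
    words: HashMap<String, u32>,
}

struct SharedCounter {
    inner: Arc<RwLock<Counter>>,
}

impl SharedCounter {
    // Read-only access: a shared read guard is enough (like train above).
    fn total(&self) -> u32 {
        self.inner.read().unwrap().words.values().sum()
    }

    // Mutating access: an exclusive write guard is required (like feed above).
    fn feed<I, S>(&mut self, sequences: I)
    where
        I: Iterator<Item = S>,
        S: AsRef<str>,
    {
        let mut guard = self.inner.write().unwrap();
        for seq in sequences {
            for word in seq.as_ref().split_whitespace() {
                *guard.words.entry(word.to_owned()).or_insert(0) += 1;
            }
        }
    }
}

fn main() {
    let mut counter = SharedCounter {
        inner: Arc::new(RwLock::new(Counter { words: HashMap::new() })),
    };
    counter.feed(["hello world", "hello again"].iter());
    assert_eq!(counter.total(), 4);
}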
Core library: BpeTrainer

@@ -133,6 +133,7 @@ impl BpeTrainerBuilder {
             initial_alphabet: self.config.initial_alphabet,
             continuing_subword_prefix: self.config.continuing_subword_prefix,
             end_of_word_suffix: self.config.end_of_word_suffix,
+            words: HashMap::new(),
         }
     }
 }

@@ -174,6 +175,8 @@ pub struct BpeTrainer {
     pub continuing_subword_prefix: Option<String>,
     /// An optional suffix to caracterize and end-of-word subword
     pub end_of_word_suffix: Option<String>,
+
+    words: HashMap<String, u32>,
 }

 impl Default for BpeTrainer {

@@ -407,9 +410,9 @@ impl BpeTrainer {
         )
     }

-    pub fn train(
+    pub fn do_train(
         &self,
-        word_counts: HashMap<String, u32>,
+        word_counts: &HashMap<String, u32>,
         model: &mut BPE,
     ) -> Result<Vec<AddedToken>> {
         let mut word_to_id: HashMap<String, u32> = HashMap::with_capacity(self.vocab_size);

@@ -425,14 +428,14 @@
         //
         // 2. Compute the initial alphabet
         //
-        self.compute_alphabet(&word_counts, &mut word_to_id, &mut id_to_word);
+        self.compute_alphabet(word_counts, &mut word_to_id, &mut id_to_word);

         //
         // 3. Tokenize words
         //
         self.update_progress(&progress, word_counts.len(), "Tokenize words");
         let (words, counts) =
-            self.tokenize_words(&word_counts, &mut word_to_id, &mut id_to_word, &progress);
+            self.tokenize_words(word_counts, &mut word_to_id, &mut id_to_word, &progress);
         self.finalize_progress(&progress, words.len());

         //

@@ -586,14 +589,45 @@ impl Trainer for BpeTrainer {
     type Model = BPE;

     /// Train a BPE model
-    fn train(&self, word_counts: HashMap<String, u32>, model: &mut BPE) -> Result<Vec<AddedToken>> {
-        self.train(word_counts, model)
+    fn train(&self, model: &mut BPE) -> Result<Vec<AddedToken>> {
+        self.do_train(&self.words, model)
     }

     /// Whether we should show progress
     fn should_show_progress(&self) -> bool {
         self.show_progress
     }

+    fn feed<I, S, F>(&mut self, iterator: I, process: F) -> Result<()>
+    where
+        I: Iterator<Item = S> + Send,
+        S: AsRef<str> + Send,
+        F: Fn(&str) -> Result<Vec<String>> + Sync,
+    {
+        let words: Result<HashMap<String, u32>> = iterator
+            .maybe_par_bridge()
+            .map(|sequence| {
+                let words = process(sequence.as_ref())?;
+                let mut map = HashMap::new();
+                for word in words {
+                    map.entry(word).and_modify(|c| *c += 1).or_insert(1);
+                }
+                Ok(map)
+            })
+            .reduce(
+                || Ok(HashMap::new()),
+                |acc, ws| {
+                    let mut acc = acc?;
+                    for (k, v) in ws? {
+                        acc.entry(k).and_modify(|c| *c += v).or_insert(v);
+                    }
+                    Ok(acc)
+                },
+            );
+
+        self.words = words?;
+        Ok(())
+    }
 }
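The feed implementation above builds the word-count map with a parallel map/reduce over the input sequences; maybe_par_bridge is a crate-internal helper that falls back to a sequential iterator when parallelism is disabled. A rough standalone equivalent of that counting step, written against plain rayon (assumed as a dependency here) and using a whitespace split where the real code calls the process closure:

use rayon::iter::{ParallelBridge, ParallelIterator};
use std::collections::HashMap;

// Count whitespace-separated words across sequences in parallel, then merge
// the per-sequence maps, mirroring the map/reduce shape of feed above.
fn count_words<I, S>(sequences: I) -> HashMap<String, u32>
where
    I: Iterator<Item = S> + Send,
    S: AsRef<str> + Send,
{
    sequences
        .par_bridge()
        .map(|sequence| {
            let mut map = HashMap::new();
            for word in sequence.as_ref().split_whitespace() {
                *map.entry(word.to_owned()).or_insert(0u32) += 1;
            }
            map
        })
        .reduce(HashMap::new, |mut acc, ws| {
            for (k, v) in ws {
                *acc.entry(k).or_insert(0) += v;
            }
            acc
        })
}

fn main() {
    let counts = count_words(["low lower", "low newest"].into_iter());
    assert_eq!(counts["low"], 2);
    assert_eq!(counts["lower"], 1);
}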
 #[cfg(test)]

@@ -624,7 +658,7 @@ mod tests {
             .min_frequency(2)
             .build();
         let mut model = BPE::default();
-        trainer.train(word_counts, &mut model).unwrap();
+        trainer.do_train(&word_counts, &mut model).unwrap();

         // Vocab should contain all of the characters from the `word_counts` mapping
         // as well as three merges: 're', 'are', and 'is'.
Core library: TrainerWrapper

@@ -145,37 +145,38 @@ impl Trainer for TrainerWrapper {
         }
     }

-    fn train(
-        &self,
-        words: HashMap<String, u32>,
-        model: &mut ModelWrapper,
-    ) -> Result<Vec<AddedToken>> {
+    fn train(&self, model: &mut ModelWrapper) -> Result<Vec<AddedToken>> {
         match self {
             TrainerWrapper::BpeTrainer(t) => match model {
-                ModelWrapper::BPE(bpe) => t.train(words, bpe),
+                ModelWrapper::BPE(bpe) => t.train(bpe),
                 _ => Err("BpeTrainer can only train a BPE".into()),
             },
             TrainerWrapper::WordPieceTrainer(t) => match model {
-                ModelWrapper::WordPiece(wp) => t.train(words, wp),
+                ModelWrapper::WordPiece(wp) => t.train(wp),
                 _ => Err("WordPieceTrainer can only train a WordPiece".into()),
             },
             TrainerWrapper::WordLevelTrainer(t) => match model {
-                ModelWrapper::WordLevel(wl) => t.train(words, wl),
+                ModelWrapper::WordLevel(wl) => t.train(wl),
                 _ => Err("WordLevelTrainer can only train a WordLevel".into()),
             },
             TrainerWrapper::UnigramTrainer(t) => match model {
-                ModelWrapper::Unigram(u) => t.train(words, u),
+                ModelWrapper::Unigram(u) => t.train(u),
                 _ => Err("UnigramTrainer can only train a Unigram".into()),
             },
         }
     }

-    fn process_tokens(&self, words: &mut HashMap<String, u32>, tokens: Vec<String>) {
+    fn feed<I, S, F>(&mut self, iterator: I, process: F) -> Result<()>
+    where
+        I: Iterator<Item = S> + Send,
+        S: AsRef<str> + Send,
+        F: Fn(&str) -> Result<Vec<String>> + Sync,
+    {
         match self {
-            TrainerWrapper::BpeTrainer(bpe) => bpe.process_tokens(words, tokens),
-            TrainerWrapper::WordPieceTrainer(wpt) => wpt.process_tokens(words, tokens),
-            TrainerWrapper::WordLevelTrainer(wpt) => wpt.process_tokens(words, tokens),
-            TrainerWrapper::UnigramTrainer(wpt) => wpt.process_tokens(words, tokens),
+            TrainerWrapper::BpeTrainer(bpe) => bpe.feed(iterator, process),
+            TrainerWrapper::WordPieceTrainer(wpt) => wpt.feed(iterator, process),
+            TrainerWrapper::WordLevelTrainer(wpt) => wpt.feed(iterator, process),
+            TrainerWrapper::UnigramTrainer(wpt) => wpt.feed(iterator, process),
         }
     }
 }

@@ -194,7 +195,7 @@ mod tests {
         let trainer = TrainerWrapper::BpeTrainer(BpeTrainer::default());
         let mut model = ModelWrapper::Unigram(Unigram::default());

-        let result = trainer.train(HashMap::new(), &mut model);
+        let result = trainer.train(&mut model);
         assert!(result.is_err());
     }
 }
Core library: UnigramTrainer

@@ -1,5 +1,6 @@
 use crate::models::unigram::{lattice::Lattice, model::Unigram};
 use crate::tokenizer::{AddedToken, Result, Trainer};
+use crate::utils::parallelism::*;
 use crate::utils::progress::{ProgressBar, ProgressStyle};
 use log::debug;
 use std::cmp::Reverse;

@@ -58,7 +59,9 @@ pub struct UnigramTrainer {
     #[builder(default = "16")]
     pub max_piece_length: usize,
     #[builder(default = "1_000_000")]
-    pub seed_size: usize,
+    seed_size: usize,
+    #[builder(default = "HashMap::new()")]
+    words: HashMap<String, u32>,
 }

 impl Default for UnigramTrainer {

@@ -451,7 +454,11 @@ impl UnigramTrainer {
             .collect();
         new_pieces
     }
-    pub fn _train(&self, sentences: Vec<Sentence>, model: &mut Unigram) -> Result<Vec<AddedToken>> {
+    pub fn do_train(
+        &self,
+        sentences: Vec<Sentence>,
+        model: &mut Unigram,
+    ) -> Result<Vec<AddedToken>> {
         let progress = self.setup_progress();
         //
         // 1. Compute frequent substrings

@@ -533,19 +540,46 @@ impl Trainer for UnigramTrainer {
     type Model = Unigram;

     /// Train a Unigram model
-    fn train(
-        &self,
-        word_counts: HashMap<String, u32>,
-        model: &mut Unigram,
-    ) -> Result<Vec<AddedToken>> {
-        let sentences: Vec<_> = word_counts.into_iter().collect();
-        self._train(sentences, model)
+    fn train(&self, model: &mut Unigram) -> Result<Vec<AddedToken>> {
+        let sentences: Vec<_> = self.words.iter().map(|(s, i)| (s.to_owned(), *i)).collect();
+        self.do_train(sentences, model)
     }

     /// Whether we should show progress
     fn should_show_progress(&self) -> bool {
         self.show_progress
     }

+    fn feed<I, S, F>(&mut self, iterator: I, process: F) -> Result<()>
+    where
+        I: Iterator<Item = S> + Send,
+        S: AsRef<str> + Send,
+        F: Fn(&str) -> Result<Vec<String>> + Sync,
+    {
+        let words: Result<HashMap<String, u32>> = iterator
+            .maybe_par_bridge()
+            .map(|sequence| {
+                let words = process(sequence.as_ref())?;
+                let mut map = HashMap::new();
+                for word in words {
+                    map.entry(word).and_modify(|c| *c += 1).or_insert(1);
+                }
+                Ok(map)
+            })
+            .reduce(
+                || Ok(HashMap::new()),
+                |acc, ws| {
+                    let mut acc = acc?;
+                    for (k, v) in ws? {
+                        acc.entry(k).and_modify(|c| *c += v).or_insert(v);
+                    }
+                    Ok(acc)
+                },
+            );
+
+        self.words = words?;
+        Ok(())
+    }
 }

 #[cfg(test)]

@@ -640,10 +674,7 @@ mod tests {

         let mut unigram = Unigram::default();
         trainer
-            .train(
-                HashMap::from_iter(vec![("The".into(), 12), ("are".into(), 11)]),
-                &mut unigram,
-            )
+            .do_train(vec![("The".into(), 12), ("are".into(), 11)], &mut unigram)
             .unwrap();

         let mut pieces = unigram.iter();

@@ -665,10 +696,7 @@ mod tests {

         let mut unigram = Unigram::default();
         trainer
-            .train(
-                HashMap::from_iter(vec![("The".into(), 12), ("are".into(), 11)]),
-                &mut unigram,
-            )
+            .do_train(vec![("The".into(), 12), ("are".into(), 11)], &mut unigram)
             .unwrap();

         let mut pieces = unigram.iter();

@@ -684,10 +712,7 @@ mod tests {

         let mut unigram = Unigram::default();
         trainer
-            .train(
-                HashMap::from_iter(vec![("The".into(), 12), ("are".into(), 11)]),
-                &mut unigram,
-            )
+            .do_train(vec![("The".into(), 12), ("are".into(), 11)], &mut unigram)
             .unwrap();

         let mut pieces = unigram.iter();

@@ -707,10 +732,7 @@ mod tests {

         let mut unigram = Unigram::default();
         trainer
-            .train(
-                HashMap::from_iter(vec![("The".into(), 12), ("are".into(), 11)]),
-                &mut unigram,
-            )
+            .do_train(vec![("The".into(), 12), ("are".into(), 11)], &mut unigram)
             .unwrap();

         let mut pieces = unigram.iter();
Core library: WordLevelTrainer

@@ -1,4 +1,5 @@
 use super::WordLevel;
+use crate::utils::parallelism::*;
 use crate::{AddedToken, Result, Trainer};
 use std::collections::HashMap;

@@ -17,6 +18,9 @@ pub struct WordLevelTrainer {
     /// A list of special tokens that the model should know of
     #[builder(default)]
     pub special_tokens: Vec<AddedToken>,
+
+    #[builder(default, private)]
+    words: HashMap<String, u32>,
 }

 impl Default for WordLevelTrainer {

@@ -26,6 +30,7 @@ impl Default for WordLevelTrainer {
             vocab_size: 30_000,
             show_progress: true,
             special_tokens: vec![],
+            words: HashMap::new(),
         }
     }
 }

@@ -35,12 +40,12 @@ impl WordLevelTrainer {
         WordLevelTrainerBuilder::default()
     }

-    fn train(
+    fn do_train(
         &self,
-        word_counts: HashMap<String, u32>,
+        word_counts: &HashMap<String, u32>,
         model: &mut WordLevel,
     ) -> Result<Vec<AddedToken>> {
-        let mut ordered_counts = word_counts.into_iter().collect::<Vec<_>>();
+        let mut ordered_counts = word_counts.iter().collect::<Vec<_>>();
         ordered_counts.sort_by_key(|(_, n)| std::cmp::Reverse(*n));
         let word_level = WordLevel::builder()
             .vocab(

@@ -50,8 +55,8 @@
                 .chain(
                     ordered_counts
                         .into_iter()
-                        .filter(|(_, n)| *n >= self.min_frequency)
-                        .map(|(w, _)| w),
+                        .filter(|(_, n)| **n >= self.min_frequency)
+                        .map(|(w, _)| w.to_owned()),
                 )
                 .take(self.vocab_size)
                 .enumerate()

@@ -72,18 +77,45 @@ impl Trainer for WordLevelTrainer {
     type Model = WordLevel;

     /// Train a WordLevel model
-    fn train(
-        &self,
-        word_counts: HashMap<String, u32>,
-        model: &mut WordLevel,
-    ) -> Result<Vec<AddedToken>> {
-        self.train(word_counts, model)
+    fn train(&self, model: &mut WordLevel) -> Result<Vec<AddedToken>> {
+        self.do_train(&self.words, model)
     }

     /// Whether we should show progress
     fn should_show_progress(&self) -> bool {
         self.show_progress
     }

+    fn feed<I, S, F>(&mut self, iterator: I, process: F) -> Result<()>
+    where
+        I: Iterator<Item = S> + Send,
+        S: AsRef<str> + Send,
+        F: Fn(&str) -> Result<Vec<String>> + Sync,
+    {
+        let words: Result<HashMap<String, u32>> = iterator
+            .maybe_par_bridge()
+            .map(|sequence| {
+                let words = process(sequence.as_ref())?;
+                let mut map = HashMap::new();
+                for word in words {
+                    map.entry(word).and_modify(|c| *c += 1).or_insert(1);
+                }
+                Ok(map)
+            })
+            .reduce(
+                || Ok(HashMap::new()),
+                |acc, ws| {
+                    let mut acc = acc?;
+                    for (k, v) in ws? {
+                        acc.entry(k).and_modify(|c| *c += v).or_insert(v);
+                    }
+                    Ok(acc)
+                },
+            );
+
+        self.words = words?;
+        Ok(())
+    }
 }
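do_train above turns the accumulated counts into a vocabulary by sorting words by descending frequency, dropping those under min_frequency, and truncating to vocab_size, with special tokens taking the first ids. A simplified standalone sketch of that selection logic; the function and its parameters are illustrative, not the crate's API.

use std::collections::HashMap;

// Build word -> id from counts: special tokens first, then the most frequent
// words that pass min_frequency, capped at vocab_size.
fn build_vocab(
    word_counts: &HashMap<String, u32>,
    special_tokens: &[&str],
    min_frequency: u32,
    vocab_size: usize,
) -> HashMap<String, u32> {
    let mut ordered: Vec<_> = word_counts.iter().collect();
    ordered.sort_by_key(|(_, n)| std::cmp::Reverse(**n));

    special_tokens
        .iter()
        .map(|t| t.to_string())
        .chain(
            ordered
                .into_iter()
                .filter(|(_, n)| **n >= min_frequency)
                .map(|(w, _)| w.to_owned()),
        )
        .take(vocab_size)
        .enumerate()
        .map(|(id, word)| (word, id as u32))
        .collect()
}

fn main() {
    let counts: HashMap<String, u32> = HashMap::from([
        ("the".to_string(), 25),
        ("are".to_string(), 10),
        ("rare".to_string(), 1),
    ]);
    let vocab = build_vocab(&counts, &["[UNK]"], 2, 3);
    assert_eq!(vocab["[UNK]"], 0);
    assert_eq!(vocab["the"], 1);
    assert_eq!(vocab["are"], 2);
    assert!(!vocab.contains_key("rare"));
}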
 #[cfg(test)]

@@ -108,7 +140,7 @@ mod tests {
         trainer.vocab_size = 5;

         let mut model = WordLevel::default();
-        trainer.train(word_counts.clone(), &mut model).unwrap();
+        trainer.do_train(&word_counts, &mut model).unwrap();
         let expected_vocab: HashMap<String, u32> = [
             ("the".into(), 0),
             ("are".into(), 1),

@@ -124,7 +156,7 @@ mod tests {
         // If we specify a min_frequency
         trainer.min_frequency = 15;
         let mut model = WordLevel::default();
-        trainer.train(word_counts, &mut model).unwrap();
+        trainer.do_train(&word_counts, &mut model).unwrap();
         let expected_vocab: HashMap<String, u32> = [
             ("the".into(), 0),
             ("are".into(), 1),
Core library: WordPieceTrainer

@@ -1,7 +1,7 @@
 use super::WordPiece;
 use crate::models::bpe::{BpeTrainer, BpeTrainerBuilder, BPE};
 use crate::tokenizer::{AddedToken, Result, Trainer};
-use std::collections::{HashMap, HashSet};
+use std::collections::HashSet;

 /// A `WordPieceTrainerBuilder` can be used to create a `WordPieceTrainer` with a custom
 /// configuration.

@@ -153,13 +153,9 @@ impl WordPieceTrainer {
         WordPieceTrainerBuilder::default()
     }

-    pub fn train(
-        &self,
-        word_counts: HashMap<String, u32>,
-        model: &mut WordPiece,
-    ) -> Result<Vec<AddedToken>> {
+    pub fn train(&self, model: &mut WordPiece) -> Result<Vec<AddedToken>> {
         let mut bpe = BPE::default();
-        let special_tokens = self.bpe_trainer.train(word_counts, &mut bpe)?;
+        let special_tokens = self.bpe_trainer.train(&mut bpe)?;
         let new_wordpiece = WordPiece::from_bpe(&bpe);

         // Transfer the vocab

@@ -175,19 +171,20 @@
 impl Trainer for WordPieceTrainer {
     type Model = WordPiece;

-    fn train(
-        &self,
-        word_counts: HashMap<String, u32>,
-        model: &mut WordPiece,
-    ) -> Result<Vec<AddedToken>> {
-        self.train(word_counts, model)
-    }
-
-    fn process_tokens(&self, mut words: &mut HashMap<String, u32>, tokens: Vec<String>) {
-        self.bpe_trainer.process_tokens(&mut words, tokens)
+    fn train(&self, model: &mut WordPiece) -> Result<Vec<AddedToken>> {
+        self.train(model)
     }

     fn should_show_progress(&self) -> bool {
         self.bpe_trainer.should_show_progress()
     }

+    fn feed<I, S, F>(&mut self, iterator: I, process: F) -> Result<()>
+    where
+        I: Iterator<Item = S> + Send,
+        S: AsRef<str> + Send,
+        F: Fn(&str) -> Result<Vec<String>> + Sync,
+    {
+        self.bpe_trainer.feed(iterator, process)
+    }
 }
Core library: mock trainer used in tests

@@ -532,11 +532,15 @@ mod tests {
         fn should_show_progress(&self) -> bool {
             true
         }
-        fn train(
-            &self,
-            _words: HashMap<String, u32>,
-            _model: &mut ModelMock,
-        ) -> Result<Vec<AddedToken>> {
+        fn train(&self, _model: &mut ModelMock) -> Result<Vec<AddedToken>> {
             unimplemented!()
         }
+        fn feed<I, S, F>(&mut self, _iterator: I, _process: F) -> Result<()>
+        where
+            I: Iterator<Item = S> + Send,
+            S: AsRef<str> + Send,
+            F: Fn(&str) -> Result<Vec<String>> + Sync,
+        {
+            unimplemented!()
+        }
     }
Core library: the Trainer trait

@@ -131,20 +131,13 @@ pub trait Trainer {
     fn should_show_progress(&self) -> bool;
     /// The actual training method. This will return a new trained Model as well as a list
     /// of `special_tokens` to be added directly to the tokenizer along with the model.
-    fn train(
-        &self,
-        words: HashMap<String, u32>,
-        model: &mut Self::Model,
-    ) -> Result<Vec<AddedToken>>;
-    /// Process a bunch of token, counting them as relevant.
-    fn process_tokens(&self, words: &mut HashMap<String, u32>, tokens: Vec<String>) {
-        for token in tokens {
-            words
-                .entry(token.clone())
-                .and_modify(|c| *c += 1)
-                .or_insert(1);
-        }
-    }
+    fn train(&self, model: &mut Self::Model) -> Result<Vec<AddedToken>>;
+    /// Process an iterator of sequences already pre-processed by the Tokenizer
+    fn feed<I, S, F>(&mut self, iterator: I, process: F) -> Result<()>
+    where
+        I: Iterator<Item = S> + Send,
+        S: AsRef<str> + Send,
+        F: Fn(&str) -> Result<Vec<String>> + Sync;
 }
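With the trait reshaped this way, training becomes a two-step protocol: feed lets the trainer accumulate word counts from pre-processed sequences, and train then builds the model from whatever was fed. A rough sketch of driving a trainer directly through that protocol, using BpeTrainer and BPE as they appear in this diff; the builder call and the whitespace split stand in for real configuration and for the tokenizer's normalization and pre-tokenization.

use tokenizers::models::bpe::{BpeTrainer, BPE};
use tokenizers::Trainer;

fn main() {
    let mut trainer = BpeTrainer::builder().min_frequency(2).build();
    let mut model = BPE::default();

    let corpus = vec!["low low low", "lower newest", "newest widest"];

    // Step 1: feed pre-processed sequences. The closure plays the role of the
    // tokenizer pipeline and returns the "words" found in one sequence.
    trainer
        .feed(corpus.into_iter(), |sequence| {
            Ok(sequence
                .split_whitespace()
                .map(|w| w.to_owned())
                .collect::<Vec<_>>())
        })
        .unwrap();

    // Step 2: build the model from the accumulated word counts.
    let special_tokens = trainer.train(&mut model).unwrap();
    println!("added {} special tokens", special_tokens.len());
}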

 #[derive(Debug, Clone, PartialEq)]
Core library: TokenizerImpl

@@ -969,99 +962,74 @@
         .collect()
     }

     /// Train a model and replace our current Model, using the given Trainer
-    fn word_count<MN, T>(&self, trainer: &T, files: Vec<String>) -> Result<HashMap<String, u32>>
+    pub fn train_from_files<T>(&mut self, trainer: &mut T, files: Vec<String>) -> Result<&mut Self>
     where
-        T: Trainer<Model = MN> + Sync,
-        MN: Model,
+        T: Trainer<Model = M> + Sync,
     {
         let max_read = 1_000_000;
-        let mut len = 0;
-        for file in files.iter() {
-            len += File::open(file)
-                .and_then(|f| f.metadata())
-                .map(|m| m.len())?;
+        use crate::utils::iter::ResultShunt;
+        ResultShunt::process(
+            files.into_iter().flat_map(|filename| {
+                match File::open(filename) {
+                    Ok(file) => {
+                        let file = BufReader::with_capacity(max_read, file);
+                        // We read new lines using this API instead of the Lines Iterator
+                        // on purpose. We want to keep the `\n` and potential `\r` between each lines
+                        // We use an iterator to be able to chain with par_bridge.
+                        itertools::Either::Left(file.lines_with_ending())
+                    }
+                    Err(e) => itertools::Either::Right(std::iter::once(Err(e))),
+                }
+            }),
+            |iter| self.train(trainer, iter).map(|_| {}),
+        )??;
+        Ok(self)
+    }
+
+    /// Train a model and replace our current Model, using the given Trainer
+    pub fn train<T, I, S>(&mut self, trainer: &mut T, sequences: I) -> Result<&mut Self>
+    where
+        T: Trainer<Model = M> + Sync,
+        I: Iterator<Item = S> + Send,
+        S: AsRef<str> + Send,
+    {
+        let (lower, upper) = sequences.size_hint();
+        let len = upper.unwrap_or(lower) as u64;
         let progress = if trainer.should_show_progress() {
             let progress = ProgressBar::new(len);
             progress.set_style(
                 ProgressStyle::default_bar()
-                    .template("[{elapsed_precise}] {msg:<40!} {wide_bar} {percent:>19!}"),
+                    .template("[{elapsed_precise}] {msg:<40!} {wide_bar} {pos:<9!}/{len:>9!}"),
             );
-            progress.set_message(&format!("Reading files ({:.2} Mo)", len / 1_000_000));
+            progress.set_message("Pre-processing sequences");
             progress.set_draw_delta(len / 100); // Redraw only every 2%
             Some(progress)
         } else {
             None
         };
-        let words = files
-            .into_iter()
-            .map(|filename| -> Result<HashMap<String, u32>> {
-                let file = File::open(filename)?;
-                let file = BufReader::with_capacity(max_read, file);
-                // We read new lines using this API instead of the Lines Iterator
-                // on purpose. We want to keep the `\n` and potential `\r` between each lines
-                // We use an iterator to be able to chain with par_bridge.
-                file.lines_with_ending()
-                    .maybe_par_bridge()
-                    .map_with(
-                        &progress,
-                        |progress, line| -> Result<HashMap<String, u32>> {
-                            let newline = line?;
-                            let b = newline.len();
-                            let mut words = HashMap::new();
-                            let normalized = self.do_normalize(newline)?;

+        trainer.feed(
+            sequences.map(|s| {
+                // if let Some(progress) = &progress {
+                //     progress.inc(1)
+                // }
+                s
+            }),
+            |seq| {
+                let normalized = self.do_normalize(seq.as_ref())?;
                 let pre_tokenized = self.do_pre_tokenize(normalized)?;
-                trainer.process_tokens(
-                    &mut words,
-                    pre_tokenized
+                Ok(pre_tokenized
                     .get_splits(OffsetReferential::Original, OffsetType::Byte)
                     .into_iter()
                     .map(|(s, _, _)| s.to_owned())
-                    .collect(),
-                );

-                if let Some(pbar) = progress {
-                    pbar.inc(b as u64);
-                }
-                Ok(words)
-            },
-        )
-        .reduce(
-            || Ok(HashMap::new()),
-            |acc, ws| {
-                let mut acc = acc?;
-                for (k, v) in ws? {
-                    acc.entry(k).and_modify(|c| *c += v).or_insert(v);
-                }
-                Ok(acc)
-            },
-        )
-        })
-        .try_fold(
-            HashMap::new(),
-            |mut acc, ws| -> Result<HashMap<String, u32>> {
-                for (k, v) in ws? {
-                    acc.entry(k).and_modify(|c| *c += v).or_insert(v);
-                }
-                Ok(acc)
+                    .collect())
             },
         )?;
         if let Some(pbar) = progress {
             pbar.finish();
         }
-        Ok(words)
-    }
-
-    /// Train a model and replace our current Model, using the given Trainer
-    pub fn train<T>(&mut self, trainer: &T, files: Vec<String>) -> Result<&mut Self>
-    where
-        T: Trainer<Model = M> + Sync,
-    {
-        let words = self.word_count(trainer, files)?;

-        let special_tokens = trainer.train(words, &mut self.model)?;
+        let special_tokens = trainer.train(&mut self.model)?;
         self.add_special_tokens(&special_tokens);

         Ok(self)
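Taken together, train_from_files is now a thin wrapper: it chains the input files into one lazy iterator of lines and hands that iterator to the new iterator-based train, which feeds the trainer and then lets it build the model. A hedged standalone sketch of the file-chaining part, using itertools (assumed as a dependency) and std's lines() in place of the crate-internal lines_with_ending helper; the file paths are placeholders.

use itertools::Either;
use std::fs::File;
use std::io::{BufRead, BufReader};

// Chain several files into one sequence iterator. A file that fails to open
// becomes a single Err item instead of aborting up front, so the caller can
// surface the error lazily, as train_from_files does through ResultShunt.
fn sequences(files: Vec<String>) -> impl Iterator<Item = std::io::Result<String>> {
    files.into_iter().flat_map(|filename| match File::open(&filename) {
        // The real code keeps the `\n`/`\r` endings; lines() is a stand-in.
        Ok(file) => Either::Left(BufReader::new(file).lines()),
        Err(e) => Either::Right(std::iter::once(Err(e))),
    })
}

fn main() -> std::io::Result<()> {
    for line in sequences(vec!["data/corpus-a.txt".into(), "data/corpus-b.txt".into()]) {
        let line = line?;
        println!("{}", line);
    }
    Ok(())
}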