mirror of https://github.com/mii443/tokenizers.git
synced 2025-09-02 07:19:24 +00:00

Add ability to train from Iterator

Changed files: bindings/python/Cargo.lock (generated, 118 lines changed), bindings/python/Cargo.toml, and the Python-binding and Rust trainer sources shown in the hunks below.
bindings/python/Cargo.lock (generated):

@@ -65,6 +65,12 @@ version = "0.1.10"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "4785bdd1c96b2a846b2bd7cc02e86b6b3dbf14e7e53446c4f54c92a361040822"
 
+[[package]]
+name = "cfg-if"
+version = "1.0.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd"
+
 [[package]]
 name = "clap"
 version = "2.33.3"
@@ -106,27 +112,68 @@ dependencies = [
  "winapi-util",
 ]
 
+[[package]]
+name = "const_fn"
+version = "0.4.3"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "c478836e029dcef17fb47c89023448c64f781a046e0300e257ad8225ae59afab"
+
+[[package]]
+name = "crossbeam"
+version = "0.8.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "fd01a6eb3daaafa260f6fc94c3a6c36390abc2080e38e3e34ced87393fb77d80"
+dependencies = [
+ "cfg-if 1.0.0",
+ "crossbeam-channel 0.5.0",
+ "crossbeam-deque 0.8.0",
+ "crossbeam-epoch 0.9.0",
+ "crossbeam-queue",
+ "crossbeam-utils 0.8.0",
+]
+
 [[package]]
 name = "crossbeam-channel"
 version = "0.4.4"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "b153fe7cbef478c567df0f972e02e6d736db11affe43dfc9c56a9374d1adfb87"
 dependencies = [
- "crossbeam-utils",
+ "crossbeam-utils 0.7.2",
  "maybe-uninit",
 ]
 
+[[package]]
+name = "crossbeam-channel"
+version = "0.5.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "dca26ee1f8d361640700bde38b2c37d8c22b3ce2d360e1fc1c74ea4b0aa7d775"
+dependencies = [
+ "cfg-if 1.0.0",
+ "crossbeam-utils 0.8.0",
+]
+
 [[package]]
 name = "crossbeam-deque"
 version = "0.7.3"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "9f02af974daeee82218205558e51ec8768b48cf524bd01d550abe5573a608285"
 dependencies = [
- "crossbeam-epoch",
- "crossbeam-utils",
+ "crossbeam-epoch 0.8.2",
+ "crossbeam-utils 0.7.2",
  "maybe-uninit",
 ]
 
+[[package]]
+name = "crossbeam-deque"
+version = "0.8.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "94af6efb46fef72616855b036a624cf27ba656ffc9be1b9a3c931cfc7749a9a9"
+dependencies = [
+ "cfg-if 1.0.0",
+ "crossbeam-epoch 0.9.0",
+ "crossbeam-utils 0.8.0",
+]
+
 [[package]]
 name = "crossbeam-epoch"
 version = "0.8.2"
@@ -134,14 +181,38 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "058ed274caafc1f60c4997b5fc07bf7dc7cca454af7c6e81edffe5f33f70dace"
 dependencies = [
  "autocfg",
- "cfg-if",
- "crossbeam-utils",
+ "cfg-if 0.1.10",
+ "crossbeam-utils 0.7.2",
  "lazy_static",
  "maybe-uninit",
  "memoffset",
  "scopeguard",
 ]
 
+[[package]]
+name = "crossbeam-epoch"
+version = "0.9.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "ec0f606a85340376eef0d6d8fec399e6d4a544d648386c6645eb6d0653b27d9f"
+dependencies = [
+ "cfg-if 1.0.0",
+ "const_fn",
+ "crossbeam-utils 0.8.0",
+ "lazy_static",
+ "memoffset",
+ "scopeguard",
+]
+
+[[package]]
+name = "crossbeam-queue"
+version = "0.3.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "6b2a58563f049aa3bae172bc4120f093b5901161c629f280a1f40ba55317d774"
+dependencies = [
+ "cfg-if 1.0.0",
+ "crossbeam-utils 0.8.0",
+]
+
 [[package]]
 name = "crossbeam-utils"
 version = "0.7.2"
@@ -149,7 +220,19 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "c3c7c73a2d1e9fc0886a08b93e98eb643461230d5f1925e4036204d5f2e261a8"
 dependencies = [
  "autocfg",
- "cfg-if",
+ "cfg-if 0.1.10",
+ "lazy_static",
+]
+
+[[package]]
+name = "crossbeam-utils"
+version = "0.8.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "ec91540d98355f690a86367e566ecad2e9e579f230230eb7c21398372be73ea5"
+dependencies = [
+ "autocfg",
+ "cfg-if 1.0.0",
+ "const_fn",
  "lazy_static",
 ]
 
@@ -269,7 +352,7 @@ version = "0.1.15"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "fc587bc0ec293155d5bfa6b9891ec18a1e330c234f896ea47fbada4cadbe47e6"
 dependencies = [
- "cfg-if",
+ "cfg-if 0.1.10",
  "libc",
  "wasi",
 ]
@@ -350,7 +433,7 @@ version = "0.1.7"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "63312a18f7ea8760cdd0a7c5aac1a619752a246b833545e3e36d1f81f7cd9e66"
 dependencies = [
- "cfg-if",
+ "cfg-if 0.1.10",
 ]
 
 [[package]]
@@ -413,7 +496,7 @@ checksum = "db65c6da02e61f55dae90a0ae427b2a5f6b3e8db09f58d10efab23af92592616"
 dependencies = [
  "arrayvec",
  "bitflags",
- "cfg-if",
+ "cfg-if 0.1.10",
  "ryu",
  "static_assertions",
 ]
@@ -439,7 +522,7 @@ version = "0.4.11"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "4fabed175da42fed1fa0746b0ea71f412aa9d35e76e95e59b192c64b9dc2bf8b"
 dependencies = [
- "cfg-if",
+ "cfg-if 0.1.10",
 ]
 
 [[package]]
@@ -546,7 +629,7 @@ name = "numpy"
 version = "0.11.0"
 source = "git+https://github.com/pyo3/rust-numpy/?rev=e331befa27fede78d4662edf08fa0508db39be01#e331befa27fede78d4662edf08fa0508db39be01"
 dependencies = [
- "cfg-if",
+ "cfg-if 0.1.10",
  "libc",
  "ndarray",
  "num-complex",
@@ -593,7 +676,7 @@ version = "0.8.0"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "c361aa727dd08437f2f1447be8b59a33b0edd15e0fcee698f935613d9efbca9b"
 dependencies = [
- "cfg-if",
+ "cfg-if 0.1.10",
  "cloudabi",
  "instant",
  "libc",
@@ -755,7 +838,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "dcf6960dc9a5b4ee8d3e4c5787b4a112a8818e0290a42ff664ad60692fdf2032"
 dependencies = [
  "autocfg",
- "crossbeam-deque",
+ "crossbeam-deque 0.7.3",
  "either",
  "rayon-core",
 ]
@@ -776,9 +859,9 @@ version = "1.8.1"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "e8c4fec834fb6e6d2dd5eece3c7b432a52f0ba887cf40e595190c4107edc08bf"
 dependencies = [
- "crossbeam-channel",
- "crossbeam-deque",
- "crossbeam-utils",
+ "crossbeam-channel 0.4.4",
+ "crossbeam-deque 0.7.3",
+ "crossbeam-utils 0.7.2",
  "lazy_static",
  "num_cpus",
 ]
@@ -912,7 +995,7 @@ version = "3.1.0"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "7a6e24d9338a0a5be79593e2fa15a648add6138caa803e2d5bc782c371732ca9"
 dependencies = [
- "cfg-if",
+ "cfg-if 0.1.10",
  "libc",
  "rand",
  "redox_syscall",
@@ -995,6 +1078,7 @@ dependencies = [
 name = "tokenizers-python"
 version = "0.9.4"
 dependencies = [
+ "crossbeam",
  "env_logger",
  "libc",
  "ndarray",
bindings/python/Cargo.toml:

@@ -18,6 +18,7 @@ pyo3 = "0.12"
 numpy = { git = "https://github.com/pyo3/rust-numpy/", rev = "e331befa27fede78d4662edf08fa0508db39be01" }
 ndarray = "0.13"
 onig = { version = "6.0", default-features = false }
+crossbeam = "0.8"
 
 [dependencies.tokenizers]
 version = "*"
@@ -12,6 +12,7 @@ use tk::tokenizer::{
     Model, PaddingDirection, PaddingParams, PaddingStrategy, PostProcessor, TokenizerImpl,
     TruncationParams, TruncationStrategy,
 };
+use tk::utils::iter::ResultShunt;
 use tokenizers as tk;
 
 use super::decoders::PyDecoder;
@@ -1069,16 +1070,53 @@ impl PyTokenizer {
     }
 
     #[args(trainer = "None")]
-    fn train(&mut self, files: Vec<String>, trainer: Option<&PyTrainer>) -> PyResult<()> {
-        let trainer =
+    fn train(&mut self, files: Vec<String>, trainer: Option<&mut PyTrainer>) -> PyResult<()> {
+        let mut trainer =
             trainer.map_or_else(|| self.tokenizer.get_model().get_trainer(), |t| t.clone());
         Python::with_gil(|py| {
             py.allow_threads(|| {
-                ToPyResult(self.tokenizer.train(&trainer, files).map(|_| {})).into()
+                ToPyResult(
+                    self.tokenizer
+                        .train_from_files(&mut trainer, files)
+                        .map(|_| {}),
+                )
+                .into()
             })
         })
     }
+
+    #[args(trainer = "None")]
+    fn train_from_iterator(
+        &mut self,
+        iterator: &PyAny,
+        trainer: Option<&mut PyTrainer>,
+    ) -> PyResult<()> {
+        let mut trainer =
+            trainer.map_or_else(|| self.tokenizer.get_model().get_trainer(), |t| t.clone());
+        let (send, recv) = std::sync::mpsc::sync_channel(256);
+        let mut sender = Some(send);
+        let iterator: PyIterator = iterator.iter()?;
+
+        crossbeam::thread::scope(|s| {
+            let _train_handle = s.spawn(|_| {
+                self.tokenizer
+                    .train(&mut trainer, recv.into_iter())
+                    .map(|_| {})
+            });
+
+            ResultShunt::process(iterator.map(|seq| seq?.extract::<&str>()), |iter| {
+                if let Some(send) = sender.take() {
+                    for seq in iter {
+                        send.send(seq)
+                            .map_err(|e| exceptions::PyException::new_err(e.to_string()))?;
+                    }
+                }
+                Ok(())
+            })?
+        })
+        .unwrap()
+    }
 
     /// Apply all the post-processing steps to the given encodings.
     ///
     /// The various steps are:
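The `train_from_iterator` binding above streams sequences from a Python iterator into the Rust trainer through a bounded channel while a scoped thread consumes them. The following standalone sketch (not part of the commit) shows the same producer/consumer shape with plain strings; the `consume` function is an illustrative stand-in for `tokenizer.train(...)`, and only the crossbeam 0.8 crate added in Cargo.toml above plus the standard library are assumed.

    use std::sync::mpsc::sync_channel;

    // Stand-in for the training call: drain the sequences on a worker thread.
    fn consume(sequences: impl Iterator<Item = String>) -> usize {
        sequences.map(|s| s.len()).sum()
    }

    fn main() {
        // Bounded channel: the producer blocks once 256 items are in flight,
        // which keeps memory flat when the source iterator is huge.
        let (send, recv) = sync_channel::<String>(256);

        let total = crossbeam::thread::scope(|s| {
            // The consumer runs concurrently and drains the channel as an iterator.
            let handle = s.spawn(|_| consume(recv.into_iter()));

            // The producer feeds items; dropping `send` closes the channel so
            // the consumer's iterator ends and the scoped thread can finish.
            for seq in vec!["hello world", "train from an iterator"] {
                send.send(seq.to_string()).unwrap();
            }
            drop(send);

            handle.join().unwrap()
        })
        .unwrap();

        println!("consumed {} bytes", total);
    }

The bounded `sync_channel(256)` is the same back-pressure mechanism the binding relies on: a slow trainer throttles the Python-side producer instead of letting unread sequences pile up.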
@@ -1,4 +1,3 @@
-use std::collections::HashMap;
 use std::sync::{Arc, RwLock};
 
 use pyo3::exceptions;
@@ -50,19 +49,20 @@ impl Trainer for PyTrainer {
         self.trainer.read().unwrap().should_show_progress()
     }
 
-    fn train(
-        &self,
-        words: HashMap<String, u32>,
-        model: &mut PyModel,
-    ) -> tk::Result<Vec<tk::AddedToken>> {
+    fn train(&self, model: &mut PyModel) -> tk::Result<Vec<tk::AddedToken>> {
         self.trainer
             .read()
             .unwrap()
-            .train(words, &mut model.model.write().unwrap())
+            .train(&mut model.model.write().unwrap())
     }
 
-    fn process_tokens(&self, words: &mut HashMap<String, u32>, tokens: Vec<String>) {
-        self.trainer.read().unwrap().process_tokens(words, tokens)
+    fn feed<I, S, F>(&mut self, iterator: I, process: F) -> tk::Result<()>
+    where
+        I: Iterator<Item = S> + Send,
+        S: AsRef<str> + Send,
+        F: Fn(&str) -> tk::Result<Vec<String>> + Sync,
+    {
+        self.trainer.write().unwrap().feed(iterator, process)
     }
 }
@@ -133,6 +133,7 @@ impl BpeTrainerBuilder {
             initial_alphabet: self.config.initial_alphabet,
             continuing_subword_prefix: self.config.continuing_subword_prefix,
             end_of_word_suffix: self.config.end_of_word_suffix,
+            words: HashMap::new(),
         }
     }
 }
@@ -174,6 +175,8 @@ pub struct BpeTrainer {
     pub continuing_subword_prefix: Option<String>,
     /// An optional suffix to caracterize and end-of-word subword
     pub end_of_word_suffix: Option<String>,
+
+    words: HashMap<String, u32>,
 }
 
 impl Default for BpeTrainer {
@@ -407,9 +410,9 @@ impl BpeTrainer {
         )
     }
 
-    pub fn train(
+    pub fn do_train(
         &self,
-        word_counts: HashMap<String, u32>,
+        word_counts: &HashMap<String, u32>,
         model: &mut BPE,
     ) -> Result<Vec<AddedToken>> {
         let mut word_to_id: HashMap<String, u32> = HashMap::with_capacity(self.vocab_size);
@@ -425,14 +428,14 @@ impl BpeTrainer {
         //
         // 2. Compute the initial alphabet
         //
-        self.compute_alphabet(&word_counts, &mut word_to_id, &mut id_to_word);
+        self.compute_alphabet(word_counts, &mut word_to_id, &mut id_to_word);
 
         //
         // 3. Tokenize words
         //
         self.update_progress(&progress, word_counts.len(), "Tokenize words");
         let (words, counts) =
-            self.tokenize_words(&word_counts, &mut word_to_id, &mut id_to_word, &progress);
+            self.tokenize_words(word_counts, &mut word_to_id, &mut id_to_word, &progress);
         self.finalize_progress(&progress, words.len());
 
         //
@@ -586,14 +589,45 @@ impl Trainer for BpeTrainer {
     type Model = BPE;
 
     /// Train a BPE model
-    fn train(&self, word_counts: HashMap<String, u32>, model: &mut BPE) -> Result<Vec<AddedToken>> {
-        self.train(word_counts, model)
+    fn train(&self, model: &mut BPE) -> Result<Vec<AddedToken>> {
+        self.do_train(&self.words, model)
     }
 
     /// Whether we should show progress
     fn should_show_progress(&self) -> bool {
         self.show_progress
     }
+
+    fn feed<I, S, F>(&mut self, iterator: I, process: F) -> Result<()>
+    where
+        I: Iterator<Item = S> + Send,
+        S: AsRef<str> + Send,
+        F: Fn(&str) -> Result<Vec<String>> + Sync,
+    {
+        let words: Result<HashMap<String, u32>> = iterator
+            .maybe_par_bridge()
+            .map(|sequence| {
+                let words = process(sequence.as_ref())?;
+                let mut map = HashMap::new();
+                for word in words {
+                    map.entry(word).and_modify(|c| *c += 1).or_insert(1);
+                }
+                Ok(map)
+            })
+            .reduce(
+                || Ok(HashMap::new()),
+                |acc, ws| {
+                    let mut acc = acc?;
+                    for (k, v) in ws? {
+                        acc.entry(k).and_modify(|c| *c += v).or_insert(v);
+                    }
+                    Ok(acc)
+                },
+            );
+
+        self.words = words?;
+        Ok(())
+    }
 }
 
 #[cfg(test)]
@@ -624,7 +658,7 @@ mod tests {
             .min_frequency(2)
             .build();
         let mut model = BPE::default();
-        trainer.train(word_counts, &mut model).unwrap();
+        trainer.do_train(&word_counts, &mut model).unwrap();
 
         // Vocab should contain all of the characters from the `word_counts` mapping
         // as well as three merges: 're', 'are', and 'is'.
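The new `feed` above builds the trainer's word-count table with a parallel map over sequences followed by a reduce that merges the per-sequence counts. Here is a self-contained sketch (not part of the commit) of that counting step, using rayon's `par_bridge` (rayon already appears in the lockfile above) in place of the crate-internal `maybe_par_bridge`, and a `process` closure standing in for the tokenizer's pre-tokenization.

    use rayon::prelude::*;
    use std::collections::HashMap;

    // Merge per-sequence word counts into a single map, in parallel.
    fn count_words<I>(
        sequences: I,
        process: impl Fn(&str) -> Vec<String> + Sync,
    ) -> HashMap<String, u32>
    where
        I: Iterator<Item = String> + Send,
    {
        sequences
            .par_bridge() // bridge the plain iterator into a rayon parallel iterator
            .map(|sequence| {
                // Per-sequence local counts, so threads never share a mutable map.
                let mut map = HashMap::new();
                for word in process(&sequence) {
                    *map.entry(word).or_insert(0u32) += 1;
                }
                map
            })
            .reduce(HashMap::new, |mut acc, ws| {
                // Fold the partial maps together.
                for (k, v) in ws {
                    *acc.entry(k).or_insert(0) += v;
                }
                acc
            })
    }

    fn main() {
        let sequences = vec!["the cat sat".to_string(), "the cat".to_string()];
        // Whitespace splitting stands in for the real normalization/pre-tokenization.
        let counts = count_words(sequences.into_iter(), |s| {
            s.split_whitespace().map(str::to_owned).collect()
        });
        assert_eq!(counts["the"], 2);
        println!("{:?}", counts);
    }

The error-propagating `Result` plumbing in the real `feed` is omitted here to keep the counting pattern visible.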
@@ -145,37 +145,38 @@ impl Trainer for TrainerWrapper {
         }
     }
 
-    fn train(
-        &self,
-        words: HashMap<String, u32>,
-        model: &mut ModelWrapper,
-    ) -> Result<Vec<AddedToken>> {
+    fn train(&self, model: &mut ModelWrapper) -> Result<Vec<AddedToken>> {
         match self {
             TrainerWrapper::BpeTrainer(t) => match model {
-                ModelWrapper::BPE(bpe) => t.train(words, bpe),
+                ModelWrapper::BPE(bpe) => t.train(bpe),
                 _ => Err("BpeTrainer can only train a BPE".into()),
             },
             TrainerWrapper::WordPieceTrainer(t) => match model {
-                ModelWrapper::WordPiece(wp) => t.train(words, wp),
+                ModelWrapper::WordPiece(wp) => t.train(wp),
                 _ => Err("WordPieceTrainer can only train a WordPiece".into()),
             },
             TrainerWrapper::WordLevelTrainer(t) => match model {
-                ModelWrapper::WordLevel(wl) => t.train(words, wl),
+                ModelWrapper::WordLevel(wl) => t.train(wl),
                 _ => Err("WordLevelTrainer can only train a WordLevel".into()),
             },
             TrainerWrapper::UnigramTrainer(t) => match model {
-                ModelWrapper::Unigram(u) => t.train(words, u),
+                ModelWrapper::Unigram(u) => t.train(u),
                 _ => Err("UnigramTrainer can only train a Unigram".into()),
             },
         }
     }
 
-    fn process_tokens(&self, words: &mut HashMap<String, u32>, tokens: Vec<String>) {
+    fn feed<I, S, F>(&mut self, iterator: I, process: F) -> Result<()>
+    where
+        I: Iterator<Item = S> + Send,
+        S: AsRef<str> + Send,
+        F: Fn(&str) -> Result<Vec<String>> + Sync,
+    {
         match self {
-            TrainerWrapper::BpeTrainer(bpe) => bpe.process_tokens(words, tokens),
-            TrainerWrapper::WordPieceTrainer(wpt) => wpt.process_tokens(words, tokens),
-            TrainerWrapper::WordLevelTrainer(wpt) => wpt.process_tokens(words, tokens),
-            TrainerWrapper::UnigramTrainer(wpt) => wpt.process_tokens(words, tokens),
+            TrainerWrapper::BpeTrainer(bpe) => bpe.feed(iterator, process),
+            TrainerWrapper::WordPieceTrainer(wpt) => wpt.feed(iterator, process),
+            TrainerWrapper::WordLevelTrainer(wpt) => wpt.feed(iterator, process),
+            TrainerWrapper::UnigramTrainer(wpt) => wpt.feed(iterator, process),
         }
     }
 }
@@ -194,7 +195,7 @@ mod tests {
         let trainer = TrainerWrapper::BpeTrainer(BpeTrainer::default());
         let mut model = ModelWrapper::Unigram(Unigram::default());
 
-        let result = trainer.train(HashMap::new(), &mut model);
+        let result = trainer.train(&mut model);
         assert!(result.is_err());
     }
 }
@@ -1,5 +1,6 @@
 use crate::models::unigram::{lattice::Lattice, model::Unigram};
 use crate::tokenizer::{AddedToken, Result, Trainer};
+use crate::utils::parallelism::*;
 use crate::utils::progress::{ProgressBar, ProgressStyle};
 use log::debug;
 use std::cmp::Reverse;
@@ -58,7 +59,9 @@ pub struct UnigramTrainer {
     #[builder(default = "16")]
     pub max_piece_length: usize,
     #[builder(default = "1_000_000")]
-    pub seed_size: usize,
+    seed_size: usize,
+    #[builder(default = "HashMap::new()")]
+    words: HashMap<String, u32>,
 }
 
 impl Default for UnigramTrainer {
@@ -451,7 +454,11 @@ impl UnigramTrainer {
             .collect();
         new_pieces
     }
-    pub fn _train(&self, sentences: Vec<Sentence>, model: &mut Unigram) -> Result<Vec<AddedToken>> {
+    pub fn do_train(
+        &self,
+        sentences: Vec<Sentence>,
+        model: &mut Unigram,
+    ) -> Result<Vec<AddedToken>> {
         let progress = self.setup_progress();
         //
         // 1. Compute frequent substrings
@@ -533,19 +540,46 @@ impl Trainer for UnigramTrainer {
     type Model = Unigram;
 
     /// Train a Unigram model
-    fn train(
-        &self,
-        word_counts: HashMap<String, u32>,
-        model: &mut Unigram,
-    ) -> Result<Vec<AddedToken>> {
-        let sentences: Vec<_> = word_counts.into_iter().collect();
-        self._train(sentences, model)
+    fn train(&self, model: &mut Unigram) -> Result<Vec<AddedToken>> {
+        let sentences: Vec<_> = self.words.iter().map(|(s, i)| (s.to_owned(), *i)).collect();
+        self.do_train(sentences, model)
     }
 
     /// Whether we should show progress
     fn should_show_progress(&self) -> bool {
         self.show_progress
     }
+
+    fn feed<I, S, F>(&mut self, iterator: I, process: F) -> Result<()>
+    where
+        I: Iterator<Item = S> + Send,
+        S: AsRef<str> + Send,
+        F: Fn(&str) -> Result<Vec<String>> + Sync,
+    {
+        let words: Result<HashMap<String, u32>> = iterator
+            .maybe_par_bridge()
+            .map(|sequence| {
+                let words = process(sequence.as_ref())?;
+                let mut map = HashMap::new();
+                for word in words {
+                    map.entry(word).and_modify(|c| *c += 1).or_insert(1);
+                }
+                Ok(map)
+            })
+            .reduce(
+                || Ok(HashMap::new()),
+                |acc, ws| {
+                    let mut acc = acc?;
+                    for (k, v) in ws? {
+                        acc.entry(k).and_modify(|c| *c += v).or_insert(v);
+                    }
+                    Ok(acc)
+                },
+            );
+
+        self.words = words?;
+        Ok(())
+    }
 }
 
 #[cfg(test)]
@@ -640,10 +674,7 @@ mod tests {
 
         let mut unigram = Unigram::default();
         trainer
-            .train(
-                HashMap::from_iter(vec![("The".into(), 12), ("are".into(), 11)]),
-                &mut unigram,
-            )
+            .do_train(vec![("The".into(), 12), ("are".into(), 11)], &mut unigram)
             .unwrap();
 
         let mut pieces = unigram.iter();
@@ -665,10 +696,7 @@ mod tests {
 
         let mut unigram = Unigram::default();
         trainer
-            .train(
-                HashMap::from_iter(vec![("The".into(), 12), ("are".into(), 11)]),
-                &mut unigram,
-            )
+            .do_train(vec![("The".into(), 12), ("are".into(), 11)], &mut unigram)
            .unwrap();
 
         let mut pieces = unigram.iter();
@@ -684,10 +712,7 @@ mod tests {
 
         let mut unigram = Unigram::default();
         trainer
-            .train(
-                HashMap::from_iter(vec![("The".into(), 12), ("are".into(), 11)]),
-                &mut unigram,
-            )
+            .do_train(vec![("The".into(), 12), ("are".into(), 11)], &mut unigram)
            .unwrap();
 
         let mut pieces = unigram.iter();
@@ -707,10 +732,7 @@ mod tests {
 
         let mut unigram = Unigram::default();
         trainer
-            .train(
-                HashMap::from_iter(vec![("The".into(), 12), ("are".into(), 11)]),
-                &mut unigram,
-            )
+            .do_train(vec![("The".into(), 12), ("are".into(), 11)], &mut unigram)
            .unwrap();
 
         let mut pieces = unigram.iter();
@@ -1,4 +1,5 @@
 use super::WordLevel;
+use crate::utils::parallelism::*;
 use crate::{AddedToken, Result, Trainer};
 use std::collections::HashMap;
 
@@ -17,6 +18,9 @@ pub struct WordLevelTrainer {
     /// A list of special tokens that the model should know of
     #[builder(default)]
     pub special_tokens: Vec<AddedToken>,
+
+    #[builder(default, private)]
+    words: HashMap<String, u32>,
 }
 
 impl Default for WordLevelTrainer {
@@ -26,6 +30,7 @@ impl Default for WordLevelTrainer {
             vocab_size: 30_000,
             show_progress: true,
             special_tokens: vec![],
+            words: HashMap::new(),
         }
     }
 }
@@ -35,12 +40,12 @@ impl WordLevelTrainer {
         WordLevelTrainerBuilder::default()
     }
 
-    fn train(
+    fn do_train(
        &self,
-        word_counts: HashMap<String, u32>,
+        word_counts: &HashMap<String, u32>,
         model: &mut WordLevel,
     ) -> Result<Vec<AddedToken>> {
-        let mut ordered_counts = word_counts.into_iter().collect::<Vec<_>>();
+        let mut ordered_counts = word_counts.iter().collect::<Vec<_>>();
         ordered_counts.sort_by_key(|(_, n)| std::cmp::Reverse(*n));
         let word_level = WordLevel::builder()
             .vocab(
@@ -50,8 +55,8 @@ impl WordLevelTrainer {
                 .chain(
                     ordered_counts
                         .into_iter()
-                        .filter(|(_, n)| *n >= self.min_frequency)
-                        .map(|(w, _)| w),
+                        .filter(|(_, n)| **n >= self.min_frequency)
+                        .map(|(w, _)| w.to_owned()),
                 )
                 .take(self.vocab_size)
                 .enumerate()
@@ -72,18 +77,45 @@ impl Trainer for WordLevelTrainer {
     type Model = WordLevel;
 
     /// Train a WordLevel model
-    fn train(
-        &self,
-        word_counts: HashMap<String, u32>,
-        model: &mut WordLevel,
-    ) -> Result<Vec<AddedToken>> {
-        self.train(word_counts, model)
+    fn train(&self, model: &mut WordLevel) -> Result<Vec<AddedToken>> {
+        self.do_train(&self.words, model)
     }
 
     /// Whether we should show progress
     fn should_show_progress(&self) -> bool {
         self.show_progress
     }
+
+    fn feed<I, S, F>(&mut self, iterator: I, process: F) -> Result<()>
+    where
+        I: Iterator<Item = S> + Send,
+        S: AsRef<str> + Send,
+        F: Fn(&str) -> Result<Vec<String>> + Sync,
+    {
+        let words: Result<HashMap<String, u32>> = iterator
+            .maybe_par_bridge()
+            .map(|sequence| {
+                let words = process(sequence.as_ref())?;
+                let mut map = HashMap::new();
+                for word in words {
+                    map.entry(word).and_modify(|c| *c += 1).or_insert(1);
+                }
+                Ok(map)
+            })
+            .reduce(
+                || Ok(HashMap::new()),
+                |acc, ws| {
+                    let mut acc = acc?;
+                    for (k, v) in ws? {
+                        acc.entry(k).and_modify(|c| *c += v).or_insert(v);
+                    }
+                    Ok(acc)
+                },
+            );
+
+        self.words = words?;
+        Ok(())
+    }
 }
 
 #[cfg(test)]
@@ -108,7 +140,7 @@ mod tests {
         trainer.vocab_size = 5;
 
         let mut model = WordLevel::default();
-        trainer.train(word_counts.clone(), &mut model).unwrap();
+        trainer.do_train(&word_counts, &mut model).unwrap();
         let expected_vocab: HashMap<String, u32> = [
             ("the".into(), 0),
             ("are".into(), 1),
@@ -124,7 +156,7 @@ mod tests {
         // If we specify a min_frequency
         trainer.min_frequency = 15;
         let mut model = WordLevel::default();
-        trainer.train(word_counts, &mut model).unwrap();
+        trainer.do_train(&word_counts, &mut model).unwrap();
         let expected_vocab: HashMap<String, u32> = [
             ("the".into(), 0),
             ("are".into(), 1),
@@ -1,7 +1,7 @@
 use super::WordPiece;
 use crate::models::bpe::{BpeTrainer, BpeTrainerBuilder, BPE};
 use crate::tokenizer::{AddedToken, Result, Trainer};
-use std::collections::{HashMap, HashSet};
+use std::collections::HashSet;
 
 /// A `WordPieceTrainerBuilder` can be used to create a `WordPieceTrainer` with a custom
 /// configuration.
@@ -153,13 +153,9 @@ impl WordPieceTrainer {
         WordPieceTrainerBuilder::default()
     }
 
-    pub fn train(
-        &self,
-        word_counts: HashMap<String, u32>,
-        model: &mut WordPiece,
-    ) -> Result<Vec<AddedToken>> {
+    pub fn train(&self, model: &mut WordPiece) -> Result<Vec<AddedToken>> {
         let mut bpe = BPE::default();
-        let special_tokens = self.bpe_trainer.train(word_counts, &mut bpe)?;
+        let special_tokens = self.bpe_trainer.train(&mut bpe)?;
         let new_wordpiece = WordPiece::from_bpe(&bpe);
 
         // Transfer the vocab
@@ -175,19 +171,20 @@ impl WordPieceTrainer {
 impl Trainer for WordPieceTrainer {
     type Model = WordPiece;
 
-    fn train(
-        &self,
-        word_counts: HashMap<String, u32>,
-        model: &mut WordPiece,
-    ) -> Result<Vec<AddedToken>> {
-        self.train(word_counts, model)
-    }
-
-    fn process_tokens(&self, mut words: &mut HashMap<String, u32>, tokens: Vec<String>) {
-        self.bpe_trainer.process_tokens(&mut words, tokens)
+    fn train(&self, model: &mut WordPiece) -> Result<Vec<AddedToken>> {
+        self.train(model)
     }
 
     fn should_show_progress(&self) -> bool {
         self.bpe_trainer.should_show_progress()
     }
+
+    fn feed<I, S, F>(&mut self, iterator: I, process: F) -> Result<()>
+    where
+        I: Iterator<Item = S> + Send,
+        S: AsRef<str> + Send,
+        F: Fn(&str) -> Result<Vec<String>> + Sync,
+    {
+        self.bpe_trainer.feed(iterator, process)
+    }
 }
@@ -532,11 +532,15 @@ mod tests {
         fn should_show_progress(&self) -> bool {
             true
         }
-        fn train(
-            &self,
-            _words: HashMap<String, u32>,
-            _model: &mut ModelMock,
-        ) -> Result<Vec<AddedToken>> {
+        fn train(&self, _model: &mut ModelMock) -> Result<Vec<AddedToken>> {
+            unimplemented!()
+        }
+        fn feed<I, S, F>(&mut self, _iterator: I, _process: F) -> Result<()>
+        where
+            I: Iterator<Item = S> + Send,
+            S: AsRef<str> + Send,
+            F: Fn(&str) -> Result<Vec<String>> + Sync,
+        {
             unimplemented!()
         }
     }
@@ -131,20 +131,13 @@ pub trait Trainer {
     fn should_show_progress(&self) -> bool;
     /// The actual training method. This will return a new trained Model as well as a list
     /// of `special_tokens` to be added directly to the tokenizer along with the model.
-    fn train(
-        &self,
-        words: HashMap<String, u32>,
-        model: &mut Self::Model,
-    ) -> Result<Vec<AddedToken>>;
-    /// Process a bunch of token, counting them as relevant.
-    fn process_tokens(&self, words: &mut HashMap<String, u32>, tokens: Vec<String>) {
-        for token in tokens {
-            words
-                .entry(token.clone())
-                .and_modify(|c| *c += 1)
-                .or_insert(1);
-        }
-    }
+    fn train(&self, model: &mut Self::Model) -> Result<Vec<AddedToken>>;
+    /// Process an iterator of sequences already pre-processed by the Tokenizer
+    fn feed<I, S, F>(&mut self, iterator: I, process: F) -> Result<()>
+    where
+        I: Iterator<Item = S> + Send,
+        S: AsRef<str> + Send,
+        F: Fn(&str) -> Result<Vec<String>> + Sync;
 }
 
 #[derive(Debug, Clone, PartialEq)]
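The reshaped trait splits training into two steps: `feed` consumes an iterator of raw sequences, using a `process` callback for normalization and pre-tokenization, and accumulates state on the trainer; `train` then builds the model from that state. Below is a self-contained toy sketch (not part of the commit) of the same two-step protocol. All names (`CountingTrainer`, `Vocab`) are invented for illustration, the `Send`/`Sync` bounds and parallelism are dropped for brevity, and `train` returns nothing instead of a list of special tokens.

    use std::collections::HashMap;

    type Result<T> = std::result::Result<T, Box<dyn std::error::Error + Send + Sync>>;

    // Toy "model": just an ordered vocabulary of seen words.
    #[derive(Default, Debug)]
    struct Vocab(Vec<String>);

    // Toy trainer mirroring the feed/train protocol from the trait above.
    #[derive(Default)]
    struct CountingTrainer {
        words: HashMap<String, u32>,
    }

    impl CountingTrainer {
        // `feed` counts the words produced by `process` for every sequence.
        fn feed<I, S, F>(&mut self, iterator: I, process: F) -> Result<()>
        where
            I: Iterator<Item = S>,
            S: AsRef<str>,
            F: Fn(&str) -> Result<Vec<String>>,
        {
            for sequence in iterator {
                for word in process(sequence.as_ref())? {
                    *self.words.entry(word).or_insert(0) += 1;
                }
            }
            Ok(())
        }

        // `train` turns the accumulated counts into a model.
        fn train(&self, model: &mut Vocab) -> Result<()> {
            let mut entries: Vec<_> = self.words.iter().collect();
            entries.sort_by_key(|(_, n)| std::cmp::Reverse(**n));
            model.0 = entries.into_iter().map(|(w, _)| w.clone()).collect();
            Ok(())
        }
    }

    fn main() -> Result<()> {
        let sequences = vec!["a b a", "b a"];
        let mut trainer = CountingTrainer::default();
        // Whitespace splitting stands in for normalization + pre-tokenization.
        trainer.feed(sequences.into_iter(), |s| {
            Ok(s.split_whitespace().map(str::to_owned).collect())
        })?;

        let mut model = Vocab::default();
        trainer.train(&mut model)?;
        assert_eq!(model.0[0], "a"); // "a" is the most frequent word
        println!("{:?}", model);
        Ok(())
    }

This split is what lets the tokenizer drive training either from files or from any iterator: the source only has to be something `feed` can walk once, as the hunk below shows.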
@@ -969,99 +962,74 @@ where
         .collect()
     }
 
-    /// Train a model and replace our current Model, using the given Trainer
-    fn word_count<MN, T>(&self, trainer: &T, files: Vec<String>) -> Result<HashMap<String, u32>>
+    pub fn train_from_files<T>(&mut self, trainer: &mut T, files: Vec<String>) -> Result<&mut Self>
     where
-        T: Trainer<Model = MN> + Sync,
-        MN: Model,
+        T: Trainer<Model = M> + Sync,
     {
         let max_read = 1_000_000;
-        let mut len = 0;
-        for file in files.iter() {
-            len += File::open(file)
-                .and_then(|f| f.metadata())
-                .map(|m| m.len())?;
-        }
+        use crate::utils::iter::ResultShunt;
+        ResultShunt::process(
+            files.into_iter().flat_map(|filename| {
+                match File::open(filename) {
+                    Ok(file) => {
+                        let file = BufReader::with_capacity(max_read, file);
+                        // We read new lines using this API instead of the Lines Iterator
+                        // on purpose. We want to keep the `\n` and potential `\r` between each lines
+                        // We use an iterator to be able to chain with par_bridge.
+                        itertools::Either::Left(file.lines_with_ending())
+                    }
+                    Err(e) => itertools::Either::Right(std::iter::once(Err(e))),
+                }
+            }),
+            |iter| self.train(trainer, iter).map(|_| {}),
+        )??;
+        Ok(self)
+    }
+
+    /// Train a model and replace our current Model, using the given Trainer
+    pub fn train<T, I, S>(&mut self, trainer: &mut T, sequences: I) -> Result<&mut Self>
+    where
+        T: Trainer<Model = M> + Sync,
+        I: Iterator<Item = S> + Send,
+        S: AsRef<str> + Send,
+    {
+        let (lower, upper) = sequences.size_hint();
+        let len = upper.unwrap_or(lower) as u64;
         let progress = if trainer.should_show_progress() {
             let progress = ProgressBar::new(len);
             progress.set_style(
                 ProgressStyle::default_bar()
-                    .template("[{elapsed_precise}] {msg:<40!} {wide_bar} {percent:>19!}"),
+                    .template("[{elapsed_precise}] {msg:<40!} {wide_bar} {pos:<9!}/{len:>9!}"),
             );
-            progress.set_message(&format!("Reading files ({:.2} Mo)", len / 1_000_000));
+            progress.set_message("Pre-processing sequences");
             progress.set_draw_delta(len / 100); // Redraw only every 2%
             Some(progress)
         } else {
             None
         };
-        let words = files
-            .into_iter()
-            .map(|filename| -> Result<HashMap<String, u32>> {
-                let file = File::open(filename)?;
-                let file = BufReader::with_capacity(max_read, file);
-                // We read new lines using this API instead of the Lines Iterator
-                // on purpose. We want to keep the `\n` and potential `\r` between each lines
-                // We use an iterator to be able to chain with par_bridge.
-                file.lines_with_ending()
-                    .maybe_par_bridge()
-                    .map_with(
-                        &progress,
-                        |progress, line| -> Result<HashMap<String, u32>> {
-                            let newline = line?;
-                            let b = newline.len();
-                            let mut words = HashMap::new();
-                            let normalized = self.do_normalize(newline)?;
-                            let pre_tokenized = self.do_pre_tokenize(normalized)?;
-                            trainer.process_tokens(
-                                &mut words,
-                                pre_tokenized
-                                    .get_splits(OffsetReferential::Original, OffsetType::Byte)
-                                    .into_iter()
-                                    .map(|(s, _, _)| s.to_owned())
-                                    .collect(),
-                            );
-
-                            if let Some(pbar) = progress {
-                                pbar.inc(b as u64);
-                            }
-                            Ok(words)
-                        },
-                    )
-                    .reduce(
-                        || Ok(HashMap::new()),
-                        |acc, ws| {
-                            let mut acc = acc?;
-                            for (k, v) in ws? {
-                                acc.entry(k).and_modify(|c| *c += v).or_insert(v);
-                            }
-                            Ok(acc)
-                        },
-                    )
-            })
-            .try_fold(
-                HashMap::new(),
-                |mut acc, ws| -> Result<HashMap<String, u32>> {
-                    for (k, v) in ws? {
-                        acc.entry(k).and_modify(|c| *c += v).or_insert(v);
-                    }
-                    Ok(acc)
-                },
-            )?;
+        trainer.feed(
+            sequences.map(|s| {
+                // if let Some(progress) = &progress {
+                //     progress.inc(1)
+                // }
+                s
+            }),
+            |seq| {
+                let normalized = self.do_normalize(seq.as_ref())?;
+                let pre_tokenized = self.do_pre_tokenize(normalized)?;
+                Ok(pre_tokenized
+                    .get_splits(OffsetReferential::Original, OffsetType::Byte)
+                    .into_iter()
+                    .map(|(s, _, _)| s.to_owned())
+                    .collect())
+            },
+        )?;
         if let Some(pbar) = progress {
             pbar.finish();
         }
-        Ok(words)
-    }
-
-    /// Train a model and replace our current Model, using the given Trainer
-    pub fn train<T>(&mut self, trainer: &T, files: Vec<String>) -> Result<&mut Self>
-    where
-        T: Trainer<Model = M> + Sync,
-    {
-        let words = self.word_count(trainer, files)?;
-
-        let special_tokens = trainer.train(words, &mut self.model)?;
+        let special_tokens = trainer.train(&mut self.model)?;
         self.add_special_tokens(&special_tokens);
 
         Ok(self)