Rust - Improve training progress for multiple files

This commit is contained in:
Anthony MOI
2020-03-03 11:04:24 -05:00
parent afe96da7e3
commit 4b596e19dd
3 changed files with 37 additions and 27 deletions

View File

@@ -1,3 +1,9 @@
# Not released yet
## Changes:
- Keep only one progress bar while reading files during training. This is better for use-cases with
a high number of files as it avoids having too many progress bars on screen.
# v0.6.0
## Changes:

View File

@@ -1,3 +1,9 @@
# Not released yet
## Changes:
- Keep only one progress bar while reading files during training. This is better for use-cases with
a high number of files as it avoids having too many progress bars on screen.
# v0.8.0
## Changes:

View File

@@ -13,7 +13,7 @@ pub use crate::utils::{
pad_encodings, truncate_encodings, PaddingParams, PaddingStrategy, TruncationParams,
TruncationStrategy,
};
use indicatif::{MultiProgress, ProgressBar, ProgressDrawTarget, ProgressStyle};
use indicatif::{ProgressBar, ProgressStyle};
use rayon::prelude::*;
use std::{
collections::HashMap,
@@ -417,31 +417,23 @@ impl Tokenizer {
/// Train a model and replace our current Model, using the given Trainer
#[allow(clippy::borrowed_box)]
pub fn train(&mut self, trainer: &Box<dyn Trainer>, files: Vec<String>) -> Result<()> {
let progress = MultiProgress::new();
let style = ProgressStyle::default_bar()
.template("[{elapsed_precise}] {msg:<40!} {wide_bar} {bytes:<9!}/{total_bytes:>9!}");
let jobs = files
.into_par_iter()
.map(|filename| {
let file = File::open(filename.clone())?;
let len = file.metadata().map(|c| c.len()).unwrap_or(0);
let pbar = progress.add(ProgressBar::new(len));
if !trainer.should_show_progress() {
pbar.set_draw_target(ProgressDrawTarget::hidden());
}
pbar.set_style(style.clone());
pbar.set_message(&filename);
Ok((filename, pbar))
})
.collect::<Result<Vec<_>>>()?;
let progress = ProgressBar::new(100 * files.len() as u64);
progress.set_style(
ProgressStyle::default_bar()
.template("[{elapsed_precise}] {msg:<40!} {wide_bar} {percent:>19!}"),
);
progress.set_message("Reading files");
let handle = std::thread::spawn(move || progress.join().unwrap());
let results = jobs
.par_iter()
.map(|(filename, pbar)| -> Result<HashMap<String, u32>> {
let results = files
.into_par_iter()
.map(|filename| -> Result<HashMap<String, u32>> {
let mut words = HashMap::new();
let file = File::open(filename.clone())?;
let mut file = BufReader::new(pbar.wrap_read(file));
let file = File::open(filename)?;
let len = file.metadata().map_or(0, |c| c.len());
let mut file = BufReader::new(file);
let mut prev_prog = 0;
let mut read = 0;
let mut curr_prog;
let mut buf = String::new();
loop {
@@ -450,22 +442,28 @@ impl Tokenizer {
// on purpose. We want to keep the `\n` and potential `\r` between each lines
match file.read_line(&mut buf)? {
0 => break,
_ => {
b => {
let normalized = self.normalize(&buf)?;
let pre_tokenized = self.pre_tokenize(normalized.get())?;
trainer.process_tokens(
&mut words,
pre_tokenized.into_iter().map(|(t, _)| t).collect(),
);
read += b as u64;
curr_prog = ((read as f64 / len as f64) * 100.0) as u64;
if curr_prog > prev_prog {
progress.inc(curr_prog - prev_prog);
prev_prog = curr_prog;
}
}
}
}
pbar.finish();
Ok(words)
})
.collect::<Vec<_>>();
handle.join().unwrap();
progress.finish();
let mut words = HashMap::new();
for result in results {