Show progress while reading files during training

This commit is contained in:
Anthony MOI
2020-01-10 00:21:46 -05:00
parent 13f3fbed30
commit 1f16fcbe77
3 changed files with 37 additions and 7 deletions

View File

@ -169,7 +169,7 @@ impl BpeTrainer {
let p = ProgressBar::new(0);
p.set_style(
ProgressStyle::default_bar()
.template("[{elapsed_precise}] {msg:30} {bar:40} {pos:>7}/{len:7}"),
.template("[{elapsed_precise}] {msg:<40!} {wide_bar} {pos:<9!}/{len:>9!}"),
);
Some(p)
} else {
@ -483,4 +483,9 @@ impl Trainer for BpeTrainer {
.or_insert(1);
}
}
/// Whether we should show progress
fn should_show_progress(&self) -> bool {
self.show_progress
}
}

View File

@ -104,4 +104,8 @@ impl Trainer for WordPieceTrainer {
fn process_tokens(&self, mut words: &mut HashMap<String, u32>, tokens: Vec<String>) {
self.bpe_trainer.process_tokens(&mut words, tokens)
}
fn should_show_progress(&self) -> bool {
self.bpe_trainer.should_show_progress()
}
}

View File

@ -13,6 +13,7 @@ pub use crate::utils::{
pad_encodings, truncate_encodings, PaddingParams, PaddingStrategy, TruncationParams,
TruncationStrategy,
};
use indicatif::{MultiProgress, ProgressBar, ProgressDrawTarget, ProgressStyle};
use rayon::prelude::*;
use std::{
collections::HashMap,
@ -61,6 +62,7 @@ pub trait Decoder {
/// A `Trainer` has the responsibility to train a model. We feed it with lines/sentences
/// and it returns a `Model` when done.
pub trait Trainer: Sync {
fn should_show_progress(&self) -> bool;
fn train(&self, words: HashMap<String, u32>) -> Result<Box<dyn Model + Sync>>;
fn process_tokens(&self, words: &mut HashMap<String, u32>, tokens: Vec<String>);
}
@ -405,13 +407,30 @@ impl Tokenizer {
/// Train a model and replace our current Model, using the given Trainer
#[allow(clippy::borrowed_box)]
pub fn train(&mut self, trainer: &Box<dyn Trainer>, files: Vec<String>) -> Result<()> {
let results = files
.par_iter()
.map(|file| -> Result<HashMap<String, u32>> {
let mut words = HashMap::new();
let progress = MultiProgress::new();
let style = ProgressStyle::default_bar()
.template("[{elapsed_precise}] {msg:<40!} {wide_bar} {pos:<9!}/{len:>9!}");
let jobs = files
.into_iter()
.map(|filename| {
let file = File::open(filename.clone())?;
let len = file.metadata().map(|c| c.len()).unwrap_or(0);
let pbar = progress.add(ProgressBar::new(len));
if !trainer.should_show_progress() {
pbar.set_draw_target(ProgressDrawTarget::hidden());
}
pbar.set_style(style.clone());
pbar.set_message(&filename);
Ok((file, pbar))
})
.collect::<Result<Vec<_>>>()?;
let file: std::fs::File = File::open(file)?;
let mut file = BufReader::new(file);
let handle = std::thread::spawn(move || progress.join().unwrap());
let results = jobs
.par_iter()
.map(|(file, pbar)| -> Result<HashMap<String, u32>> {
let mut words = HashMap::new();
let mut file = BufReader::new(pbar.wrap_read(file));
let mut buf = String::new();
loop {
@ -431,9 +450,11 @@ impl Tokenizer {
}
}
pbar.finish();
Ok(words)
})
.collect::<Vec<_>>();
handle.join().unwrap();
let mut words = HashMap::new();
for result in results {