diff --git a/tokenizers/src/models/bpe/trainer.rs b/tokenizers/src/models/bpe/trainer.rs
index 727ebd40..13ab8440 100644
--- a/tokenizers/src/models/bpe/trainer.rs
+++ b/tokenizers/src/models/bpe/trainer.rs
@@ -169,7 +169,7 @@ impl BpeTrainer {
             let p = ProgressBar::new(0);
             p.set_style(
                 ProgressStyle::default_bar()
-                    .template("[{elapsed_precise}] {msg:30} {bar:40} {pos:>7}/{len:7}"),
+                    .template("[{elapsed_precise}] {msg:<40!} {wide_bar} {pos:<9!}/{len:>9!}"),
             );
             Some(p)
         } else {
@@ -483,4 +483,9 @@ impl Trainer for BpeTrainer {
                 .or_insert(1);
         }
     }
+
+    /// Whether we should show progress
+    fn should_show_progress(&self) -> bool {
+        self.show_progress
+    }
 }
diff --git a/tokenizers/src/models/wordpiece/trainer.rs b/tokenizers/src/models/wordpiece/trainer.rs
index 2f5d163c..08a873a3 100644
--- a/tokenizers/src/models/wordpiece/trainer.rs
+++ b/tokenizers/src/models/wordpiece/trainer.rs
@@ -104,4 +104,8 @@ impl Trainer for WordPieceTrainer {
     fn process_tokens(&self, mut words: &mut HashMap<String, u32>, tokens: Vec<String>) {
         self.bpe_trainer.process_tokens(&mut words, tokens)
     }
+
+    fn should_show_progress(&self) -> bool {
+        self.bpe_trainer.should_show_progress()
+    }
 }
diff --git a/tokenizers/src/tokenizer/mod.rs b/tokenizers/src/tokenizer/mod.rs
index 1e7bb159..7166e14a 100644
--- a/tokenizers/src/tokenizer/mod.rs
+++ b/tokenizers/src/tokenizer/mod.rs
@@ -13,6 +13,7 @@ pub use crate::utils::{
     pad_encodings, truncate_encodings, PaddingParams, PaddingStrategy, TruncationParams,
     TruncationStrategy,
 };
+use indicatif::{MultiProgress, ProgressBar, ProgressDrawTarget, ProgressStyle};
 use rayon::prelude::*;
 use std::{
     collections::HashMap,
@@ -61,6 +62,7 @@ pub trait Decoder {
 /// A `Trainer` has the responsibility to train a model. We feed it with lines/sentences
 /// and it returns a `Model` when done.
 pub trait Trainer: Sync {
+    fn should_show_progress(&self) -> bool;
     fn train(&self, words: HashMap<String, u32>) -> Result<Box<dyn Model + Sync>>;
     fn process_tokens(&self, words: &mut HashMap<String, u32>, tokens: Vec<String>);
 }
@@ -405,13 +407,30 @@ impl Tokenizer {
     /// Train a model and replace our current Model, using the given Trainer
     #[allow(clippy::borrowed_box)]
     pub fn train(&mut self, trainer: &Box<dyn Trainer>, files: Vec<String>) -> Result<()> {
-        let results = files
-            .par_iter()
-            .map(|file| -> Result<HashMap<String, u32>> {
-                let mut words = HashMap::new();
-
-                let file: std::fs::File = File::open(file)?;
-                let mut file = BufReader::new(file);
+        let progress = MultiProgress::new();
+        let style = ProgressStyle::default_bar()
+            .template("[{elapsed_precise}] {msg:<40!} {wide_bar} {pos:<9!}/{len:>9!}");
+        let jobs = files
+            .into_iter()
+            .map(|filename| {
+                let file = File::open(filename.clone())?;
+                let len = file.metadata().map(|c| c.len()).unwrap_or(0);
+                let pbar = progress.add(ProgressBar::new(len));
+                if !trainer.should_show_progress() {
+                    pbar.set_draw_target(ProgressDrawTarget::hidden());
+                }
+                pbar.set_style(style.clone());
+                pbar.set_message(&filename);
+                Ok((file, pbar))
+            })
+            .collect::<Result<Vec<_>>>()?;
+
+        let handle = std::thread::spawn(move || progress.join().unwrap());
+        let results = jobs
+            .par_iter()
+            .map(|(file, pbar)| -> Result<HashMap<String, u32>> {
+                let mut words = HashMap::new();
+                let mut file = BufReader::new(pbar.wrap_read(file));
 
                 let mut buf = String::new();
 
                 loop {
@@ -431,9 +450,11 @@
                 }
             }
 
+                pbar.finish();
                 Ok(words)
             })
             .collect::<Vec<_>>();
+        handle.join().unwrap();
 
         let mut words = HashMap::new();
         for result in results {