mirror of
https://github.com/mii443/tokenizers.git
synced 2025-08-23 16:49:27 +00:00
Show progress while reading files during training
This commit is contained in:
@ -169,7 +169,7 @@ impl BpeTrainer {
|
||||
let p = ProgressBar::new(0);
|
||||
p.set_style(
|
||||
ProgressStyle::default_bar()
|
||||
.template("[{elapsed_precise}] {msg:30} {bar:40} {pos:>7}/{len:7}"),
|
||||
.template("[{elapsed_precise}] {msg:<40!} {wide_bar} {pos:<9!}/{len:>9!}"),
|
||||
);
|
||||
Some(p)
|
||||
} else {
|
||||
@ -483,4 +483,9 @@ impl Trainer for BpeTrainer {
|
||||
.or_insert(1);
|
||||
}
|
||||
}
|
||||
|
||||
/// Whether we should show progress.
///
/// Returns the value of this trainer's `show_progress` field unchanged.
|
||||
fn should_show_progress(&self) -> bool {
|
||||
self.show_progress
|
||||
}
|
||||
}
|
||||
|
@ -104,4 +104,8 @@ impl Trainer for WordPieceTrainer {
|
||||
/// Forwards `words` and `tokens` to the inner `bpe_trainer`'s `process_tokens`.
// NOTE(review): `mut words` combined with `&mut words` hands over a `&mut &mut HashMap`;
// passing `words` directly looks sufficient — confirm before simplifying.
fn process_tokens(&self, mut words: &mut HashMap<String, u32>, tokens: Vec<String>) {
|
||||
self.bpe_trainer.process_tokens(&mut words, tokens)
|
||||
}
|
||||
|
||||
/// Whether we should show progress; forwards to the inner `bpe_trainer`'s
/// `should_show_progress`.
fn should_show_progress(&self) -> bool {
|
||||
self.bpe_trainer.should_show_progress()
|
||||
}
|
||||
}
|
||||
|
@ -13,6 +13,7 @@ pub use crate::utils::{
|
||||
pad_encodings, truncate_encodings, PaddingParams, PaddingStrategy, TruncationParams,
|
||||
TruncationStrategy,
|
||||
};
|
||||
use indicatif::{MultiProgress, ProgressBar, ProgressDrawTarget, ProgressStyle};
|
||||
use rayon::prelude::*;
|
||||
use std::{
|
||||
collections::HashMap,
|
||||
@ -61,6 +62,7 @@ pub trait Decoder {
|
||||
/// A `Trainer` has the responsibility to train a model. We feed it with lines/sentences
|
||||
/// and it returns a `Model` when done.
|
||||
pub trait Trainer: Sync {
|
||||
/// Whether progress should be displayed. `Tokenizer::train` hides its per-file
/// progress bars when this returns `false`.
fn should_show_progress(&self) -> bool;
|
||||
/// Trains from `words` and returns the resulting boxed `Model`, or an error.
// NOTE(review): `words` presumably maps each word to its occurrence count — confirm
// against the implementations.
fn train(&self, words: HashMap<String, u32>) -> Result<Box<dyn Model + Sync>>;
|
||||
/// Processes `tokens`, with mutable access to `words`.
// NOTE(review): presumably accumulates the tokens into `words` — confirm against the
// implementations.
fn process_tokens(&self, words: &mut HashMap<String, u32>, tokens: Vec<String>);
|
||||
}
|
||||
@ -405,13 +407,30 @@ impl Tokenizer {
|
||||
/// Train a model and replace our current Model, using the given Trainer
|
||||
#[allow(clippy::borrowed_box)]
|
||||
pub fn train(&mut self, trainer: &Box<dyn Trainer>, files: Vec<String>) -> Result<()> {
|
||||
let results = files
|
||||
.par_iter()
|
||||
.map(|file| -> Result<HashMap<String, u32>> {
|
||||
let mut words = HashMap::new();
|
||||
let progress = MultiProgress::new();
|
||||
let style = ProgressStyle::default_bar()
|
||||
.template("[{elapsed_precise}] {msg:<40!} {wide_bar} {pos:<9!}/{len:>9!}");
|
||||
let jobs = files
|
||||
.into_iter()
|
||||
.map(|filename| {
|
||||
let file = File::open(filename.clone())?;
|
||||
let len = file.metadata().map(|c| c.len()).unwrap_or(0);
|
||||
let pbar = progress.add(ProgressBar::new(len));
|
||||
if !trainer.should_show_progress() {
|
||||
pbar.set_draw_target(ProgressDrawTarget::hidden());
|
||||
}
|
||||
pbar.set_style(style.clone());
|
||||
pbar.set_message(&filename);
|
||||
Ok((file, pbar))
|
||||
})
|
||||
.collect::<Result<Vec<_>>>()?;
|
||||
|
||||
let file: std::fs::File = File::open(file)?;
|
||||
let mut file = BufReader::new(file);
|
||||
let handle = std::thread::spawn(move || progress.join().unwrap());
|
||||
let results = jobs
|
||||
.par_iter()
|
||||
.map(|(file, pbar)| -> Result<HashMap<String, u32>> {
|
||||
let mut words = HashMap::new();
|
||||
let mut file = BufReader::new(pbar.wrap_read(file));
|
||||
|
||||
let mut buf = String::new();
|
||||
loop {
|
||||
@ -431,9 +450,11 @@ impl Tokenizer {
|
||||
}
|
||||
}
|
||||
|
||||
pbar.finish();
|
||||
Ok(words)
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
handle.join().unwrap();
|
||||
|
||||
let mut words = HashMap::new();
|
||||
for result in results {
|
||||
|
Reference in New Issue
Block a user