mirror of
https://github.com/mii443/tokenizers.git
synced 2025-08-23 16:49:27 +00:00
Show progress while reading files during training
This commit is contained in:
@ -169,7 +169,7 @@ impl BpeTrainer {
|
|||||||
let p = ProgressBar::new(0);
|
let p = ProgressBar::new(0);
|
||||||
p.set_style(
|
p.set_style(
|
||||||
ProgressStyle::default_bar()
|
ProgressStyle::default_bar()
|
||||||
.template("[{elapsed_precise}] {msg:30} {bar:40} {pos:>7}/{len:7}"),
|
.template("[{elapsed_precise}] {msg:<40!} {wide_bar} {pos:<9!}/{len:>9!}"),
|
||||||
);
|
);
|
||||||
Some(p)
|
Some(p)
|
||||||
} else {
|
} else {
|
||||||
@ -483,4 +483,9 @@ impl Trainer for BpeTrainer {
|
|||||||
.or_insert(1);
|
.or_insert(1);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Whether we should show progress
|
||||||
|
fn should_show_progress(&self) -> bool {
|
||||||
|
self.show_progress
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
@ -104,4 +104,8 @@ impl Trainer for WordPieceTrainer {
|
|||||||
fn process_tokens(&self, mut words: &mut HashMap<String, u32>, tokens: Vec<String>) {
|
fn process_tokens(&self, mut words: &mut HashMap<String, u32>, tokens: Vec<String>) {
|
||||||
self.bpe_trainer.process_tokens(&mut words, tokens)
|
self.bpe_trainer.process_tokens(&mut words, tokens)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn should_show_progress(&self) -> bool {
|
||||||
|
self.bpe_trainer.should_show_progress()
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
@ -13,6 +13,7 @@ pub use crate::utils::{
|
|||||||
pad_encodings, truncate_encodings, PaddingParams, PaddingStrategy, TruncationParams,
|
pad_encodings, truncate_encodings, PaddingParams, PaddingStrategy, TruncationParams,
|
||||||
TruncationStrategy,
|
TruncationStrategy,
|
||||||
};
|
};
|
||||||
|
use indicatif::{MultiProgress, ProgressBar, ProgressDrawTarget, ProgressStyle};
|
||||||
use rayon::prelude::*;
|
use rayon::prelude::*;
|
||||||
use std::{
|
use std::{
|
||||||
collections::HashMap,
|
collections::HashMap,
|
||||||
@ -61,6 +62,7 @@ pub trait Decoder {
|
|||||||
/// A `Trainer` has the responsibility to train a model. We feed it with lines/sentences
|
/// A `Trainer` has the responsibility to train a model. We feed it with lines/sentences
|
||||||
/// and it returns a `Model` when done.
|
/// and it returns a `Model` when done.
|
||||||
pub trait Trainer: Sync {
|
pub trait Trainer: Sync {
|
||||||
|
fn should_show_progress(&self) -> bool;
|
||||||
fn train(&self, words: HashMap<String, u32>) -> Result<Box<dyn Model + Sync>>;
|
fn train(&self, words: HashMap<String, u32>) -> Result<Box<dyn Model + Sync>>;
|
||||||
fn process_tokens(&self, words: &mut HashMap<String, u32>, tokens: Vec<String>);
|
fn process_tokens(&self, words: &mut HashMap<String, u32>, tokens: Vec<String>);
|
||||||
}
|
}
|
||||||
@ -405,13 +407,30 @@ impl Tokenizer {
|
|||||||
/// Train a model and replace our current Model, using the given Trainer
|
/// Train a model and replace our current Model, using the given Trainer
|
||||||
#[allow(clippy::borrowed_box)]
|
#[allow(clippy::borrowed_box)]
|
||||||
pub fn train(&mut self, trainer: &Box<dyn Trainer>, files: Vec<String>) -> Result<()> {
|
pub fn train(&mut self, trainer: &Box<dyn Trainer>, files: Vec<String>) -> Result<()> {
|
||||||
let results = files
|
let progress = MultiProgress::new();
|
||||||
.par_iter()
|
let style = ProgressStyle::default_bar()
|
||||||
.map(|file| -> Result<HashMap<String, u32>> {
|
.template("[{elapsed_precise}] {msg:<40!} {wide_bar} {pos:<9!}/{len:>9!}");
|
||||||
let mut words = HashMap::new();
|
let jobs = files
|
||||||
|
.into_iter()
|
||||||
|
.map(|filename| {
|
||||||
|
let file = File::open(filename.clone())?;
|
||||||
|
let len = file.metadata().map(|c| c.len()).unwrap_or(0);
|
||||||
|
let pbar = progress.add(ProgressBar::new(len));
|
||||||
|
if !trainer.should_show_progress() {
|
||||||
|
pbar.set_draw_target(ProgressDrawTarget::hidden());
|
||||||
|
}
|
||||||
|
pbar.set_style(style.clone());
|
||||||
|
pbar.set_message(&filename);
|
||||||
|
Ok((file, pbar))
|
||||||
|
})
|
||||||
|
.collect::<Result<Vec<_>>>()?;
|
||||||
|
|
||||||
let file: std::fs::File = File::open(file)?;
|
let handle = std::thread::spawn(move || progress.join().unwrap());
|
||||||
let mut file = BufReader::new(file);
|
let results = jobs
|
||||||
|
.par_iter()
|
||||||
|
.map(|(file, pbar)| -> Result<HashMap<String, u32>> {
|
||||||
|
let mut words = HashMap::new();
|
||||||
|
let mut file = BufReader::new(pbar.wrap_read(file));
|
||||||
|
|
||||||
let mut buf = String::new();
|
let mut buf = String::new();
|
||||||
loop {
|
loop {
|
||||||
@ -431,9 +450,11 @@ impl Tokenizer {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pbar.finish();
|
||||||
Ok(words)
|
Ok(words)
|
||||||
})
|
})
|
||||||
.collect::<Vec<_>>();
|
.collect::<Vec<_>>();
|
||||||
|
handle.join().unwrap();
|
||||||
|
|
||||||
let mut words = HashMap::new();
|
let mut words = HashMap::new();
|
||||||
for result in results {
|
for result in results {
|
||||||
|
Reference in New Issue
Block a user