Rust - Improve training progress for multiple files
First CHANGELOG file (latest released entry v0.6.0):

@@ -1,3 +1,9 @@
+# Not released yet
+
+## Changes:
+- Keep only one progress bar while reading files during training. This is better for use-cases with
+a high number of files as it avoids having too many progress bar on screen.
+
 # v0.6.0
 
 ## Changes:
Second CHANGELOG file (latest released entry v0.8.0):

@@ -1,3 +1,9 @@
+# Not released yet
+
+## Changes:
+- Keep only one progress bar while reading files during training. This is better for use-cases with
+a high number of files as it avoids having too many progress bar on screen.
+
 # v0.8.0
 
 ## Changes:
Rust source, imports:

@@ -13,7 +13,7 @@ pub use crate::utils::{
     pad_encodings, truncate_encodings, PaddingParams, PaddingStrategy, TruncationParams,
     TruncationStrategy,
 };
-use indicatif::{MultiProgress, ProgressBar, ProgressDrawTarget, ProgressStyle};
+use indicatif::{ProgressBar, ProgressStyle};
 use rayon::prelude::*;
 use std::{
     collections::HashMap,
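
The import swap summarizes the whole change: `MultiProgress` (one bar per file, pumped from a background thread via `join()`) and `ProgressDrawTarget` (previously used to hide the per-file bars) are dropped, and all rayon workers now share a single `ProgressBar`. That works because indicatif's `ProgressBar` is a cheap handle around shared state and is safe to update from many threads at once. A minimal standalone sketch of the shared-bar pattern, assuming an indicatif version whose `.template(...)` returns the style directly (as in the diff); the tick count here is made up:

use indicatif::{ProgressBar, ProgressStyle};
use rayon::prelude::*;

fn main() {
    // One bar for all workers: every borrow updates the same terminal line.
    let pb = ProgressBar::new(1_000);
    pb.set_style(
        ProgressStyle::default_bar()
            .template("[{elapsed_precise}] {msg:<40!} {wide_bar} {percent:>19!}"),
    );
    pb.set_message("Reading files");

    // Rayon threads increment the shared bar concurrently.
    (0..1_000u64).into_par_iter().for_each(|_| pb.inc(1));

    pb.finish();
}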
Rust source, Tokenizer::train(), first hunk:

@@ -417,31 +417,23 @@ impl Tokenizer {
     /// Train a model and replace our current Model, using the given Trainer
     #[allow(clippy::borrowed_box)]
     pub fn train(&mut self, trainer: &Box<dyn Trainer>, files: Vec<String>) -> Result<()> {
-        let progress = MultiProgress::new();
-        let style = ProgressStyle::default_bar()
-            .template("[{elapsed_precise}] {msg:<40!} {wide_bar} {bytes:<9!}/{total_bytes:>9!}");
-        let jobs = files
-            .into_par_iter()
-            .map(|filename| {
-                let file = File::open(filename.clone())?;
-                let len = file.metadata().map(|c| c.len()).unwrap_or(0);
-                let pbar = progress.add(ProgressBar::new(len));
-                if !trainer.should_show_progress() {
-                    pbar.set_draw_target(ProgressDrawTarget::hidden());
-                }
-                pbar.set_style(style.clone());
-                pbar.set_message(&filename);
-                Ok((filename, pbar))
-            })
-            .collect::<Result<Vec<_>>>()?;
-
-        let handle = std::thread::spawn(move || progress.join().unwrap());
-        let results = jobs
-            .par_iter()
-            .map(|(filename, pbar)| -> Result<HashMap<String, u32>> {
+        let progress = ProgressBar::new(100 * files.len() as u64);
+        progress.set_style(
+            ProgressStyle::default_bar()
+                .template("[{elapsed_precise}] {msg:<40!} {wide_bar} {percent:>19!}"),
+        );
+        progress.set_message("Reading files");
+
+        let results = files
+            .into_par_iter()
+            .map(|filename| -> Result<HashMap<String, u32>> {
                 let mut words = HashMap::new();
-                let file = File::open(filename.clone())?;
-                let mut file = BufReader::new(pbar.wrap_read(file));
+                let file = File::open(filename)?;
+                let len = file.metadata().map_or(0, |c| c.len());
+                let mut file = BufReader::new(file);
+                let mut prev_prog = 0;
+                let mut read = 0;
+                let mut curr_prog;
 
                 let mut buf = String::new();
                 loop {
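
Note the sizing trick in the new code: instead of totalling real byte lengths, the shared bar gets a flat budget of 100 ticks per file (`ProgressBar::new(100 * files.len() as u64)`), so each file advances the bar by the same amount regardless of its size, and per-file byte counts only matter locally. A spot check of that arithmetic, as a hypothetical helper (not part of the crate):

/// Flat budget of 100 ticks per file, independent of file size.
fn bar_length(num_files: usize) -> u64 {
    100 * num_files as u64
}

fn main() {
    // e.g. 250 training files -> a bar running from 0 to 25_000 ticks,
    // with every fully-read file accounting for exactly 100 of them.
    assert_eq!(bar_length(250), 25_000);
}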
Rust source, Tokenizer::train(), second hunk:

@@ -450,22 +442,28 @@ impl Tokenizer {
                     // on purpose. We want to keep the `\n` and potential `\r` between each lines
                     match file.read_line(&mut buf)? {
                         0 => break,
-                        _ => {
+                        b => {
                             let normalized = self.normalize(&buf)?;
                             let pre_tokenized = self.pre_tokenize(normalized.get())?;
                             trainer.process_tokens(
                                 &mut words,
                                 pre_tokenized.into_iter().map(|(t, _)| t).collect(),
                             );
+
+                            read += b as u64;
+                            curr_prog = ((read as f64 / len as f64) * 100.0) as u64;
+                            if curr_prog > prev_prog {
+                                progress.inc(curr_prog - prev_prog);
+                                prev_prog = curr_prog;
+                            }
                         }
                     }
                 }
 
-                pbar.finish();
                 Ok(words)
             })
             .collect::<Vec<_>>();
-        handle.join().unwrap();
+        progress.finish();
 
         let mut words = HashMap::new();
         for result in results {
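
Two details carry the progress accounting above: the match binding is renamed from `_` to `b` so the byte count returned by `read_line` (including the trailing newline) can feed `read`, and the bar is only incremented when the file's integer percentage actually changes, so a file sends at most 100 updates to the shared bar no matter how many lines it has. The same bookkeeping restated as a hypothetical standalone helper with spot checks (`FileProgress` is illustrative, not crate API, and it assumes a known non-zero file length):

/// Per-file progress state, mirroring the `read`/`prev_prog`/`curr_prog`
/// variables in the diff.
struct FileProgress {
    len: u64,       // total bytes in the file (assumed > 0 here)
    read: u64,      // bytes consumed so far
    prev_prog: u64, // last whole percent reported to the bar
}

impl FileProgress {
    fn new(len: u64) -> Self {
        FileProgress { len, read: 0, prev_prog: 0 }
    }

    /// Record `bytes` more bytes read; return the whole-percent delta
    /// that should be `inc()`-ed on the shared bar (usually 0).
    fn advance(&mut self, bytes: u64) -> u64 {
        self.read += bytes;
        let curr_prog = ((self.read as f64 / self.len as f64) * 100.0) as u64;
        if curr_prog > self.prev_prog {
            let delta = curr_prog - self.prev_prog;
            self.prev_prog = curr_prog;
            delta
        } else {
            0
        }
    }
}

fn main() {
    let mut fp = FileProgress::new(1_000);
    assert_eq!(fp.advance(4), 0);    // 0.4% read: below one whole percent
    assert_eq!(fp.advance(6), 1);    // 1.0% read: first tick
    assert_eq!(fp.advance(990), 99); // rest of the file: remaining ticks
}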