Improve progress tracking while training

Anthony MOI
2020-11-25 15:55:58 -05:00
committed by Anthony MOI
parent 75deaecdd0
commit c36ac0bfdf
3 changed files with 158 additions and 45 deletions


@@ -12,7 +12,6 @@ use tk::tokenizer::{
     Model, PaddingDirection, PaddingParams, PaddingStrategy, PostProcessor, TokenizerImpl,
     TruncationParams, TruncationStrategy,
 };
-use tk::utils::iter::ResultShunt;
 use tokenizers as tk;
 
 use super::decoders::PyDecoder;
@@ -1085,30 +1084,23 @@ impl PyTokenizer {
         })
     }
 
-    #[args(trainer = "None")]
+    #[args(trainer = "None", length = "None")]
     fn train_from_iterator(
         &mut self,
         iterator: &PyAny,
         trainer: Option<&mut PyTrainer>,
+        length: Option<usize>,
     ) -> PyResult<()> {
+        use crate::utils::PySendIterator;
         let mut trainer =
             trainer.map_or_else(|| self.tokenizer.get_model().get_trainer(), |t| t.clone());
 
-        let (send, recv) = std::sync::mpsc::sync_channel(256);
-        let mut sender = Some(send);
-        let iterator: PyIterator = iterator.iter()?;
-        crossbeam::thread::scope(|s| {
-            let _train_handle = s.spawn(|_| {
-                self.tokenizer
-                    .train(&mut trainer, recv.into_iter())
-                    .map(|_| {})
-            });
-            ResultShunt::process(
+        let py_send = PySendIterator::new(
            // Each element of the iterator can either be:
            // - An iterator, to allow batching
            // - A string
-           iterator.flat_map(|seq| match seq {
+           iterator.iter()?.flat_map(|seq| match seq {
                Ok(s) => {
                    if let Ok(iter) = s.iter() {
                        itertools::Either::Left(iter.map(|i| i?.extract::<&str>()))
@@ -1118,18 +1110,15 @@ impl PyTokenizer {
                }
                Err(e) => itertools::Either::Right(std::iter::once(Err(e))),
            }),
-           |iter| {
-               if let Some(send) = sender.take() {
-                   for seq in iter {
-                       send.send(seq)
-                           .map_err(|e| exceptions::PyException::new_err(e.to_string()))?;
-                   }
-               }
-               Ok(())
-           },
-        )?
+           length,
+        );
+        py_send.execute(|iter| {
+            self.tokenizer
+                .train(&mut trainer, iter)
+                .map(|_| {})
+                .map_err(|e| exceptions::PyException::new_err(e.to_string()))
         })
-        .unwrap()
     }
/// Apply all the post-processing steps to the given encodings.
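
Taken together, this first file's change moves the channel plumbing out of train_from_iterator into a reusable PySendIterator helper (added in the next file), and threads a new optional `length` argument through so the trainer can size a progress bar. The shape being encapsulated is a bounded producer/consumer handoff: the producer keeps draining the Python iterator on the calling thread (the one allowed to touch Python objects), while training consumes from the channel on a worker thread. A minimal standalone sketch of that shape, with hypothetical names (`train` stands in for the real training call; only the crossbeam crate, which the diff itself uses, is assumed):

    use std::sync::mpsc::sync_channel;

    // Stand-in for the real training call: just drains the iterator.
    fn train(sequences: impl Iterator<Item = String>) {
        for _sequence in sequences {}
    }

    fn main() {
        // Bounded channel: the producer blocks instead of buffering the
        // whole corpus when the consumer falls behind.
        let (send, recv) = sync_channel::<String>(256);

        crossbeam::thread::scope(|s| {
            // Consumer thread: trains from whatever the channel yields.
            s.spawn(move |_| train(recv.into_iter()));

            // Producer side: stays on the calling thread, which in the
            // bindings is the one driving the Python iterator.
            for sequence in ["hello world", "second sequence"] {
                send.send(sequence.to_string()).unwrap();
            }
            drop(send); // close the channel so the consumer's iterator ends
        })
        .unwrap();
    }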


@@ -12,6 +12,75 @@ pub use normalization::*;
 pub use pretokenization::*;
 pub use regex::*;
 
+// PySendIterator
+
+use std::sync::mpsc::{sync_channel, IntoIter};
+
+use tk::utils::iter::ResultShunt;
+
+pub struct MaybeSizedIterator<I> {
+    length: Option<usize>,
+    iter: I,
+}
+
+impl<I> Iterator for MaybeSizedIterator<I>
+where
+    I: Iterator,
+{
+    type Item = I::Item;
+
+    fn next(&mut self) -> Option<Self::Item> {
+        self.iter.next()
+    }
+
+    fn size_hint(&self) -> (usize, Option<usize>) {
+        (self.length.unwrap_or(0), None)
+    }
+}
+
+pub struct PySendIterator<I: Iterator> {
+    iter: I,
+    length: Option<usize>,
+}
+
+impl<I, T> PySendIterator<I>
+where
+    I: Iterator<Item = PyResult<T>>,
+    T: Send,
+{
+    pub fn new(iter: I, length: Option<usize>) -> Self {
+        PySendIterator { iter, length }
+    }
+
+    pub fn execute<F>(self, mut scope: F) -> PyResult<()>
+    where
+        F: FnMut(MaybeSizedIterator<IntoIter<T>>) -> PyResult<()> + Send + Sync,
+    {
+        let (send, recv) = sync_channel(256);
+        let mut sender = Some(send);
+
+        crossbeam::thread::scope(|s| {
+            let length = self.length;
+            s.spawn(move |_| {
+                scope(MaybeSizedIterator {
+                    length,
+                    iter: recv.into_iter(),
+                })
+            });
+
+            ResultShunt::process(self.iter, |iter| {
+                if let Some(send) = sender.take() {
+                    for i in iter {
+                        send.send(i)
+                            .map_err(|e| exceptions::PyException::new_err(e.to_string()))?;
+                    }
+                }
+                Ok(())
+            })?
+        })
+        .unwrap()
+    }
+}
+
 // PyChar
 // This type is a temporary hack to accept `char` as argument
 // To be removed once https://github.com/PyO3/pyo3/pull/1282 has been released
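
The sole purpose of MaybeSizedIterator is progress reporting: a channel receiver cannot know how many items will ever arrive, so its size_hint() is (0, None) and useless for sizing a progress bar; the wrapper substitutes the caller-provided `length` as the lower bound. A small sketch of that effect, assuming it runs inside this same module (the struct's fields are private):

    use std::sync::mpsc::sync_channel;

    fn main() {
        let (send, recv) = sync_channel::<u32>(8);
        for i in 0..3 {
            send.send(i).unwrap();
        }
        drop(send);

        // The bare receiver iterator reports nothing usable:
        let inner = recv.into_iter();
        assert_eq!(inner.size_hint(), (0, None));

        // The wrapper restores whatever hint the caller supplied:
        let sized = MaybeSizedIterator { length: Some(3), iter: inner };
        assert_eq!(sized.size_hint(), (3, None));
        assert_eq!(sized.count(), 3); // iteration itself is pure delegation
    }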


@@ -23,6 +23,7 @@ use serde::de::DeserializeOwned;
 use serde::export::Formatter;
 use serde::{Deserialize, Serialize};
 
+use crate::utils::iter::ResultShunt;
 use crate::utils::parallelism::*;
 use crate::utils::progress::{ProgressBar, ProgressStyle};
@@ -124,7 +125,7 @@ pub trait Decoder {
 }
 
 /// A `Trainer` has the responsibility to train a model. We feed it with lines/sentences
-/// and it returns a `Model` when done.
+/// and then it can train the given `Model`.
 pub trait Trainer {
     type Model: Model + Sized;
 
     /// Whether we should show progress during the training.
@@ -132,7 +133,8 @@ pub trait Trainer {
     /// The actual training method. This will return a new trained Model as well as a list
     /// of `special_tokens` to be added directly to the tokenizer along with the model.
     fn train(&self, model: &mut Self::Model) -> Result<Vec<AddedToken>>;
 
-    /// Process an iterator of sequences already pre-processed by the Tokenizer
+    /// Process an iterator of sequences, calling `process` on each of them
+    /// in order to pre-process them as needed.
     fn feed<I, S, F>(&mut self, iterator: I, process: F) -> Result<()>
     where
         I: Iterator<Item = S> + Send,
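
For context on how feed is meant to be used: the tokenizer hands it the raw sequences plus a `process` callback (normalization + pre-tokenization), and the trainer accumulates whatever statistics it needs. A toy, sequential implementation of that contract; the hunk cuts off mid-signature, so the remaining bounds `S: AsRef<str> + Send` and `F: Fn(&str) -> Result<Vec<String>> + Sync` are an assumption inferred from the pre-tokenization callback shown further down, and tk's `Result` alias is reproduced so the sketch stands alone. Real trainers would parallelize this loop.

    use std::collections::HashMap;

    // Assumed error alias, mirroring tk's boxed-error Result.
    type Result<T> = std::result::Result<T, Box<dyn std::error::Error + Send + Sync>>;

    #[derive(Default)]
    struct ToyTrainer {
        words: HashMap<String, u64>,
    }

    impl ToyTrainer {
        // Same shape as Trainer::feed: consume sequences, count the words
        // that `process` extracts from each of them.
        fn feed<I, S, F>(&mut self, iterator: I, process: F) -> Result<()>
        where
            I: Iterator<Item = S> + Send,
            S: AsRef<str> + Send,
            F: Fn(&str) -> Result<Vec<String>> + Sync,
        {
            for sequence in iterator {
                for word in process(sequence.as_ref())? {
                    *self.words.entry(word).or_insert(0) += 1;
                }
            }
            Ok(())
        }
    }

    fn main() -> Result<()> {
        let mut trainer = ToyTrainer::default();
        trainer.feed(["hello world", "hello again"].iter().copied(), |s| {
            Ok(s.split_whitespace().map(str::to_owned).collect())
        })?;
        assert_eq!(trainer.words["hello"], 2);
        Ok(())
    }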
@@ -962,12 +964,20 @@ where
             .collect()
     }
 
     /// Train our Model from files
     pub fn train_from_files<T>(&mut self, trainer: &mut T, files: Vec<String>) -> Result<&mut Self>
     where
         T: Trainer<Model = M> + Sync,
     {
+        let mut len = 0;
+        for file in files.iter() {
+            len += File::open(file)
+                .and_then(|f| f.metadata())
+                .map(|m| m.len())?;
+        }
+        let max_read = 1_000_000;
 
-        use crate::utils::iter::ResultShunt;
         ResultShunt::process(
             files.into_iter().flat_map(|filename| {
                 match File::open(filename) {
@@ -981,12 +991,52 @@ where
                     Err(e) => itertools::Either::Right(std::iter::once(Err(e))),
                 }
             }),
-            |iter| self.train(trainer, iter).map(|_| {}),
+            |sequences| -> Result<()> {
+                let progress = if trainer.should_show_progress() {
+                    let progress = ProgressBar::new(len);
+                    progress.set_style(
+                        ProgressStyle::default_bar()
+                            .template("[{elapsed_precise}] {msg:<40!} {wide_bar} {percent:>18!}%"),
+                    );
+                    progress
+                        .set_message(&format!("Pre-processing files ({:.2} Mo)", len / 1_000_000));
+                    progress.set_draw_delta(len / 100); // Redraw only every 2%
+                    Some(progress)
+                } else {
+                    None
+                };
+
+                trainer.feed(
+                    sequences.map(|s| {
+                        if let Some(progress) = &progress {
+                            progress.inc(s.len() as u64)
+                        }
+                        s
+                    }),
+                    |seq| {
+                        let normalized = self.do_normalize(seq.as_ref())?;
+                        let pre_tokenized = self.do_pre_tokenize(normalized)?;
+                        Ok(pre_tokenized
+                            .get_splits(OffsetReferential::Original, OffsetType::Byte)
+                            .into_iter()
+                            .map(|(s, _, _)| s.to_owned())
+                            .collect())
+                    },
+                )?;
+
+                if let Some(pbar) = progress {
+                    pbar.finish();
+                }
+
+                let special_tokens = trainer.train(&mut self.model)?;
+                self.add_special_tokens(&special_tokens);
+
+                Ok(())
+            },
         )??;
         Ok(self)
     }
 
-    /// Train a model and replace our current Model, using the given Trainer
+    /// Train our Model, using the given Trainer and iterator
     pub fn train<T, I, S>(&mut self, trainer: &mut T, sequences: I) -> Result<&mut Self>
     where
         T: Trainer<Model = M> + Sync,
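
The interesting part of the train_from_files change: progress is now tracked in bytes rather than sequences. The bar's total is the summed file size collected up front from filesystem metadata, and each sequence advances it by its own byte length, so the bar reflects how much input has been consumed even though the number of lines is unknown. A standalone sketch of that pattern, assuming an indicatif release from this era (set_draw_delta was dropped in later versions):

    use indicatif::{ProgressBar, ProgressStyle};

    fn main() {
        // Stand-in for the summed File::metadata().len() of every input file.
        let len: u64 = 2_000_000;

        let progress = ProgressBar::new(len);
        progress.set_style(
            ProgressStyle::default_bar()
                .template("[{elapsed_precise}] {msg:<40!} {wide_bar} {percent:>18!}%"),
        );
        progress.set_draw_delta(len / 100); // throttle terminal redraws

        for line in ["first line of the corpus", "second line"] {
            // Advance by bytes read, not by sequence count.
            progress.inc(line.len() as u64);
        }
        progress.finish();
    }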
@@ -1002,7 +1052,12 @@ where
                     .template("[{elapsed_precise}] {msg:<40!} {wide_bar} {pos:<9!}/{len:>9!}"),
             );
             progress.set_message("Pre-processing sequences");
+            if len > 0 {
+                progress.set_draw_delta(len / 100); // Redraw only every 2%
+            } else {
+                // Trying to have a good default to avoid progress tracking being the bottleneck
+                progress.set_draw_delta(1000);
+            }
             Some(progress)
         } else {
             None
@@ -1010,9 +1065,9 @@ where
         trainer.feed(
             sequences.map(|s| {
-                // if let Some(progress) = &progress {
-                //     progress.inc(1)
-                // }
+                if let Some(progress) = &progress {
+                    progress.inc(1)
+                }
                 s
             }),
             |seq| {
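
Lastly, the `len > 0` guard above is what makes the re-enabled per-sequence progress affordable. `train` presumably gets its bar length from the sequence iterator's size_hint (which is exactly what MaybeSizedIterator injects when a Python caller passes `length`); when no hint is available the lower bound is 0, so `len / 100` would be a draw delta of 0 and the bar would redraw on every single sequence, hence the fixed fallback of 1000. A quick illustration of why the hint is often legitimately absent:

    fn main() {
        let v = vec![1, 2, 3];

        // An in-memory collection knows its exact length...
        assert_eq!(v.iter().size_hint(), (3, Some(3)));

        // ...but an adaptor like `filter` can only promise a lower bound of
        // 0, and a channel receiver reports (0, None) too.
        assert_eq!(v.iter().filter(|&&x| x > 1).size_hint(), (0, Some(3)));
    }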