Improve progress tracking while training

This commit is contained in:
Anthony MOI
2020-11-25 15:55:58 -05:00
committed by Anthony MOI
parent 75deaecdd0
commit c36ac0bfdf
3 changed files with 158 additions and 45 deletions

View File

@ -12,7 +12,6 @@ use tk::tokenizer::{
Model, PaddingDirection, PaddingParams, PaddingStrategy, PostProcessor, TokenizerImpl,
TruncationParams, TruncationStrategy,
};
use tk::utils::iter::ResultShunt;
use tokenizers as tk;
use super::decoders::PyDecoder;
@ -1085,51 +1084,41 @@ impl PyTokenizer {
})
}
#[args(trainer = "None")]
#[args(trainer = "None", length = "None")]
fn train_from_iterator(
&mut self,
iterator: &PyAny,
trainer: Option<&mut PyTrainer>,
length: Option<usize>,
) -> PyResult<()> {
use crate::utils::PySendIterator;
let mut trainer =
trainer.map_or_else(|| self.tokenizer.get_model().get_trainer(), |t| t.clone());
let (send, recv) = std::sync::mpsc::sync_channel(256);
let mut sender = Some(send);
let iterator: PyIterator = iterator.iter()?;
crossbeam::thread::scope(|s| {
let _train_handle = s.spawn(|_| {
self.tokenizer
.train(&mut trainer, recv.into_iter())
.map(|_| {})
});
let py_send = PySendIterator::new(
// Each element of the iterator can either be:
// - An iterator, to allow batching
// - A string
iterator.iter()?.flat_map(|seq| match seq {
Ok(s) => {
if let Ok(iter) = s.iter() {
itertools::Either::Left(iter.map(|i| i?.extract::<&str>()))
} else {
itertools::Either::Right(std::iter::once(s.extract::<&str>()))
}
}
Err(e) => itertools::Either::Right(std::iter::once(Err(e))),
}),
length,
);
ResultShunt::process(
// Each element of the iterator can either be:
// - An iterator, to allow batching
// - A string
iterator.flat_map(|seq| match seq {
Ok(s) => {
if let Ok(iter) = s.iter() {
itertools::Either::Left(iter.map(|i| i?.extract::<&str>()))
} else {
itertools::Either::Right(std::iter::once(s.extract::<&str>()))
}
}
Err(e) => itertools::Either::Right(std::iter::once(Err(e))),
}),
|iter| {
if let Some(send) = sender.take() {
for seq in iter {
send.send(seq)
.map_err(|e| exceptions::PyException::new_err(e.to_string()))?;
}
}
Ok(())
},
)?
py_send.execute(|iter| {
self.tokenizer
.train(&mut trainer, iter)
.map(|_| {})
.map_err(|e| exceptions::PyException::new_err(e.to_string()))
})
.unwrap()
}
/// Apply all the post-processing steps to the given encodings.