Improve progress tracking while training
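This commit threads sequence-length information from the Python bindings down into the Rust training loop so that real progress can be reported while training. Summarized from the hunks below: `train_from_iterator` gains an optional `length` argument; a new `PySendIterator` helper (with a `MaybeSizedIterator` wrapper that exposes the length through `size_hint`) replaces the hand-rolled channel plumbing; `train_from_files` sizes its progress bar by total file bytes; and the previously commented-out per-sequence `progress.inc(1)` is re-enabled, with a sane redraw rate when the length is unknown. The hunks appear to touch `bindings/python/src/tokenizer.rs`, `bindings/python/src/utils/mod.rs`, and `tokenizers/src/tokenizer/mod.rs`.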
@@ -12,7 +12,6 @@ use tk::tokenizer::{
     Model, PaddingDirection, PaddingParams, PaddingStrategy, PostProcessor, TokenizerImpl,
     TruncationParams, TruncationStrategy,
 };
-use tk::utils::iter::ResultShunt;
 use tokenizers as tk;
 
 use super::decoders::PyDecoder;
@@ -1085,30 +1084,23 @@ impl PyTokenizer {
         })
     }
 
-    #[args(trainer = "None")]
+    #[args(trainer = "None", length = "None")]
     fn train_from_iterator(
         &mut self,
         iterator: &PyAny,
         trainer: Option<&mut PyTrainer>,
+        length: Option<usize>,
     ) -> PyResult<()> {
+        use crate::utils::PySendIterator;
+
         let mut trainer =
             trainer.map_or_else(|| self.tokenizer.get_model().get_trainer(), |t| t.clone());
-        let (send, recv) = std::sync::mpsc::sync_channel(256);
-        let mut sender = Some(send);
-        let iterator: PyIterator = iterator.iter()?;
 
-        crossbeam::thread::scope(|s| {
-            let _train_handle = s.spawn(|_| {
-                self.tokenizer
-                    .train(&mut trainer, recv.into_iter())
-                    .map(|_| {})
-            });
-
-            ResultShunt::process(
-                // Each element of the iterator can either be:
-                // - An iterator, to allow batching
-                // - A string
-                iterator.flat_map(|seq| match seq {
-                    Ok(s) => {
-                        if let Ok(iter) = s.iter() {
-                            itertools::Either::Left(iter.map(|i| i?.extract::<&str>()))
+        let py_send = PySendIterator::new(
+            // Each element of the iterator can either be:
+            // - An iterator, to allow batching
+            // - A string
+            iterator.iter()?.flat_map(|seq| match seq {
+                Ok(s) => {
+                    if let Ok(iter) = s.iter() {
+                        itertools::Either::Left(iter.map(|i| i?.extract::<&str>()))
@@ -1118,18 +1110,15 @@ impl PyTokenizer {
-                    }
-                    Err(e) => itertools::Either::Right(std::iter::once(Err(e))),
-                }),
-                |iter| {
-                    if let Some(send) = sender.take() {
-                        for seq in iter {
-                            send.send(seq)
-                                .map_err(|e| exceptions::PyException::new_err(e.to_string()))?;
-                        }
-                    }
-                    Ok(())
-                },
-            )?
+                }
+                Err(e) => itertools::Either::Right(std::iter::once(Err(e))),
+            }),
+            length,
+        );
+
+        py_send.execute(|iter| {
+            self.tokenizer
+                .train(&mut trainer, iter)
+                .map(|_| {})
+                .map_err(|e| exceptions::PyException::new_err(e.to_string()))
         })
-        .unwrap()
     }
 
     /// Apply all the post-processing steps to the given encodings.
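The `PySendIterator` used above is introduced in the bindings' utils module (presumably `bindings/python/src/utils/mod.rs`), together with the `MaybeSizedIterator` wrapper that carries the optional length through `size_hint`: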
@@ -12,6 +12,75 @@ pub use normalization::*;
 pub use pretokenization::*;
 pub use regex::*;
 
+// PySendIterator
+
+use std::sync::mpsc::{sync_channel, IntoIter};
+use tk::utils::iter::ResultShunt;
+
+pub struct MaybeSizedIterator<I> {
+    length: Option<usize>,
+    iter: I,
+}
+
+impl<I> Iterator for MaybeSizedIterator<I>
+where
+    I: Iterator,
+{
+    type Item = I::Item;
+
+    fn next(&mut self) -> Option<Self::Item> {
+        self.iter.next()
+    }
+
+    fn size_hint(&self) -> (usize, Option<usize>) {
+        (self.length.unwrap_or(0), None)
+    }
+}
+
+pub struct PySendIterator<I: Iterator> {
+    iter: I,
+    length: Option<usize>,
+}
+
+impl<I, T> PySendIterator<I>
+where
+    I: Iterator<Item = PyResult<T>>,
+    T: Send,
+{
+    pub fn new(iter: I, length: Option<usize>) -> Self {
+        PySendIterator { iter, length }
+    }
+
+    pub fn execute<F>(self, mut scope: F) -> PyResult<()>
+    where
+        F: FnMut(MaybeSizedIterator<IntoIter<T>>) -> PyResult<()> + Send + Sync,
+    {
+        let (send, recv) = sync_channel(256);
+        let mut sender = Some(send);
+
+        crossbeam::thread::scope(|s| {
+            let length = self.length;
+            s.spawn(move |_| {
+                scope(MaybeSizedIterator {
+                    length,
+                    iter: recv.into_iter(),
+                })
+            });
+
+            ResultShunt::process(self.iter, |iter| {
+                if let Some(send) = sender.take() {
+                    for i in iter {
+                        send.send(i)
+                            .map_err(|e| exceptions::PyException::new_err(e.to_string()))?;
+                    }
+                }
+                Ok(())
+            })?
+        })
+        .unwrap()
+    }
+}
+
 // PyChar
 // This type is a temporary hack to accept `char` as argument
 // To be removed once https://github.com/PyO3/pyo3/pull/1282 has been released
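To make the pattern concrete, here is a minimal, self-contained sketch of what `PySendIterator::execute` does, with illustrative names and plain `&str` items instead of PyO3 types (it assumes a `crossbeam` dependency, as in the crate). The producer stays on the calling thread, mirroring a GIL-bound Python iterator that is not `Send`, while a scoped thread consumes items through a wrapper whose `size_hint` advertises the known length to anything wanting to size a progress bar.

use std::sync::mpsc::sync_channel;

struct MaybeSized<I> {
    length: Option<usize>,
    iter: I,
}

impl<I: Iterator> Iterator for MaybeSized<I> {
    type Item = I::Item;

    fn next(&mut self) -> Option<Self::Item> {
        self.iter.next()
    }

    fn size_hint(&self) -> (usize, Option<usize>) {
        // Lower bound only, like MaybeSizedIterator above.
        (self.length.unwrap_or(0), None)
    }
}

fn main() {
    let items = vec!["hello", "world", "again"];
    let length = Some(items.len());
    // Bounded channel: the producer blocks instead of buffering everything.
    let (send, recv) = sync_channel(2);

    crossbeam::thread::scope(|s| {
        // Consumer thread: sees an ordinary iterator with a usable size_hint.
        s.spawn(move |_| {
            let iter = MaybeSized { length, iter: recv.into_iter() };
            assert_eq!(iter.size_hint().0, 3);
            for item in iter {
                println!("consumed: {}", item);
            }
        });

        // Producer side, kept on the current thread.
        for item in items {
            send.send(item).unwrap();
        }
        drop(send); // closing the channel terminates the consumer's loop
    })
    .unwrap();
}

The remaining hunks move to the core crate (apparently `tokenizers/src/tokenizer/mod.rs`), where the `Trainer` docs and the training loops pick up the new progress plumbing.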
@@ -23,6 +23,7 @@ use serde::de::DeserializeOwned;
 use serde::export::Formatter;
 use serde::{Deserialize, Serialize};
 
+use crate::utils::iter::ResultShunt;
 use crate::utils::parallelism::*;
 use crate::utils::progress::{ProgressBar, ProgressStyle};
 
@@ -124,7 +125,7 @@ pub trait Decoder {
 }
 
 /// A `Trainer` has the responsibility to train a model. We feed it with lines/sentences
-/// and it returns a `Model` when done.
+/// and then it can train the given `Model`.
 pub trait Trainer {
     type Model: Model + Sized;
     /// Whether we should show progress during the training.
@@ -132,7 +133,8 @@ pub trait Trainer {
     /// The actual training method. This will return a new trained Model as well as a list
     /// of `special_tokens` to be added directly to the tokenizer along with the model.
     fn train(&self, model: &mut Self::Model) -> Result<Vec<AddedToken>>;
-    /// Process an iterator of sequences already pre-processed by the Tokenizer
+    /// Process an iterator of sequences, calling `process` for each of them in order to
+    /// pre-process the said sequence as relevant.
     fn feed<I, S, F>(&mut self, iterator: I, process: F) -> Result<()>
     where
         I: Iterator<Item = S> + Send,
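The `feed` contract is easier to see with a toy implementation. This is a hedged sketch, not the crate's real `Trainer` (which also carries a `Model` type, `should_show_progress`, etc.), and the `S` and `F` bounds are assumptions, since the hunk cuts off after the `I` bound: `feed` drains the iterator, applies the tokenizer-supplied `process` callback to pre-process each raw sequence into words, and accumulates counts.

use std::collections::HashMap;

type Result<T> = std::result::Result<T, Box<dyn std::error::Error + Send + Sync>>;

#[derive(Default)]
struct ToyTrainer {
    words: HashMap<String, u64>,
}

impl ToyTrainer {
    // `process` is supplied by the tokenizer: it turns one raw sequence into
    // the pre-tokenized words the trainer should count.
    fn feed<I, S, F>(&mut self, iterator: I, process: F) -> Result<()>
    where
        I: Iterator<Item = S> + Send,
        S: AsRef<str> + Send,
        F: Fn(&str) -> Result<Vec<String>> + Sync,
    {
        for sequence in iterator {
            for word in process(sequence.as_ref())? {
                *self.words.entry(word).or_insert(0) += 1;
            }
        }
        Ok(())
    }
}

fn main() -> Result<()> {
    let mut trainer = ToyTrainer::default();
    trainer.feed(
        ["hello world", "hello again"].iter().copied(),
        // Stand-in for do_normalize + do_pre_tokenize: whitespace splitting.
        |s| Ok(s.split_whitespace().map(str::to_owned).collect()),
    )?;
    assert_eq!(trainer.words["hello"], 2);
    Ok(())
}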
@@ -962,12 +964,20 @@ where
         .collect()
     }
 
+    /// Train our Model from files
     pub fn train_from_files<T>(&mut self, trainer: &mut T, files: Vec<String>) -> Result<&mut Self>
     where
         T: Trainer<Model = M> + Sync,
     {
+        let mut len = 0;
+        for file in files.iter() {
+            len += File::open(file)
+                .and_then(|f| f.metadata())
+                .map(|m| m.len())?;
+        }
+
         let max_read = 1_000_000;
-        use crate::utils::iter::ResultShunt;
+
         ResultShunt::process(
             files.into_iter().flat_map(|filename| {
                 match File::open(filename) {
@@ -981,12 +991,52 @@ where
                     Err(e) => itertools::Either::Right(std::iter::once(Err(e))),
                 }
             }),
-            |iter| self.train(trainer, iter).map(|_| {}),
+            |sequences| -> Result<()> {
+                let progress = if trainer.should_show_progress() {
+                    let progress = ProgressBar::new(len);
+                    progress.set_style(
+                        ProgressStyle::default_bar()
+                            .template("[{elapsed_precise}] {msg:<40!} {wide_bar} {percent:>18!}%"),
+                    );
+                    progress
+                        .set_message(&format!("Pre-processing files ({:.2} Mo)", len / 1_000_000));
+                    progress.set_draw_delta(len / 100); // Redraw only every 2%
+                    Some(progress)
+                } else {
+                    None
+                };
+
+                trainer.feed(
+                    sequences.map(|s| {
+                        if let Some(progress) = &progress {
+                            progress.inc(s.len() as u64)
+                        }
+                        s
+                    }),
+                    |seq| {
+                        let normalized = self.do_normalize(seq.as_ref())?;
+                        let pre_tokenized = self.do_pre_tokenize(normalized)?;
+                        Ok(pre_tokenized
+                            .get_splits(OffsetReferential::Original, OffsetType::Byte)
+                            .into_iter()
+                            .map(|(s, _, _)| s.to_owned())
+                            .collect())
+                    },
+                )?;
+
+                if let Some(pbar) = progress {
+                    pbar.finish();
+                }
+                let special_tokens = trainer.train(&mut self.model)?;
+                self.add_special_tokens(&special_tokens);
+
+                Ok(())
+            },
         )??;
         Ok(self)
     }
 
-    /// Train a model and replace our current Model, using the given Trainer
+    /// Train our Model, using the given Trainer and iterator
     pub fn train<T, I, S>(&mut self, trainer: &mut T, sequences: I) -> Result<&mut Self>
     where
         T: Trainer<Model = M> + Sync,
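The progress arithmetic in this hunk is worth spelling out: `len` is the combined size of the input files in bytes (computed in the previous hunk), so the bar's template reports `{percent}` rather than a position, and each pre-processed sequence advances the bar by its byte length. Below is a hedged, standalone sketch of the same idea, with hypothetical file paths, using `indicatif` directly rather than the crate's `ProgressBar`/`ProgressStyle` wrappers, and assuming the older indicatif API this diff targets, where `template` returns the style directly rather than a `Result`.

use std::fs::File;
use std::io::{BufRead, BufReader};

use indicatif::{ProgressBar, ProgressStyle};

fn main() -> std::io::Result<()> {
    // Hypothetical input files, standing in for the training corpus.
    let files = vec!["data/part1.txt".to_string(), "data/part2.txt".to_string()];

    // Total input size in bytes, from file metadata, as in train_from_files.
    let mut len = 0;
    for file in files.iter() {
        len += File::open(file).and_then(|f| f.metadata()).map(|m| m.len())?;
    }

    let progress = ProgressBar::new(len);
    progress.set_style(
        ProgressStyle::default_bar()
            .template("[{elapsed_precise}] {msg:<40!} {wide_bar} {percent:>18!}%"),
    );
    progress.set_message("Pre-processing files");

    for file in files {
        for line in BufReader::new(File::open(file)?).lines() {
            let line = line?;
            // Advance by bytes, not by line count; line.len() omits the
            // newline byte, so the bar can land slightly under 100%.
            progress.inc(line.len() as u64);
        }
    }
    progress.finish();
    Ok(())
}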
@@ -1002,7 +1052,12 @@ where
                     .template("[{elapsed_precise}] {msg:<40!} {wide_bar} {pos:<9!}/{len:>9!}"),
             );
             progress.set_message("Pre-processing sequences");
-            progress.set_draw_delta(len / 100); // Redraw only every 2%
+            if len > 0 {
+                progress.set_draw_delta(len / 100); // Redraw only every 2%
+            } else {
+                // Trying to have a good default to avoid progress tracking being the bottleneck
+                progress.set_draw_delta(1000);
+            }
             Some(progress)
         } else {
             None
@@ -1010,9 +1065,9 @@ where
 
         trainer.feed(
             sequences.map(|s| {
-                // if let Some(progress) = &progress {
-                //     progress.inc(1)
-                // }
+                if let Some(progress) = &progress {
+                    progress.inc(1)
+                }
                 s
             }),
             |seq| {
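Taken together: `train` already built a bar sized by `len` (presumably derived from the iterator's `size_hint`, which is exactly what `MaybeSizedIterator` now feeds through from the Python side), so with the redraw rate guarded above for the unknown-length case, the per-sequence `progress.inc(1)` could finally be uncommented.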