diff --git a/bindings/python/src/tokenizer.rs b/bindings/python/src/tokenizer.rs
index dcc8ced0..e62d8de7 100644
--- a/bindings/python/src/tokenizer.rs
+++ b/bindings/python/src/tokenizer.rs
@@ -12,7 +12,6 @@ use tk::tokenizer::{
     Model, PaddingDirection, PaddingParams, PaddingStrategy, PostProcessor, TokenizerImpl,
     TruncationParams, TruncationStrategy,
 };
-use tk::utils::iter::ResultShunt;
 use tokenizers as tk;
 
 use super::decoders::PyDecoder;
@@ -1085,51 +1084,41 @@ impl PyTokenizer {
         })
     }
 
-    #[args(trainer = "None")]
+    #[args(trainer = "None", length = "None")]
     fn train_from_iterator(
         &mut self,
         iterator: &PyAny,
         trainer: Option<&mut PyTrainer>,
+        length: Option<usize>,
     ) -> PyResult<()> {
+        use crate::utils::PySendIterator;
+
         let mut trainer =
             trainer.map_or_else(|| self.tokenizer.get_model().get_trainer(), |t| t.clone());
 
-        let (send, recv) = std::sync::mpsc::sync_channel(256);
-        let mut sender = Some(send);
-        let iterator: PyIterator = iterator.iter()?;
-        crossbeam::thread::scope(|s| {
-            let _train_handle = s.spawn(|_| {
-                self.tokenizer
-                    .train(&mut trainer, recv.into_iter())
-                    .map(|_| {})
-            });
+        let py_send = PySendIterator::new(
+            // Each element of the iterator can either be:
+            //  - An iterator, to allow batching
+            //  - A string
+            iterator.iter()?.flat_map(|seq| match seq {
+                Ok(s) => {
+                    if let Ok(iter) = s.iter() {
+                        itertools::Either::Left(iter.map(|i| i?.extract::<&str>()))
+                    } else {
+                        itertools::Either::Right(std::iter::once(s.extract::<&str>()))
+                    }
+                }
+                Err(e) => itertools::Either::Right(std::iter::once(Err(e))),
+            }),
+            length,
+        );
 
-            ResultShunt::process(
-                // Each element of the iterator can either be:
-                //  - An iterator, to allow batching
-                //  - A string
-                iterator.flat_map(|seq| match seq {
-                    Ok(s) => {
-                        if let Ok(iter) = s.iter() {
-                            itertools::Either::Left(iter.map(|i| i?.extract::<&str>()))
-                        } else {
-                            itertools::Either::Right(std::iter::once(s.extract::<&str>()))
-                        }
-                    }
-                    Err(e) => itertools::Either::Right(std::iter::once(Err(e))),
-                }),
-                |iter| {
-                    if let Some(send) = sender.take() {
-                        for seq in iter {
-                            send.send(seq)
-                                .map_err(|e| exceptions::PyException::new_err(e.to_string()))?;
-                        }
-                    }
-                    Ok(())
-                },
-            )?
+        py_send.execute(|iter| {
+            self.tokenizer
+                .train(&mut trainer, iter)
+                .map(|_| {})
+                .map_err(|e| exceptions::PyException::new_err(e.to_string()))
         })
-        .unwrap()
     }
 
     /// Apply all the post-processing steps to the given encodings.
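The new `PySendIterator` helper (added in `bindings/python/src/utils/mod.rs` below) factors out the channel plumbing that previously lived inline in `train_from_iterator`. As a rough, self-contained sketch only — not part of the patch, and using a plain in-process iterator where the real code holds a non-`Send` Python iterator — the pattern it encapsulates looks like this:

```rust
use std::sync::mpsc::sync_channel;

fn main() {
    // Stand-in for the fallible, non-Send source iterator (a Python iterator in the bindings).
    let producer = (0..5u32).map(|i| Ok::<String, String>(format!("sequence {}", i)));

    // Bounded channel: the producer blocks once 256 items are in flight.
    let (send, recv) = sync_channel::<String>(256);

    crossbeam::thread::scope(|s| {
        // Consumer thread: sees the receiver as a plain blocking iterator,
        // which is what the training side receives in the real code.
        let handle = s.spawn(move |_| recv.into_iter().count());

        // Producer stays on the calling thread and forwards items,
        // stopping at the first error it encounters.
        for item in producer {
            match item {
                Ok(seq) => send.send(seq).expect("receiver hung up"),
                Err(_) => break,
            }
        }
        drop(send); // Closing the channel lets the consumer's iterator end.

        println!("consumed {} sequences", handle.join().unwrap());
    })
    .unwrap();
}
```

The `length` argument threaded through `PySendIterator::new` surfaces only via `MaybeSizedIterator::size_hint`, which is presumably why `train` below now guards `set_draw_delta` against a zero length.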
diff --git a/bindings/python/src/utils/mod.rs b/bindings/python/src/utils/mod.rs
index fa351b9e..e74a1b3b 100644
--- a/bindings/python/src/utils/mod.rs
+++ b/bindings/python/src/utils/mod.rs
@@ -12,6 +12,75 @@ pub use normalization::*;
 pub use pretokenization::*;
 pub use regex::*;
 
+// PySendIterator
+
+use std::sync::mpsc::{sync_channel, IntoIter};
+use tk::utils::iter::ResultShunt;
+
+pub struct MaybeSizedIterator<I> {
+    length: Option<usize>,
+    iter: I,
+}
+
+impl<I> Iterator for MaybeSizedIterator<I>
+where
+    I: Iterator,
+{
+    type Item = I::Item;
+
+    fn next(&mut self) -> Option<Self::Item> {
+        self.iter.next()
+    }
+
+    fn size_hint(&self) -> (usize, Option<usize>) {
+        (self.length.unwrap_or(0), None)
+    }
+}
+
+pub struct PySendIterator<I> {
+    iter: I,
+    length: Option<usize>,
+}
+
+impl<I, T> PySendIterator<I>
+where
+    I: Iterator<Item = PyResult<T>>,
+    T: Send,
+{
+    pub fn new(iter: I, length: Option<usize>) -> Self {
+        PySendIterator { iter, length }
+    }
+
+    pub fn execute<F>(self, mut scope: F) -> PyResult<()>
+    where
+        F: FnMut(MaybeSizedIterator<IntoIter<T>>) -> PyResult<()> + Send + Sync,
+    {
+        let (send, recv) = sync_channel(256);
+        let mut sender = Some(send);
+
+        crossbeam::thread::scope(|s| {
+            let length = self.length;
+            s.spawn(move |_| {
+                scope(MaybeSizedIterator {
+                    length,
+                    iter: recv.into_iter(),
+                })
+            });
+
+            ResultShunt::process(self.iter, |iter| {
+                if let Some(send) = sender.take() {
+                    for i in iter {
+                        send.send(i)
+                            .map_err(|e| exceptions::PyException::new_err(e.to_string()))?;
+                    }
+                }
+                Ok(())
+            })?
+        })
+        .unwrap()
+    }
+}
+
 // PyChar
 // This type is a temporary hack to accept `char` as argument
 // To be removed once https://github.com/PyO3/pyo3/pull/1282 has been released
diff --git a/tokenizers/src/tokenizer/mod.rs b/tokenizers/src/tokenizer/mod.rs
index f2764efc..341892ae 100644
--- a/tokenizers/src/tokenizer/mod.rs
+++ b/tokenizers/src/tokenizer/mod.rs
@@ -23,6 +23,7 @@ use serde::de::DeserializeOwned;
 use serde::export::Formatter;
 use serde::{Deserialize, Serialize};
 
+use crate::utils::iter::ResultShunt;
 use crate::utils::parallelism::*;
 use crate::utils::progress::{ProgressBar, ProgressStyle};
 
@@ -124,7 +125,7 @@ pub trait Decoder {
 }
 
 /// A `Trainer` has the responsibility to train a model. We feed it with lines/sentences
-/// and it returns a `Model` when done.
+/// and then it can train the given `Model`.
 pub trait Trainer {
     type Model: Model + Sized;
     /// Whether we should show progress during the training.
@@ -132,7 +133,8 @@
     /// The actual training method. This will return a new trained Model as well as a list
     /// of `special_tokens` to be added directly to the tokenizer along with the model.
     fn train(&self, model: &mut Self::Model) -> Result<Vec<AddedToken>>;
-    /// Process an iterator of sequences already pre-processed by the Tokenizer
+    /// Process an iterator of sequences, calling `process` for each of them in order to
+    /// pre-process them as relevant.
     fn feed<I, S, F>(&mut self, iterator: I, process: F) -> Result<()>
     where
         I: Iterator<Item = S> + Send,
@@ -962,12 +964,20 @@
             .collect()
     }
 
+    /// Train our Model from files
     pub fn train_from_files<T>(&mut self, trainer: &mut T, files: Vec<String>) -> Result<&mut Self>
     where
         T: Trainer<Model = M> + Sync,
     {
+        let mut len = 0;
+        for file in files.iter() {
+            len += File::open(file)
+                .and_then(|f| f.metadata())
+                .map(|m| m.len())?;
+        }
+
         let max_read = 1_000_000;
-        use crate::utils::iter::ResultShunt;
+
         ResultShunt::process(
             files.into_iter().flat_map(|filename| {
                 match File::open(filename) {
@@ -981,12 +991,52 @@
                     Err(e) => itertools::Either::Right(std::iter::once(Err(e))),
                 }
             }),
-            |iter| self.train(trainer, iter).map(|_| {}),
+            |sequences| -> Result<()> {
+                let progress = if trainer.should_show_progress() {
+                    let progress = ProgressBar::new(len);
+                    progress.set_style(
+                        ProgressStyle::default_bar()
+                            .template("[{elapsed_precise}] {msg:<40!} {wide_bar} {percent:>18!}%"),
+                    );
+                    progress
+                        .set_message(&format!("Pre-processing files ({:.2} Mo)", len / 1_000_000));
+                    progress.set_draw_delta(len / 100); // Redraw only every 2%
+                    Some(progress)
+                } else {
+                    None
+                };
+
+                trainer.feed(
+                    sequences.map(|s| {
+                        if let Some(progress) = &progress {
+                            progress.inc(s.len() as u64)
+                        }
+                        s
+                    }),
+                    |seq| {
+                        let normalized = self.do_normalize(seq.as_ref())?;
+                        let pre_tokenized = self.do_pre_tokenize(normalized)?;
+                        Ok(pre_tokenized
+                            .get_splits(OffsetReferential::Original, OffsetType::Byte)
+                            .into_iter()
+                            .map(|(s, _, _)| s.to_owned())
+                            .collect())
+                    },
+                )?;
+
+                if let Some(pbar) = progress {
+                    pbar.finish();
+                }
+                let special_tokens = trainer.train(&mut self.model)?;
+                self.add_special_tokens(&special_tokens);
+
+                Ok(())
+            },
         )??;
         Ok(self)
     }
 
-    /// Train a model and replace our current Model, using the given Trainer
+    /// Train our Model, using the given Trainer and iterator
     pub fn train<T, I>(&mut self, trainer: &mut T, sequences: I) -> Result<&mut Self>
     where
         T: Trainer<Model = M> + Sync,
@@ -1002,7 +1052,12 @@
                     .template("[{elapsed_precise}] {msg:<40!} {wide_bar} {pos:<9!}/{len:>9!}"),
             );
             progress.set_message("Pre-processing sequences");
-            progress.set_draw_delta(len / 100); // Redraw only every 2%
+            if len > 0 {
+                progress.set_draw_delta(len / 100); // Redraw only every 2%
+            } else {
+                // Trying to have a good default to avoid progress tracking being the bottleneck
+                progress.set_draw_delta(1000);
+            }
             Some(progress)
         } else {
             None
@@ -1010,9 +1065,9 @@
 
         trainer.feed(
             sequences.map(|s| {
-                // if let Some(progress) = &progress {
-                //     progress.inc(1)
-                // }
+                if let Some(progress) = &progress {
+                    progress.inc(1)
+                }
                 s
             }),
            |seq| {