Add ability to train from Iterator

This commit is contained in:
Anthony MOI
2020-11-12 12:58:14 -05:00
committed by Anthony MOI
parent 6e364cb685
commit e0a70f1fb2
11 changed files with 380 additions and 199 deletions

View File

@ -12,6 +12,7 @@ use tk::tokenizer::{
Model, PaddingDirection, PaddingParams, PaddingStrategy, PostProcessor, TokenizerImpl,
TruncationParams, TruncationStrategy,
};
use tk::utils::iter::ResultShunt;
use tokenizers as tk;
use super::decoders::PyDecoder;
@ -1069,16 +1070,53 @@ impl PyTokenizer {
}
#[args(trainer = "None")]
fn train(&mut self, files: Vec<String>, trainer: Option<&PyTrainer>) -> PyResult<()> {
let trainer =
fn train(&mut self, files: Vec<String>, trainer: Option<&mut PyTrainer>) -> PyResult<()> {
let mut trainer =
trainer.map_or_else(|| self.tokenizer.get_model().get_trainer(), |t| t.clone());
Python::with_gil(|py| {
py.allow_threads(|| {
ToPyResult(self.tokenizer.train(&trainer, files).map(|_| {})).into()
ToPyResult(
self.tokenizer
.train_from_files(&mut trainer, files)
.map(|_| {}),
)
.into()
})
})
}
/// Train the tokenizer's model from the sequences yielded by a Python iterator.
///
/// `iterator` is any Python iterable whose items can be extracted as `str`.
/// `trainer` is optional: when it is `None`, the trainer returned by the
/// current model's `get_trainer()` is used instead.
///
/// The sequences are pulled from Python on the calling thread and handed to a
/// scoped worker thread, which runs the training, through a bounded channel.
///
/// # Errors
///
/// Returns a Python exception when:
/// - `iterator` is not iterable, the iteration itself raises, or an item
///   cannot be extracted as `str`;
/// - the training reports an error;
/// - the worker stopped receiving before every sequence was delivered.
#[args(trainer = "None")]
fn train_from_iterator(
    &mut self,
    iterator: &PyAny,
    trainer: Option<&mut PyTrainer>,
) -> PyResult<()> {
    let mut trainer =
        trainer.map_or_else(|| self.tokenizer.get_model().get_trainer(), |t| t.clone());

    // Bounded channel: the producer blocks once 256 sequences are pending.
    let (send, recv) = std::sync::mpsc::sync_channel(256);
    // Wrapped in an `Option` so the feeding closure below can take ownership of
    // the sender and drop it when it is done; dropping it is what terminates
    // the receiving iterator on the worker side.
    let mut sender = Some(send);
    let iterator: PyIterator = iterator.iter()?;

    // NOTE(review): unlike the file-based training path, nothing here releases
    // the GIL while the worker runs -- confirm training never needs to acquire it.
    crossbeam::thread::scope(|s| {
        let train_handle = s.spawn(|_| {
            self.tokenizer
                .train(&mut trainer, recv.into_iter())
                .map(|_| {})
        });

        // Outer `Result`: failure raised while iterating / extracting on the
        // Python side. Inner `Result`: failure to hand a sequence to the worker.
        let fed = ResultShunt::process(iterator.map(|seq| seq?.extract::<&str>()), |iter| {
            if let Some(send) = sender.take() {
                for seq in iter {
                    if let Err(e) = send.send(seq) {
                        return Err(e.to_string());
                    }
                }
            }
            Ok(())
        });

        // The sender has been dropped by now, so the worker's input is finite
        // and joining cannot block forever. A panic in the worker is re-raised
        // on this thread instead of being silently discarded.
        let trained = train_handle
            .join()
            .unwrap_or_else(|payload| std::panic::resume_unwind(payload));

        match (fed, trained) {
            // An exception coming from the Python iterator is the root cause.
            (Err(py_err), _) => Err(py_err),
            // A training failure explains any send failure, so report it first.
            (Ok(_), Err(train_err)) => {
                Err(exceptions::PyException::new_err(train_err.to_string()))
            }
            (Ok(Err(send_err)), Ok(())) => Err(exceptions::PyException::new_err(send_err)),
            (Ok(Ok(())), Ok(())) => Ok(()),
        }
    })
    .unwrap()
}
/// Apply all the post-processing steps to the given encodings.
///
/// The various steps are: