Python - Improve training with iterators

This commit is contained in:
Anthony MOI
2020-12-15 10:50:01 -05:00
committed by Anthony MOI
parent dad8d6249e
commit 5938a12b3f
5 changed files with 181 additions and 195 deletions

View File

@ -12,6 +12,7 @@ use tk::tokenizer::{
Model, PaddingDirection, PaddingParams, PaddingStrategy, PostProcessor, TokenizerImpl,
TruncationParams, TruncationStrategy,
};
use tk::utils::iter::ResultShunt;
use tokenizers as tk;
use super::decoders::PyDecoder;
@ -22,6 +23,7 @@ use super::normalizers::PyNormalizer;
use super::pre_tokenizers::PyPreTokenizer;
use super::trainers::PyTrainer;
use crate::processors::PyPostProcessor;
use crate::utils::{MaybeSizedIterator, PyBufferedIterator};
/// Represents a token that can be be added to a :class:`~tokenizers.Tokenizer`.
/// It can have special options that defines the way it should behave.
@ -1120,40 +1122,43 @@ impl PyTokenizer {
#[text_signature = "(self, iterator, trainer=None, length=None)"]
fn train_from_iterator(
&mut self,
py: Python,
iterator: &PyAny,
trainer: Option<&mut PyTrainer>,
length: Option<usize>,
) -> PyResult<()> {
use crate::utils::PySendIterator;
let mut trainer =
trainer.map_or_else(|| self.tokenizer.get_model().get_trainer(), |t| t.clone());
let py_send = PySendIterator::new(
// Each element of the iterator can either be:
// - An iterator, to allow batching
// - A string
iterator.iter()?.flat_map(|seq| match seq {
Ok(s) => {
if let Ok(s) = s.downcast::<PyString>() {
itertools::Either::Right(std::iter::once(s.to_str()))
} else {
match s.iter() {
Ok(iter) => itertools::Either::Left(iter.map(|i| i?.extract::<&str>())),
Err(e) => itertools::Either::Right(std::iter::once(Err(e))),
}
let buffered_iter = PyBufferedIterator::new(
iterator,
|element| {
// Each element of the iterator can either be:
// - An iterator, to allow batching
// - A string
if let Ok(s) = element.downcast::<PyString>() {
itertools::Either::Right(std::iter::once(s.to_str().map(|s| s.to_owned())))
} else {
match element.iter() {
Ok(iter) => itertools::Either::Left(
iter.map(|i| i?.extract::<String>())
.collect::<Vec<_>>()
.into_iter(),
),
Err(e) => itertools::Either::Right(std::iter::once(Err(e))),
}
}
Err(e) => itertools::Either::Right(std::iter::once(Err(e))),
}),
length,
);
},
256,
)?;
py_send.execute(|iter| {
self.tokenizer
.train(&mut trainer, iter)
.map(|_| {})
.map_err(|e| exceptions::PyException::new_err(e.to_string()))
py.allow_threads(|| {
ResultShunt::process(buffered_iter, |iter| {
self.tokenizer
.train(&mut trainer, MaybeSizedIterator::new(iter, length))
.map(|_| {})
.map_err(|e| exceptions::PyException::new_err(e.to_string()))
})?
})
}