Also accept iterators of batches in train_from_iterator

This commit is contained in:
Anthony MOI
2020-11-24 22:48:55 -05:00
committed by Anthony MOI
parent e0a70f1fb2
commit 75deaecdd0
3 changed files with 25 additions and 8 deletions

View File

@ -1080,6 +1080,7 @@ version = "0.9.4"
dependencies = [
"crossbeam",
"env_logger",
"itertools 0.9.0",
"libc",
"ndarray",
"numpy",

View File

@ -19,6 +19,7 @@ numpy = { git = "https://github.com/pyo3/rust-numpy/", rev = "e331befa27fede78d4
ndarray = "0.13"
onig = { version = "6.0", default-features = false }
crossbeam = "0.8"
itertools = "0.9"
[dependencies.tokenizers]
version = "*"

View File

@ -1104,15 +1104,30 @@ impl PyTokenizer {
.map(|_| {})
});
ResultShunt::process(iterator.map(|seq| seq?.extract::<&str>()), |iter| {
if let Some(send) = sender.take() {
for seq in iter {
send.send(seq)
.map_err(|e| exceptions::PyException::new_err(e.to_string()))?;
ResultShunt::process(
// Each element of the iterator can either be:
// - An iterator, to allow batching
// - A string
iterator.flat_map(|seq| match seq {
Ok(s) => {
if let Ok(iter) = s.iter() {
itertools::Either::Left(iter.map(|i| i?.extract::<&str>()))
} else {
itertools::Either::Right(std::iter::once(s.extract::<&str>()))
}
}
}
Ok(())
})?
Err(e) => itertools::Either::Right(std::iter::once(Err(e))),
}),
|iter| {
if let Some(send) = sender.take() {
for seq in iter {
send.send(seq)
.map_err(|e| exceptions::PyException::new_err(e.to_string()))?;
}
}
Ok(())
},
)?
})
.unwrap()
}