mirror of
https://github.com/mii443/tokenizers.git
synced 2025-08-22 16:25:30 +00:00
Also accept iterators of batches in train_from_iterator
This commit is contained in:
1
bindings/python/Cargo.lock
generated
1
bindings/python/Cargo.lock
generated
@ -1080,6 +1080,7 @@ version = "0.9.4"
|
||||
dependencies = [
|
||||
"crossbeam",
|
||||
"env_logger",
|
||||
"itertools 0.9.0",
|
||||
"libc",
|
||||
"ndarray",
|
||||
"numpy",
|
||||
|
@ -19,6 +19,7 @@ numpy = { git = "https://github.com/pyo3/rust-numpy/", rev = "e331befa27fede78d4
|
||||
ndarray = "0.13"
|
||||
onig = { version = "6.0", default-features = false }
|
||||
crossbeam = "0.8"
|
||||
itertools = "0.9"
|
||||
|
||||
[dependencies.tokenizers]
|
||||
version = "*"
|
||||
|
@ -1104,15 +1104,30 @@ impl PyTokenizer {
|
||||
.map(|_| {})
|
||||
});
|
||||
|
||||
ResultShunt::process(iterator.map(|seq| seq?.extract::<&str>()), |iter| {
|
||||
if let Some(send) = sender.take() {
|
||||
for seq in iter {
|
||||
send.send(seq)
|
||||
.map_err(|e| exceptions::PyException::new_err(e.to_string()))?;
|
||||
ResultShunt::process(
|
||||
// Each element of the iterator can either be:
|
||||
// - An iterator, to allow batching
|
||||
// - A string
|
||||
iterator.flat_map(|seq| match seq {
|
||||
Ok(s) => {
|
||||
if let Ok(iter) = s.iter() {
|
||||
itertools::Either::Left(iter.map(|i| i?.extract::<&str>()))
|
||||
} else {
|
||||
itertools::Either::Right(std::iter::once(s.extract::<&str>()))
|
||||
}
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
})?
|
||||
Err(e) => itertools::Either::Right(std::iter::once(Err(e))),
|
||||
}),
|
||||
|iter| {
|
||||
if let Some(send) = sender.take() {
|
||||
for seq in iter {
|
||||
send.send(seq)
|
||||
.map_err(|e| exceptions::PyException::new_err(e.to_string()))?;
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
},
|
||||
)?
|
||||
})
|
||||
.unwrap()
|
||||
}
|
||||
|
Reference in New Issue
Block a user