mirror of
https://github.com/mii443/tokenizers.git
synced 2025-08-23 00:35:35 +00:00
Add ability to train from Iterator
This commit is contained in:
@ -12,6 +12,7 @@ use tk::tokenizer::{
|
||||
Model, PaddingDirection, PaddingParams, PaddingStrategy, PostProcessor, TokenizerImpl,
|
||||
TruncationParams, TruncationStrategy,
|
||||
};
|
||||
use tk::utils::iter::ResultShunt;
|
||||
use tokenizers as tk;
|
||||
|
||||
use super::decoders::PyDecoder;
|
||||
@ -1069,16 +1070,53 @@ impl PyTokenizer {
|
||||
}
|
||||
|
||||
// Train the tokenizer from a list of files.
//
// `trainer` is optional: when it is `None`, a trainer is obtained from the
// current model via `self.tokenizer.get_model().get_trainer()`; otherwise the
// supplied trainer is cloned and the clone is used.
//
// NOTE(review): this span is a rendered diff hunk. It appears to contain both
// the previous version (`trainer: Option<&PyTrainer>`, calling
// `self.tokenizer.train(&trainer, files)`) and the new version
// (`trainer: Option<&mut PyTrainer>`, calling `train_from_files(&mut trainer, files)`)
// interleaved without +/- markers — confirm against the actual commit.
#[args(trainer = "None")]
|
||||
fn train(&mut self, files: Vec<String>, trainer: Option<&PyTrainer>) -> PyResult<()> {
|
||||
let trainer =
|
||||
fn train(&mut self, files: Vec<String>, trainer: Option<&mut PyTrainer>) -> PyResult<()> {
|
||||
let mut trainer =
|
||||
trainer.map_or_else(|| self.tokenizer.get_model().get_trainer(), |t| t.clone());
|
||||
Python::with_gil(|py| {
|
||||
// The training call runs inside `allow_threads`; per PyO3 this releases the
// GIL for the duration of the closure.
py.allow_threads(|| {
|
||||
ToPyResult(self.tokenizer.train(&trainer, files).map(|_| {})).into()
|
||||
ToPyResult(
|
||||
self.tokenizer
|
||||
.train_from_files(&mut trainer, files)
|
||||
// Discard the success value; only the error (if any) is converted for Python.
.map(|_| {}),
|
||||
)
|
||||
.into()
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
/// Train the tokenizer from a Python iterable of strings.
///
/// `iterator` is turned into a `PyIterator` via `iter()`, and every item is
/// extracted as `&str`. `trainer` is optional: when it is `None`, a trainer is
/// obtained from the current model via `get_model().get_trainer()`; otherwise
/// the supplied trainer is cloned and the clone is used.
///
/// Sequences are handed to a scoped worker thread through a bounded channel;
/// the worker runs `self.tokenizer.train` over the receiving end.
#[args(trainer = "None")]
|
||||
fn train_from_iterator(
|
||||
&mut self,
|
||||
iterator: &PyAny,
|
||||
trainer: Option<&mut PyTrainer>,
|
||||
) -> PyResult<()> {
|
||||
let mut trainer =
|
||||
trainer.map_or_else(|| self.tokenizer.get_model().get_trainer(), |t| t.clone());
|
||||
// Bounded channel (capacity 256): `send` blocks once the buffer is full, so
// the producer cannot run arbitrarily far ahead of the training thread.
let (send, recv) = std::sync::mpsc::sync_channel(256);
|
||||
// Wrapped in `Option` so the sender can be moved out with `take()` inside the
// closure below.
let mut sender = Some(send);
|
||||
let iterator: PyIterator = iterator.iter()?;
|
||||
|
||||
crossbeam::thread::scope(|s| {
|
||||
// Consumer: trains on whatever arrives through `recv`.
// NOTE(review): the handle is bound to `_train_handle` and is not joined or
// inspected in this function, so the `Result` returned by `train` does not
// appear to be propagated to the caller — confirm whether training errors
// are meant to be surfaced.
let _train_handle = s.spawn(|_| {
|
||||
self.tokenizer
|
||||
.train(&mut trainer, recv.into_iter())
|
||||
.map(|_| {})
|
||||
});
|
||||
|
||||
// Producer: forwards each extracted `&str` to the training thread.
// `ResultShunt::process` is defined elsewhere; presumably it stops at the
// first `Err` from the iterator and returns it — verify against its source.
ResultShunt::process(iterator.map(|seq| seq?.extract::<&str>()), |iter| {
|
||||
// The sender is moved into `send` here and dropped when this `if let` block
// ends; with no sender left, the receiver's iterator terminates.
if let Some(send) = sender.take() {
|
||||
for seq in iter {
|
||||
send.send(seq)
|
||||
// A failed send is reported as a Python exception carrying the error text.
.map_err(|e| exceptions::PyException::new_err(e.to_string()))?;
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
})?
|
||||
})
|
||||
// NOTE(review): per crossbeam's documentation `scope` returns `Err` when a
// spawned thread panicked, in which case this `unwrap` would panic as well —
// confirm this is the intended behavior.
.unwrap()
|
||||
}
|
||||
|
||||
/// Apply all the post-processing steps to the given encodings.
|
||||
///
|
||||
/// The various steps are:
|
||||
|
Reference in New Issue
Block a user