mirror of
https://github.com/mii443/tokenizers.git
synced 2025-08-23 00:35:35 +00:00
Add ability to train from Iterator
This commit is contained in:
@ -1,4 +1,3 @@
|
||||
use std::collections::HashMap;
|
||||
use std::sync::{Arc, RwLock};
|
||||
|
||||
use pyo3::exceptions;
|
||||
@ -50,19 +49,20 @@ impl Trainer for PyTrainer {
|
||||
self.trainer.read().unwrap().should_show_progress()
|
||||
}
|
||||
|
||||
fn train(
|
||||
&self,
|
||||
words: HashMap<String, u32>,
|
||||
model: &mut PyModel,
|
||||
) -> tk::Result<Vec<tk::AddedToken>> {
|
||||
fn train(&self, model: &mut PyModel) -> tk::Result<Vec<tk::AddedToken>> {
|
||||
self.trainer
|
||||
.read()
|
||||
.unwrap()
|
||||
.train(words, &mut model.model.write().unwrap())
|
||||
.train(&mut model.model.write().unwrap())
|
||||
}
|
||||
|
||||
fn process_tokens(&self, words: &mut HashMap<String, u32>, tokens: Vec<String>) {
|
||||
self.trainer.read().unwrap().process_tokens(words, tokens)
|
||||
fn feed<I, S, F>(&mut self, iterator: I, process: F) -> tk::Result<()>
|
||||
where
|
||||
I: Iterator<Item = S> + Send,
|
||||
S: AsRef<str> + Send,
|
||||
F: Fn(&str) -> tk::Result<Vec<String>> + Sync,
|
||||
{
|
||||
self.trainer.write().unwrap().feed(iterator, process)
|
||||
}
|
||||
}
|
||||
|
||||
|
Reference in New Issue
Block a user