mirror of https://github.com/mii443/tokenizers.git
Add Trainer trait and Tokenizer.train
@@ -13,9 +13,16 @@
 //! ...)
 //!
+use rayon::prelude::*;
+use std::{
+    collections::HashMap,
+    error::Error,
+    fs::File,
+    io::{BufRead, BufReader},
+};
 
 /// A Normalizer takes care of pre-processing strings
 pub trait Normalizer {
     // TODO: Use Cow here to avoid useless allocation if nothing is modified
     fn normalize(&self, s: &str) -> String;
 }
 
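The `// TODO` on `Normalizer::normalize` points at a cheaper signature. A minimal sketch of the `Cow`-based variant the comment suggests (not part of this commit; the replacement rule is purely illustrative):

    use std::borrow::Cow;

    // Borrow when nothing changes, allocate only when a modification
    // is actually made -- the point of the TODO above.
    fn normalize(s: &str) -> Cow<'_, str> {
        if s.contains('\u{00A0}') {
            // Illustrative rule: replace non-breaking spaces.
            Cow::Owned(s.replace('\u{00A0}', " "))
        } else {
            Cow::Borrowed(s)
        }
    }

    fn main() {
        assert!(matches!(normalize("plain"), Cow::Borrowed(_)));
    }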
@@ -43,6 +50,13 @@ pub trait Decoder {
     fn decode(&self, tokens: Vec<String>) -> String;
 }
 
+/// A Trainer has the responsibility to train a Model. We feed it with lines/sentences
+/// and it returns a Model when done.
+pub trait Trainer: Sync {
+    fn train(&self, words: HashMap<String, u32>) -> Result<Box<dyn Model + Sync>, Box<dyn Error>>;
+    fn process_tokens(&self, words: &mut HashMap<String, u32>, tokens: Vec<String>);
+}
+
 /// A Token represents the output of the Tokenizer
 #[derive(Debug, PartialEq)]
 pub struct Token {
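The new trait splits training into two phases: `process_tokens` folds a stream of pre-tokenized strings into a word-frequency map, and `train` consumes that map to build a `Model`. A self-contained sketch of the counting contract `process_tokens` is expected to fulfil (a hypothetical free-function implementation, not the crate's):

    use std::collections::HashMap;

    // Hypothetical Trainer::process_tokens body: bump a counter for
    // every token seen, so `train` later receives global frequencies.
    fn process_tokens(words: &mut HashMap<String, u32>, tokens: Vec<String>) {
        for token in tokens {
            *words.entry(token).or_insert(0) += 1;
        }
    }

    fn main() {
        let mut words = HashMap::new();
        process_tokens(&mut words, vec!["hello".into(), "hello".into(), "world".into()]);
        assert_eq!(words["hello"], 2);
    }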
@@ -127,10 +141,8 @@ impl Tokenizer {
 
     /// Encode the given sentence
     pub fn encode(&self, sentence: &str) -> Vec<Token> {
-        let pre_tokenized = match &self.pre_tokenizer {
-            None => vec![sentence.to_owned()],
-            Some(pre_tokenizer) => pre_tokenizer.pre_tokenize(sentence),
-        };
+        let normalized = self.normalize(sentence);
+        let pre_tokenized = self.pre_tokenize(&normalized);
 
         self.model.tokenize(pre_tokenized)
     }
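Note the behavioural fix here: the old body pre-tokenized the raw sentence and never consulted the normalizers, while the new body normalizes first and pre-tokenizes the normalized text. A plain-Rust sketch of that ordering (lowercasing and whitespace splitting are stand-ins, not the crate's logic):

    // normalize -> pre_tokenize -> model.tokenize, in that order.
    fn encode_pipeline(sentence: &str) -> Vec<String> {
        let normalized = sentence.to_lowercase(); // stand-in for normalize()
        normalized
            .split_whitespace() // stand-in for pre_tokenize()
            .map(str::to_owned)
            .collect() // a real Model would turn these into Tokens
    }

    fn main() {
        assert_eq!(encode_pipeline("Hello World"), vec!["hello", "world"]);
    }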
@@ -161,4 +173,61 @@ impl Tokenizer {
             .map(|sentence| self.decode(sentence))
             .collect()
     }
+
+    /// Train a model and replace our current Model, using the given Trainer
+    pub fn train(
+        &mut self,
+        trainer: &Box<dyn Trainer>,
+        files: Vec<String>,
+    ) -> Result<(), Box<dyn Error>> {
+        let results = files
+            .par_iter()
+            .map(|file| -> std::io::Result<HashMap<String, u32>> {
+                let mut words = HashMap::new();
+
+                let file: std::fs::File = File::open(file)?;
+                let file = BufReader::new(file);
+
+                for line in file.lines() {
+                    let line = line?;
+                    let normalized = self.normalize(&line);
+                    let pre_tokenized = self.pre_tokenize(&normalized);
+                    trainer.process_tokens(&mut words, pre_tokenized);
+                }
+
+                Ok(words)
+            })
+            .collect::<Vec<_>>();
+
+        let mut words = HashMap::new();
+        for result in results {
+            for (word, count) in result? {
+                words
+                    .entry(word)
+                    .and_modify(|c| *c += count)
+                    .or_insert(count);
+            }
+        }
+
+        self.model = trainer.train(words)?;
+
+        Ok(())
+    }
+
+    /// PreTokenization logic, handling the case where there is no PreTokenizer set
+    fn pre_tokenize(&self, sentence: &str) -> Vec<String> {
+        match &self.pre_tokenizer {
+            None => vec![sentence.to_owned()],
+            Some(pre_tokenizer) => pre_tokenizer.pre_tokenize(sentence),
+        }
+    }
+
+    /// Normalization logic, go through all normalizers
+    fn normalize(&self, sentence: &str) -> String {
+        if self.normalizers.len() == 0 {
+            sentence.to_owned()
+        } else {
+            unimplemented!("Normalization has not been implemented yet")
+        }
+    }
 }
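`train` fans the corpus out with rayon: each file is counted into its own map on a worker thread via `par_iter`, and the per-file maps are merged sequentially afterwards, so the shared map needs no locking. A self-contained sketch of that map/merge shape (in-memory strings stand in for files; requires the rayon crate):

    use rayon::prelude::*;
    use std::collections::HashMap;

    fn main() {
        // One word-count map per "file", built in parallel...
        let corpora = vec!["hello world", "hello tokenizer"];
        let partials: Vec<HashMap<String, u32>> = corpora
            .par_iter()
            .map(|text| {
                let mut words = HashMap::new();
                for token in text.split_whitespace() {
                    *words.entry(token.to_owned()).or_insert(0) += 1;
                }
                words
            })
            .collect();

        // ...then merged sequentially, mirroring the loop in `train`.
        let mut words: HashMap<String, u32> = HashMap::new();
        for partial in partials {
            for (word, count) in partial {
                *words.entry(word).or_insert(0) += count;
            }
        }
        assert_eq!(words["hello"], 2);
    }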