Add Trainer trait and Tokenizer.train

This commit is contained in:
Anthony MOI
2019-12-03 15:38:45 -05:00
parent 768eb9b920
commit 466555bade

View File

@@ -13,9 +13,16 @@
//! ...)
//!
use rayon::prelude::*;
use std::{
collections::HashMap,
error::Error,
fs::File,
io::{BufRead, BufReader},
};
/// A Normalizer takes care of pre-processing strings
pub trait Normalizer {
// TODO: Use Cow here to avoid useless allocation if nothing is modified
/// Normalize the given string, returning the (possibly modified) owned result.
fn normalize(&self, s: &str) -> String;
}
@@ -43,6 +50,13 @@ pub trait Decoder {
fn decode(&self, tokens: Vec<String>) -> String;
}
/// A Trainer has the responsibility to train a Model. We feed it with lines/sentences
/// and it returns a Model when done.
pub trait Trainer: Sync {
/// Build a new Model from the accumulated word counts.
fn train(&self, words: HashMap<String, u32>) -> Result<Box<dyn Model + Sync>, Box<dyn Error>>;
/// Fold the given pre-tokenized tokens into the running word-count map.
fn process_tokens(&self, words: &mut HashMap<String, u32>, tokens: Vec<String>);
}
/// A Token represents the output of the Tokenizer
#[derive(Debug, PartialEq)]
pub struct Token {
@@ -127,10 +141,8 @@ impl Tokenizer {
/// Encode the given sentence.
///
/// The sentence is first normalized, then pre-tokenized, and the resulting
/// pieces are handed to the Model for tokenization.
pub fn encode(&self, sentence: &str) -> Vec<Token> {
    // The old direct match on `self.pre_tokenizer` was dead code (its binding
    // was immediately shadowed); normalization + pre-tokenization now go
    // through the dedicated helpers.
    let normalized = self.normalize(sentence);
    let pre_tokenized = self.pre_tokenize(&normalized);
    self.model.tokenize(pre_tokenized)
}
@@ -161,4 +173,61 @@ impl Tokenizer {
.map(|sentence| self.decode(sentence))
.collect()
}
/// Train a model and replace our current Model, using the given Trainer.
///
/// Each file is processed in parallel (rayon): its lines are normalized,
/// pre-tokenized, and fed to the Trainer to accumulate per-file word counts.
/// The per-file maps are then merged and passed to `Trainer::train`, whose
/// resulting Model replaces `self.model`.
///
/// # Errors
/// Returns any I/O error encountered while reading the files, or any error
/// reported by the Trainer during training.
pub fn train(
    &mut self,
    // `&dyn Trainer` instead of `&Box<dyn Trainer>`: callers holding a Box
    // still coerce via deref, and borrowing the trait object directly avoids
    // a needless double indirection (clippy::borrowed_box).
    trainer: &dyn Trainer,
    files: Vec<String>,
) -> Result<(), Box<dyn Error>> {
    let results = files
        .par_iter()
        .map(|file| -> std::io::Result<HashMap<String, u32>> {
            let mut words = HashMap::new();
            let file = BufReader::new(File::open(file)?);
            for line in file.lines() {
                let line = line?;
                let normalized = self.normalize(&line);
                let pre_tokenized = self.pre_tokenize(&normalized);
                trainer.process_tokens(&mut words, pre_tokenized);
            }
            Ok(words)
        })
        .collect::<Vec<_>>();

    // Merge the per-file counts, propagating the first I/O error if any.
    let mut words = HashMap::new();
    for result in results {
        for (word, count) in result? {
            *words.entry(word).or_insert(0) += count;
        }
    }

    self.model = trainer.train(words)?;
    Ok(())
}
/// PreTokenization logic, handling the case where there is no PreTokenizer set.
///
/// Without a PreTokenizer, the whole sentence is kept as a single piece.
fn pre_tokenize(&self, sentence: &str) -> Vec<String> {
    if let Some(pre_tokenizer) = &self.pre_tokenizer {
        pre_tokenizer.pre_tokenize(sentence)
    } else {
        vec![sentence.to_owned()]
    }
}
/// Normalization logic, go through all normalizers.
///
/// With no normalizers configured, the sentence is returned unchanged;
/// chaining actual normalizers is not implemented yet.
fn normalize(&self, sentence: &str) -> String {
    // `is_empty()` over `len() == 0` (clippy::len_zero).
    if self.normalizers.is_empty() {
        sentence.to_owned()
    } else {
        unimplemented!("Normalization has not been implemented yet")
    }
}
}