Steps of the pipeline can fail

Anthony MOI
2019-12-11 07:18:38 -05:00
parent 7cb2fe2ea0
commit a929a99e05


@@ -15,26 +15,27 @@
 use rayon::prelude::*;
 use std::{
     collections::HashMap,
     error::Error,
     fs::File,
     io::{BufRead, BufReader},
 };
+pub type Result<T> = std::result::Result<T, Box<dyn std::error::Error + Send + Sync>>;

 /// A Normalizer takes care of pre-processing strings
 pub trait Normalizer {
-    fn normalize(&self, s: String) -> String;
+    fn normalize(&self, s: String) -> Result<String>;
 }

 /// A PreTokenizer takes care of pre-tokenizing strings before this goes to the model
 pub trait PreTokenizer {
     // TODO: Should return offsets with each substring
-    fn pre_tokenize(&self, s: &str) -> Vec<String>;
+    fn pre_tokenize(&self, s: &str) -> Result<Vec<String>>;
 }

 /// Represents a `Model` used during Tokenization (Like BPE or Word or Unigram)
 pub trait Model {
-    fn tokenize(&self, tokens: Vec<String>) -> Vec<Token>;
-    fn decode(&self, ids: Vec<u32>) -> Vec<String>;
+    fn tokenize(&self, tokens: Vec<String>) -> Result<Vec<Token>>;
+    fn decode(&self, ids: Vec<u32>) -> Result<Vec<String>>;
     fn token_to_id(&self, token: &str) -> Option<u32>;
     fn id_to_token(&self, id: u32) -> Option<String>;
     fn get_vocab_size(&self) -> usize;
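With the new crate-wide Result alias, every step of the pipeline can report failure instead of panicking, and an implementor only has to wrap its happy path in Ok. A minimal sketch of what an implementor now looks like (hypothetical, not part of this commit; Strip is an illustrative name):

// Hypothetical Normalizer under the new fallible signature.
struct Strip;

impl Normalizer for Strip {
    fn normalize(&self, s: String) -> Result<String> {
        // This trivial implementation always succeeds, but the signature
        // now lets a normalizer surface an error with `Err(...)`.
        Ok(s.trim().to_owned())
    }
}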
@@ -43,18 +44,18 @@ pub trait Model {
 /// A PostProcessor has the responsibility to post process an encoded output of the Tokenizer.
 /// Truncating, Padding, etc... are PostProcessor steps
 pub trait PostProcessor {
-    fn process(&self, encoding: Encoding, pair_encoding: Option<Encoding>) -> Encoding;
+    fn process(&self, encoding: Encoding, pair_encoding: Option<Encoding>) -> Result<Encoding>;
 }

 /// A Decoder has the responsibility to merge the given Vec<String> in a String
 pub trait Decoder {
-    fn decode(&self, tokens: Vec<String>) -> String;
+    fn decode(&self, tokens: Vec<String>) -> Result<String>;
 }

 /// A Trainer has the responsibility to train a Model. We feed it with lines/sentences
 /// and it returns a Model when done.
 pub trait Trainer: Sync {
-    fn train(&self, words: HashMap<String, u32>) -> Result<Box<dyn Model + Sync>, Box<dyn Error>>;
+    fn train(&self, words: HashMap<String, u32>) -> Result<Box<dyn Model + Sync>>;
     fn process_tokens(&self, words: &mut HashMap<String, u32>, tokens: Vec<String>);
 }
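A Decoder follows the same pattern. A hedged sketch (illustrative only; note that a &str converts into the boxed Send + Sync error through Into):

// Hypothetical Decoder, not part of this commit.
struct SpaceJoin;

impl Decoder for SpaceJoin {
    fn decode(&self, tokens: Vec<String>) -> Result<String> {
        if tokens.is_empty() {
            // A `&str` converts into Box<dyn Error + Send + Sync>.
            return Err("cannot decode an empty token sequence".into());
        }
        Ok(tokens.join(" "))
    }
}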
@@ -82,6 +83,24 @@ pub struct Encoding {
     offsets: Vec<(usize, usize)>,
 }
 impl Encoding {
+    pub fn new(
+        original: String,
+        normalized: String,
+        ids: Vec<u32>,
+        type_ids: Vec<u32>,
+        tokens: Vec<String>,
+        offsets: Vec<(usize, usize)>,
+    ) -> Self {
+        Encoding {
+            original,
+            normalized,
+            ids,
+            type_ids,
+            tokens,
+            offsets,
+        }
+    }
+
     pub fn get_original(&self) -> &str {
         &self.original
     }
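For reference, a call to the new constructor; the helper name and all values below are made up for illustration:

// Hypothetical helper, not part of this commit.
fn toy_encoding() -> Encoding {
    Encoding::new(
        "Hello".into(),       // original
        "hello".into(),       // normalized
        vec![42],             // ids
        vec![0],              // type_ids
        vec!["hello".into()], // tokens
        vec![(0, 5)],         // offsets
    )
}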
@@ -183,19 +202,19 @@ impl Tokenizer {
     }

     /// Encode the given sentence
-    pub fn encode(&self, input: EncodeInput) -> Encoding {
-        let generate_output = move |sentence: String, type_id: u32| -> Encoding {
+    pub fn encode(&self, input: EncodeInput) -> Result<Encoding> {
+        let generate_output = move |sentence: String, type_id: u32| -> Result<Encoding> {
             // 1. Normalization
             // TODO: Make sure we have the offsets update necessary to go from the original text to
             // the normalized one
             let original = sentence.clone();
-            let normalized = self.normalize(sentence);
+            let normalized = self.normalize(sentence)?;

             // 2. Pre tokenization
-            let pre_tokenized = self.pre_tokenize(&normalized);
+            let pre_tokenized = self.pre_tokenize(&normalized)?;

             // 3. Model
-            let output = self.model.tokenize(pre_tokenized);
+            let output = self.model.tokenize(pre_tokenized)?;
             let length = output.len();

             let (ids, tokens, offsets) = output.into_iter().fold(
@@ -212,14 +231,14 @@ impl Tokenizer {
                 },
             );

-            Encoding {
+            Ok(Encoding {
                 original,
                 normalized,
                 ids,
                 type_ids: vec![type_id; length],
                 tokens,
                 offsets,
-            }
+            })
         };

         let (sentence, pair) = match input {
@@ -227,15 +246,18 @@
             EncodeInput::Dual(s1, s2) => (s1, Some(s2)),
         };

-        let encoding = generate_output(sentence, 0);
-        let pair_encoding = pair.map(|pair| generate_output(pair, 1));
+        let encoding = generate_output(sentence, 0)?;
+        let pair_encoding = match pair {
+            Some(pair) => Some(generate_output(pair, 1)?),
+            None => None,
+        };

         // 4. Post processing
         self.post_process(encoding, pair_encoding)
     }

     /// Encode all the sentences in parallel, using multiple threads
-    pub fn encode_batch(&self, inputs: Vec<EncodeInput>) -> Vec<Encoding> {
+    pub fn encode_batch(&self, inputs: Vec<EncodeInput>) -> Result<Vec<Encoding>> {
         inputs
             .into_par_iter()
             .map(|input| self.encode(input))
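On the caller side, encode and encode_batch now propagate a failure from any of the four steps. A hypothetical caller (function name and inputs made up):

// Hypothetical caller, not part of this commit.
fn encode_pair(tokenizer: &Tokenizer) -> Result<Encoding> {
    // An error in normalization, pre-tokenization, the model, or
    // post-processing now surfaces in the returned Result instead of a panic.
    tokenizer.encode(EncodeInput::Dual(
        "How are you?".into(),
        "Fine, thanks!".into(),
    ))
}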
@@ -243,18 +265,18 @@ impl Tokenizer {
     }

     /// Decode the given ids, back to a String
-    pub fn decode(&self, ids: Vec<u32>) -> String {
-        let tokens = self.model.decode(ids);
+    pub fn decode(&self, ids: Vec<u32>) -> Result<String> {
+        let tokens = self.model.decode(ids)?;

         if let Some(decoder) = &self.decoder {
             decoder.decode(tokens)
         } else {
-            tokens.join(" ")
+            Ok(tokens.join(" "))
         }
     }

     /// Decode all sentences in parallel
-    pub fn decode_batch(&self, sentences: Vec<Vec<u32>>) -> Vec<String> {
+    pub fn decode_batch(&self, sentences: Vec<Vec<u32>>) -> Result<Vec<String>> {
         sentences
             .into_par_iter()
             .map(|sentence| self.decode(sentence))
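Decoding is fallible in the same way; a caller that prefers not to propagate can match on the result. A hypothetical usage (name and ids made up):

// Hypothetical usage, not part of this commit.
fn print_decoded(tokenizer: &Tokenizer, ids: Vec<u32>) {
    match tokenizer.decode(ids) {
        Ok(text) => println!("decoded: {}", text),
        Err(e) => eprintln!("decoding failed: {}", e),
    }
}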
@@ -262,14 +284,10 @@ impl Tokenizer {
     }

     /// Train a model and replace our current Model, using the given Trainer
-    pub fn train(
-        &mut self,
-        trainer: &Box<dyn Trainer>,
-        files: Vec<String>,
-    ) -> Result<(), Box<dyn Error>> {
+    pub fn train(&mut self, trainer: &Box<dyn Trainer>, files: Vec<String>) -> Result<()> {
         let results = files
             .par_iter()
-            .map(|file| -> std::io::Result<HashMap<String, u32>> {
+            .map(|file| -> Result<HashMap<String, u32>> {
                 let mut words = HashMap::new();

                 let file: std::fs::File = File::open(file)?;
@@ -277,8 +295,8 @@ impl Tokenizer {
                 for line in file.lines() {
                     let line = line?;
-                    let normalized = self.normalize(line);
-                    let pre_tokenized = self.pre_tokenize(&normalized);
+                    let normalized = self.normalize(line)?;
+                    let pre_tokenized = self.pre_tokenize(&normalized)?;
                     trainer.process_tokens(&mut words, pre_tokenized);
                 }
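Because the alias boxes any Error + Send + Sync, the IO errors from File::open and BufRead::lines now share a type with the pipeline errors, which is what lets train keep using ? inside the rayon closure. A hypothetical call site (name and path made up):

// Hypothetical call site, not part of this commit.
fn train_from_corpus(tokenizer: &mut Tokenizer, trainer: Box<dyn Trainer>) -> Result<()> {
    tokenizer.train(&trainer, vec!["data/corpus.txt".into()])
}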
@@ -302,30 +320,34 @@ impl Tokenizer {
     }

     /// PreTokenization logic, handling the case where there is no PreTokenizer set
-    fn pre_tokenize(&self, sentence: &str) -> Vec<String> {
+    fn pre_tokenize(&self, sentence: &str) -> Result<Vec<String>> {
         match &self.pre_tokenizer {
-            None => vec![sentence.to_owned()],
+            None => Ok(vec![sentence.to_owned()]),
             Some(pre_tokenizer) => pre_tokenizer.pre_tokenize(sentence),
         }
     }

     /// Normalization logic, go through all normalizers
-    fn normalize(&self, sentence: String) -> String {
+    fn normalize(&self, sentence: String) -> Result<String> {
         if let Some(normalizer) = &self.normalizer {
             normalizer.normalize(sentence)
         } else {
-            sentence.to_owned()
+            Ok(sentence.to_owned())
         }
     }

     /// Post processing logic, handling the case where there is no PostProcessor set
-    fn post_process(&self, encoding: Encoding, pair_encoding: Option<Encoding>) -> Encoding {
+    fn post_process(
+        &self,
+        encoding: Encoding,
+        pair_encoding: Option<Encoding>,
+    ) -> Result<Encoding> {
         if let Some(processor) = &self.post_processor {
             processor.process(encoding, pair_encoding)
         } else {
             match pair_encoding {
-                None => encoding,
-                Some(pair) => Encoding {
+                None => Ok(encoding),
+                Some(pair) => Ok(Encoding {
                     original: format!("{}{}", encoding.original, pair.original),
                     normalized: format!("{}{}", encoding.normalized, pair.normalized),
                     ids: [&encoding.ids[..], &pair.ids[..]].concat(),
@@ -345,7 +367,7 @@ impl Tokenizer {
                             .collect::<Vec<_>>(),
                     ]
                     .concat(),
-                },
+                }),
             }
         }
     }
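A custom PostProcessor follows the same recipe as the default branch above. A minimal, hypothetical sketch (Noop is an illustrative name, not part of this commit):

// Hypothetical PostProcessor under the new fallible signature.
struct Noop;

impl PostProcessor for Noop {
    fn process(&self, encoding: Encoding, pair_encoding: Option<Encoding>) -> Result<Encoding> {
        match pair_encoding {
            None => Ok(encoding),
            // This toy processor rejects pairs instead of silently merging them.
            Some(_) => Err("Noop does not handle sentence pairs".into()),
        }
    }
}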