Mirror of https://github.com/mii443/tokenizers.git (synced 2025-08-23 00:35:35 +00:00)
Steps of the pipeline can fail
This commit makes every step of the tokenization pipeline fallible: Normalizer, PreTokenizer, Model, PostProcessor, Decoder, and Trainer now return a shared Result type, so errors can propagate through encode, decode, and train with the ? operator.
@@ -15,26 +15,27 @@
 use rayon::prelude::*;
 use std::{
     collections::HashMap,
-    error::Error,
     fs::File,
     io::{BufRead, BufReader},
 };
 
+pub type Result<T> = std::result::Result<T, Box<dyn std::error::Error + Send + Sync>>;
+
 /// A Normalizer takes care of pre-processing strings
 pub trait Normalizer {
-    fn normalize(&self, s: String) -> String;
+    fn normalize(&self, s: String) -> Result<String>;
 }
 
 /// A PreTokenizer takes care of pre-tokenizing strings before this goes to the model
 pub trait PreTokenizer {
     // TODO: Should return offsets with each substring
-    fn pre_tokenize(&self, s: &str) -> Vec<String>;
+    fn pre_tokenize(&self, s: &str) -> Result<Vec<String>>;
 }
 
 /// Represents a `Model` used during Tokenization (Like BPE or Word or Unigram)
 pub trait Model {
-    fn tokenize(&self, tokens: Vec<String>) -> Vec<Token>;
-    fn decode(&self, ids: Vec<u32>) -> Vec<String>;
+    fn tokenize(&self, tokens: Vec<String>) -> Result<Vec<Token>>;
+    fn decode(&self, ids: Vec<u32>) -> Result<Vec<String>>;
     fn token_to_id(&self, token: &str) -> Option<u32>;
     fn id_to_token(&self, id: u32) -> Option<String>;
     fn get_vocab_size(&self) -> usize;
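With the new crate-wide `Result<T>` alias, any error type that is `Send + Sync` can be boxed and propagated with `?`. A minimal sketch of what a fallible `Normalizer` implementation could look like under the new signature (the `StripControlChars` type and its failure condition are hypothetical, not part of this crate):

```rust
// A minimal sketch of a Normalizer under the new fallible signature.
pub struct StripControlChars;

impl Normalizer for StripControlChars {
    fn normalize(&self, s: String) -> Result<String> {
        // Refuse input we cannot normalize meaningfully: `.into()` turns the
        // &str into the boxed error of the crate-wide Result alias.
        if s.contains('\u{FFFD}') {
            return Err("input contains U+FFFD replacement characters".into());
        }
        Ok(s.chars().filter(|c| !c.is_control()).collect())
    }
}
```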
@@ -43,18 +44,18 @@ pub trait Model {
 /// A PostProcessor has the responsibility to post process an encoded output of the Tokenizer.
 /// Truncating, Padding, etc... are PostProcessor steps
 pub trait PostProcessor {
-    fn process(&self, encoding: Encoding, pair_encoding: Option<Encoding>) -> Encoding;
+    fn process(&self, encoding: Encoding, pair_encoding: Option<Encoding>) -> Result<Encoding>;
 }
 
 /// A Decoder has the responsibility to merge the given Vec<String> in a String
 pub trait Decoder {
-    fn decode(&self, tokens: Vec<String>) -> String;
+    fn decode(&self, tokens: Vec<String>) -> Result<String>;
 }
 
 /// A Trainer has the responsibility to train a Model. We feed it with lines/sentences
 /// and it returns a Model when done.
 pub trait Trainer: Sync {
-    fn train(&self, words: HashMap<String, u32>) -> Result<Box<dyn Model + Sync>, Box<dyn Error>>;
+    fn train(&self, words: HashMap<String, u32>) -> Result<Box<dyn Model + Sync>>;
     fn process_tokens(&self, words: &mut HashMap<String, u32>, tokens: Vec<String>);
 }
 
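`Decoder::decode` is fallible now too, so even a decoder that can never fail must wrap its output in `Ok`. A sketch that mirrors the `tokens.join(" ")` fallback used by `Tokenizer::decode` later in this diff (`WhitespaceDecoder` is a hypothetical name):

```rust
// Hypothetical decoder that joins tokens with single spaces.
pub struct WhitespaceDecoder;

impl Decoder for WhitespaceDecoder {
    fn decode(&self, tokens: Vec<String>) -> Result<String> {
        // Nothing here can fail, so the Result is always Ok.
        Ok(tokens.join(" "))
    }
}
```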
@@ -82,6 +83,24 @@ pub struct Encoding {
     offsets: Vec<(usize, usize)>,
 }
 impl Encoding {
+    pub fn new(
+        original: String,
+        normalized: String,
+        ids: Vec<u32>,
+        type_ids: Vec<u32>,
+        tokens: Vec<String>,
+        offsets: Vec<(usize, usize)>,
+    ) -> Self {
+        Encoding {
+            original,
+            normalized,
+            ids,
+            type_ids,
+            tokens,
+            offsets,
+        }
+    }
+
     pub fn get_original(&self) -> &str {
         &self.original
     }
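The new `Encoding::new` constructor lets callers build an `Encoding` without access to its private fields. A usage sketch with placeholder values (per the TODO earlier in the diff, whether offsets refer to the original or the normalized text is still unsettled):

```rust
// Placeholder values for a two-token input; offsets are (start, end) pairs.
let encoding = Encoding::new(
    "Hello world".into(),                 // original
    "hello world".into(),                 // normalized
    vec![42, 7],                          // ids
    vec![0, 0],                           // type_ids
    vec!["hello".into(), "world".into()], // tokens
    vec![(0, 5), (6, 11)],                // offsets
);
assert_eq!(encoding.get_original(), "Hello world");
```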
@@ -183,19 +202,19 @@ impl Tokenizer {
     }
 
     /// Encode the given sentence
-    pub fn encode(&self, input: EncodeInput) -> Encoding {
-        let generate_output = move |sentence: String, type_id: u32| -> Encoding {
+    pub fn encode(&self, input: EncodeInput) -> Result<Encoding> {
+        let generate_output = move |sentence: String, type_id: u32| -> Result<Encoding> {
             // 1. Normalization
             // TODO: Make sure we have the offsets update necessary to go from the original text to
             // the normalized one
             let original = sentence.clone();
-            let normalized = self.normalize(sentence);
+            let normalized = self.normalize(sentence)?;
 
             // 2. Pre tokenization
-            let pre_tokenized = self.pre_tokenize(&normalized);
+            let pre_tokenized = self.pre_tokenize(&normalized)?;
 
             // 3. Model
-            let output = self.model.tokenize(pre_tokenized);
+            let output = self.model.tokenize(pre_tokenized)?;
             let length = output.len();
 
             let (ids, tokens, offsets) = output.into_iter().fold(
@@ -212,14 +231,14 @@ impl Tokenizer {
             },
         );
 
-            Encoding {
+            Ok(Encoding {
                 original,
                 normalized,
                 ids,
                 type_ids: vec![type_id; length],
                 tokens,
                 offsets,
-            }
+            })
         };
 
         let (sentence, pair) = match input {
@@ -227,15 +246,18 @@ impl Tokenizer {
             EncodeInput::Dual(s1, s2) => (s1, Some(s2)),
         };
 
-        let encoding = generate_output(sentence, 0);
-        let pair_encoding = pair.map(|pair| generate_output(pair, 1));
+        let encoding = generate_output(sentence, 0)?;
+        let pair_encoding = match pair {
+            Some(pair) => Some(generate_output(pair, 1)?),
+            None => None,
+        };
 
         // 4. Post processing
         self.post_process(encoding, pair_encoding)
     }
 
     /// Encode all the sentences in parallel, using multiple threads
-    pub fn encode_batch(&self, inputs: Vec<EncodeInput>) -> Vec<Encoding> {
+    pub fn encode_batch(&self, inputs: Vec<EncodeInput>) -> Result<Vec<Encoding>> {
         inputs
             .into_par_iter()
             .map(|input| self.encode(input))
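One detail worth noting in the hunk above: `pair.map(|pair| generate_output(pair, 1))` would now produce an `Option<Result<Encoding>>`, which `?` cannot unwrap, hence the explicit `match`. An equivalent one-liner, assuming `Option::transpose` is available on the toolchain in use:

```rust
// transpose() turns Option<Result<Encoding>> into Result<Option<Encoding>>,
// letting `?` propagate the error while keeping the Option shape.
let pair_encoding = pair.map(|pair| generate_output(pair, 1)).transpose()?;
```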
@@ -243,18 +265,18 @@ impl Tokenizer {
     }
 
     /// Decode the given ids, back to a String
-    pub fn decode(&self, ids: Vec<u32>) -> String {
-        let tokens = self.model.decode(ids);
+    pub fn decode(&self, ids: Vec<u32>) -> Result<String> {
+        let tokens = self.model.decode(ids)?;
 
         if let Some(decoder) = &self.decoder {
             decoder.decode(tokens)
         } else {
-            tokens.join(" ")
+            Ok(tokens.join(" "))
         }
     }
 
     /// Decode all sentences in parallel
-    pub fn decode_batch(&self, sentences: Vec<Vec<u32>>) -> Vec<String> {
+    pub fn decode_batch(&self, sentences: Vec<Vec<u32>>) -> Result<Vec<String>> {
         sentences
             .into_par_iter()
             .map(|sentence| self.decode(sentence))
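Callers of `decode` now receive a `Result<String>` and must handle the failure path. A hypothetical call site (the ids are placeholders):

```rust
// Errors from the Model or the Decoder now surface to the caller.
match tokenizer.decode(vec![101, 2023, 102]) {
    Ok(text) => println!("decoded: {}", text),
    Err(e) => eprintln!("decoding failed: {}", e),
}
```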
@@ -262,14 +284,10 @@ impl Tokenizer {
     }
 
     /// Train a model and replace our current Model, using the given Trainer
-    pub fn train(
-        &mut self,
-        trainer: &Box<dyn Trainer>,
-        files: Vec<String>,
-    ) -> Result<(), Box<dyn Error>> {
+    pub fn train(&mut self, trainer: &Box<dyn Trainer>, files: Vec<String>) -> Result<()> {
         let results = files
             .par_iter()
-            .map(|file| -> std::io::Result<HashMap<String, u32>> {
+            .map(|file| -> Result<HashMap<String, u32>> {
                 let mut words = HashMap::new();
 
                 let file: std::fs::File = File::open(file)?;
@@ -277,8 +295,8 @@ impl Tokenizer {
 
                 for line in file.lines() {
                     let line = line?;
-                    let normalized = self.normalize(line);
-                    let pre_tokenized = self.pre_tokenize(&normalized);
+                    let normalized = self.normalize(line)?;
+                    let pre_tokenized = self.pre_tokenize(&normalized)?;
                     trainer.process_tokens(&mut words, pre_tokenized);
                 }
 
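The per-file closure previously returned `std::io::Result`, which could not carry normalization or pre-tokenization errors; the crate-wide `Result` can. `File::open(file)?` keeps compiling because `?` inserts a `From` conversion, and `std::io::Error` converts into `Box<dyn std::error::Error + Send + Sync>`. A standalone illustration (the `count_lines` helper is hypothetical):

```rust
use std::io::BufRead;

// `?` auto-boxes the io::Error into the crate-wide error type.
fn count_lines(path: &str) -> Result<usize> {
    let file = std::fs::File::open(path)?;
    Ok(std::io::BufReader::new(file).lines().count())
}
```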
@@ -302,30 +320,34 @@ impl Tokenizer {
     }
 
     /// PreTokenization logic, handling the case where there is no PreTokenizer set
-    fn pre_tokenize(&self, sentence: &str) -> Vec<String> {
+    fn pre_tokenize(&self, sentence: &str) -> Result<Vec<String>> {
         match &self.pre_tokenizer {
-            None => vec![sentence.to_owned()],
+            None => Ok(vec![sentence.to_owned()]),
             Some(pre_tokenizer) => pre_tokenizer.pre_tokenize(sentence),
         }
     }
 
     /// Normalization logic, go through all normalizers
-    fn normalize(&self, sentence: String) -> String {
+    fn normalize(&self, sentence: String) -> Result<String> {
         if let Some(normalizer) = &self.normalizer {
             normalizer.normalize(sentence)
         } else {
-            sentence.to_owned()
+            Ok(sentence.to_owned())
         }
     }
 
     /// Post processing logic, handling the case where there is no PostProcessor set
-    fn post_process(&self, encoding: Encoding, pair_encoding: Option<Encoding>) -> Encoding {
+    fn post_process(
+        &self,
+        encoding: Encoding,
+        pair_encoding: Option<Encoding>,
+    ) -> Result<Encoding> {
         if let Some(processor) = &self.post_processor {
             processor.process(encoding, pair_encoding)
         } else {
             match pair_encoding {
-                None => encoding,
-                Some(pair) => Encoding {
+                None => Ok(encoding),
+                Some(pair) => Ok(Encoding {
                     original: format!("{}{}", encoding.original, pair.original),
                     normalized: format!("{}{}", encoding.normalized, pair.normalized),
                     ids: [&encoding.ids[..], &pair.ids[..]].concat(),
@@ -345,7 +367,7 @@ impl Tokenizer {
                         .collect::<Vec<_>>(),
                     ]
                     .concat(),
-                },
+                }),
             }
         }
     }