mirror of https://github.com/mii443/tokenizers.git
synced 2025-08-23 16:49:27 +00:00
Steps of the pipeline can fail
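
This change threads a crate-wide `Result` alias through every step of the tokenization pipeline (normalization, pre-tokenization, the model, post-processing, decoding, and training), so failures propagate to the caller instead of being unrepresentable in the signatures.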
@@ -15,26 +15,27 @@
 use rayon::prelude::*;
 use std::{
     collections::HashMap,
-    error::Error,
     fs::File,
     io::{BufRead, BufReader},
 };

+pub type Result<T> = std::result::Result<T, Box<dyn std::error::Error + Send + Sync>>;
+
 /// A Normalizer takes care of pre-processing strings
 pub trait Normalizer {
-    fn normalize(&self, s: String) -> String;
+    fn normalize(&self, s: String) -> Result<String>;
 }

 /// A PreTokenizer takes care of pre-tokenizing strings before this goes to the model
 pub trait PreTokenizer {
     // TODO: Should return offsets with each substring
-    fn pre_tokenize(&self, s: &str) -> Vec<String>;
+    fn pre_tokenize(&self, s: &str) -> Result<Vec<String>>;
 }

 /// Represents a `Model` used during Tokenization (Like BPE or Word or Unigram)
 pub trait Model {
-    fn tokenize(&self, tokens: Vec<String>) -> Vec<Token>;
-    fn decode(&self, ids: Vec<u32>) -> Vec<String>;
+    fn tokenize(&self, tokens: Vec<String>) -> Result<Vec<Token>>;
+    fn decode(&self, ids: Vec<u32>) -> Result<Vec<String>>;
     fn token_to_id(&self, token: &str) -> Option<u32>;
     fn id_to_token(&self, id: u32) -> Option<String>;
     fn get_vocab_size(&self) -> usize;
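
Note: with the new alias in scope, implementors return `Result` even when they cannot fail. A minimal sketch of a conforming Normalizer, assuming the trait as defined above (`LowercaseNormalizer` is a hypothetical example type, not part of this change):

    // Hypothetical normalizer under the new fallible signature.
    // Infallible implementations simply wrap their output in Ok(...).
    pub struct LowercaseNormalizer;

    impl Normalizer for LowercaseNormalizer {
        fn normalize(&self, s: String) -> Result<String> {
            Ok(s.to_lowercase())
        }
    }
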
@@ -43,18 +44,18 @@ pub trait Model {
 /// A PostProcessor has the responsibility to post process an encoded output of the Tokenizer.
 /// Truncating, Padding, etc... are PostProcessor steps
 pub trait PostProcessor {
-    fn process(&self, encoding: Encoding, pair_encoding: Option<Encoding>) -> Encoding;
+    fn process(&self, encoding: Encoding, pair_encoding: Option<Encoding>) -> Result<Encoding>;
 }

 /// A Decoder has the responsibility to merge the given Vec<String> in a String
 pub trait Decoder {
-    fn decode(&self, tokens: Vec<String>) -> String;
+    fn decode(&self, tokens: Vec<String>) -> Result<String>;
 }

 /// A Trainer has the responsibility to train a Model. We feed it with lines/sentences
 /// and it returns a Model when done.
 pub trait Trainer: Sync {
-    fn train(&self, words: HashMap<String, u32>) -> Result<Box<dyn Model + Sync>, Box<dyn Error>>;
+    fn train(&self, words: HashMap<String, u32>) -> Result<Box<dyn Model + Sync>>;
     fn process_tokens(&self, words: &mut HashMap<String, u32>, tokens: Vec<String>);
 }

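
Note: the `Send + Sync` bounds on the boxed error are what let failures cross thread boundaries in the rayon-based batch methods below. Any concrete error implementing `std::error::Error + Send + Sync` converts into the alias through `?`; a sketch under that assumption (`count_lines` is a hypothetical helper, reusing the file's `File`/`BufReader` imports):

    fn count_lines(path: &str) -> Result<usize> {
        let file = File::open(path)?; // std::io::Error -> boxed alias via From
        let mut n = 0;
        for line in BufReader::new(file).lines() {
            line?; // read errors propagate the same way
            n += 1;
        }
        Ok(n)
    }
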
@@ -82,6 +83,24 @@ pub struct Encoding {
     offsets: Vec<(usize, usize)>,
 }
 impl Encoding {
+    pub fn new(
+        original: String,
+        normalized: String,
+        ids: Vec<u32>,
+        type_ids: Vec<u32>,
+        tokens: Vec<String>,
+        offsets: Vec<(usize, usize)>,
+    ) -> Self {
+        Encoding {
+            original,
+            normalized,
+            ids,
+            type_ids,
+            tokens,
+            offsets,
+        }
+    }
+
     pub fn get_original(&self) -> &str {
         &self.original
     }
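
Note: a usage sketch for the new constructor; all values are illustrative:

    // Field order follows the constructor:
    // original, normalized, ids, type_ids, tokens, offsets.
    let encoding = Encoding::new(
        "Hello".into(),       // original
        "hello".into(),       // normalized
        vec![42],             // ids
        vec![0],              // type_ids
        vec!["hello".into()], // tokens
        vec![(0, 5)],         // offsets, one (start, end) pair per token
    );
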
@@ -183,19 +202,19 @@ impl Tokenizer {
     }

     /// Encode the given sentence
-    pub fn encode(&self, input: EncodeInput) -> Encoding {
-        let generate_output = move |sentence: String, type_id: u32| -> Encoding {
+    pub fn encode(&self, input: EncodeInput) -> Result<Encoding> {
+        let generate_output = move |sentence: String, type_id: u32| -> Result<Encoding> {
             // 1. Normalization
             // TODO: Make sure we have the offsets update necessary to go from the original text to
             // the normalized one
             let original = sentence.clone();
-            let normalized = self.normalize(sentence);
+            let normalized = self.normalize(sentence)?;

             // 2. Pre tokenization
-            let pre_tokenized = self.pre_tokenize(&normalized);
+            let pre_tokenized = self.pre_tokenize(&normalized)?;

             // 3. Model
-            let output = self.model.tokenize(pre_tokenized);
+            let output = self.model.tokenize(pre_tokenized)?;
             let length = output.len();

             let (ids, tokens, offsets) = output.into_iter().fold(
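
Note: each pipeline step now ends in `?`, which only compiles because `generate_output` returns `Result<Encoding>`. Roughly, the operator expands to an early return (the real expansion also passes the error through `From::from`):

    let normalized = match self.normalize(sentence) {
        Ok(value) => value,
        Err(err) => return Err(err),
    };
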
@@ -212,14 +231,14 @@ impl Tokenizer {
                 },
             );

-            Encoding {
+            Ok(Encoding {
                 original,
                 normalized,
                 ids,
                 type_ids: vec![type_id; length],
                 tokens,
                 offsets,
-            }
+            })
         };

         let (sentence, pair) = match input {
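
Note: with the closure returning `Result<Encoding>`, its final expression is wrapped in `Ok(...)`; any failing step has already exited through `?` before this point.
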
@@ -227,15 +246,18 @@
             EncodeInput::Dual(s1, s2) => (s1, Some(s2)),
         };

-        let encoding = generate_output(sentence, 0);
-        let pair_encoding = pair.map(|pair| generate_output(pair, 1));
+        let encoding = generate_output(sentence, 0)?;
+        let pair_encoding = match pair {
+            Some(pair) => Some(generate_output(pair, 1)?),
+            None => None,
+        };

         // 4. Post processing
         self.post_process(encoding, pair_encoding)
     }

     /// Encode all the sentences in parallel, using multiple threads
-    pub fn encode_batch(&self, inputs: Vec<EncodeInput>) -> Vec<Encoding> {
+    pub fn encode_batch(&self, inputs: Vec<EncodeInput>) -> Result<Vec<Encoding>> {
         inputs
             .into_par_iter()
             .map(|input| self.encode(input))
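
Note: `pair.map(|pair| generate_output(pair, 1))` would now yield an `Option<Result<Encoding>>`, hence the explicit `match`. An equivalent alternative (not what this commit does) is `Option::transpose`:

    // Option<Result<Encoding>> -> Result<Option<Encoding>>, then `?`.
    let pair_encoding = pair.map(|pair| generate_output(pair, 1)).transpose()?;

In `encode_batch`, the elided `.collect()` can presumably gather the per-item `Result`s into a single `Result<Vec<Encoding>>`, which rayon's `ParallelIterator` supports just like the standard `Iterator`.
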
@@ -243,18 +265,18 @@
     }

     /// Decode the given ids, back to a String
-    pub fn decode(&self, ids: Vec<u32>) -> String {
-        let tokens = self.model.decode(ids);
+    pub fn decode(&self, ids: Vec<u32>) -> Result<String> {
+        let tokens = self.model.decode(ids)?;

         if let Some(decoder) = &self.decoder {
             decoder.decode(tokens)
         } else {
-            tokens.join(" ")
+            Ok(tokens.join(" "))
         }
     }

     /// Decode all sentences in parallel
-    pub fn decode_batch(&self, sentences: Vec<Vec<u32>>) -> Vec<String> {
+    pub fn decode_batch(&self, sentences: Vec<Vec<u32>>) -> Result<Vec<String>> {
         sentences
             .into_par_iter()
             .map(|sentence| self.decode(sentence))
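
Note: from the caller's side the pipeline now composes with `?`. A sketch, assuming a fully configured `Tokenizer`; the ids are illustrative:

    fn print_decoded(tokenizer: &Tokenizer, ids: Vec<u32>) -> Result<()> {
        let text = tokenizer.decode(ids)?; // model or decoder failure propagates
        println!("{}", text);
        Ok(())
    }
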
@@ -262,14 +284,10 @@
     }

     /// Train a model and replace our current Model, using the given Trainer
-    pub fn train(
-        &mut self,
-        trainer: &Box<dyn Trainer>,
-        files: Vec<String>,
-    ) -> Result<(), Box<dyn Error>> {
+    pub fn train(&mut self, trainer: &Box<dyn Trainer>, files: Vec<String>) -> Result<()> {
         let results = files
             .par_iter()
-            .map(|file| -> std::io::Result<HashMap<String, u32>> {
+            .map(|file| -> Result<HashMap<String, u32>> {
                 let mut words = HashMap::new();

                 let file: std::fs::File = File::open(file)?;
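
Note: the per-file closure previously returned `std::io::Result`, which only admits I/O errors. Since `normalize` and `pre_tokenize` are now fallible too (next hunk), the closure switches to the crate-wide `Result` so that both I/O errors and pipeline errors flow through the same `?`.
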
@@ -277,8 +295,8 @@

                 for line in file.lines() {
                     let line = line?;
-                    let normalized = self.normalize(line);
-                    let pre_tokenized = self.pre_tokenize(&normalized);
+                    let normalized = self.normalize(line)?;
+                    let pre_tokenized = self.pre_tokenize(&normalized)?;
                     trainer.process_tokens(&mut words, pre_tokenized);
                 }

@@ -302,30 +320,34 @@
     }

     /// PreTokenization logic, handling the case where there is no PreTokenizer set
-    fn pre_tokenize(&self, sentence: &str) -> Vec<String> {
+    fn pre_tokenize(&self, sentence: &str) -> Result<Vec<String>> {
         match &self.pre_tokenizer {
-            None => vec![sentence.to_owned()],
+            None => Ok(vec![sentence.to_owned()]),
             Some(pre_tokenizer) => pre_tokenizer.pre_tokenize(sentence),
         }
     }

     /// Normalization logic, go through all normalizers
-    fn normalize(&self, sentence: String) -> String {
+    fn normalize(&self, sentence: String) -> Result<String> {
         if let Some(normalizer) = &self.normalizer {
             normalizer.normalize(sentence)
         } else {
-            sentence.to_owned()
+            Ok(sentence.to_owned())
         }
     }

     /// Post processing logic, handling the case where there is no PostProcessor set
-    fn post_process(&self, encoding: Encoding, pair_encoding: Option<Encoding>) -> Encoding {
+    fn post_process(
+        &self,
+        encoding: Encoding,
+        pair_encoding: Option<Encoding>,
+    ) -> Result<Encoding> {
         if let Some(processor) = &self.post_processor {
             processor.process(encoding, pair_encoding)
         } else {
             match pair_encoding {
-                None => encoding,
-                Some(pair) => Encoding {
+                None => Ok(encoding),
+                Some(pair) => Ok(Encoding {
                     original: format!("{}{}", encoding.original, pair.original),
                     normalized: format!("{}{}", encoding.normalized, pair.normalized),
                     ids: [&encoding.ids[..], &pair.ids[..]].concat(),
@@ -345,7 +367,7 @@ impl Tokenizer {
                     .collect::<Vec<_>>(),
                 ]
                 .concat(),
-            },
+            }),
         }
     }
 }
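
Note: custom processors can now report failure instead of panicking. A minimal sketch under the new `PostProcessor` signature (`SingleOnlyProcessor` is a hypothetical example type, not part of this change):

    pub struct SingleOnlyProcessor;

    impl PostProcessor for SingleOnlyProcessor {
        fn process(
            &self,
            encoding: Encoding,
            pair_encoding: Option<Encoding>,
        ) -> Result<Encoding> {
            if pair_encoding.is_some() {
                // A &str converts into the boxed error alias via Into.
                return Err("SingleOnlyProcessor cannot handle sentence pairs".into());
            }
            Ok(encoding)
        }
    }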