Improved Tokenizer interface

commit 132a0fc4b4
parent 018f57f054
Author: Anthony MOI
Date:   2019-12-10 11:41:54 -05:00

2 changed files with 148 additions and 44 deletions


@@ -6,7 +6,7 @@ use clap::{App, AppSettings, Arg, ArgMatches, SubCommand};
 use std::io::{self, BufRead, Write};
 use tokenizers::models::bpe::{Error, BPE};
 use tokenizers::pre_tokenizers::byte_level::ByteLevel;
-use tokenizers::tokenizer::Tokenizer;
+use tokenizers::tokenizer::{EncodeInput, Tokenizer};

 fn shell(matches: &ArgMatches) -> Result<(), Error> {
     let vocab = matches
@@ -33,21 +33,12 @@ fn shell(matches: &ArgMatches) -> Result<(), Error> {
         let buffer = buffer.trim_end();

         let timer = std::time::Instant::now();
-        let encoded = tokenizer.encode(buffer);
+        let encoded = tokenizer.encode(EncodeInput::Single(buffer.to_owned()));
         let elapsed = timer.elapsed();

         println!("\nInput:\t\t{}", buffer);
-        println!(
-            "Tokens:\t\t{:?}",
-            encoded.iter().map(|t| &t.value).collect::<Vec<_>>()
-        );
-        println!(
-            "IDs:\t\t{:?}",
-            encoded.iter().map(|t| t.id).collect::<Vec<_>>()
-        );
-        println!(
-            "Offsets:\t{:?}",
-            encoded.iter().map(|t| t.offsets).collect::<Vec<_>>()
-        );
+        println!("Tokens:\t\t{:?}", encoded.get_tokens());
+        println!("IDs:\t\t{:?}", encoded.get_ids());
+        println!("Offsets:\t{:?}", encoded.get_offsets());
         println!("Tokenized in {:?}", elapsed);
     }
 }
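Taken together, the shell loop's new call pattern looks roughly like this. A minimal sketch only: BPE::from_files is assumed from the surrounding CLI code, which builds the model from the vocab/merges paths given on the command line; everything else comes from this diff.

    use tokenizers::models::bpe::{Error, BPE};
    use tokenizers::tokenizer::{EncodeInput, Tokenizer};

    fn encode_once(vocab: &str, merges: &str, line: &str) -> Result<(), Error> {
        // Assumed constructor from this era of the crate.
        let bpe = BPE::from_files(vocab, merges)?;
        let tokenizer = Tokenizer::new(Box::new(bpe));

        // encode now takes an EncodeInput and returns one Encoding,
        // replacing the old Vec<Token> and its per-token iteration.
        let encoded = tokenizer.encode(EncodeInput::Single(line.to_owned()));
        println!("Tokens:  {:?}", encoded.get_tokens());
        println!("IDs:     {:?}", encoded.get_ids());
        println!("Offsets: {:?}", encoded.get_offsets());
        Ok(())
    }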


@@ -22,12 +22,12 @@ use std::{
 /// A Normalizer takes care of pre-processing strings
 pub trait Normalizer {
-    // TODO: Use Cow here to avoid useless allocation if nothing is modified
-    fn normalize(&self, s: &str) -> String;
+    fn normalize(&self, s: String) -> String;
 }

 /// A PreTokenizer takes care of pre-tokenizing strings before this goes to the model
 pub trait PreTokenizer {
-    // TODO: Should return offsets with each substring
     fn pre_tokenize(&self, s: &str) -> Vec<String>;
 }
@@ -42,7 +42,7 @@ pub trait Model {
 /// A PostProcessor has the responsibility to post process an encoded output of the Tokenizer.
 /// Truncating, Padding, etc... are PostProcessor steps
 pub trait PostProcessor {
-    fn process(&self, tokens: Vec<Token>) -> Vec<Token>;
+    fn process(&self, encoding: Encoding, pair_encoding: Option<Encoding>) -> Encoding;
 }

 /// A Decoder has the responsibility to merge the given Vec<String> in a String
@@ -57,13 +57,12 @@ pub trait Trainer: Sync {
     fn process_tokens(&self, words: &mut HashMap<String, u32>, tokens: Vec<String>);
 }

-/// A Token represents the output of the Tokenizer
+/// A Token
 #[derive(Debug, PartialEq)]
 pub struct Token {
     pub id: u32,
     pub value: String,
     pub offsets: (usize, usize),
-    // TODO: Find out the best way to define the customizable part (For post processing steps)
 }

 impl Token {
     pub fn new(id: u32, value: String, offsets: (usize, usize)) -> Self {
@@ -71,16 +70,57 @@ impl Token {
     }
 }

+/// The Encoding struct represents the output of the Tokenizer
+#[derive(Default)]
+pub struct Encoding {
+    original: String,
+    normalized: String,
+    ids: Vec<u32>,
+    type_ids: Vec<u32>,
+    tokens: Vec<String>,
+    offsets: Vec<(usize, usize)>,
+}
+
+impl Encoding {
+    pub fn get_original(&self) -> &str {
+        &self.original
+    }
+
+    pub fn get_normalized(&self) -> &str {
+        &self.normalized
+    }
+
+    pub fn get_tokens(&self) -> &[String] {
+        &self.tokens[..]
+    }
+
+    pub fn get_ids(&self) -> &[u32] {
+        &self.ids
+    }
+
+    pub fn get_type_ids(&self) -> &[u32] {
+        &self.type_ids
+    }
+
+    pub fn get_offsets(&self) -> &[(usize, usize)] {
+        &self.offsets
+    }
+}
+
+pub enum EncodeInput {
+    Single(String),
+    Dual(String, String),
+}
+
 ///
 /// ## Tokenizer
 ///
 /// A Tokenizer is capable of encoding/decoding any text
 ///
 pub struct Tokenizer {
-    normalizers: Vec<Box<dyn Normalizer + Sync>>,
+    normalizer: Option<Box<dyn Normalizer + Sync>>,
     pre_tokenizer: Option<Box<dyn PreTokenizer + Sync>>,
     model: Box<dyn Model + Sync>,
-    post_processors: Vec<Box<dyn PostProcessor + Sync>>,
+    post_processor: Option<Box<dyn PostProcessor + Sync>>,
     decoder: Option<Box<dyn Decoder + Sync>>,
 }
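To make the two new public types concrete, a small usage sketch. Encoding values normally come out of Tokenizer::encode; Encoding::default() is used here only to have a value in hand, courtesy of the derived Default.

    use tokenizers::tokenizer::{EncodeInput, Encoding};

    fn sketch() {
        // Two ways to feed the tokenizer: one sequence, or a pair
        // (the pair's second member will receive type_id 1).
        let _single = EncodeInput::Single("Hello there!".to_owned());
        let _pair = EncodeInput::Dual("Question?".to_owned(), "Context.".to_owned());

        // All Encoding fields are private; consumers go through the getters.
        let encoding = Encoding::default(); // for illustration only
        assert!(encoding.get_ids().is_empty());
        assert!(encoding.get_tokens().is_empty());
        assert_eq!(encoding.get_original(), "");
    }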
@@ -88,17 +128,17 @@ impl Tokenizer {
     /// Instanciate a new Tokenizer, with the given Model
     pub fn new(model: Box<dyn Model + Sync>) -> Self {
         Tokenizer {
-            normalizers: vec![],
+            normalizer: None,
             pre_tokenizer: None,
             model,
-            post_processors: vec![],
+            post_processor: None,
             decoder: None,
         }
     }

     /// Set the normalizers
-    pub fn with_normalizers(&mut self, normalizers: Vec<Box<dyn Normalizer + Sync>>) -> &Self {
-        self.normalizers = normalizers;
+    pub fn with_normalizers(&mut self, normalizer: Box<dyn Normalizer + Sync>) -> &Self {
+        self.normalizer = Some(normalizer);
         self
     }
@@ -108,12 +148,9 @@ impl Tokenizer {
         self
     }

-    /// Set the post processors
-    pub fn with_post_processors(
-        &mut self,
-        post_processors: Vec<Box<dyn PostProcessor + Sync>>,
-    ) -> &Self {
-        self.post_processors = post_processors;
+    /// Set the post processor
+    pub fn with_post_processor(&mut self, post_processor: Box<dyn PostProcessor + Sync>) -> &Self {
+        self.post_processor = Some(post_processor);
         self
     }
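Since each slot is now a single optional component, wiring one up means implementing the trait and boxing it. A sketch with a toy normalizer: Lowercase is invented here for illustration, only the trait signature comes from the diff. Note the setter keeps its plural name in this commit even though it now accepts a single component.

    use tokenizers::tokenizer::{Model, Normalizer, Tokenizer};

    // Toy normalizer demonstrating the new owned-String signature.
    struct Lowercase;
    impl Normalizer for Lowercase {
        fn normalize(&self, s: String) -> String {
            s.to_lowercase()
        }
    }

    fn configure(model: Box<dyn Model + Sync>) -> Tokenizer {
        let mut tokenizer = Tokenizer::new(model);
        tokenizer.with_normalizers(Box::new(Lowercase));
        tokenizer
    }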
@@ -140,18 +177,62 @@ impl Tokenizer {
     }

     /// Encode the given sentence
-    pub fn encode(&self, sentence: &str) -> Vec<Token> {
-        let normalized = self.normalize(sentence);
-        let pre_tokenized = self.pre_tokenize(&normalized);
-
-        self.model.tokenize(pre_tokenized)
+    pub fn encode(&self, input: EncodeInput) -> Encoding {
+        let generate_output = move |sentence: String, type_id: u32| -> Encoding {
+            // 1. Normalization
+            // TODO: Make sure we have the offsets update necessary to go from the original text to
+            // the normalized one
+            let original = sentence.clone();
+            let normalized = self.normalize(sentence);
+
+            // 2. Pre tokenization
+            let pre_tokenized = self.pre_tokenize(&normalized);
+
+            // 3. Model
+            let output = self.model.tokenize(pre_tokenized);
+            let length = output.len();
+
+            let (ids, tokens, offsets) = output.into_iter().fold(
+                (
+                    Vec::with_capacity(length),
+                    Vec::with_capacity(length),
+                    Vec::with_capacity(length),
+                ),
+                |(mut ids, mut tokens, mut offsets), t| {
+                    ids.push(t.id);
+                    tokens.push(t.value);
+                    offsets.push(t.offsets);
+                    (ids, tokens, offsets)
+                },
+            );
+
+            Encoding {
+                original,
+                normalized,
+                ids,
+                type_ids: vec![type_id; length],
+                tokens,
+                offsets,
+            }
+        };
+
+        let (sentence, pair) = match input {
+            EncodeInput::Single(s1) => (s1, None),
+            EncodeInput::Dual(s1, s2) => (s1, Some(s2)),
+        };
+
+        let encoding = generate_output(sentence, 0);
+        let pair_encoding = pair.map(|pair| generate_output(pair, 1));
+
+        // 4. Post processing
+        self.post_process(encoding, pair_encoding)
     }

     /// Encode all the sentences in parallel, using multiple threads
-    pub fn encode_batch(&self, sentences: Vec<&str>) -> Vec<Vec<Token>> {
-        sentences
-            .par_iter()
-            .map(|sentence| self.encode(sentence))
+    pub fn encode_batch(&self, inputs: Vec<EncodeInput>) -> Vec<Encoding> {
+        inputs
+            .into_par_iter()
+            .map(|input| self.encode(input))
             .collect()
     }
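With the pipeline in place, both entry points consume EncodeInput. A sketch of the two paths, with a configured Tokenizer assumed to be in scope; the printed values are illustrative.

    use tokenizers::tokenizer::{EncodeInput, Tokenizer};

    fn demo(tokenizer: &Tokenizer) {
        // Single sequence: every token is generated with type_id 0.
        let single = tokenizer.encode(EncodeInput::Single("Hello there!".to_owned()));
        println!("{:?}", single.get_type_ids()); // e.g. [0, 0, 0]

        // Pair: the second side is generated with type_id 1, then merged
        // by post_process.
        let pair = tokenizer.encode(EncodeInput::Dual(
            "First sequence.".to_owned(),
            "Second sequence.".to_owned(),
        ));
        println!("{:?}", pair.get_type_ids()); // e.g. [0, 0, 0, 1, 1, 1]

        // The batch path fans the same work out over rayon threads.
        let batch = tokenizer.encode_batch(vec![
            EncodeInput::Single("one".to_owned()),
            EncodeInput::Dual("two".to_owned(), "three".to_owned()),
        ]);
        println!("{} encodings", batch.len());
    }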
@@ -190,7 +271,7 @@ impl Tokenizer {
         for line in file.lines() {
             let line = line?;

-            let normalized = self.normalize(&line);
+            let normalized = self.normalize(line);
             let pre_tokenized = self.pre_tokenize(&normalized);
             trainer.process_tokens(&mut words, pre_tokenized);
         }
@@ -223,11 +304,43 @@ impl Tokenizer {
     }

     /// Normalization logic, go through all normalizers
-    fn normalize(&self, sentence: &str) -> String {
-        if self.normalizers.len() == 0 {
-            sentence.to_owned()
+    fn normalize(&self, sentence: String) -> String {
+        if let Some(normalizer) = &self.normalizer {
+            normalizer.normalize(sentence)
         } else {
-            unimplemented!("Normalization has not been implemented yet")
+            sentence.to_owned()
+        }
+    }
+
+    /// Post processing logic, handling the case where there is no PostProcessor set
+    fn post_process(&self, encoding: Encoding, pair_encoding: Option<Encoding>) -> Encoding {
+        if let Some(processor) = &self.post_processor {
+            processor.process(encoding, pair_encoding)
+        } else {
+            match pair_encoding {
+                None => encoding,
+                Some(pair) => Encoding {
+                    original: format!("{}{}", encoding.original, pair.original),
+                    normalized: format!("{}{}", encoding.normalized, pair.normalized),
+                    ids: [&encoding.ids[..], &pair.ids[..]].concat(),
+                    type_ids: [&encoding.type_ids[..], &pair.type_ids[..]].concat(),
+                    tokens: [&encoding.tokens[..], &pair.tokens[..]].concat(),
+                    offsets: [
+                        &encoding.offsets[..],
+                        &pair
+                            .offsets
+                            .into_iter()
+                            .map(|(start, end)| {
+                                (
+                                    start + encoding.original.len(),
+                                    end + encoding.original.len(),
+                                )
+                            })
+                            .collect::<Vec<_>>(),
+                    ]
+                    .concat(),
+                },
+            }
         }
     }
 }
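The only non-trivial part of the fallback merge is the offset rebasing: the pair's offsets are expressed relative to its own original string, so each one is shifted by the byte length of the first original before concatenation. A standalone illustration of that arithmetic, plain Rust rather than the crate's API:

    fn main() {
        // "Hello " has byte length 6, so the pair's offsets move up by 6.
        let first_original = "Hello ";
        let pair_offsets = vec![(0, 5), (6, 7)];
        let shifted: Vec<(usize, usize)> = pair_offsets
            .into_iter()
            .map(|(start, end)| (start + first_original.len(), end + first_original.len()))
            .collect();
        assert_eq!(shifted, vec![(6, 11), (12, 13)]);
    }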