Improved Tokenizer interface

Anthony MOI
2019-12-10 11:41:54 -05:00
parent 018f57f054
commit 132a0fc4b4
2 changed files with 148 additions and 44 deletions

View File

@@ -6,7 +6,7 @@ use clap::{App, AppSettings, Arg, ArgMatches, SubCommand};
use std::io::{self, BufRead, Write};
use tokenizers::models::bpe::{Error, BPE};
use tokenizers::pre_tokenizers::byte_level::ByteLevel;
use tokenizers::tokenizer::Tokenizer;
use tokenizers::tokenizer::{EncodeInput, Tokenizer};
fn shell(matches: &ArgMatches) -> Result<(), Error> {
let vocab = matches
@@ -33,21 +33,12 @@ fn shell(matches: &ArgMatches) -> Result<(), Error> {
let buffer = buffer.trim_end();
let timer = std::time::Instant::now();
let encoded = tokenizer.encode(buffer);
let encoded = tokenizer.encode(EncodeInput::Single(buffer.to_owned()));
let elapsed = timer.elapsed();
println!("\nInput:\t\t{}", buffer);
println!(
"Tokens:\t\t{:?}",
encoded.iter().map(|t| &t.value).collect::<Vec<_>>()
);
println!(
"IDs:\t\t{:?}",
encoded.iter().map(|t| t.id).collect::<Vec<_>>()
);
println!(
"Offsets:\t{:?}",
encoded.iter().map(|t| t.offsets).collect::<Vec<_>>()
);
println!("Tokens:\t\t{:?}", encoded.get_tokens());
println!("IDs:\t\t{:?}", encoded.get_ids());
println!("Offsets:\t{:?}", encoded.get_offsets());
println!("Tokenized in {:?}", elapsed);
}
}
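
With the new interface, callers wrap their text in an EncodeInput and read the results through the Encoding getters instead of iterating over Tokens. Below is a minimal sketch (not part of this commit) of a sentence-pair call against an already-built Tokenizer; the helper name and sentences are placeholders.

use tokenizers::tokenizer::{EncodeInput, Tokenizer};

fn print_pair(tokenizer: &Tokenizer) {
    // A Dual input produces one Encoding; type_ids are 0 for the first
    // sentence and 1 for the second.
    let encoded = tokenizer.encode(EncodeInput::Dual(
        "Hello there".to_owned(),
        "General Kenobi".to_owned(),
    ));
    println!("Tokens:\t\t{:?}", encoded.get_tokens());
    println!("Type IDs:\t{:?}", encoded.get_type_ids());
    println!("Offsets:\t{:?}", encoded.get_offsets());
}
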

View File

@@ -22,12 +22,12 @@ use std::{
/// A Normalizer takes care of pre-processing strings
pub trait Normalizer {
// TODO: Use Cow here to avoid useless allocation if nothing is modified
fn normalize(&self, s: &str) -> String;
fn normalize(&self, s: String) -> String;
}
/// A PreTokenizer takes care of pre-tokenizing strings before they go to the model
pub trait PreTokenizer {
// TODO: Should return offsets with each substring
fn pre_tokenize(&self, s: &str) -> Vec<String>;
}
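
Normalizer::normalize now takes its input by value instead of by reference. A minimal sketch of a hypothetical lowercasing normalizer written against the new signature (the Lowercase type is illustrative, not part of this commit, and the trait is assumed to be reachable at tokenizers::tokenizer::Normalizer):

use tokenizers::tokenizer::Normalizer;

// Hypothetical normalizer used only to illustrate the owned-String signature.
pub struct Lowercase;

impl Normalizer for Lowercase {
    // Taking ownership means a pass-through implementation could return s
    // as-is without cloning; here we build a lowercased copy instead.
    fn normalize(&self, s: String) -> String {
        s.to_lowercase()
    }
}
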
@@ -42,7 +42,7 @@ pub trait Model {
/// A PostProcessor has the responsibility to post process an encoded output of the Tokenizer.
/// Truncating, Padding, etc... are PostProcessor steps
pub trait PostProcessor {
fn process(&self, tokens: Vec<Token>) -> Vec<Token>;
fn process(&self, encoding: Encoding, pair_encoding: Option<Encoding>) -> Encoding;
}
/// A Decoder has the responsibility to merge the given Vec<String> into a String
@@ -57,13 +57,12 @@ pub trait Trainer: Sync {
fn process_tokens(&self, words: &mut HashMap<String, u32>, tokens: Vec<String>);
}
/// A Token represents the output of the Tokenizer
/// A Token
#[derive(Debug, PartialEq)]
pub struct Token {
pub id: u32,
pub value: String,
pub offsets: (usize, usize),
// TODO: Find out the best way to define the customizable part (For post processing steps)
}
impl Token {
pub fn new(id: u32, value: String, offsets: (usize, usize)) -> Self {
@@ -71,16 +70,57 @@ impl Token {
}
}
/// The Encoding struct represents the output of the Tokenizer
#[derive(Default)]
pub struct Encoding {
original: String,
normalized: String,
ids: Vec<u32>,
type_ids: Vec<u32>,
tokens: Vec<String>,
offsets: Vec<(usize, usize)>,
}
impl Encoding {
pub fn get_original(&self) -> &str {
&self.original
}
pub fn get_normalized(&self) -> &str {
&self.normalized
}
pub fn get_tokens(&self) -> &[String] {
&self.tokens[..]
}
pub fn get_ids(&self) -> &[u32] {
&self.ids
}
pub fn get_type_ids(&self) -> &[u32] {
&self.type_ids
}
pub fn get_offsets(&self) -> &[(usize, usize)] {
&self.offsets
}
}
pub enum EncodeInput {
Single(String),
Dual(String, String),
}
///
/// ## Tokenizer
///
/// A Tokenizer is capable of encoding/decoding any text
///
pub struct Tokenizer {
normalizers: Vec<Box<dyn Normalizer + Sync>>,
normalizer: Option<Box<dyn Normalizer + Sync>>,
pre_tokenizer: Option<Box<dyn PreTokenizer + Sync>>,
model: Box<dyn Model + Sync>,
post_processors: Vec<Box<dyn PostProcessor + Sync>>,
post_processor: Option<Box<dyn PostProcessor + Sync>>,
decoder: Option<Box<dyn Decoder + Sync>>,
}
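
EncodeInput makes the single-sentence and sentence-pair cases explicit at the call site. A small hypothetical helper (not part of this commit) that builds one from an optional second sentence:

use tokenizers::tokenizer::EncodeInput;

fn to_input(sentence: String, pair: Option<String>) -> EncodeInput {
    match pair {
        // A pair becomes a Dual input so the second sentence gets type_id 1.
        Some(second) => EncodeInput::Dual(sentence, second),
        None => EncodeInput::Single(sentence),
    }
}
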
@@ -88,17 +128,17 @@ impl Tokenizer {
/// Instantiate a new Tokenizer with the given Model
pub fn new(model: Box<dyn Model + Sync>) -> Self {
Tokenizer {
normalizers: vec![],
normalizer: None,
pre_tokenizer: None,
model,
post_processors: vec![],
post_processor: None,
decoder: None,
}
}
/// Set the normalizers
pub fn with_normalizers(&mut self, normalizers: Vec<Box<dyn Normalizer + Sync>>) -> &Self {
self.normalizers = normalizers;
pub fn with_normalizers(&mut self, normalizer: Box<dyn Normalizer + Sync>) -> &Self {
self.normalizer = Some(normalizer);
self
}
@@ -108,12 +148,9 @@ impl Tokenizer {
self
}
/// Set the post processors
pub fn with_post_processors(
&mut self,
post_processors: Vec<Box<dyn PostProcessor + Sync>>,
) -> &Self {
self.post_processors = post_processors;
/// Set the post processor
pub fn with_post_processor(&mut self, post_processor: Box<dyn PostProcessor + Sync>) -> &Self {
self.post_processor = Some(post_processor);
self
}
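
The setters now store a single optional component each instead of a Vec. A minimal configuration sketch, assuming boxed Model, Normalizer and PostProcessor implementations are passed in (the build function is hypothetical and the trait paths are assumed to match the Tokenizer's module):

use tokenizers::tokenizer::{Model, Normalizer, PostProcessor, Tokenizer};

fn build(
    model: Box<dyn Model + Sync>,
    normalizer: Box<dyn Normalizer + Sync>,
    post_processor: Box<dyn PostProcessor + Sync>,
) -> Tokenizer {
    let mut tokenizer = Tokenizer::new(model);
    // Each setter replaces any previously stored component rather than
    // appending to a list.
    tokenizer.with_normalizers(normalizer);
    tokenizer.with_post_processor(post_processor);
    tokenizer
}
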
@@ -140,18 +177,62 @@ impl Tokenizer {
}
/// Encode the given input, which can be a single sentence or a pair of sentences
pub fn encode(&self, sentence: &str) -> Vec<Token> {
let normalized = self.normalize(sentence);
let pre_tokenized = self.pre_tokenize(&normalized);
pub fn encode(&self, input: EncodeInput) -> Encoding {
let generate_output = move |sentence: String, type_id: u32| -> Encoding {
// 1. Normalization
// TODO: Make sure we have the offsets update necessary to go from the original text to
// the normalized one
let original = sentence.clone();
let normalized = self.normalize(sentence);
self.model.tokenize(pre_tokenized)
// 2. Pre tokenization
let pre_tokenized = self.pre_tokenize(&normalized);
// 3. Model
let output = self.model.tokenize(pre_tokenized);
let length = output.len();
let (ids, tokens, offsets) = output.into_iter().fold(
(
Vec::with_capacity(length),
Vec::with_capacity(length),
Vec::with_capacity(length),
),
|(mut ids, mut tokens, mut offsets), t| {
ids.push(t.id);
tokens.push(t.value);
offsets.push(t.offsets);
(ids, tokens, offsets)
},
);
Encoding {
original,
normalized,
ids,
type_ids: vec![type_id; length],
tokens,
offsets,
}
};
let (sentence, pair) = match input {
EncodeInput::Single(s1) => (s1, None),
EncodeInput::Dual(s1, s2) => (s1, Some(s2)),
};
let encoding = generate_output(sentence, 0);
let pair_encoding = pair.map(|pair| generate_output(pair, 1));
// 4. Post processing
self.post_process(encoding, pair_encoding)
}
/// Encode all the inputs in parallel, using multiple threads
pub fn encode_batch(&self, sentences: Vec<&str>) -> Vec<Vec<Token>> {
sentences
.par_iter()
.map(|sentence| self.encode(sentence))
pub fn encode_batch(&self, inputs: Vec<EncodeInput>) -> Vec<Encoding> {
inputs
.into_par_iter()
.map(|input| self.encode(input))
.collect()
}
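
encode_batch takes ownership of the inputs and maps over them in parallel with rayon, returning one Encoding per input in the same order. A short usage sketch (the batch_ids helper and the sentences are placeholders):

use tokenizers::tokenizer::{EncodeInput, Tokenizer};

fn batch_ids(tokenizer: &Tokenizer) -> Vec<Vec<u32>> {
    let encodings = tokenizer.encode_batch(vec![
        EncodeInput::Single("A single sentence".to_owned()),
        EncodeInput::Dual("A question?".to_owned(), "An answer.".to_owned()),
    ]);
    // Copy the ids out of each Encoding; get_ids only hands back a slice.
    encodings
        .iter()
        .map(|encoding| encoding.get_ids().to_vec())
        .collect()
}
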
@@ -190,7 +271,7 @@ impl Tokenizer {
for line in file.lines() {
let line = line?;
let normalized = self.normalize(&line);
let normalized = self.normalize(line);
let pre_tokenized = self.pre_tokenize(&normalized);
trainer.process_tokens(&mut words, pre_tokenized);
}
@@ -223,11 +304,43 @@ impl Tokenizer {
}
/// Normalization logic, applying the normalizer if one is set
fn normalize(&self, sentence: &str) -> String {
if self.normalizers.len() == 0 {
sentence.to_owned()
fn normalize(&self, sentence: String) -> String {
if let Some(normalizer) = &self.normalizer {
normalizer.normalize(sentence)
} else {
unimplemented!("Normalization has not been implemented yet")
sentence.to_owned()
}
}
/// Post processing logic, handling the case where there is no PostProcessor set
fn post_process(&self, encoding: Encoding, pair_encoding: Option<Encoding>) -> Encoding {
if let Some(processor) = &self.post_processor {
processor.process(encoding, pair_encoding)
} else {
match pair_encoding {
None => encoding,
Some(pair) => Encoding {
original: format!("{}{}", encoding.original, pair.original),
normalized: format!("{}{}", encoding.normalized, pair.normalized),
ids: [&encoding.ids[..], &pair.ids[..]].concat(),
type_ids: [&encoding.type_ids[..], &pair.type_ids[..]].concat(),
tokens: [&encoding.tokens[..], &pair.tokens[..]].concat(),
offsets: [
&encoding.offsets[..],
&pair
.offsets
.into_iter()
.map(|(start, end)| {
(
start + encoding.original.len(),
end + encoding.original.len(),
)
})
.collect::<Vec<_>>(),
]
.concat(),
},
}
}
}
}
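
The PostProcessor trait now receives the pair encoding explicitly and must return a single Encoding, mirroring the fallback merge above. A trivial, hypothetical implementation of the new signature that keeps the first encoding and drops the pair (a real processor would truncate, pad, or add special tokens while merging both):

use tokenizers::tokenizer::{Encoding, PostProcessor};

// Illustrative only; not a processor shipped by this commit.
pub struct KeepFirst;

impl PostProcessor for KeepFirst {
    fn process(&self, encoding: Encoding, _pair_encoding: Option<Encoding>) -> Encoding {
        // Discard the pair and return the first encoding unchanged.
        encoding
    }
}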