Improved Tokenizer interface
@@ -6,7 +6,7 @@ use clap::{App, AppSettings, Arg, ArgMatches, SubCommand};
 use std::io::{self, BufRead, Write};
 use tokenizers::models::bpe::{Error, BPE};
 use tokenizers::pre_tokenizers::byte_level::ByteLevel;
-use tokenizers::tokenizer::Tokenizer;
+use tokenizers::tokenizer::{EncodeInput, Tokenizer};
 
 fn shell(matches: &ArgMatches) -> Result<(), Error> {
     let vocab = matches
@@ -33,21 +33,12 @@ fn shell(matches: &ArgMatches) -> Result<(), Error> {
         let buffer = buffer.trim_end();
 
         let timer = std::time::Instant::now();
-        let encoded = tokenizer.encode(buffer);
+        let encoded = tokenizer.encode(EncodeInput::Single(buffer.to_owned()));
         let elapsed = timer.elapsed();
         println!("\nInput:\t\t{}", buffer);
-        println!(
-            "Tokens:\t\t{:?}",
-            encoded.iter().map(|t| &t.value).collect::<Vec<_>>()
-        );
-        println!(
-            "IDs:\t\t{:?}",
-            encoded.iter().map(|t| t.id).collect::<Vec<_>>()
-        );
-        println!(
-            "Offsets:\t{:?}",
-            encoded.iter().map(|t| t.offsets).collect::<Vec<_>>()
-        );
+        println!("Tokens:\t\t{:?}", encoded.get_tokens());
+        println!("IDs:\t\t{:?}", encoded.get_ids());
+        println!("Offsets:\t{:?}", encoded.get_offsets());
         println!("Tokenized in {:?}", elapsed);
     }
 }

@@ -22,12 +22,12 @@ use std::{
 
 /// A Normalizer takes care of pre-processing strings
 pub trait Normalizer {
-    // TODO: Use Cow here to avoid useless allocation if nothing is modified
-    fn normalize(&self, s: &str) -> String;
+    fn normalize(&self, s: String) -> String;
 }
 
 /// A PreTokenizer takes care of pre-tokenizing strings before this goes to the model
 pub trait PreTokenizer {
+    // TODO: Should return offsets with each substring
     fn pre_tokenize(&self, s: &str) -> Vec<String>;
 }
 
@@ -42,7 +42,7 @@ pub trait Model {
 /// A PostProcessor has the responsibility to post process an encoded output of the Tokenizer.
 /// Truncating, Padding, etc... are PostProcessor steps
 pub trait PostProcessor {
-    fn process(&self, tokens: Vec<Token>) -> Vec<Token>;
+    fn process(&self, encoding: Encoding, pair_encoding: Option<Encoding>) -> Encoding;
 }
 
 /// A Decoder has the responsibility to merge the given Vec<String> in a String
@@ -57,13 +57,12 @@ pub trait Trainer: Sync {
     fn process_tokens(&self, words: &mut HashMap<String, u32>, tokens: Vec<String>);
 }
 
-/// A Token represents the output of the Tokenizer
+/// A Token
 #[derive(Debug, PartialEq)]
 pub struct Token {
     pub id: u32,
     pub value: String,
     pub offsets: (usize, usize),
-    // TODO: Find out the best way to define the customizable part (For post processing steps)
 }
 impl Token {
     pub fn new(id: u32, value: String, offsets: (usize, usize)) -> Self {
@@ -71,16 +70,57 @@ impl Token {
     }
 }
 
+/// The Encoding struct represents the output of the Tokenizer
+#[derive(Default)]
+pub struct Encoding {
+    original: String,
+    normalized: String,
+    ids: Vec<u32>,
+    type_ids: Vec<u32>,
+    tokens: Vec<String>,
+    offsets: Vec<(usize, usize)>,
+}
+impl Encoding {
+    pub fn get_original(&self) -> &str {
+        &self.original
+    }
+
+    pub fn get_normalized(&self) -> &str {
+        &self.normalized
+    }
+
+    pub fn get_tokens(&self) -> &[String] {
+        &self.tokens[..]
+    }
+
+    pub fn get_ids(&self) -> &[u32] {
+        &self.ids
+    }
+
+    pub fn get_type_ids(&self) -> &[u32] {
+        &self.type_ids
+    }
+
+    pub fn get_offsets(&self) -> &[(usize, usize)] {
+        &self.offsets
+    }
+}
+
+pub enum EncodeInput {
+    Single(String),
+    Dual(String, String),
+}
+
 ///
 /// ## Tokenizer
 ///
 /// A Tokenizer is capable of encoding/decoding any text
 ///
 pub struct Tokenizer {
-    normalizers: Vec<Box<dyn Normalizer + Sync>>,
+    normalizer: Option<Box<dyn Normalizer + Sync>>,
     pre_tokenizer: Option<Box<dyn PreTokenizer + Sync>>,
     model: Box<dyn Model + Sync>,
-    post_processors: Vec<Box<dyn PostProcessor + Sync>>,
+    post_processor: Option<Box<dyn PostProcessor + Sync>>,
     decoder: Option<Box<dyn Decoder + Sync>>,
 }
 
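Note: a minimal usage sketch of the Encoding and EncodeInput types added above (not part of the commit). It assumes a tokenizer: Tokenizer value has already been built with Tokenizer::new and a model; the input strings are illustrative.

    // Illustrative only: `tokenizer` is assumed to be an already-constructed Tokenizer.
    let single = tokenizer.encode(EncodeInput::Single("Hello world".to_owned()));
    println!("{:?}", single.get_tokens());  // &[String]
    println!("{:?}", single.get_ids());     // &[u32]
    println!("{:?}", single.get_offsets()); // &[(usize, usize)]

    // With a sentence pair, the second sentence is encoded with type_id 1.
    let pair = tokenizer.encode(EncodeInput::Dual("Hello".to_owned(), "world".to_owned()));
    println!("{:?}", pair.get_type_ids());  // 0 for the first sentence, 1 for the second
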
@@ -88,17 +128,17 @@ impl Tokenizer {
     /// Instanciate a new Tokenizer, with the given Model
     pub fn new(model: Box<dyn Model + Sync>) -> Self {
         Tokenizer {
-            normalizers: vec![],
+            normalizer: None,
             pre_tokenizer: None,
             model,
-            post_processors: vec![],
+            post_processor: None,
             decoder: None,
         }
     }
 
     /// Set the normalizers
-    pub fn with_normalizers(&mut self, normalizers: Vec<Box<dyn Normalizer + Sync>>) -> &Self {
-        self.normalizers = normalizers;
+    pub fn with_normalizers(&mut self, normalizer: Box<dyn Normalizer + Sync>) -> &Self {
+        self.normalizer = Some(normalizer);
         self
     }
 
@@ -108,12 +148,9 @@ impl Tokenizer {
         self
     }
 
-    /// Set the post processors
-    pub fn with_post_processors(
-        &mut self,
-        post_processors: Vec<Box<dyn PostProcessor + Sync>>,
-    ) -> &Self {
-        self.post_processors = post_processors;
+    /// Set the post processor
+    pub fn with_post_processor(&mut self, post_processor: Box<dyn PostProcessor + Sync>) -> &Self {
+        self.post_processor = Some(post_processor);
         self
     }
 
@@ -140,18 +177,62 @@ impl Tokenizer {
     }
 
     /// Encode the given sentence
-    pub fn encode(&self, sentence: &str) -> Vec<Token> {
-        let normalized = self.normalize(sentence);
-        let pre_tokenized = self.pre_tokenize(&normalized);
-
-        self.model.tokenize(pre_tokenized)
+    pub fn encode(&self, input: EncodeInput) -> Encoding {
+        let generate_output = move |sentence: String, type_id: u32| -> Encoding {
+            // 1. Normalization
+            // TODO: Make sure we have the offsets update necessary to go from the original text to
+            // the normalized one
+            let original = sentence.clone();
+            let normalized = self.normalize(sentence);
+
+            // 2. Pre tokenization
+            let pre_tokenized = self.pre_tokenize(&normalized);
+
+            // 3. Model
+            let output = self.model.tokenize(pre_tokenized);
+            let length = output.len();
+
+            let (ids, tokens, offsets) = output.into_iter().fold(
+                (
+                    Vec::with_capacity(length),
+                    Vec::with_capacity(length),
+                    Vec::with_capacity(length),
+                ),
+                |(mut ids, mut tokens, mut offsets), t| {
+                    ids.push(t.id);
+                    tokens.push(t.value);
+                    offsets.push(t.offsets);
+                    (ids, tokens, offsets)
+                },
+            );
+
+            Encoding {
+                original,
+                normalized,
+                ids,
+                type_ids: vec![type_id; length],
+                tokens,
+                offsets,
+            }
+        };
+
+        let (sentence, pair) = match input {
+            EncodeInput::Single(s1) => (s1, None),
+            EncodeInput::Dual(s1, s2) => (s1, Some(s2)),
+        };
+
+        let encoding = generate_output(sentence, 0);
+        let pair_encoding = pair.map(|pair| generate_output(pair, 1));
+
+        // 4. Post processing
+        self.post_process(encoding, pair_encoding)
     }
 
     /// Encode all the sentences in parallel, using multiple threads
-    pub fn encode_batch(&self, sentences: Vec<&str>) -> Vec<Vec<Token>> {
-        sentences
-            .par_iter()
-            .map(|sentence| self.encode(sentence))
+    pub fn encode_batch(&self, inputs: Vec<EncodeInput>) -> Vec<Encoding> {
+        inputs
+            .into_par_iter()
+            .map(|input| self.encode(input))
             .collect()
     }
 
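Note: a standalone sketch of the fold used in step 3 above (not part of the commit). It splits a Vec<Token> into the parallel ids / tokens / offsets vectors that back an Encoding; the ids, values, and offsets are made up, and Token::new is the constructor defined earlier in this file.

    let output = vec![
        Token::new(15, "hello".to_owned(), (0, 5)),  // illustrative values
        Token::new(27, "world".to_owned(), (6, 11)),
    ];
    let length = output.len();
    let (ids, tokens, offsets) = output.into_iter().fold(
        (
            Vec::with_capacity(length),
            Vec::with_capacity(length),
            Vec::with_capacity(length),
        ),
        |(mut ids, mut tokens, mut offsets), t| {
            ids.push(t.id);
            tokens.push(t.value);
            offsets.push(t.offsets);
            (ids, tokens, offsets)
        },
    );
    assert_eq!(ids, vec![15, 27]);
    assert_eq!(tokens, vec!["hello".to_owned(), "world".to_owned()]);
    assert_eq!(offsets, vec![(0, 5), (6, 11)]);
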
@@ -190,7 +271,7 @@ impl Tokenizer {
 
         for line in file.lines() {
             let line = line?;
-            let normalized = self.normalize(&line);
+            let normalized = self.normalize(line);
             let pre_tokenized = self.pre_tokenize(&normalized);
             trainer.process_tokens(&mut words, pre_tokenized);
         }
@@ -223,11 +304,43 @@ impl Tokenizer {
     }
 
     /// Normalization logic, go through all normalizers
-    fn normalize(&self, sentence: &str) -> String {
-        if self.normalizers.len() == 0 {
-            sentence.to_owned()
+    fn normalize(&self, sentence: String) -> String {
+        if let Some(normalizer) = &self.normalizer {
+            normalizer.normalize(sentence)
         } else {
-            unimplemented!("Normalization has not been implemented yet")
+            sentence.to_owned()
+        }
+    }
+
+    /// Post processing logic, handling the case where there is no PostProcessor set
+    fn post_process(&self, encoding: Encoding, pair_encoding: Option<Encoding>) -> Encoding {
+        if let Some(processor) = &self.post_processor {
+            processor.process(encoding, pair_encoding)
+        } else {
+            match pair_encoding {
+                None => encoding,
+                Some(pair) => Encoding {
+                    original: format!("{}{}", encoding.original, pair.original),
+                    normalized: format!("{}{}", encoding.normalized, pair.normalized),
+                    ids: [&encoding.ids[..], &pair.ids[..]].concat(),
+                    type_ids: [&encoding.type_ids[..], &pair.type_ids[..]].concat(),
+                    tokens: [&encoding.tokens[..], &pair.tokens[..]].concat(),
+                    offsets: [
+                        &encoding.offsets[..],
+                        &pair
+                            .offsets
+                            .into_iter()
+                            .map(|(start, end)| {
+                                (
+                                    start + encoding.original.len(),
+                                    end + encoding.original.len(),
+                                )
+                            })
+                            .collect::<Vec<_>>(),
+                    ]
+                    .concat(),
+                },
+            }
         }
     }
 }
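Note: a small illustration of the default pair merge above (made-up values, not part of the commit). When no PostProcessor is set, offsets coming from the pair encoding are shifted by the length of the first encoding's original string, so they keep pointing into the concatenated text.

    let first_original = "Hello";              // plays the role of encoding.original
    let pair_offsets = vec![(0, 5), (6, 11)];  // offsets relative to the pair sentence only
    let shifted: Vec<(usize, usize)> = pair_offsets
        .into_iter()
        .map(|(start, end)| (start + first_original.len(), end + first_original.len()))
        .collect();
    assert_eq!(shifted, vec![(5, 10), (11, 16)]);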