Make Tokenizer generic over pre-tokenizers.

Sebastian Pütz
2020-07-25 15:57:51 +02:00
committed by Anthony MOI
parent 08b8c48127
commit bcc54a2ea1
6 changed files with 62 additions and 46 deletions
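In short: Tokenizer<M, N> becomes Tokenizer<M, N, PT> with PT: PreTokenizer, so the pre-tokenizer is a concrete type parameter instead of a boxed trait object, and with_pre_tokenizer takes it by value. A minimal before/after sketch, using only calls that appear in the diffs below:

// Before: the pre-tokenizer is type-erased behind Box<dyn PreTokenizer>.
let mut tokenizer = Tokenizer::<_, NormalizerWrapper>::new(bpe);
tokenizer.with_pre_tokenizer(Box::new(ByteLevel::default()));

// After: the pre-tokenizer type is part of the Tokenizer's type.
let mut tokenizer = Tokenizer::<_, NormalizerWrapper, ByteLevel>::new(bpe);
tokenizer.with_pre_tokenizer(ByteLevel::default());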

View File

@@ -11,27 +11,28 @@ use tokenizers::normalizers::NormalizerWrapper;
 use tokenizers::pre_tokenizers::byte_level::ByteLevel;
 use tokenizers::pre_tokenizers::whitespace::Whitespace;
 use tokenizers::tokenizer::{AddedToken, EncodeInput, Tokenizer, Trainer};
-use tokenizers::{Model, Normalizer};
+use tokenizers::{Model, Normalizer, PreTokenizer};

 static BATCH_SIZE: usize = 1_000;

-fn create_gpt2_tokenizer(bpe: BPE) -> Tokenizer<BPE, NormalizerWrapper> {
+fn create_gpt2_tokenizer(bpe: BPE) -> Tokenizer<BPE, NormalizerWrapper, ByteLevel> {
     let mut tokenizer = Tokenizer::new(bpe);
-    tokenizer.with_pre_tokenizer(Box::new(ByteLevel::default()));
+    tokenizer.with_pre_tokenizer(ByteLevel::default());
     tokenizer.with_decoder(Box::new(ByteLevel::default()));
     tokenizer.add_tokens(&[AddedToken::from("ing", false).single_word(false)]);
     tokenizer.add_special_tokens(&[AddedToken::from("[ENT]", true).single_word(true)]);
     tokenizer
 }

-fn iter_bench_encode<M, N>(
+fn iter_bench_encode<M, N, PT>(
     iters: u64,
-    tokenizer: &Tokenizer<M, N>,
+    tokenizer: &Tokenizer<M, N, PT>,
     lines: &[EncodeInput],
 ) -> Duration
 where
     M: Model,
     N: Normalizer,
+    PT: PreTokenizer,
 {
     let mut duration = Duration::new(0, 0);
     let mut line_index: usize = 0;
@@ -47,14 +48,15 @@ where
     duration
 }

-fn iter_bench_encode_batch<M, N>(
+fn iter_bench_encode_batch<M, N, PT>(
     iters: u64,
-    tokenizer: &Tokenizer<M, N>,
+    tokenizer: &Tokenizer<M, N, PT>,
     batches: &[Vec<EncodeInput>],
 ) -> Duration
 where
     M: Model,
     N: Normalizer,
+    PT: PreTokenizer,
 {
     let mut duration = Duration::new(0, 0);
     let mut batch_index: usize = 0;
@@ -109,12 +111,17 @@ fn bench_gpt2(c: &mut Criterion) {
     });
 }

-fn iter_bench_train<T: Trainer<Model = BPE>>(
+fn iter_bench_train<T, M, PT>(
     iters: u64,
-    mut tokenizer: Tokenizer<BPE, NormalizerWrapper>,
+    mut tokenizer: Tokenizer<M, NormalizerWrapper, PT>,
     trainer: &T,
     files: Vec<String>,
-) -> Duration {
+) -> Duration
+where
+    M: Model,
+    PT: PreTokenizer,
+    T: Trainer<Model = M>,
+{
     let mut duration = Duration::new(0, 0);
     for _i in 0..iters {
         let start = Instant::now();
@@ -129,7 +136,7 @@ fn bench_train(c: &mut Criterion) {
     c.bench_function("BPE Train vocabulary (small)", |b| {
         b.iter_custom(|iters| {
             let mut tokenizer = Tokenizer::new(BPE::default());
-            tokenizer.with_pre_tokenizer(Box::new(Whitespace::default()));
+            tokenizer.with_pre_tokenizer(Whitespace::default());
             iter_bench_train(
                 iters,
                 tokenizer,
@@ -142,7 +149,7 @@ fn bench_train(c: &mut Criterion) {
     c.bench_function("BPE Train vocabulary (big)", |b| {
         b.iter_custom(|iters| {
             let mut tokenizer = Tokenizer::new(BPE::default());
-            tokenizer.with_pre_tokenizer(Box::new(Whitespace::default()));
+            tokenizer.with_pre_tokenizer(Whitespace::default());
             iter_bench_train(iters, tokenizer, &trainer, vec!["data/big.txt".to_string()])
         })
     });
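The benchmark helpers are now generic instead of hard-wired to BPE; M and PT are inferred at each call site, while N is pinned to NormalizerWrapper by the helper signatures. A hedged sketch of such a call (inside b.iter_custom, as above):

// M = BPE and PT = Whitespace are inferred from `tokenizer`;
// N = NormalizerWrapper is fixed by iter_bench_train's signature.
let mut tokenizer = Tokenizer::new(BPE::default());
tokenizer.with_pre_tokenizer(Whitespace::default());
iter_bench_train(iters, tokenizer, &trainer, vec!["data/big.txt".to_string()]);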

View File

@@ -18,8 +18,8 @@ fn shell(matches: &ArgMatches) -> Result<()> {
         .expect("Must give a merges.txt file");

     let bpe = BPE::from_files(vocab, merges).build()?;
-    let mut tokenizer = Tokenizer::<_, NormalizerWrapper>::new(bpe);
-    tokenizer.with_pre_tokenizer(Box::new(ByteLevel::default()));
+    let mut tokenizer = Tokenizer::<_, NormalizerWrapper, ByteLevel>::new(bpe);
+    tokenizer.with_pre_tokenizer(ByteLevel::default());
     tokenizer.with_decoder(Box::new(ByteLevel::default()));
     tokenizer.add_tokens(&[AddedToken::from(String::from("ing"), false).single_word(false)]);

View File

@@ -28,6 +28,7 @@
 //! use tokenizers::tokenizer::{Result, Tokenizer, EncodeInput};
 //! use tokenizers::models::bpe::BPE;
 //! use tokenizers::normalizers::NormalizerWrapper;
+//! use tokenizers::pre_tokenizers::PreTokenizerWrapper;
 //!
 //! fn main() -> Result<()> {
 //!     let bpe_builder = BPE::from_files("./path/to/vocab.json", "./path/to/merges.txt");
@@ -36,7 +37,7 @@
 //!         .unk_token("[UNK]".into())
 //!         .build()?;
 //!
-//!     let mut tokenizer = Tokenizer::<_, NormalizerWrapper>::new(bpe);
+//!     let mut tokenizer = Tokenizer::<_, NormalizerWrapper, PreTokenizerWrapper>::new(bpe);
 //!
 //!     let encoding = tokenizer.encode("Hey there!", false)?;
 //!     println!("{:?}", encoding.get_tokens());
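Assembled from the fragments in these two hunks, the updated crate-level example reads roughly as below; the middle of the builder chain and the final Ok(()) are elided in the diff and assumed here:

use tokenizers::tokenizer::{Result, Tokenizer, EncodeInput};
use tokenizers::models::bpe::BPE;
use tokenizers::normalizers::NormalizerWrapper;
use tokenizers::pre_tokenizers::PreTokenizerWrapper;

fn main() -> Result<()> {
    let bpe_builder = BPE::from_files("./path/to/vocab.json", "./path/to/merges.txt");
    let bpe = bpe_builder
        .unk_token("[UNK]".into())
        .build()?;

    let mut tokenizer = Tokenizer::<_, NormalizerWrapper, PreTokenizerWrapper>::new(bpe);

    let encoding = tokenizer.encode("Hey there!", false)?;
    println!("{:?}", encoding.get_tokens());
    Ok(())
}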

View File

@@ -213,10 +213,10 @@ impl fmt::Display for BuilderError {
 /// Builder for Tokenizer structs.
 ///
 /// `build()` fails if the `model` is missing.
-pub struct TokenizerBuilder<M, N> {
+pub struct TokenizerBuilder<M, N, PT> {
     model: Option<M>,
     normalizer: Option<N>,
-    pre_tokenizer: Option<Box<dyn PreTokenizer>>,
+    pre_tokenizer: Option<PT>,
     post_processor: Option<Box<dyn PostProcessor>>,
     decoder: Option<Box<dyn Decoder>>,
@@ -226,20 +226,22 @@ pub struct TokenizerBuilder<M, N> {
     padding: Option<PaddingParams>,
 }

-impl<M, N> Default for TokenizerBuilder<M, N>
+impl<M, N, PT> Default for TokenizerBuilder<M, N, PT>
 where
     M: Model,
     N: Normalizer,
+    PT: PreTokenizer,
 {
     fn default() -> Self {
         Self::new()
     }
 }

-impl<M, N> TokenizerBuilder<M, N>
+impl<M, N, PT> TokenizerBuilder<M, N, PT>
 where
     M: Model,
     N: Normalizer,
+    PT: PreTokenizer,
 {
     /// Get an empty TokenizerBuilder.
     pub fn new() -> Self {
@@ -258,7 +260,7 @@ where
     /// Convert the TokenizerBuilder to a Tokenizer.
     ///
     /// Conversion fails if the `model` is missing.
-    pub fn build(self) -> Result<Tokenizer<M, N>> {
+    pub fn build(self) -> Result<Tokenizer<M, N, PT>> {
         let model = self
             .model
             .ok_or_else(|| Box::new(BuilderError("Model missing.".into())))?;
@@ -288,7 +290,7 @@ where
     }

     /// Set the pretokenizer.
-    pub fn with_pretokenizer(mut self, pretokenizer: Option<Box<dyn PreTokenizer>>) -> Self {
+    pub fn with_pretokenizer(mut self, pretokenizer: Option<PT>) -> Self {
         self.pre_tokenizer = pretokenizer;
         self
     }
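Since pre_tokenizer is now Option<PT>, the builder hands the concrete type straight through to build(). A small sketch using only the builder API shown above; as documented, build() fails when the model is missing:

// `build()` fails if the `model` is missing, per the doc comment above.
let result = TokenizerBuilder::<BPE, NormalizerWrapper, ByteLevel>::new()
    .with_pretokenizer(Some(ByteLevel::default()))
    .build();
assert!(result.is_err()); // BuilderError("Model missing.")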
@@ -319,10 +321,10 @@ where
 }

 /// A `Tokenizer` is capable of encoding/decoding any text.
-pub struct Tokenizer<M, N> {
+pub struct Tokenizer<M, N, PT> {
     // Tokenizer parts
     normalizer: Option<N>,
-    pre_tokenizer: Option<Box<dyn PreTokenizer>>,
+    pre_tokenizer: Option<PT>,
     model: M,
     post_processor: Option<Box<dyn PostProcessor>>,
     decoder: Option<Box<dyn Decoder>>,
@@ -335,10 +337,11 @@ pub struct Tokenizer<M, N> {
     padding: Option<PaddingParams>,
 }

-impl<M, N> Tokenizer<M, N>
+impl<M, N, PT> Tokenizer<M, N, PT>
 where
     M: Model,
     N: Normalizer,
+    PT: PreTokenizer,
 {
     /// Instantiate a new Tokenizer, with the given Model
     pub fn new(model: M) -> Self {
@@ -368,14 +371,13 @@ where
     }

     /// Set the pre tokenizer
-    pub fn with_pre_tokenizer(&mut self, pre_tokenizer: Box<dyn PreTokenizer>) -> &Self {
+    pub fn with_pre_tokenizer(&mut self, pre_tokenizer: PT) -> &Self {
         self.pre_tokenizer = Some(pre_tokenizer);
         self
     }

     /// Get the pre tokenizer
-    #[allow(clippy::borrowed_box)]
-    pub fn get_pre_tokenizer(&self) -> Option<&Box<dyn PreTokenizer>> {
+    pub fn get_pre_tokenizer(&self) -> Option<&PT> {
         self.pre_tokenizer.as_ref()
     }
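Dropping the Box also retires the clippy::borrowed_box allowance, and the getter returns the concrete type. A short sketch, assuming the types used elsewhere in this commit:

let mut tokenizer = Tokenizer::<_, NormalizerWrapper, ByteLevel>::new(BPE::default());
tokenizer.with_pre_tokenizer(ByteLevel::default());
// The getter now yields Option<&ByteLevel> instead of Option<&Box<dyn PreTokenizer>>.
let pre_tokenizer: Option<&ByteLevel> = tokenizer.get_pre_tokenizer();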
@@ -577,7 +579,8 @@ where
     /// # use tokenizers::Tokenizer;
     /// # use tokenizers::models::bpe::BPE;
     /// # use tokenizers::normalizers::NormalizerWrapper;
-    /// # let mut tokenizer = Tokenizer::<_, NormalizerWrapper>::new(BPE::default());
+    /// # use tokenizers::pre_tokenizers::PreTokenizerWrapper;
+    /// # let mut tokenizer = Tokenizer::<_, NormalizerWrapper, PreTokenizerWrapper>::new(BPE::default());
     /// #
     /// // Sequences:
     /// tokenizer.encode("Single sequence", false);
@@ -752,7 +755,7 @@ where
     }

     /// Train a model and replace our current Model, using the given Trainer
-    pub fn train<T, TM>(self, trainer: &T, files: Vec<String>) -> Result<Tokenizer<TM, N>>
+    pub fn train<T, TM>(self, trainer: &T, files: Vec<String>) -> Result<Tokenizer<TM, N, PT>>
     where
         T: Trainer<Model = TM>,
         TM: Model,
@@ -912,10 +915,11 @@ where
     }
 }

-impl<M, N> std::str::FromStr for Tokenizer<M, N>
+impl<M, N, PT> std::str::FromStr for Tokenizer<M, N, PT>
 where
     M: for<'de> Deserialize<'de> + Model,
     N: for<'de> Deserialize<'de> + Normalizer,
+    PT: for<'de> Deserialize<'de> + PreTokenizer,
 {
     type Err = Error;

@@ -924,10 +928,11 @@ where
     }
 }

-impl<M, N> Tokenizer<M, N>
+impl<M, N, PT> Tokenizer<M, N, PT>
 where
     M: DeserializeOwned + Model,
     N: DeserializeOwned + Normalizer,
+    PT: DeserializeOwned + PreTokenizer,
 {
     /// Instantiate a new Tokenizer from the given file
     pub fn from_file<P: AsRef<Path>>(file: P) -> Result<Self> {
@@ -937,10 +942,11 @@ where
     }
 }

-impl<M, N> Tokenizer<M, N>
+impl<M, N, PT> Tokenizer<M, N, PT>
 where
     M: Serialize,
     N: Serialize,
+    PT: Serialize,
 {
     /// Serialize the current tokenizer as a String
     pub fn to_string(&self, pretty: bool) -> Result<String> {
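With these Serialize/DeserializeOwned bounds on PT, a tokenizer round-trips through JSON whenever all three parts are serde-capable; PreTokenizerWrapper is the natural choice when the concrete pre-tokenizer type is not known up front. A hedged round-trip sketch (the file name is a placeholder, and error propagation assumes the crate's Result):

// Serialize the current tokenizer (pretty-printed), then reload it.
let json = tokenizer.to_string(true)?;
std::fs::write("tokenizer.json", json)?;
let reloaded = Tokenizer::<BPE, NormalizerWrapper, PreTokenizerWrapper>::from_file("tokenizer.json")?;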

View File

@@ -8,14 +8,15 @@ use serde::{
 };

 use super::{added_vocabulary::AddedTokenWithId, Tokenizer};
-use crate::{Model, Normalizer, TokenizerBuilder};
+use crate::{Model, Normalizer, PreTokenizer, TokenizerBuilder};

 static SERIALIZATION_VERSION: &str = "1.0";

-impl<M, N> Serialize for Tokenizer<M, N>
+impl<M, N, PT> Serialize for Tokenizer<M, N, PT>
 where
     M: Serialize,
     N: Serialize,
+    PT: Serialize,
 {
     fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
     where
@@ -44,10 +45,11 @@ where
     }
 }

-impl<'de, M, N> Deserialize<'de> for Tokenizer<M, N>
+impl<'de, M, N, PT> Deserialize<'de> for Tokenizer<M, N, PT>
 where
     M: Deserialize<'de> + Model,
     N: Deserialize<'de> + Normalizer,
+    PT: Deserialize<'de> + PreTokenizer,
 {
     fn deserialize<De>(deserializer: De) -> Result<Self, De::Error>
     where
@@ -66,19 +68,20 @@ where
                 "decoder",
                 "model",
             ],
-            TokenizerVisitor(PhantomData, PhantomData),
+            TokenizerVisitor(PhantomData, PhantomData, PhantomData),
         )
     }
 }

-struct TokenizerVisitor<M, N>(PhantomData<M>, PhantomData<N>);
+struct TokenizerVisitor<M, N, PT>(PhantomData<M>, PhantomData<N>, PhantomData<PT>);

-impl<'de, M, N> Visitor<'de> for TokenizerVisitor<M, N>
+impl<'de, M, N, PT> Visitor<'de> for TokenizerVisitor<M, N, PT>
 where
     M: Deserialize<'de> + Model,
     N: Deserialize<'de> + Normalizer,
+    PT: Deserialize<'de> + PreTokenizer,
 {
-    type Value = Tokenizer<M, N>;
+    type Value = Tokenizer<M, N, PT>;

     fn expecting(&self, fmt: &mut std::fmt::Formatter) -> std::fmt::Result {
         write!(fmt, "struct Tokenizer")
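The visitor stays zero-sized: each generic parameter is mentioned through a PhantomData field rather than stored data, which is why a third PhantomData appears in both the struct and its construction. A standalone illustration of the pattern (ExampleVisitor is a hypothetical name):

use std::marker::PhantomData;

// PhantomData<T> lets a struct name a type parameter without storing a value,
// keeping the visitor zero-sized while still tying M, N, PT to its output type.
struct ExampleVisitor<M, N, PT>(PhantomData<M>, PhantomData<N>, PhantomData<PT>);

fn make_visitor<M, N, PT>() -> ExampleVisitor<M, N, PT> {
    ExampleVisitor(PhantomData, PhantomData, PhantomData)
}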

View File

@@ -5,11 +5,12 @@ use tokenizers::normalizers::bert::BertNormalizer;
 use tokenizers::normalizers::NormalizerWrapper;
 use tokenizers::pre_tokenizers::bert::BertPreTokenizer;
 use tokenizers::pre_tokenizers::byte_level::ByteLevel;
+use tokenizers::pre_tokenizers::PreTokenizerWrapper;
 use tokenizers::processors::bert::BertProcessing;
 use tokenizers::tokenizer::{Model, Tokenizer};

 #[allow(dead_code)]
-pub fn get_empty() -> Tokenizer<BPE, NormalizerWrapper> {
+pub fn get_empty() -> Tokenizer<BPE, NormalizerWrapper, PreTokenizerWrapper> {
     Tokenizer::new(BPE::default())
 }
@@ -24,11 +25,9 @@ pub fn get_byte_level_bpe() -> BPE {
 pub fn get_byte_level(
     add_prefix_space: bool,
     trim_offsets: bool,
-) -> Tokenizer<BPE, NormalizerWrapper> {
+) -> Tokenizer<BPE, NormalizerWrapper, ByteLevel> {
     let mut tokenizer = Tokenizer::new(get_byte_level_bpe());
-    tokenizer.with_pre_tokenizer(Box::new(
-        ByteLevel::default().add_prefix_space(add_prefix_space),
-    ));
+    tokenizer.with_pre_tokenizer(ByteLevel::default().add_prefix_space(add_prefix_space));
     tokenizer.with_decoder(Box::new(ByteLevel::default()));
     tokenizer.with_post_processor(Box::new(ByteLevel::default().trim_offsets(trim_offsets)));
@@ -43,10 +42,10 @@ pub fn get_bert_wordpiece() -> WordPiece {
 }

 #[allow(dead_code)]
-pub fn get_bert() -> Tokenizer<WordPiece, BertNormalizer> {
+pub fn get_bert() -> Tokenizer<WordPiece, BertNormalizer, BertPreTokenizer> {
     let mut tokenizer = Tokenizer::new(get_bert_wordpiece());
     tokenizer.with_normalizer(BertNormalizer::default());
-    tokenizer.with_pre_tokenizer(Box::new(BertPreTokenizer));
+    tokenizer.with_pre_tokenizer(BertPreTokenizer);
     tokenizer.with_decoder(Box::new(WordPieceDecoder::default()));
     tokenizer.with_post_processor(Box::new(BertProcessing::new(
         (
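A short usage sketch for the updated test helpers; the concrete pre-tokenizer is now visible in each return type:

// Return types taken from the helper signatures above.
let byte_level: Tokenizer<BPE, NormalizerWrapper, ByteLevel> = get_byte_level(true, false);
let bert: Tokenizer<WordPiece, BertNormalizer, BertPreTokenizer> = get_bert();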