diff --git a/tokenizers/benches/bpe_benchmark.rs b/tokenizers/benches/bpe_benchmark.rs
index e20b3f33..270374fd 100644
--- a/tokenizers/benches/bpe_benchmark.rs
+++ b/tokenizers/benches/bpe_benchmark.rs
@@ -11,27 +11,28 @@
 use tokenizers::normalizers::NormalizerWrapper;
 use tokenizers::pre_tokenizers::byte_level::ByteLevel;
 use tokenizers::pre_tokenizers::whitespace::Whitespace;
 use tokenizers::tokenizer::{AddedToken, EncodeInput, Tokenizer, Trainer};
-use tokenizers::{Model, Normalizer};
+use tokenizers::{Model, Normalizer, PreTokenizer};
 
 static BATCH_SIZE: usize = 1_000;
 
-fn create_gpt2_tokenizer(bpe: BPE) -> Tokenizer<BPE, NormalizerWrapper> {
+fn create_gpt2_tokenizer(bpe: BPE) -> Tokenizer<BPE, NormalizerWrapper, ByteLevel> {
     let mut tokenizer = Tokenizer::new(bpe);
-    tokenizer.with_pre_tokenizer(Box::new(ByteLevel::default()));
+    tokenizer.with_pre_tokenizer(ByteLevel::default());
     tokenizer.with_decoder(Box::new(ByteLevel::default()));
     tokenizer.add_tokens(&[AddedToken::from("ing", false).single_word(false)]);
     tokenizer.add_special_tokens(&[AddedToken::from("[ENT]", true).single_word(true)]);
     tokenizer
 }
 
-fn iter_bench_encode<M, N>(
+fn iter_bench_encode<M, N, PT>(
     iters: u64,
-    tokenizer: &Tokenizer<M, N>,
+    tokenizer: &Tokenizer<M, N, PT>,
     lines: &[EncodeInput],
 ) -> Duration
 where
     M: Model,
     N: Normalizer,
+    PT: PreTokenizer,
 {
     let mut duration = Duration::new(0, 0);
     let mut line_index: usize = 0;
@@ -47,14 +48,15 @@ where
     duration
 }
 
-fn iter_bench_encode_batch<M, N>(
+fn iter_bench_encode_batch<M, N, PT>(
     iters: u64,
-    tokenizer: &Tokenizer<M, N>,
+    tokenizer: &Tokenizer<M, N, PT>,
     batches: &[Vec<EncodeInput>],
 ) -> Duration
 where
     M: Model,
     N: Normalizer,
+    PT: PreTokenizer,
 {
     let mut duration = Duration::new(0, 0);
     let mut batch_index: usize = 0;
@@ -109,12 +111,17 @@ fn bench_gpt2(c: &mut Criterion) {
     });
 }
 
-fn iter_bench_train<M: Model, T: Trainer<Model = M>>(
+fn iter_bench_train<M, PT, T>(
     iters: u64,
-    mut tokenizer: Tokenizer<M, NormalizerWrapper>,
+    mut tokenizer: Tokenizer<M, NormalizerWrapper, PT>,
     trainer: &T,
     files: Vec<String>,
-) -> Duration {
+) -> Duration
+where
+    M: Model,
+    PT: PreTokenizer,
+    T: Trainer<Model = M>,
+{
     let mut duration = Duration::new(0, 0);
     for _i in 0..iters {
         let start = Instant::now();
@@ -129,7 +136,7 @@ fn bench_train(c: &mut Criterion) {
     c.bench_function("BPE Train vocabulary (small)", |b| {
         b.iter_custom(|iters| {
             let mut tokenizer = Tokenizer::new(BPE::default());
-            tokenizer.with_pre_tokenizer(Box::new(Whitespace::default()));
+            tokenizer.with_pre_tokenizer(Whitespace::default());
             iter_bench_train(
                 iters,
                 tokenizer,
@@ -142,7 +149,7 @@ fn bench_train(c: &mut Criterion) {
     c.bench_function("BPE Train vocabulary (big)", |b| {
         b.iter_custom(|iters| {
             let mut tokenizer = Tokenizer::new(BPE::default());
-            tokenizer.with_pre_tokenizer(Box::new(Whitespace::default()));
+            tokenizer.with_pre_tokenizer(Whitespace::default());
             iter_bench_train(iters, tokenizer, &trainer, vec!["data/big.txt".to_string()])
         })
     });
diff --git a/tokenizers/src/cli.rs b/tokenizers/src/cli.rs
index 155b6f61..c1f09902 100644
--- a/tokenizers/src/cli.rs
+++ b/tokenizers/src/cli.rs
@@ -18,8 +18,8 @@ fn shell(matches: &ArgMatches) -> Result<()> {
         .expect("Must give a merges.txt file");
 
     let bpe = BPE::from_files(vocab, merges).build()?;
-    let mut tokenizer = Tokenizer::<_, NormalizerWrapper>::new(bpe);
-    tokenizer.with_pre_tokenizer(Box::new(ByteLevel::default()));
+    let mut tokenizer = Tokenizer::<_, NormalizerWrapper, ByteLevel>::new(bpe);
+    tokenizer.with_pre_tokenizer(ByteLevel::default());
     tokenizer.with_decoder(Box::new(ByteLevel::default()));
 
     tokenizer.add_tokens(&[AddedToken::from(String::from("ing"), false).single_word(false)]);
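Note on the call-site shape these two hunks introduce: the pre-tokenizer is now the third type parameter on `Tokenizer` and is passed by value, with no `Box::new(...)` wrapper. A minimal sketch mirroring the cli.rs hunk above (the function name `build_shell_tokenizer` and the vocab/merges paths are placeholders for illustration, not part of this PR):

    use tokenizers::models::bpe::BPE;
    use tokenizers::normalizers::NormalizerWrapper;
    use tokenizers::pre_tokenizers::byte_level::ByteLevel;
    use tokenizers::tokenizer::{Result, Tokenizer};

    fn build_shell_tokenizer() -> Result<Tokenizer<BPE, NormalizerWrapper, ByteLevel>> {
        let bpe = BPE::from_files("vocab.json", "merges.txt").build()?;
        // The pre-tokenizer is a plain value now; its concrete type appears
        // as the third parameter instead of a boxed trait object.
        let mut tokenizer = Tokenizer::<_, NormalizerWrapper, ByteLevel>::new(bpe);
        tokenizer.with_pre_tokenizer(ByteLevel::default());
        Ok(tokenizer)
    }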
diff --git a/tokenizers/src/lib.rs b/tokenizers/src/lib.rs
index 833dad58..7e46af46 100644
--- a/tokenizers/src/lib.rs
+++ b/tokenizers/src/lib.rs
@@ -28,6 +28,7 @@
 //! use tokenizers::tokenizer::{Result, Tokenizer, EncodeInput};
 //! use tokenizers::models::bpe::BPE;
 //! use tokenizers::normalizers::NormalizerWrapper;
+//! use tokenizers::pre_tokenizers::PreTokenizerWrapper;
 //!
 //! fn main() -> Result<()> {
 //!     let bpe_builder = BPE::from_files("./path/to/vocab.json", "./path/to/merges.txt");
@@ -36,7 +37,7 @@
 //!         .unk_token("[UNK]".into())
 //!         .build()?;
 //!
-//!     let mut tokenizer = Tokenizer::<_, NormalizerWrapper>::new(bpe);
+//!     let mut tokenizer = Tokenizer::<_, NormalizerWrapper, PreTokenizerWrapper>::new(bpe);
 //!
 //!     let encoding = tokenizer.encode("Hey there!", false)?;
 //!     println!("{:?}", encoding.get_tokens());
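As the updated doc example shows, code that does not commit to one concrete pre-tokenizer names `PreTokenizerWrapper` as the `PT` parameter. A hedged sketch of that pattern, assuming the wrapper enum provides `From` conversions for the concrete pre-tokenizers (as the other `*Wrapper` enums do; `Whitespace` stands in for any of them here):

    use tokenizers::models::bpe::BPE;
    use tokenizers::normalizers::NormalizerWrapper;
    use tokenizers::pre_tokenizers::whitespace::Whitespace;
    use tokenizers::pre_tokenizers::PreTokenizerWrapper;
    use tokenizers::tokenizer::{Result, Tokenizer};

    fn main() -> Result<()> {
        let bpe = BPE::from_files("./path/to/vocab.json", "./path/to/merges.txt").build()?;
        let mut tokenizer = Tokenizer::<_, NormalizerWrapper, PreTokenizerWrapper>::new(bpe);
        // Any concrete pre-tokenizer converts into the wrapper enum
        // (assumed From impl), so the choice stays open at compile time.
        tokenizer.with_pre_tokenizer(Whitespace::default().into());
        Ok(())
    }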
diff --git a/tokenizers/src/tokenizer/mod.rs b/tokenizers/src/tokenizer/mod.rs
index 3f94cab2..3545b532 100644
--- a/tokenizers/src/tokenizer/mod.rs
+++ b/tokenizers/src/tokenizer/mod.rs
@@ -213,10 +213,10 @@ impl fmt::Display for BuilderError {
 /// Builder for Tokenizer structs.
 ///
 /// `build()` fails if the `model` is missing.
-pub struct TokenizerBuilder<M, N> {
+pub struct TokenizerBuilder<M, N, PT> {
     model: Option<M>,
     normalizer: Option<N>,
-    pre_tokenizer: Option<Box<dyn PreTokenizer>>,
+    pre_tokenizer: Option<PT>,
     post_processor: Option<Box<dyn PostProcessor>>,
     decoder: Option<Box<dyn Decoder>>,
@@ -226,20 +226,22 @@ pub struct TokenizerBuilder<M, N> {
     padding: Option<PaddingParams>,
 }
 
-impl<M, N> Default for TokenizerBuilder<M, N>
+impl<M, N, PT> Default for TokenizerBuilder<M, N, PT>
 where
     M: Model,
     N: Normalizer,
+    PT: PreTokenizer,
 {
     fn default() -> Self {
         Self::new()
     }
 }
 
-impl<M, N> TokenizerBuilder<M, N>
+impl<M, N, PT> TokenizerBuilder<M, N, PT>
 where
     M: Model,
     N: Normalizer,
+    PT: PreTokenizer,
 {
     /// Get an empty TokenizerBuilder.
     pub fn new() -> Self {
@@ -258,7 +260,7 @@ where
     /// Convert the TokenizerBuilder to a Tokenizer.
     ///
     /// Conversion fails if the `model` is missing.
-    pub fn build(self) -> Result<Tokenizer<M, N>> {
+    pub fn build(self) -> Result<Tokenizer<M, N, PT>> {
         let model = self
             .model
             .ok_or_else(|| Box::new(BuilderError("Model missing.".into())))?;
@@ -288,7 +290,7 @@ where
     }
 
     /// Set the pretokenizer.
-    pub fn with_pretokenizer(mut self, pretokenizer: Option<Box<dyn PreTokenizer>>) -> Self {
+    pub fn with_pretokenizer(mut self, pretokenizer: Option<PT>) -> Self {
         self.pre_tokenizer = pretokenizer;
         self
     }
@@ -319,10 +321,10 @@ where
     }
 }
 
 /// A `Tokenizer` is capable of encoding/decoding any text.
-pub struct Tokenizer<M, N> {
+pub struct Tokenizer<M, N, PT> {
     // Tokenizer parts
     normalizer: Option<N>,
-    pre_tokenizer: Option<Box<dyn PreTokenizer>>,
+    pre_tokenizer: Option<PT>,
     model: M,
     post_processor: Option<Box<dyn PostProcessor>>,
     decoder: Option<Box<dyn Decoder>>,
@@ -335,10 +337,11 @@ pub struct Tokenizer<M, N> {
     padding: Option<PaddingParams>,
 }
 
-impl<M, N> Tokenizer<M, N>
+impl<M, N, PT> Tokenizer<M, N, PT>
 where
     M: Model,
     N: Normalizer,
+    PT: PreTokenizer,
 {
     /// Instantiate a new Tokenizer, with the given Model
     pub fn new(model: M) -> Self {
@@ -368,14 +371,13 @@ where
     }
 
     /// Set the pre tokenizer
-    pub fn with_pre_tokenizer(&mut self, pre_tokenizer: Box<dyn PreTokenizer>) -> &Self {
+    pub fn with_pre_tokenizer(&mut self, pre_tokenizer: PT) -> &Self {
         self.pre_tokenizer = Some(pre_tokenizer);
         self
     }
 
     /// Get the pre tokenizer
-    #[allow(clippy::borrowed_box)]
-    pub fn get_pre_tokenizer(&self) -> Option<&Box<dyn PreTokenizer>> {
+    pub fn get_pre_tokenizer(&self) -> Option<&PT> {
         self.pre_tokenizer.as_ref()
     }
@@ -577,7 +579,8 @@ where
     /// # use tokenizers::Tokenizer;
     /// # use tokenizers::models::bpe::BPE;
     /// # use tokenizers::normalizers::NormalizerWrapper;
-    /// # let mut tokenizer = Tokenizer::<_, NormalizerWrapper>::new(BPE::default());
+    /// # use tokenizers::pre_tokenizers::PreTokenizerWrapper;
+    /// # let mut tokenizer = Tokenizer::<_, NormalizerWrapper, PreTokenizerWrapper>::new(BPE::default());
     /// #
     /// // Sequences:
     /// tokenizer.encode("Single sequence", false);
@@ -752,7 +755,7 @@ where
     }
 
     /// Train a model and replace our current Model, using the given Trainer
-    pub fn train<T, TM>(self, trainer: &T, files: Vec<String>) -> Result<Tokenizer<TM, N>>
+    pub fn train<T, TM>(self, trainer: &T, files: Vec<String>) -> Result<Tokenizer<TM, N, PT>>
     where
         T: Trainer<Model = TM>,
         TM: Model,
@@ -912,10 +915,11 @@ where
     }
 }
 
-impl<M, N> std::str::FromStr for Tokenizer<M, N>
+impl<M, N, PT> std::str::FromStr for Tokenizer<M, N, PT>
 where
     M: for<'de> Deserialize<'de> + Model,
     N: for<'de> Deserialize<'de> + Normalizer,
+    PT: for<'de> Deserialize<'de> + PreTokenizer,
 {
     type Err = Error;
 
@@ -924,10 +928,11 @@ where
     }
 }
 
-impl<M, N> Tokenizer<M, N>
+impl<M, N, PT> Tokenizer<M, N, PT>
 where
     M: DeserializeOwned + Model,
     N: DeserializeOwned + Normalizer,
+    PT: DeserializeOwned + PreTokenizer,
 {
     /// Instantiate a new Tokenizer from the given file
     pub fn from_file<P: AsRef<Path>>(file: P) -> Result<Self> {
@@ -937,10 +942,11 @@ where
     }
 }
 
-impl<M, N> Tokenizer<M, N>
+impl<M, N, PT> Tokenizer<M, N, PT>
 where
     M: Serialize,
     N: Serialize,
+    PT: Serialize,
 {
     /// Serialize the current tokenizer as a String
     pub fn to_string(&self, pretty: bool) -> Result<String> {
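One practical effect of the mod.rs hunks: `get_pre_tokenizer` now returns `Option<&PT>` rather than `Option<&Box<dyn PreTokenizer>>`, and the `clippy::borrowed_box` allowance disappears with it, so callers get the concrete type back without downcasting. A small illustration (the `inspect` helper is hypothetical, added only to show the types):

    use tokenizers::models::bpe::BPE;
    use tokenizers::normalizers::NormalizerWrapper;
    use tokenizers::pre_tokenizers::byte_level::ByteLevel;
    use tokenizers::tokenizer::Tokenizer;

    // With PT = ByteLevel, the getter yields Option<&ByteLevel> directly.
    fn inspect(tokenizer: &Tokenizer<BPE, NormalizerWrapper, ByteLevel>) -> Option<&ByteLevel> {
        tokenizer.get_pre_tokenizer()
    }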
diff --git a/tokenizers/src/tokenizer/serialization.rs b/tokenizers/src/tokenizer/serialization.rs
index 07ec106a..473f32bf 100644
--- a/tokenizers/src/tokenizer/serialization.rs
+++ b/tokenizers/src/tokenizer/serialization.rs
@@ -8,14 +8,15 @@ use serde::{
 };
 
 use super::{added_vocabulary::AddedTokenWithId, Tokenizer};
-use crate::{Model, Normalizer, TokenizerBuilder};
+use crate::{Model, Normalizer, PreTokenizer, TokenizerBuilder};
 
 static SERIALIZATION_VERSION: &str = "1.0";
 
-impl<M, N> Serialize for Tokenizer<M, N>
+impl<M, N, PT> Serialize for Tokenizer<M, N, PT>
 where
     M: Serialize,
     N: Serialize,
+    PT: Serialize,
 {
     fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
     where
@@ -44,10 +45,11 @@ where
     }
 }
 
-impl<'de, M, N> Deserialize<'de> for Tokenizer<M, N>
+impl<'de, M, N, PT> Deserialize<'de> for Tokenizer<M, N, PT>
 where
     M: Deserialize<'de> + Model,
     N: Deserialize<'de> + Normalizer,
+    PT: Deserialize<'de> + PreTokenizer,
 {
     fn deserialize<De>(deserializer: De) -> Result<Self, De::Error>
     where
@@ -66,19 +68,20 @@ where
                 "decoder",
                 "model",
             ],
-            TokenizerVisitor(PhantomData, PhantomData),
+            TokenizerVisitor(PhantomData, PhantomData, PhantomData),
         )
     }
 }
 
-struct TokenizerVisitor<M, N>(PhantomData<M>, PhantomData<N>);
+struct TokenizerVisitor<M, N, PT>(PhantomData<M>, PhantomData<N>, PhantomData<PT>);
 
-impl<'de, M, N> Visitor<'de> for TokenizerVisitor<M, N>
+impl<'de, M, N, PT> Visitor<'de> for TokenizerVisitor<M, N, PT>
 where
     M: Deserialize<'de> + Model,
     N: Deserialize<'de> + Normalizer,
+    PT: Deserialize<'de> + PreTokenizer,
 {
-    type Value = Tokenizer<M, N>;
+    type Value = Tokenizer<M, N, PT>;
 
     fn expecting(&self, fmt: &mut std::fmt::Formatter) -> std::fmt::Result {
         write!(fmt, "struct Tokenizer")
diff --git a/tokenizers/tests/common/mod.rs b/tokenizers/tests/common/mod.rs
index e78a1ead..1afd563e 100644
--- a/tokenizers/tests/common/mod.rs
+++ b/tokenizers/tests/common/mod.rs
@@ -5,11 +5,12 @@
 use tokenizers::normalizers::bert::BertNormalizer;
 use tokenizers::normalizers::NormalizerWrapper;
 use tokenizers::pre_tokenizers::bert::BertPreTokenizer;
 use tokenizers::pre_tokenizers::byte_level::ByteLevel;
+use tokenizers::pre_tokenizers::PreTokenizerWrapper;
 use tokenizers::processors::bert::BertProcessing;
 use tokenizers::tokenizer::{Model, Tokenizer};
 
 #[allow(dead_code)]
-pub fn get_empty() -> Tokenizer<BPE, NormalizerWrapper> {
+pub fn get_empty() -> Tokenizer<BPE, NormalizerWrapper, PreTokenizerWrapper> {
     Tokenizer::new(BPE::default())
 }
 
@@ -24,11 +25,9 @@ pub fn get_byte_level_bpe() -> BPE {
 pub fn get_byte_level(
     add_prefix_space: bool,
     trim_offsets: bool,
-) -> Tokenizer<BPE, NormalizerWrapper> {
+) -> Tokenizer<BPE, NormalizerWrapper, ByteLevel> {
     let mut tokenizer = Tokenizer::new(get_byte_level_bpe());
-    tokenizer.with_pre_tokenizer(Box::new(
-        ByteLevel::default().add_prefix_space(add_prefix_space),
-    ));
+    tokenizer.with_pre_tokenizer(ByteLevel::default().add_prefix_space(add_prefix_space));
     tokenizer.with_decoder(Box::new(ByteLevel::default()));
     tokenizer.with_post_processor(Box::new(ByteLevel::default().trim_offsets(trim_offsets)));
 
@@ -43,10 +42,10 @@ pub fn get_bert_wordpiece() -> WordPiece {
 }
 
 #[allow(dead_code)]
-pub fn get_bert() -> Tokenizer<WordPiece, BertNormalizer> {
+pub fn get_bert() -> Tokenizer<WordPiece, BertNormalizer, BertPreTokenizer> {
     let mut tokenizer = Tokenizer::new(get_bert_wordpiece());
     tokenizer.with_normalizer(BertNormalizer::default());
-    tokenizer.with_pre_tokenizer(Box::new(BertPreTokenizer));
+    tokenizer.with_pre_tokenizer(BertPreTokenizer);
     tokenizer.with_decoder(Box::new(WordPieceDecoder::default()));
     tokenizer.with_post_processor(Box::new(BertProcessing::new(
         (
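Taken together, the serialization bounds above keep round-tripping intact for the fully generic type: serializing needs `M: Serialize, N: Serialize, PT: Serialize`, and the `FromStr`/`from_file` paths need the matching `Deserialize` bounds, which the wrapper enums are meant to satisfy when the concrete parts are only known at runtime. A sketch of the round trip using only APIs touched in this diff (the `round_trip` helper is hypothetical; it assumes `BPE` and the two wrappers implement the serde traits, which these bounds require of any usable choice):

    use tokenizers::models::bpe::BPE;
    use tokenizers::normalizers::NormalizerWrapper;
    use tokenizers::pre_tokenizers::PreTokenizerWrapper;
    use tokenizers::tokenizer::{Result, Tokenizer};

    fn round_trip(tokenizer: &Tokenizer<BPE, NormalizerWrapper, PreTokenizerWrapper>) -> Result<()> {
        // to_string comes from the impl gated on M/N/PT: Serialize.
        let json = tokenizer.to_string(true)?;
        // parse() goes through the FromStr impl gated on the Deserialize bounds.
        let reloaded: Tokenizer<BPE, NormalizerWrapper, PreTokenizerWrapper> = json.parse()?;
        let _ = reloaded;
        Ok(())
    }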