Make Tokenizer generic over pre-tokenizers.

Author: Sebastian Pütz
Date: 2020-07-25 15:57:51 +02:00
Committed-by: Anthony MOI
Parent: 08b8c48127
Commit: bcc54a2ea1
6 changed files with 62 additions and 46 deletions
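
In short: `Tokenizer<M, N>` becomes `Tokenizer<M, N, PT>` with `PT: PreTokenizer`, so the pre-tokenizer is stored as a concrete type and dispatched statically instead of living behind a `Box<dyn PreTokenizer>`. A minimal sketch of a post-commit call site, built only from types that appear in the diff below:

```rust
use tokenizers::models::bpe::BPE;
use tokenizers::normalizers::NormalizerWrapper;
use tokenizers::pre_tokenizers::whitespace::Whitespace;
use tokenizers::tokenizer::Tokenizer;

fn main() {
    // The pre-tokenizer is now part of the tokenizer's type (`Whitespace` here),
    // and the setter takes the value directly rather than a boxed trait object.
    let mut tokenizer: Tokenizer<BPE, NormalizerWrapper, Whitespace> =
        Tokenizer::new(BPE::default());
    tokenizer.with_pre_tokenizer(Whitespace::default());
}
```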

View File

@@ -11,27 +11,28 @@ use tokenizers::normalizers::NormalizerWrapper;
 use tokenizers::pre_tokenizers::byte_level::ByteLevel;
 use tokenizers::pre_tokenizers::whitespace::Whitespace;
 use tokenizers::tokenizer::{AddedToken, EncodeInput, Tokenizer, Trainer};
-use tokenizers::{Model, Normalizer};
+use tokenizers::{Model, Normalizer, PreTokenizer};

 static BATCH_SIZE: usize = 1_000;

-fn create_gpt2_tokenizer(bpe: BPE) -> Tokenizer<BPE, NormalizerWrapper> {
+fn create_gpt2_tokenizer(bpe: BPE) -> Tokenizer<BPE, NormalizerWrapper, ByteLevel> {
     let mut tokenizer = Tokenizer::new(bpe);
-    tokenizer.with_pre_tokenizer(Box::new(ByteLevel::default()));
+    tokenizer.with_pre_tokenizer(ByteLevel::default());
     tokenizer.with_decoder(Box::new(ByteLevel::default()));
     tokenizer.add_tokens(&[AddedToken::from("ing", false).single_word(false)]);
     tokenizer.add_special_tokens(&[AddedToken::from("[ENT]", true).single_word(true)]);
     tokenizer
 }

-fn iter_bench_encode<M, N>(
+fn iter_bench_encode<M, N, PT>(
     iters: u64,
-    tokenizer: &Tokenizer<M, N>,
+    tokenizer: &Tokenizer<M, N, PT>,
     lines: &[EncodeInput],
 ) -> Duration
 where
     M: Model,
     N: Normalizer,
+    PT: PreTokenizer,
 {
     let mut duration = Duration::new(0, 0);
     let mut line_index: usize = 0;
@@ -47,14 +48,15 @@ where
     duration
 }

-fn iter_bench_encode_batch<M, N>(
+fn iter_bench_encode_batch<M, N, PT>(
     iters: u64,
-    tokenizer: &Tokenizer<M, N>,
+    tokenizer: &Tokenizer<M, N, PT>,
     batches: &[Vec<EncodeInput>],
 ) -> Duration
 where
     M: Model,
     N: Normalizer,
+    PT: PreTokenizer,
 {
     let mut duration = Duration::new(0, 0);
     let mut batch_index: usize = 0;
@@ -109,12 +111,17 @@ fn bench_gpt2(c: &mut Criterion) {
     });
 }

-fn iter_bench_train<T: Trainer<Model = BPE>>(
+fn iter_bench_train<T, M, PT>(
     iters: u64,
-    mut tokenizer: Tokenizer<BPE, NormalizerWrapper>,
+    mut tokenizer: Tokenizer<M, NormalizerWrapper, PT>,
     trainer: &T,
     files: Vec<String>,
-) -> Duration {
+) -> Duration
+where
+    M: Model,
+    PT: PreTokenizer,
+    T: Trainer<Model = M>,
+{
     let mut duration = Duration::new(0, 0);
     for _i in 0..iters {
         let start = Instant::now();
@@ -129,7 +136,7 @@ fn bench_train(c: &mut Criterion) {
c.bench_function("BPE Train vocabulary (small)", |b| {
b.iter_custom(|iters| {
let mut tokenizer = Tokenizer::new(BPE::default());
tokenizer.with_pre_tokenizer(Box::new(Whitespace::default()));
tokenizer.with_pre_tokenizer(Whitespace::default());
iter_bench_train(
iters,
tokenizer,
@@ -142,7 +149,7 @@ fn bench_train(c: &mut Criterion) {
c.bench_function("BPE Train vocabulary (big)", |b| {
b.iter_custom(|iters| {
let mut tokenizer = Tokenizer::new(BPE::default());
tokenizer.with_pre_tokenizer(Box::new(Whitespace::default()));
tokenizer.with_pre_tokenizer(Whitespace::default());
iter_bench_train(iters, tokenizer, &trainer, vec!["data/big.txt".to_string()])
})
});
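
Generic helpers over a tokenizer, like `iter_bench_encode` above, now need the extra `PT: PreTokenizer` bound and get monomorphized per concrete tokenizer type. A hedged sketch of the same pattern (`count_tokens` is a hypothetical helper, not part of the crate):

```rust
use tokenizers::tokenizer::{EncodeInput, Tokenizer};
use tokenizers::{Model, Normalizer, PreTokenizer};

// Hypothetical helper: returns the token count for one input, written the same
// way as the benchmark helpers, with all three tokenizer parts as parameters.
fn count_tokens<M, N, PT>(tokenizer: &Tokenizer<M, N, PT>, input: EncodeInput) -> usize
where
    M: Model,
    N: Normalizer,
    PT: PreTokenizer,
{
    tokenizer
        .encode(input, false)
        .map(|encoding| encoding.get_tokens().len())
        .unwrap_or(0)
}
```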

View File

@@ -18,8 +18,8 @@ fn shell(matches: &ArgMatches) -> Result<()> {
         .expect("Must give a merges.txt file");
     let bpe = BPE::from_files(vocab, merges).build()?;

-    let mut tokenizer = Tokenizer::<_, NormalizerWrapper>::new(bpe);
-    tokenizer.with_pre_tokenizer(Box::new(ByteLevel::default()));
+    let mut tokenizer = Tokenizer::<_, NormalizerWrapper, ByteLevel>::new(bpe);
+    tokenizer.with_pre_tokenizer(ByteLevel::default());
     tokenizer.with_decoder(Box::new(ByteLevel::default()));
     tokenizer.add_tokens(&[AddedToken::from(String::from("ing"), false).single_word(false)]);
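
With the extra parameter, turbofish annotations grow a third slot; the model can stay `_` because it is inferred from the constructor argument, while the normalizer and pre-tokenizer must be named. A small sketch of the annotation pattern used above:

```rust
use tokenizers::models::bpe::BPE;
use tokenizers::normalizers::NormalizerWrapper;
use tokenizers::pre_tokenizers::byte_level::ByteLevel;
use tokenizers::tokenizer::Tokenizer;

fn main() {
    // `_` is inferred as BPE from the argument; `NormalizerWrapper` and
    // `ByteLevel` must be spelled out since nothing else pins them down.
    let mut tokenizer = Tokenizer::<_, NormalizerWrapper, ByteLevel>::new(BPE::default());
    tokenizer.with_pre_tokenizer(ByteLevel::default());
}
```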

View File

@@ -28,6 +28,7 @@
 //! use tokenizers::tokenizer::{Result, Tokenizer, EncodeInput};
 //! use tokenizers::models::bpe::BPE;
 //! use tokenizers::normalizers::NormalizerWrapper;
+//! use tokenizers::pre_tokenizers::PreTokenizerWrapper;
 //!
 //! fn main() -> Result<()> {
 //!     let bpe_builder = BPE::from_files("./path/to/vocab.json", "./path/to/merges.txt");
@@ -36,7 +37,7 @@
 //!         .unk_token("[UNK]".into())
 //!         .build()?;
 //!
-//!     let mut tokenizer = Tokenizer::<_, NormalizerWrapper>::new(bpe);
+//!     let mut tokenizer = Tokenizer::<_, NormalizerWrapper, PreTokenizerWrapper>::new(bpe);
 //!
 //!     let encoding = tokenizer.encode("Hey there!", false)?;
 //!     println!("{:?}", encoding.get_tokens());
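
The doc example reaches for `PreTokenizerWrapper` when no single concrete pre-tokenizer should be baked into the type; mirroring `NormalizerWrapper`, it is presumably an enum over the crate's pre-tokenizers. A hedged sketch of picking one at runtime behind the wrapper, assuming the usual `From` conversions into the enum exist:

```rust
use tokenizers::models::bpe::BPE;
use tokenizers::normalizers::NormalizerWrapper;
use tokenizers::pre_tokenizers::byte_level::ByteLevel;
use tokenizers::pre_tokenizers::whitespace::Whitespace;
use tokenizers::pre_tokenizers::PreTokenizerWrapper;
use tokenizers::tokenizer::Tokenizer;

fn main() {
    let mut tokenizer =
        Tokenizer::<_, NormalizerWrapper, PreTokenizerWrapper>::new(BPE::default());
    // One static type (the wrapper enum), two possible behaviours at runtime.
    let byte_level = true; // hypothetical runtime flag
    if byte_level {
        tokenizer.with_pre_tokenizer(ByteLevel::default().into());
    } else {
        tokenizer.with_pre_tokenizer(Whitespace::default().into());
    }
}
```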

View File

@@ -213,10 +213,10 @@ impl fmt::Display for BuilderError {
 /// Builder for Tokenizer structs.
 ///
 /// `build()` fails if the `model` is missing.
-pub struct TokenizerBuilder<M, N> {
+pub struct TokenizerBuilder<M, N, PT> {
     model: Option<M>,
     normalizer: Option<N>,
-    pre_tokenizer: Option<Box<dyn PreTokenizer>>,
+    pre_tokenizer: Option<PT>,
     post_processor: Option<Box<dyn PostProcessor>>,
     decoder: Option<Box<dyn Decoder>>,
@@ -226,20 +226,22 @@ pub struct TokenizerBuilder<M, N> {
     padding: Option<PaddingParams>,
 }

-impl<M, N> Default for TokenizerBuilder<M, N>
+impl<M, N, PT> Default for TokenizerBuilder<M, N, PT>
 where
     M: Model,
     N: Normalizer,
+    PT: PreTokenizer,
 {
     fn default() -> Self {
         Self::new()
     }
 }

-impl<M, N> TokenizerBuilder<M, N>
+impl<M, N, PT> TokenizerBuilder<M, N, PT>
 where
     M: Model,
     N: Normalizer,
+    PT: PreTokenizer,
 {
     /// Get an empty TokenizerBuilder.
     pub fn new() -> Self {
@@ -258,7 +260,7 @@ where
     /// Convert the TokenizerBuilder to a Tokenizer.
     ///
     /// Conversion fails if the `model` is missing.
-    pub fn build(self) -> Result<Tokenizer<M, N>> {
+    pub fn build(self) -> Result<Tokenizer<M, N, PT>> {
         let model = self
             .model
             .ok_or_else(|| Box::new(BuilderError("Model missing.".into())))?;
@@ -288,7 +290,7 @@ where
     }

     /// Set the pretokenizer.
-    pub fn with_pretokenizer(mut self, pretokenizer: Option<Box<dyn PreTokenizer>>) -> Self {
+    pub fn with_pretokenizer(mut self, pretokenizer: Option<PT>) -> Self {
         self.pre_tokenizer = pretokenizer;
         self
     }
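
Taken together, the builder now carries the pre-tokenizer in its type as well. A hedged sketch of the builder flow (`with_model` is assumed to exist alongside the `with_pretokenizer` setter shown above):

```rust
use tokenizers::models::bpe::BPE;
use tokenizers::normalizers::NormalizerWrapper;
use tokenizers::pre_tokenizers::whitespace::Whitespace;
use tokenizers::tokenizer::Result;
use tokenizers::TokenizerBuilder;

fn main() -> Result<()> {
    // `build()` still fails if the model is missing; the pre-tokenizer slot
    // is optional, but its type is fixed at compile time.
    let _tokenizer = TokenizerBuilder::<BPE, NormalizerWrapper, Whitespace>::new()
        .with_model(BPE::default())
        .with_pretokenizer(Some(Whitespace::default()))
        .build()?;
    Ok(())
}
```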
@@ -319,10 +321,10 @@ where
 }

 /// A `Tokenizer` is capable of encoding/decoding any text.
-pub struct Tokenizer<M, N> {
+pub struct Tokenizer<M, N, PT> {
     // Tokenizer parts
     normalizer: Option<N>,
-    pre_tokenizer: Option<Box<dyn PreTokenizer>>,
+    pre_tokenizer: Option<PT>,
     model: M,
     post_processor: Option<Box<dyn PostProcessor>>,
     decoder: Option<Box<dyn Decoder>>,
@@ -335,10 +337,11 @@ pub struct Tokenizer<M, N> {
     padding: Option<PaddingParams>,
 }

-impl<M, N> Tokenizer<M, N>
+impl<M, N, PT> Tokenizer<M, N, PT>
 where
     M: Model,
     N: Normalizer,
+    PT: PreTokenizer,
 {
     /// Instantiate a new Tokenizer, with the given Model
     pub fn new(model: M) -> Self {
@@ -368,14 +371,13 @@ where
     }

     /// Set the pre tokenizer
-    pub fn with_pre_tokenizer(&mut self, pre_tokenizer: Box<dyn PreTokenizer>) -> &Self {
+    pub fn with_pre_tokenizer(&mut self, pre_tokenizer: PT) -> &Self {
         self.pre_tokenizer = Some(pre_tokenizer);
         self
     }

     /// Get the pre tokenizer
-    #[allow(clippy::borrowed_box)]
-    pub fn get_pre_tokenizer(&self) -> Option<&Box<dyn PreTokenizer>> {
+    pub fn get_pre_tokenizer(&self) -> Option<&PT> {
         self.pre_tokenizer.as_ref()
     }
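
With the field a plain `Option<PT>`, the getter returns `Option<&PT>` and the `clippy::borrowed_box` suppression disappears; callers get the concrete pre-tokenizer back without any downcasting. A small sketch:

```rust
use tokenizers::models::bpe::BPE;
use tokenizers::normalizers::NormalizerWrapper;
use tokenizers::pre_tokenizers::byte_level::ByteLevel;
use tokenizers::tokenizer::Tokenizer;

fn main() {
    let mut tokenizer = Tokenizer::<_, NormalizerWrapper, ByteLevel>::new(BPE::default());
    tokenizer.with_pre_tokenizer(ByteLevel::default());

    // `Option<&ByteLevel>` rather than `Option<&Box<dyn PreTokenizer>>`:
    // the concrete type's own methods stay reachable through the getter.
    let pre_tok: Option<&ByteLevel> = tokenizer.get_pre_tokenizer();
    assert!(pre_tok.is_some());
}
```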
@@ -577,7 +579,8 @@ where
     /// # use tokenizers::Tokenizer;
     /// # use tokenizers::models::bpe::BPE;
     /// # use tokenizers::normalizers::NormalizerWrapper;
-    /// # let mut tokenizer = Tokenizer::<_, NormalizerWrapper>::new(BPE::default());
+    /// # use tokenizers::pre_tokenizers::PreTokenizerWrapper;
+    /// # let mut tokenizer = Tokenizer::<_, NormalizerWrapper, PreTokenizerWrapper>::new(BPE::default());
     /// #
     /// // Sequences:
     /// tokenizer.encode("Single sequence", false);
@@ -752,7 +755,7 @@ where
     }

     /// Train a model and replace our current Model, using the given Trainer
-    pub fn train<T, TM>(self, trainer: &T, files: Vec<String>) -> Result<Tokenizer<TM, N>>
+    pub fn train<T, TM>(self, trainer: &T, files: Vec<String>) -> Result<Tokenizer<TM, N, PT>>
     where
         T: Trainer<Model = TM>,
         TM: Model,
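
`train` consumes the tokenizer and swaps the model type for the trainer's output (`M` to `TM`) while the normalizer and pre-tokenizer types carry over. A hedged sketch, assuming `BpeTrainer::default()` is available (the file path is a placeholder):

```rust
use tokenizers::models::bpe::{BpeTrainer, BPE};
use tokenizers::normalizers::NormalizerWrapper;
use tokenizers::pre_tokenizers::whitespace::Whitespace;
use tokenizers::tokenizer::{Result, Tokenizer};

fn main() -> Result<()> {
    let mut tokenizer = Tokenizer::<_, NormalizerWrapper, Whitespace>::new(BPE::default());
    tokenizer.with_pre_tokenizer(Whitespace::default());

    // N and PT survive in the returned type; only the model slot is replaced,
    // and here BpeTrainer's associated Model is BPE again.
    let trainer = BpeTrainer::default();
    let _trained: Tokenizer<BPE, NormalizerWrapper, Whitespace> =
        tokenizer.train(&trainer, vec!["corpus.txt".to_string()])?; // placeholder path
    Ok(())
}
```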
@@ -912,10 +915,11 @@ where
     }
 }

-impl<M, N> std::str::FromStr for Tokenizer<M, N>
+impl<M, N, PT> std::str::FromStr for Tokenizer<M, N, PT>
 where
     M: for<'de> Deserialize<'de> + Model,
     N: for<'de> Deserialize<'de> + Normalizer,
+    PT: for<'de> Deserialize<'de> + PreTokenizer,
 {
     type Err = Error;
@@ -924,10 +928,11 @@ where
     }
 }

-impl<M, N> Tokenizer<M, N>
+impl<M, N, PT> Tokenizer<M, N, PT>
 where
     M: DeserializeOwned + Model,
     N: DeserializeOwned + Normalizer,
+    PT: DeserializeOwned + PreTokenizer,
 {
     /// Instantiate a new Tokenizer from the given file
     pub fn from_file<P: AsRef<Path>>(file: P) -> Result<Self> {
@@ -937,10 +942,11 @@ where
     }
 }

-impl<M, N> Tokenizer<M, N>
+impl<M, N, PT> Tokenizer<M, N, PT>
 where
     M: Serialize,
     N: Serialize,
+    PT: Serialize,
 {
     /// Serialize the current tokenizer as a String
     pub fn to_string(&self, pretty: bool) -> Result<String> {
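
Deserialization is where the extra parameter is most visible: all three types must be nameable, because there is no value to infer them from, and the wrapper enums are the natural choice when a file's contents aren't known up front. A hedged sketch (`tokenizer.json` is a placeholder path):

```rust
use tokenizers::models::bpe::BPE;
use tokenizers::normalizers::NormalizerWrapper;
use tokenizers::pre_tokenizers::PreTokenizerWrapper;
use tokenizers::tokenizer::{Result, Tokenizer};

fn main() -> Result<()> {
    // `_` would not work here: nothing at the call site determines M, N, or PT.
    let tokenizer =
        Tokenizer::<BPE, NormalizerWrapper, PreTokenizerWrapper>::from_file("tokenizer.json")?;
    println!("{}", tokenizer.to_string(true)?);
    Ok(())
}
```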

View File

@@ -8,14 +8,15 @@ use serde::{
 };

 use super::{added_vocabulary::AddedTokenWithId, Tokenizer};
-use crate::{Model, Normalizer, TokenizerBuilder};
+use crate::{Model, Normalizer, PreTokenizer, TokenizerBuilder};

 static SERIALIZATION_VERSION: &str = "1.0";

-impl<M, N> Serialize for Tokenizer<M, N>
+impl<M, N, PT> Serialize for Tokenizer<M, N, PT>
 where
     M: Serialize,
     N: Serialize,
+    PT: Serialize,
 {
     fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
     where
@@ -44,10 +45,11 @@ where
     }
 }

-impl<'de, M, N> Deserialize<'de> for Tokenizer<M, N>
+impl<'de, M, N, PT> Deserialize<'de> for Tokenizer<M, N, PT>
 where
     M: Deserialize<'de> + Model,
     N: Deserialize<'de> + Normalizer,
+    PT: Deserialize<'de> + PreTokenizer,
 {
     fn deserialize<De>(deserializer: De) -> Result<Self, De::Error>
     where
@@ -66,19 +68,20 @@ where
"decoder",
"model",
],
TokenizerVisitor(PhantomData, PhantomData),
TokenizerVisitor(PhantomData, PhantomData, PhantomData),
)
}
}
struct TokenizerVisitor<M, N>(PhantomData<M>, PhantomData<N>);
struct TokenizerVisitor<M, N, PT>(PhantomData<M>, PhantomData<N>, PhantomData<PT>);
impl<'de, M, N> Visitor<'de> for TokenizerVisitor<M, N>
impl<'de, M, N, PT> Visitor<'de> for TokenizerVisitor<M, N, PT>
where
M: Deserialize<'de> + Model,
N: Deserialize<'de> + Normalizer,
PT: Deserialize<'de> + PreTokenizer,
{
type Value = Tokenizer<M, N>;
type Value = Tokenizer<M, N, PT>;
fn expecting(&self, fmt: &mut std::fmt::Formatter) -> std::fmt::Result {
write!(fmt, "struct Tokenizer")
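
The visitor gains a third `PhantomData` for the same reason the struct gains a third parameter: the marker fields exist only to thread `M`, `N`, and `PT` through to `type Value`. A stripped-down illustration of the pattern, detached from serde (all names here are hypothetical):

```rust
use std::marker::PhantomData;

// Zero-sized: the PhantomData fields only record which types to produce,
// exactly like TokenizerVisitor(PhantomData, PhantomData, PhantomData) above.
struct Maker<M, N, PT>(PhantomData<M>, PhantomData<N>, PhantomData<PT>);

impl<M: Default, N: Default, PT: Default> Maker<M, N, PT> {
    fn make(self) -> (M, N, PT) {
        (M::default(), N::default(), PT::default())
    }
}

fn main() {
    let maker: Maker<u8, String, Vec<u32>> = Maker(PhantomData, PhantomData, PhantomData);
    let (_m, _n, _pt) = maker.make();
}
```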

View File

@@ -5,11 +5,12 @@ use tokenizers::normalizers::bert::BertNormalizer;
 use tokenizers::normalizers::NormalizerWrapper;
 use tokenizers::pre_tokenizers::bert::BertPreTokenizer;
 use tokenizers::pre_tokenizers::byte_level::ByteLevel;
+use tokenizers::pre_tokenizers::PreTokenizerWrapper;
 use tokenizers::processors::bert::BertProcessing;
 use tokenizers::tokenizer::{Model, Tokenizer};

 #[allow(dead_code)]
-pub fn get_empty() -> Tokenizer<BPE, NormalizerWrapper> {
+pub fn get_empty() -> Tokenizer<BPE, NormalizerWrapper, PreTokenizerWrapper> {
     Tokenizer::new(BPE::default())
 }
@@ -24,11 +25,9 @@ pub fn get_byte_level_bpe() -> BPE {
 pub fn get_byte_level(
     add_prefix_space: bool,
     trim_offsets: bool,
-) -> Tokenizer<BPE, NormalizerWrapper> {
+) -> Tokenizer<BPE, NormalizerWrapper, ByteLevel> {
     let mut tokenizer = Tokenizer::new(get_byte_level_bpe());
-    tokenizer.with_pre_tokenizer(Box::new(
-        ByteLevel::default().add_prefix_space(add_prefix_space),
-    ));
+    tokenizer.with_pre_tokenizer(ByteLevel::default().add_prefix_space(add_prefix_space));
     tokenizer.with_decoder(Box::new(ByteLevel::default()));
     tokenizer.with_post_processor(Box::new(ByteLevel::default().trim_offsets(trim_offsets)));
@@ -43,10 +42,10 @@ pub fn get_bert_wordpiece() -> WordPiece {
 }

 #[allow(dead_code)]
-pub fn get_bert() -> Tokenizer<WordPiece, BertNormalizer> {
+pub fn get_bert() -> Tokenizer<WordPiece, BertNormalizer, BertPreTokenizer> {
     let mut tokenizer = Tokenizer::new(get_bert_wordpiece());
     tokenizer.with_normalizer(BertNormalizer::default());
-    tokenizer.with_pre_tokenizer(Box::new(BertPreTokenizer));
+    tokenizer.with_pre_tokenizer(BertPreTokenizer);
     tokenizer.with_decoder(Box::new(WordPieceDecoder::default()));
     tokenizer.with_post_processor(Box::new(BertProcessing::new(
         (