mirror of
https://github.com/mii443/tokenizers.git
synced 2025-12-09 22:28:29 +00:00
Make Tokenizer generic over pre-tokenizers.
This commit is contained in:
committed by
Anthony MOI
parent
08b8c48127
commit
bcc54a2ea1
@@ -11,27 +11,28 @@ use tokenizers::normalizers::NormalizerWrapper;
|
||||
use tokenizers::pre_tokenizers::byte_level::ByteLevel;
|
||||
use tokenizers::pre_tokenizers::whitespace::Whitespace;
|
||||
use tokenizers::tokenizer::{AddedToken, EncodeInput, Tokenizer, Trainer};
|
||||
use tokenizers::{Model, Normalizer};
|
||||
use tokenizers::{Model, Normalizer, PreTokenizer};
|
||||
|
||||
static BATCH_SIZE: usize = 1_000;
|
||||
|
||||
fn create_gpt2_tokenizer(bpe: BPE) -> Tokenizer<BPE, NormalizerWrapper> {
|
||||
fn create_gpt2_tokenizer(bpe: BPE) -> Tokenizer<BPE, NormalizerWrapper, ByteLevel> {
|
||||
let mut tokenizer = Tokenizer::new(bpe);
|
||||
tokenizer.with_pre_tokenizer(Box::new(ByteLevel::default()));
|
||||
tokenizer.with_pre_tokenizer(ByteLevel::default());
|
||||
tokenizer.with_decoder(Box::new(ByteLevel::default()));
|
||||
tokenizer.add_tokens(&[AddedToken::from("ing", false).single_word(false)]);
|
||||
tokenizer.add_special_tokens(&[AddedToken::from("[ENT]", true).single_word(true)]);
|
||||
tokenizer
|
||||
}
|
||||
|
||||
fn iter_bench_encode<M, N>(
|
||||
fn iter_bench_encode<M, N, PT>(
|
||||
iters: u64,
|
||||
tokenizer: &Tokenizer<M, N>,
|
||||
tokenizer: &Tokenizer<M, N, PT>,
|
||||
lines: &[EncodeInput],
|
||||
) -> Duration
|
||||
where
|
||||
M: Model,
|
||||
N: Normalizer,
|
||||
PT: PreTokenizer,
|
||||
{
|
||||
let mut duration = Duration::new(0, 0);
|
||||
let mut line_index: usize = 0;
|
||||
@@ -47,14 +48,15 @@ where
|
||||
duration
|
||||
}
|
||||
|
||||
fn iter_bench_encode_batch<M, N>(
|
||||
fn iter_bench_encode_batch<M, N, PT>(
|
||||
iters: u64,
|
||||
tokenizer: &Tokenizer<M, N>,
|
||||
tokenizer: &Tokenizer<M, N, PT>,
|
||||
batches: &[Vec<EncodeInput>],
|
||||
) -> Duration
|
||||
where
|
||||
M: Model,
|
||||
N: Normalizer,
|
||||
PT: PreTokenizer,
|
||||
{
|
||||
let mut duration = Duration::new(0, 0);
|
||||
let mut batch_index: usize = 0;
|
||||
@@ -109,12 +111,17 @@ fn bench_gpt2(c: &mut Criterion) {
|
||||
});
|
||||
}
|
||||
|
||||
fn iter_bench_train<T: Trainer<Model = BPE>>(
|
||||
fn iter_bench_train<T, M, PT>(
|
||||
iters: u64,
|
||||
mut tokenizer: Tokenizer<BPE, NormalizerWrapper>,
|
||||
mut tokenizer: Tokenizer<M, NormalizerWrapper, PT>,
|
||||
trainer: &T,
|
||||
files: Vec<String>,
|
||||
) -> Duration {
|
||||
) -> Duration
|
||||
where
|
||||
M: Model,
|
||||
PT: PreTokenizer,
|
||||
T: Trainer<Model = M>,
|
||||
{
|
||||
let mut duration = Duration::new(0, 0);
|
||||
for _i in 0..iters {
|
||||
let start = Instant::now();
|
||||
@@ -129,7 +136,7 @@ fn bench_train(c: &mut Criterion) {
|
||||
c.bench_function("BPE Train vocabulary (small)", |b| {
|
||||
b.iter_custom(|iters| {
|
||||
let mut tokenizer = Tokenizer::new(BPE::default());
|
||||
tokenizer.with_pre_tokenizer(Box::new(Whitespace::default()));
|
||||
tokenizer.with_pre_tokenizer(Whitespace::default());
|
||||
iter_bench_train(
|
||||
iters,
|
||||
tokenizer,
|
||||
@@ -142,7 +149,7 @@ fn bench_train(c: &mut Criterion) {
|
||||
c.bench_function("BPE Train vocabulary (big)", |b| {
|
||||
b.iter_custom(|iters| {
|
||||
let mut tokenizer = Tokenizer::new(BPE::default());
|
||||
tokenizer.with_pre_tokenizer(Box::new(Whitespace::default()));
|
||||
tokenizer.with_pre_tokenizer(Whitespace::default());
|
||||
iter_bench_train(iters, tokenizer, &trainer, vec!["data/big.txt".to_string()])
|
||||
})
|
||||
});
|
||||
|
||||
@@ -18,8 +18,8 @@ fn shell(matches: &ArgMatches) -> Result<()> {
|
||||
.expect("Must give a merges.txt file");
|
||||
|
||||
let bpe = BPE::from_files(vocab, merges).build()?;
|
||||
let mut tokenizer = Tokenizer::<_, NormalizerWrapper>::new(bpe);
|
||||
tokenizer.with_pre_tokenizer(Box::new(ByteLevel::default()));
|
||||
let mut tokenizer = Tokenizer::<_, NormalizerWrapper, ByteLevel>::new(bpe);
|
||||
tokenizer.with_pre_tokenizer(ByteLevel::default());
|
||||
tokenizer.with_decoder(Box::new(ByteLevel::default()));
|
||||
|
||||
tokenizer.add_tokens(&[AddedToken::from(String::from("ing"), false).single_word(false)]);
|
||||
|
||||
@@ -28,6 +28,7 @@
|
||||
//! use tokenizers::tokenizer::{Result, Tokenizer, EncodeInput};
|
||||
//! use tokenizers::models::bpe::BPE;
|
||||
//! use tokenizers::normalizers::NormalizerWrapper;
|
||||
//! use tokenizers::pre_tokenizers::PreTokenizerWrapper;
|
||||
//!
|
||||
//! fn main() -> Result<()> {
|
||||
//! let bpe_builder = BPE::from_files("./path/to/vocab.json", "./path/to/merges.txt");
|
||||
@@ -36,7 +37,7 @@
|
||||
//! .unk_token("[UNK]".into())
|
||||
//! .build()?;
|
||||
//!
|
||||
//! let mut tokenizer = Tokenizer::<_, NormalizerWrapper>::new(bpe);
|
||||
//! let mut tokenizer = Tokenizer::<_, NormalizerWrapper, PreTokenizerWrapper>::new(bpe);
|
||||
//!
|
||||
//! let encoding = tokenizer.encode("Hey there!", false)?;
|
||||
//! println!("{:?}", encoding.get_tokens());
|
||||
|
||||
@@ -213,10 +213,10 @@ impl fmt::Display for BuilderError {
|
||||
/// Builder for Tokenizer structs.
|
||||
///
|
||||
/// `build()` fails if the `model` is missing.
|
||||
pub struct TokenizerBuilder<M, N> {
|
||||
pub struct TokenizerBuilder<M, N, PT> {
|
||||
model: Option<M>,
|
||||
normalizer: Option<N>,
|
||||
pre_tokenizer: Option<Box<dyn PreTokenizer>>,
|
||||
pre_tokenizer: Option<PT>,
|
||||
post_processor: Option<Box<dyn PostProcessor>>,
|
||||
decoder: Option<Box<dyn Decoder>>,
|
||||
|
||||
@@ -226,20 +226,22 @@ pub struct TokenizerBuilder<M, N> {
|
||||
padding: Option<PaddingParams>,
|
||||
}
|
||||
|
||||
impl<M, N> Default for TokenizerBuilder<M, N>
|
||||
impl<M, N, PT> Default for TokenizerBuilder<M, N, PT>
|
||||
where
|
||||
M: Model,
|
||||
N: Normalizer,
|
||||
PT: PreTokenizer,
|
||||
{
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
impl<M, N> TokenizerBuilder<M, N>
|
||||
impl<M, N, PT> TokenizerBuilder<M, N, PT>
|
||||
where
|
||||
M: Model,
|
||||
N: Normalizer,
|
||||
PT: PreTokenizer,
|
||||
{
|
||||
/// Get an empty TokenizerBuilder.
|
||||
pub fn new() -> Self {
|
||||
@@ -258,7 +260,7 @@ where
|
||||
/// Convert the TokenizerBuilder to a Tokenizer.
|
||||
///
|
||||
/// Conversion fails if the `model` is missing.
|
||||
pub fn build(self) -> Result<Tokenizer<M, N>> {
|
||||
pub fn build(self) -> Result<Tokenizer<M, N, PT>> {
|
||||
let model = self
|
||||
.model
|
||||
.ok_or_else(|| Box::new(BuilderError("Model missing.".into())))?;
|
||||
@@ -288,7 +290,7 @@ where
|
||||
}
|
||||
|
||||
/// Set the pretokenizer.
|
||||
pub fn with_pretokenizer(mut self, pretokenizer: Option<Box<dyn PreTokenizer>>) -> Self {
|
||||
pub fn with_pretokenizer(mut self, pretokenizer: Option<PT>) -> Self {
|
||||
self.pre_tokenizer = pretokenizer;
|
||||
self
|
||||
}
|
||||
@@ -319,10 +321,10 @@ where
|
||||
}
|
||||
|
||||
/// A `Tokenizer` is capable of encoding/decoding any text.
|
||||
pub struct Tokenizer<M, N> {
|
||||
pub struct Tokenizer<M, N, PT> {
|
||||
// Tokenizer parts
|
||||
normalizer: Option<N>,
|
||||
pre_tokenizer: Option<Box<dyn PreTokenizer>>,
|
||||
pre_tokenizer: Option<PT>,
|
||||
model: M,
|
||||
post_processor: Option<Box<dyn PostProcessor>>,
|
||||
decoder: Option<Box<dyn Decoder>>,
|
||||
@@ -335,10 +337,11 @@ pub struct Tokenizer<M, N> {
|
||||
padding: Option<PaddingParams>,
|
||||
}
|
||||
|
||||
impl<M, N> Tokenizer<M, N>
|
||||
impl<M, N, PT> Tokenizer<M, N, PT>
|
||||
where
|
||||
M: Model,
|
||||
N: Normalizer,
|
||||
PT: PreTokenizer,
|
||||
{
|
||||
/// Instantiate a new Tokenizer, with the given Model
|
||||
pub fn new(model: M) -> Self {
|
||||
@@ -368,14 +371,13 @@ where
|
||||
}
|
||||
|
||||
/// Set the pre tokenizer
|
||||
pub fn with_pre_tokenizer(&mut self, pre_tokenizer: Box<dyn PreTokenizer>) -> &Self {
|
||||
pub fn with_pre_tokenizer(&mut self, pre_tokenizer: PT) -> &Self {
|
||||
self.pre_tokenizer = Some(pre_tokenizer);
|
||||
self
|
||||
}
|
||||
|
||||
/// Get the pre tokenizer
|
||||
#[allow(clippy::borrowed_box)]
|
||||
pub fn get_pre_tokenizer(&self) -> Option<&Box<dyn PreTokenizer>> {
|
||||
pub fn get_pre_tokenizer(&self) -> Option<&PT> {
|
||||
self.pre_tokenizer.as_ref()
|
||||
}
|
||||
|
||||
@@ -577,7 +579,8 @@ where
|
||||
/// # use tokenizers::Tokenizer;
|
||||
/// # use tokenizers::models::bpe::BPE;
|
||||
/// # use tokenizers::normalizers::NormalizerWrapper;
|
||||
/// # let mut tokenizer = Tokenizer::<_, NormalizerWrapper>::new(BPE::default());
|
||||
/// # use tokenizers::pre_tokenizers::PreTokenizerWrapper;
|
||||
/// # let mut tokenizer = Tokenizer::<_, NormalizerWrapper, PreTokenizerWrapper>::new(BPE::default());
|
||||
/// #
|
||||
/// // Sequences:
|
||||
/// tokenizer.encode("Single sequence", false);
|
||||
@@ -752,7 +755,7 @@ where
|
||||
}
|
||||
|
||||
/// Train a model and replace our current Model, using the given Trainer
|
||||
pub fn train<T, TM>(self, trainer: &T, files: Vec<String>) -> Result<Tokenizer<TM, N>>
|
||||
pub fn train<T, TM>(self, trainer: &T, files: Vec<String>) -> Result<Tokenizer<TM, N, PT>>
|
||||
where
|
||||
T: Trainer<Model = TM>,
|
||||
TM: Model,
|
||||
@@ -912,10 +915,11 @@ where
|
||||
}
|
||||
}
|
||||
|
||||
impl<M, N> std::str::FromStr for Tokenizer<M, N>
|
||||
impl<M, N, PT> std::str::FromStr for Tokenizer<M, N, PT>
|
||||
where
|
||||
M: for<'de> Deserialize<'de> + Model,
|
||||
N: for<'de> Deserialize<'de> + Normalizer,
|
||||
PT: for<'de> Deserialize<'de> + PreTokenizer,
|
||||
{
|
||||
type Err = Error;
|
||||
|
||||
@@ -924,10 +928,11 @@ where
|
||||
}
|
||||
}
|
||||
|
||||
impl<M, N> Tokenizer<M, N>
|
||||
impl<M, N, PT> Tokenizer<M, N, PT>
|
||||
where
|
||||
M: DeserializeOwned + Model,
|
||||
N: DeserializeOwned + Normalizer,
|
||||
PT: DeserializeOwned + PreTokenizer,
|
||||
{
|
||||
/// Instantiate a new Tokenizer from the given file
|
||||
pub fn from_file<P: AsRef<Path>>(file: P) -> Result<Self> {
|
||||
@@ -937,10 +942,11 @@ where
|
||||
}
|
||||
}
|
||||
|
||||
impl<M, N> Tokenizer<M, N>
|
||||
impl<M, N, PT> Tokenizer<M, N, PT>
|
||||
where
|
||||
M: Serialize,
|
||||
N: Serialize,
|
||||
PT: Serialize,
|
||||
{
|
||||
/// Serialize the current tokenizer as a String
|
||||
pub fn to_string(&self, pretty: bool) -> Result<String> {
|
||||
|
||||
@@ -8,14 +8,15 @@ use serde::{
|
||||
};
|
||||
|
||||
use super::{added_vocabulary::AddedTokenWithId, Tokenizer};
|
||||
use crate::{Model, Normalizer, TokenizerBuilder};
|
||||
use crate::{Model, Normalizer, PreTokenizer, TokenizerBuilder};
|
||||
|
||||
static SERIALIZATION_VERSION: &str = "1.0";
|
||||
|
||||
impl<M, N> Serialize for Tokenizer<M, N>
|
||||
impl<M, N, PT> Serialize for Tokenizer<M, N, PT>
|
||||
where
|
||||
M: Serialize,
|
||||
N: Serialize,
|
||||
PT: Serialize,
|
||||
{
|
||||
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
|
||||
where
|
||||
@@ -44,10 +45,11 @@ where
|
||||
}
|
||||
}
|
||||
|
||||
impl<'de, M, N> Deserialize<'de> for Tokenizer<M, N>
|
||||
impl<'de, M, N, PT> Deserialize<'de> for Tokenizer<M, N, PT>
|
||||
where
|
||||
M: Deserialize<'de> + Model,
|
||||
N: Deserialize<'de> + Normalizer,
|
||||
PT: Deserialize<'de> + PreTokenizer,
|
||||
{
|
||||
fn deserialize<De>(deserializer: De) -> Result<Self, De::Error>
|
||||
where
|
||||
@@ -66,19 +68,20 @@ where
|
||||
"decoder",
|
||||
"model",
|
||||
],
|
||||
TokenizerVisitor(PhantomData, PhantomData),
|
||||
TokenizerVisitor(PhantomData, PhantomData, PhantomData),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
struct TokenizerVisitor<M, N>(PhantomData<M>, PhantomData<N>);
|
||||
struct TokenizerVisitor<M, N, PT>(PhantomData<M>, PhantomData<N>, PhantomData<PT>);
|
||||
|
||||
impl<'de, M, N> Visitor<'de> for TokenizerVisitor<M, N>
|
||||
impl<'de, M, N, PT> Visitor<'de> for TokenizerVisitor<M, N, PT>
|
||||
where
|
||||
M: Deserialize<'de> + Model,
|
||||
N: Deserialize<'de> + Normalizer,
|
||||
PT: Deserialize<'de> + PreTokenizer,
|
||||
{
|
||||
type Value = Tokenizer<M, N>;
|
||||
type Value = Tokenizer<M, N, PT>;
|
||||
|
||||
fn expecting(&self, fmt: &mut std::fmt::Formatter) -> std::fmt::Result {
|
||||
write!(fmt, "struct Tokenizer")
|
||||
|
||||
@@ -5,11 +5,12 @@ use tokenizers::normalizers::bert::BertNormalizer;
|
||||
use tokenizers::normalizers::NormalizerWrapper;
|
||||
use tokenizers::pre_tokenizers::bert::BertPreTokenizer;
|
||||
use tokenizers::pre_tokenizers::byte_level::ByteLevel;
|
||||
use tokenizers::pre_tokenizers::PreTokenizerWrapper;
|
||||
use tokenizers::processors::bert::BertProcessing;
|
||||
use tokenizers::tokenizer::{Model, Tokenizer};
|
||||
|
||||
#[allow(dead_code)]
|
||||
pub fn get_empty() -> Tokenizer<BPE, NormalizerWrapper> {
|
||||
pub fn get_empty() -> Tokenizer<BPE, NormalizerWrapper, PreTokenizerWrapper> {
|
||||
Tokenizer::new(BPE::default())
|
||||
}
|
||||
|
||||
@@ -24,11 +25,9 @@ pub fn get_byte_level_bpe() -> BPE {
|
||||
pub fn get_byte_level(
|
||||
add_prefix_space: bool,
|
||||
trim_offsets: bool,
|
||||
) -> Tokenizer<BPE, NormalizerWrapper> {
|
||||
) -> Tokenizer<BPE, NormalizerWrapper, ByteLevel> {
|
||||
let mut tokenizer = Tokenizer::new(get_byte_level_bpe());
|
||||
tokenizer.with_pre_tokenizer(Box::new(
|
||||
ByteLevel::default().add_prefix_space(add_prefix_space),
|
||||
));
|
||||
tokenizer.with_pre_tokenizer(ByteLevel::default().add_prefix_space(add_prefix_space));
|
||||
tokenizer.with_decoder(Box::new(ByteLevel::default()));
|
||||
tokenizer.with_post_processor(Box::new(ByteLevel::default().trim_offsets(trim_offsets)));
|
||||
|
||||
@@ -43,10 +42,10 @@ pub fn get_bert_wordpiece() -> WordPiece {
|
||||
}
|
||||
|
||||
#[allow(dead_code)]
|
||||
pub fn get_bert() -> Tokenizer<WordPiece, BertNormalizer> {
|
||||
pub fn get_bert() -> Tokenizer<WordPiece, BertNormalizer, BertPreTokenizer> {
|
||||
let mut tokenizer = Tokenizer::new(get_bert_wordpiece());
|
||||
tokenizer.with_normalizer(BertNormalizer::default());
|
||||
tokenizer.with_pre_tokenizer(Box::new(BertPreTokenizer));
|
||||
tokenizer.with_pre_tokenizer(BertPreTokenizer);
|
||||
tokenizer.with_decoder(Box::new(WordPieceDecoder::default()));
|
||||
tokenizer.with_post_processor(Box::new(BertProcessing::new(
|
||||
(
|
||||
|
||||
Reference in New Issue
Block a user