Hide generics

Sebastian Puetz
2020-07-31 12:22:13 +02:00
committed by Anthony MOI
parent d62adf7195
commit 42b810488f
9 changed files with 147 additions and 113 deletions

View File

@@ -2,17 +2,17 @@ extern crate tokenizers as tk;
 use crate::encoding::*;
 use neon::prelude::*;
-use tk::tokenizer::{EncodeInput, Encoding, Tokenizer};
+use tk::tokenizer::{EncodeInput, Encoding, TokenizerImpl};

 pub struct WorkingTokenizer {
     _arc: std::sync::Arc<()>,
-    ptr: *const Tokenizer,
+    ptr: *const TokenizerImpl,
 }

 impl WorkingTokenizer {
     /// This is unsafe because the caller must ensure that the given tokenizer
     /// wont be modified for the duration of the task. We keep an arc here to let the
     /// caller know when we are done with our pointer on Tokenizer
-    pub unsafe fn new(tokenizer: &Tokenizer, arc: std::sync::Arc<()>) -> Self {
+    pub unsafe fn new(tokenizer: &TokenizerImpl, arc: std::sync::Arc<()>) -> Self {
         WorkingTokenizer {
             _arc: arc,
             ptr: tokenizer as *const _,
@@ -41,7 +41,7 @@ impl Task for EncodeTask {
             EncodeTask::Single(worker, input, add_special_tokens) => {
                 let mut input: Option<EncodeInput> =
                     unsafe { std::ptr::replace(input as *const _ as *mut _, None) };
-                let tokenizer: &Tokenizer = unsafe { &*worker.ptr };
+                let tokenizer: &TokenizerImpl = unsafe { &*worker.ptr };
                 tokenizer
                     .encode(
                         input.take().ok_or("No provided input")?,
@@ -53,7 +53,7 @@ impl Task for EncodeTask {
             EncodeTask::Batch(worker, input, add_special_tokens) => {
                 let mut input: Option<Vec<EncodeInput>> =
                     unsafe { std::ptr::replace(input as *const _ as *mut _, None) };
-                let tokenizer: &Tokenizer = unsafe { &*worker.ptr };
+                let tokenizer: &TokenizerImpl = unsafe { &*worker.ptr };
                 tokenizer
                     .encode_batch(
                         input.take().ok_or("No provided input")?,
@@ -120,14 +120,14 @@ impl Task for DecodeTask {
     fn perform(&self) -> Result<Self::Output, Self::Error> {
         match self {
             DecodeTask::Single(worker, ids, skip_special_tokens) => {
-                let tokenizer: &Tokenizer = unsafe { &*worker.ptr };
+                let tokenizer: &TokenizerImpl = unsafe { &*worker.ptr };
                 tokenizer
                     .decode(ids.to_vec(), *skip_special_tokens)
                     .map_err(|e| format!("{}", e))
                     .map(DecodeOutput::Single)
             }
             DecodeTask::Batch(worker, ids, skip_special_tokens) => {
-                let tokenizer: &Tokenizer = unsafe { &*worker.ptr };
+                let tokenizer: &TokenizerImpl = unsafe { &*worker.ptr };
                 tokenizer
                     .decode_batch(ids.to_vec(), *skip_special_tokens)
                     .map_err(|e| format!("{}", e))
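For readers unfamiliar with the pattern used by WorkingTokenizer above: the worker keeps a raw pointer to the (now TokenizerImpl-typed) tokenizer plus a cloned Arc<()> that only signals liveness, exactly as the doc comment in the hunk describes. A minimal standalone sketch of that guard pattern, with hypothetical names (Guarded, get) that are not taken from the bindings:

use std::sync::Arc;

// Illustration only: the Arc<()> carries no data, it merely lets the owner
// observe (via Arc::strong_count) whether a worker still holds a guard.
struct Guarded<T> {
    _alive: Arc<()>,
    ptr: *const T,
}

impl<T> Guarded<T> {
    /// Safety: `value` must stay alive and unmodified for as long as the
    /// returned guard exists, mirroring the contract documented on
    /// `WorkingTokenizer::new` in the diff above.
    unsafe fn new(value: &T, alive: Arc<()>) -> Self {
        Guarded {
            _alive: alive,
            ptr: value as *const T,
        }
    }

    /// Safety: only valid while the contract of `new` is upheld.
    unsafe fn get(&self) -> &T {
        &*self.ptr
    }
}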

View File

@@ -345,7 +345,7 @@ pub struct PaddingParams(#[serde(with = "PaddingParamsDef")] pub tk::PaddingPara
 /// Tokenizer
 pub struct Tokenizer {
-    tokenizer: tk::Tokenizer,
+    tokenizer: tk::TokenizerImpl,

     /// Whether we have a running task. We keep this to make sure we never
     /// modify the underlying tokenizer while a task is running
@@ -1016,7 +1016,7 @@ declare_types! {
 pub fn tokenizer_from_string(mut cx: FunctionContext) -> JsResult<JsTokenizer> {
     let s = cx.extract::<String>(0)?;

-    let tokenizer: tk::tokenizer::Tokenizer = s
+    let tokenizer: tk::tokenizer::TokenizerImpl = s
         .parse()
         .map_err(|e| cx.throw_error::<_, ()>(format!("{}", e)).unwrap_err())?;
@@ -1030,7 +1030,7 @@ pub fn tokenizer_from_string(mut cx: FunctionContext) -> JsResult<JsTokenizer> {
 pub fn tokenizer_from_file(mut cx: FunctionContext) -> JsResult<JsTokenizer> {
     let s = cx.extract::<String>(0)?;

-    let tokenizer = tk::tokenizer::Tokenizer::from_file(s)
+    let tokenizer = tk::tokenizer::TokenizerImpl::from_file(s)
         .map_err(|e| cx.throw_error::<_, ()>(format!("{}", e)).unwrap_err())?;

     let mut js_tokenizer = JsTokenizer::new::<_, JsTokenizer, _>(&mut cx, vec![])?;

View File

@@ -7,7 +7,7 @@ use pyo3::types::*;
 use pyo3::PyObjectProtocol;
 use tk::models::bpe::BPE;
 use tk::tokenizer::{
-    PaddingDirection, PaddingParams, PaddingStrategy, PostProcessor, Tokenizer, TruncationParams,
+    PaddingDirection, PaddingParams, PaddingStrategy, PostProcessor, TokenizerImpl, TruncationParams,
     TruncationStrategy,
 };
 use tokenizers as tk;
@@ -267,16 +267,16 @@ impl From<PreTokenizedEncodeInput> for tk::tokenizer::EncodeInput {
     }
 }

-type TokenizerImpl = Tokenizer<PyModel, PyNormalizer, PyPreTokenizer, PyPostProcessor, PyDecoder>;
+type Tokenizer = TokenizerImpl<PyModel, PyNormalizer, PyPreTokenizer, PyPostProcessor, PyDecoder>;

 #[pyclass(dict, module = "tokenizers", name=Tokenizer)]
 #[derive(Clone)]
 pub struct PyTokenizer {
-    tokenizer: TokenizerImpl,
+    tokenizer: Tokenizer,
 }

 impl PyTokenizer {
-    fn new(tokenizer: TokenizerImpl) -> Self {
+    fn new(tokenizer: Tokenizer) -> Self {
         PyTokenizer { tokenizer }
     }
@@ -331,7 +331,7 @@ impl PyTokenizer {
     #[staticmethod]
     fn from_file(path: &str) -> PyResult<Self> {
-        let tokenizer: PyResult<_> = ToPyResult(Tokenizer::from_file(path)).into();
+        let tokenizer: PyResult<_> = ToPyResult(TokenizerImpl::from_file(path)).into();
         Ok(Self {
             tokenizer: tokenizer?,
         })

View File

@@ -6,20 +6,16 @@ use std::fs::File;
 use std::io::{BufRead, BufReader};
 use std::path::Path;
 use std::time::{Duration, Instant};
-use tokenizers::decoders::DecoderWrapper;
 use tokenizers::models::bpe::{BpeTrainerBuilder, BPE};
-use tokenizers::normalizers::NormalizerWrapper;
+use tokenizers::models::{ModelWrapper, TrainerWrapper};
 use tokenizers::pre_tokenizers::byte_level::ByteLevel;
 use tokenizers::pre_tokenizers::whitespace::Whitespace;
-use tokenizers::processors::PostProcessorWrapper;
-use tokenizers::tokenizer::{AddedToken, EncodeInput, Tokenizer, Trainer};
-use tokenizers::{Decoder, Model, Normalizer, PostProcessor, PreTokenizer};
+use tokenizers::tokenizer::{AddedToken, EncodeInput, Trainer};
+use tokenizers::Tokenizer;

 static BATCH_SIZE: usize = 1_000;

-fn create_gpt2_tokenizer(
-    bpe: BPE,
-) -> Tokenizer<BPE, NormalizerWrapper, ByteLevel, PostProcessorWrapper, ByteLevel> {
+fn create_gpt2_tokenizer(bpe: BPE) -> Tokenizer {
     let mut tokenizer = Tokenizer::new(bpe);
     tokenizer.with_pre_tokenizer(ByteLevel::default());
     tokenizer.with_decoder(ByteLevel::default());
@@ -28,18 +24,7 @@ fn create_gpt2_tokenizer(
     tokenizer
 }

-fn iter_bench_encode<M, N, PT, PP, D>(
-    iters: u64,
-    tokenizer: &Tokenizer<M, N, PT, PP, D>,
-    lines: &[EncodeInput],
-) -> Duration
-where
-    M: Model,
-    N: Normalizer,
-    PT: PreTokenizer,
-    PP: PostProcessor,
-    D: Decoder,
-{
+fn iter_bench_encode(iters: u64, tokenizer: &Tokenizer, lines: &[EncodeInput]) -> Duration {
     let mut duration = Duration::new(0, 0);
     let mut line_index: usize = 0;
     for _i in 0..iters {
@@ -54,18 +39,11 @@ where
     duration
 }

-fn iter_bench_encode_batch<M, N, PT, PP, D>(
+fn iter_bench_encode_batch(
     iters: u64,
-    tokenizer: &Tokenizer<M, N, PT, PP, D>,
+    tokenizer: &Tokenizer,
     batches: &[Vec<EncodeInput>],
-) -> Duration
-where
-    M: Model,
-    N: Normalizer,
-    PT: PreTokenizer,
-    PP: PostProcessor,
-    D: Decoder,
-{
+) -> Duration {
     let mut duration = Duration::new(0, 0);
     let mut batch_index: usize = 0;
     for _i in 0..iters {
@@ -119,28 +97,30 @@ fn bench_gpt2(c: &mut Criterion) {
     });
 }

-fn iter_bench_train<T, M, PT>(
+fn iter_bench_train<T>(
     iters: u64,
-    mut tokenizer: Tokenizer<M, NormalizerWrapper, PT, PostProcessorWrapper, DecoderWrapper>,
+    tokenizer: Tokenizer,
     trainer: &T,
     files: Vec<String>,
 ) -> Duration
 where
-    M: Model,
-    PT: PreTokenizer,
-    T: Trainer<Model = M>,
+    T: Trainer<Model = ModelWrapper>,
 {
+    let mut tokenizer = tokenizer.into_inner();
     let mut duration = Duration::new(0, 0);
     for _i in 0..iters {
         let start = Instant::now();
-        tokenizer = black_box(tokenizer.train(trainer, files.clone())).unwrap();
+        tokenizer = black_box(tokenizer.train(trainer, files.clone()).unwrap());
         duration = duration.checked_add(start.elapsed()).unwrap();
     }
     duration
 }

 fn bench_train(c: &mut Criterion) {
-    let trainer = BpeTrainerBuilder::default().show_progress(false).build();
+    let trainer: TrainerWrapper = BpeTrainerBuilder::default()
+        .show_progress(false)
+        .build()
+        .into();
     c.bench_function("BPE Train vocabulary (small)", |b| {
         b.iter_custom(|iters| {
             let mut tokenizer = Tokenizer::new(BPE::default());
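The train benchmark above now routes through the concrete wrapper types: the BPE trainer is converted into a TrainerWrapper, and the non-generic Tokenizer is unwrapped with into_inner() before calling train. A minimal sketch of that flow, using only APIs visible in this diff (the file path is a placeholder, not from the commit):

use tokenizers::models::bpe::{BpeTrainerBuilder, BPE};
use tokenizers::models::TrainerWrapper;
use tokenizers::Tokenizer;

fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
    // Convert the concrete BPE trainer into the type-erased TrainerWrapper,
    // as bench_train does above.
    let trainer: TrainerWrapper = BpeTrainerBuilder::default()
        .show_progress(false)
        .build()
        .into();

    // The public Tokenizer hides the generics; unwrap it to the underlying
    // TokenizerImpl<ModelWrapper, ...> before training, mirroring
    // iter_bench_train.
    let tokenizer = Tokenizer::new(BPE::default());
    let mut inner = tokenizer.into_inner();

    // "data/small.txt" is a placeholder path for illustration only.
    inner = inner.train(&trainer, vec!["data/small.txt".into()])?;
    let _ = inner;
    Ok(())
}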

View File

@@ -5,10 +5,9 @@
 use clap::{App, AppSettings, Arg, ArgMatches, SubCommand};
 use std::io::{self, BufRead, Write};
 use tokenizers::models::bpe::BPE;
-use tokenizers::normalizers::NormalizerWrapper;
 use tokenizers::pre_tokenizers::byte_level::ByteLevel;
-use tokenizers::processors::PostProcessorWrapper;
-use tokenizers::tokenizer::{AddedToken, Result, Tokenizer};
+use tokenizers::tokenizer::{AddedToken, Result};
+use tokenizers::Tokenizer;

 fn shell(matches: &ArgMatches) -> Result<()> {
     let vocab = matches
@@ -19,8 +18,7 @@ fn shell(matches: &ArgMatches) -> Result<()> {
         .expect("Must give a merges.txt file");
     let bpe = BPE::from_files(vocab, merges).build()?;

-    let mut tokenizer =
-        Tokenizer::<_, NormalizerWrapper, ByteLevel, PostProcessorWrapper, ByteLevel>::new(bpe);
+    let mut tokenizer = Tokenizer::new(bpe);
     tokenizer.with_pre_tokenizer(ByteLevel::default());
     tokenizer.with_decoder(ByteLevel::default());

View File

@@ -39,11 +39,7 @@
 //!     .unk_token("[UNK]".into())
 //!     .build()?;
 //!
-//! let mut tokenizer = Tokenizer::<_,
-//!     NormalizerWrapper,
-//!     PreTokenizerWrapper,
-//!     PostProcessorWrapper,
-//!     DecoderWrapper>::new(bpe);
+//! let mut tokenizer = Tokenizer::new(bpe);
 //!
 //! let encoding = tokenizer.encode("Hey there!", false)?;
 //! println!("{:?}", encoding.get_tokens());

View File

@@ -15,6 +15,7 @@ use std::{
     fs::File,
     io::prelude::*,
     io::BufReader,
+    ops::{Deref, DerefMut},
     path::{Path, PathBuf},
 };
@@ -23,6 +24,11 @@ use serde::de::DeserializeOwned;
 use serde::export::Formatter;
 use serde::{Deserialize, Serialize};

+use crate::decoders::DecoderWrapper;
+use crate::models::ModelWrapper;
+use crate::normalizers::NormalizerWrapper;
+use crate::pre_tokenizers::PreTokenizerWrapper;
+use crate::processors::PostProcessorWrapper;
 use crate::tokenizer::normalizer::Range;
 use crate::utils::parallelism::*;
@@ -264,11 +270,11 @@ where
     /// Convert the TokenizerBuilder to a Tokenizer.
     ///
     /// Conversion fails if the `model` is missing.
-    pub fn build(self) -> Result<Tokenizer<M, N, PT, PP, D>> {
+    pub fn build(self) -> Result<TokenizerImpl<M, N, PT, PP, D>> {
         let model = self
             .model
             .ok_or_else(|| Box::new(BuilderError("Model missing.".into())))?;
-        Ok(Tokenizer {
+        Ok(TokenizerImpl {
             normalizer: self.normalizer,
             pre_tokenizer: self.pre_tokenizer,
             model,
@@ -324,9 +330,82 @@ where
     }
 }

+#[derive(Serialize, Deserialize)]
+pub struct Tokenizer(
+    TokenizerImpl<
+        ModelWrapper,
+        NormalizerWrapper,
+        PreTokenizerWrapper,
+        PostProcessorWrapper,
+        DecoderWrapper,
+    >,
+);
+
+impl Tokenizer {
+    /// Construct a new Tokenizer based on the model.
+    pub fn new(model: impl Into<ModelWrapper>) -> Self {
+        Self(TokenizerImpl::new(model.into()))
+    }
+
+    /// Unwrap the TokenizerImpl.
+    pub fn into_inner(
+        self,
+    ) -> TokenizerImpl<
+        ModelWrapper,
+        NormalizerWrapper,
+        PreTokenizerWrapper,
+        PostProcessorWrapper,
+        DecoderWrapper,
+    > {
+        self.0
+    }
+}
+
+impl<M, N, PT, PP, D> From<TokenizerImpl<M, N, PT, PP, D>> for Tokenizer
+where
+    M: Into<ModelWrapper>,
+    N: Into<NormalizerWrapper>,
+    PT: Into<PreTokenizerWrapper>,
+    PP: Into<PostProcessorWrapper>,
+    D: Into<DecoderWrapper>,
+{
+    fn from(t: TokenizerImpl<M, N, PT, PP, D>) -> Self {
+        Self(TokenizerImpl {
+            model: t.model.into(),
+            normalizer: t.normalizer.map(Into::into),
+            pre_tokenizer: t.pre_tokenizer.map(Into::into),
+            post_processor: t.post_processor.map(Into::into),
+            decoder: t.decoder.map(Into::into),
+            added_vocabulary: t.added_vocabulary,
+            padding: t.padding,
+            truncation: t.truncation,
+        })
+    }
+}
+
+impl Deref for Tokenizer {
+    type Target = TokenizerImpl<
+        ModelWrapper,
+        NormalizerWrapper,
+        PreTokenizerWrapper,
+        PostProcessorWrapper,
+        DecoderWrapper,
+    >;
+
+    fn deref(&self) -> &Self::Target {
+        &self.0
+    }
+}
+
+impl DerefMut for Tokenizer {
+    fn deref_mut(&mut self) -> &mut Self::Target {
+        &mut self.0
+    }
+}
+
 /// A `Tokenizer` is capable of encoding/decoding any text.
 #[derive(Clone)]
-pub struct Tokenizer<M, N, PT, PP, D> {
+pub struct TokenizerImpl<M, N, PT, PP, D> {
     // Tokenizer parts
     normalizer: Option<N>,
     pre_tokenizer: Option<PT>,
@@ -342,7 +421,7 @@ pub struct Tokenizer<M, N, PT, PP, D> {
     padding: Option<PaddingParams>,
 }

-impl<M, N, PT, PP, D> Tokenizer<M, N, PT, PP, D>
+impl<M, N, PT, PP, D> TokenizerImpl<M, N, PT, PP, D>
 where
     M: Model,
     N: Normalizer,
@@ -352,7 +431,7 @@ where
 {
     /// Instantiate a new Tokenizer, with the given Model
     pub fn new(model: M) -> Self {
-        Tokenizer {
+        TokenizerImpl {
            normalizer: None,
            pre_tokenizer: None,
            model,
@@ -367,8 +446,8 @@ where
     }

     /// Set the normalizer
-    pub fn with_normalizer(&mut self, normalizer: N) -> &Self {
-        self.normalizer = Some(normalizer);
+    pub fn with_normalizer(&mut self, normalizer: impl Into<N>) -> &Self {
+        self.normalizer = Some(normalizer.into());
         self
     }
@@ -378,8 +457,8 @@ where
     }

     /// Set the pre tokenizer
-    pub fn with_pre_tokenizer(&mut self, pre_tokenizer: PT) -> &Self {
-        self.pre_tokenizer = Some(pre_tokenizer);
+    pub fn with_pre_tokenizer(&mut self, pre_tokenizer: impl Into<PT>) -> &Self {
+        self.pre_tokenizer = Some(pre_tokenizer.into());
         self
     }
@@ -389,8 +468,8 @@ where
     }

     /// Set the post processor
-    pub fn with_post_processor(&mut self, post_processor: PP) -> &Self {
-        self.post_processor = Some(post_processor);
+    pub fn with_post_processor(&mut self, post_processor: impl Into<PP>) -> &Self {
+        self.post_processor = Some(post_processor.into());
         self
     }
@@ -400,8 +479,8 @@ where
     }

     /// Set the decoder
-    pub fn with_decoder(&mut self, decoder: D) -> &Self {
-        self.decoder = Some(decoder);
+    pub fn with_decoder(&mut self, decoder: impl Into<D>) -> &Self {
+        self.decoder = Some(decoder.into());
         self
     }
@@ -411,8 +490,8 @@ where
     }

     /// Set the model
-    pub fn with_model(&mut self, model: M) -> &mut Self {
-        self.model = model;
+    pub fn with_model(&mut self, model: impl Into<M>) -> &mut Self {
+        self.model = model.into();
         self
     }
@@ -587,13 +666,7 @@ where
     /// # use tokenizers::pre_tokenizers::PreTokenizerWrapper;
     /// # use tokenizers::processors::PostProcessorWrapper;
     /// # use tokenizers::decoders::DecoderWrapper;
-    /// # let mut tokenizer =
-    /// #     Tokenizer::<_,
-    /// #         NormalizerWrapper,
-    /// #         PreTokenizerWrapper,
-    /// #         PostProcessorWrapper,
-    /// #         DecoderWrapper
-    /// #     >::new(BPE::default());
+    /// # let mut tokenizer = Tokenizer::new(BPE::default());
     /// #
     /// // Sequences:
     /// tokenizer.encode("Single sequence", false);
@@ -772,7 +845,7 @@ where
         self,
         trainer: &T,
         files: Vec<String>,
-    ) -> Result<Tokenizer<TM, N, PT, PP, D>>
+    ) -> Result<TokenizerImpl<TM, N, PT, PP, D>>
     where
         T: Trainer<Model = TM>,
         TM: Model,
@@ -780,7 +853,7 @@ where
         let words = self.word_count(trainer, files)?;
         let (model, special_tokens) = trainer.train(words)?;

-        let mut new_tok = Tokenizer {
+        let mut new_tok = TokenizerImpl {
             normalizer: self.normalizer,
             pre_tokenizer: self.pre_tokenizer,
             model,
@@ -932,7 +1005,7 @@ where
     }
 }

-impl<M, N, PT, PP, D> std::str::FromStr for Tokenizer<M, N, PT, PP, D>
+impl<M, N, PT, PP, D> std::str::FromStr for TokenizerImpl<M, N, PT, PP, D>
 where
     M: for<'de> Deserialize<'de> + Model,
     N: for<'de> Deserialize<'de> + Normalizer,
@@ -947,7 +1020,7 @@ where
     }
 }

-impl<M, N, PT, PP, D> Tokenizer<M, N, PT, PP, D>
+impl<M, N, PT, PP, D> TokenizerImpl<M, N, PT, PP, D>
 where
     M: DeserializeOwned + Model,
     N: DeserializeOwned + Normalizer,
@@ -963,7 +1036,7 @@ where
     }
 }

-impl<M, N, PT, PP, D> Tokenizer<M, N, PT, PP, D>
+impl<M, N, PT, PP, D> TokenizerImpl<M, N, PT, PP, D>
 where
     M: Serialize,
     N: Serialize,
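Taken together, the hunks above mean downstream code can hold a single non-generic Tokenizer and still reach every TokenizerImpl method through Deref/DerefMut, while the setters now accept concrete components thanks to the new impl Into<...> bounds. A small usage sketch, assuming only APIs visible in this commit:

use tokenizers::models::bpe::BPE;
use tokenizers::pre_tokenizers::byte_level::ByteLevel;
use tokenizers::Tokenizer;

fn build() -> Tokenizer {
    // `Tokenizer::new` accepts anything convertible into a ModelWrapper.
    let mut tokenizer = Tokenizer::new(BPE::default());

    // These calls resolve on TokenizerImpl<ModelWrapper, ...> via DerefMut;
    // the `impl Into<PT>` / `impl Into<D>` setters convert the concrete
    // ByteLevel into the corresponding wrapper enums.
    tokenizer.with_pre_tokenizer(ByteLevel::default());
    tokenizer.with_decoder(ByteLevel::default());

    tokenizer
}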

View File

@@ -7,12 +7,12 @@ use serde::{
     Deserialize, Deserializer, Serialize, Serializer,
 };

-use super::{added_vocabulary::AddedTokenWithId, Tokenizer};
+use super::{added_vocabulary::AddedTokenWithId, TokenizerImpl};
 use crate::{Decoder, Model, Normalizer, PostProcessor, PreTokenizer, TokenizerBuilder};

 static SERIALIZATION_VERSION: &str = "1.0";

-impl<M, N, PT, PP, D> Serialize for Tokenizer<M, N, PT, PP, D>
+impl<M, N, PT, PP, D> Serialize for TokenizerImpl<M, N, PT, PP, D>
 where
     M: Serialize,
     N: Serialize,
@@ -47,7 +47,7 @@ where
     }
 }

-impl<'de, M, N, PT, PP, D> Deserialize<'de> for Tokenizer<M, N, PT, PP, D>
+impl<'de, M, N, PT, PP, D> Deserialize<'de> for TokenizerImpl<M, N, PT, PP, D>
 where
     M: Deserialize<'de> + Model,
     N: Deserialize<'de> + Normalizer,
@@ -99,7 +99,7 @@ where
     PP: Deserialize<'de> + PostProcessor,
     D: Deserialize<'de> + Decoder,
 {
-    type Value = Tokenizer<M, N, PT, PP, D>;
+    type Value = TokenizerImpl<M, N, PT, PP, D>;

     fn expecting(&self, fmt: &mut std::fmt::Formatter) -> std::fmt::Result {
         write!(fmt, "struct Tokenizer")
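Because Serialize/Deserialize stay implemented on the generic TokenizerImpl and are derived for the new Tokenizer wrapper, a tokenizer can be round-tripped through its JSON representation. A hedged sketch, assuming serde_json is available to the caller (it is already used elsewhere in the crate):

use tokenizers::models::bpe::BPE;
use tokenizers::Tokenizer;

fn roundtrip() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
    let tokenizer = Tokenizer::new(BPE::default());

    // The wrapper derives Serialize/Deserialize, so serde_json handles it
    // directly; the generic TokenizerImpl gets the same treatment through
    // the impls shown above.
    let json = serde_json::to_string(&tokenizer)?;
    let restored: Tokenizer = serde_json::from_str(&json)?;
    let _ = restored;
    Ok(())
}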

View File

@@ -1,19 +1,14 @@
 use tokenizers::decoders::wordpiece::WordPiece as WordPieceDecoder;
-use tokenizers::decoders::DecoderWrapper;
 use tokenizers::models::bpe::BPE;
 use tokenizers::models::wordpiece::WordPiece;
 use tokenizers::normalizers::bert::BertNormalizer;
-use tokenizers::normalizers::NormalizerWrapper;
 use tokenizers::pre_tokenizers::bert::BertPreTokenizer;
 use tokenizers::pre_tokenizers::byte_level::ByteLevel;
-use tokenizers::pre_tokenizers::PreTokenizerWrapper;
 use tokenizers::processors::bert::BertProcessing;
-use tokenizers::processors::PostProcessorWrapper;
 use tokenizers::tokenizer::{Model, Tokenizer};

 #[allow(dead_code)]
-pub fn get_empty(
-) -> Tokenizer<BPE, NormalizerWrapper, PreTokenizerWrapper, PostProcessorWrapper, DecoderWrapper> {
+pub fn get_empty() -> Tokenizer {
     Tokenizer::new(BPE::default())
 }
@@ -25,10 +20,7 @@ pub fn get_byte_level_bpe() -> BPE {
 }

 #[allow(dead_code)]
-pub fn get_byte_level(
-    add_prefix_space: bool,
-    trim_offsets: bool,
-) -> Tokenizer<BPE, NormalizerWrapper, ByteLevel, ByteLevel, ByteLevel> {
+pub fn get_byte_level(add_prefix_space: bool, trim_offsets: bool) -> Tokenizer {
     let mut tokenizer = Tokenizer::new(get_byte_level_bpe());
     tokenizer.with_pre_tokenizer(ByteLevel::default().add_prefix_space(add_prefix_space));
     tokenizer.with_decoder(ByteLevel::default());
@@ -45,21 +37,16 @@ pub fn get_bert_wordpiece() -> WordPiece {
 }

 #[allow(dead_code)]
-pub fn get_bert(
-) -> Tokenizer<WordPiece, BertNormalizer, BertPreTokenizer, BertProcessing, WordPieceDecoder> {
+pub fn get_bert() -> Tokenizer {
     let mut tokenizer = Tokenizer::new(get_bert_wordpiece());
     tokenizer.with_normalizer(BertNormalizer::default());
     tokenizer.with_pre_tokenizer(BertPreTokenizer);
     tokenizer.with_decoder(WordPieceDecoder::default());
+    let sep = tokenizer.get_model().token_to_id("[SEP]").unwrap();
+    let cls = tokenizer.get_model().token_to_id("[CLS]").unwrap();
     tokenizer.with_post_processor(BertProcessing::new(
-        (
-            String::from("[SEP]"),
-            tokenizer.get_model().token_to_id("[SEP]").unwrap(),
-        ),
-        (
-            String::from("[CLS]"),
-            tokenizer.get_model().token_to_id("[CLS]").unwrap(),
-        ),
+        (String::from("[SEP]"), sep),
+        (String::from("[CLS]"), cls),
     ));

     tokenizer