Hide generics

Sebastian Puetz
2020-07-31 12:22:13 +02:00
committed by Anthony MOI
parent d62adf7195
commit 42b810488f
9 changed files with 147 additions and 113 deletions
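In short: the generic `Tokenizer<M, N, PT, PP, D>` struct is renamed to `TokenizerImpl<M, N, PT, PP, D>`, and a new non-generic `Tokenizer` newtype fixes all five slots to the wrapper enums, so common call sites no longer spell out any type parameters. A minimal before/after sketch of the caller-side effect, pieced together from the hunks below (nothing here beyond what the diff itself shows):

use tokenizers::models::bpe::BPE;
use tokenizers::tokenizer::Result;
use tokenizers::Tokenizer;

fn main() -> Result<()> {
    // Before this commit, even the default pipeline needed a turbofish:
    // Tokenizer::<_, NormalizerWrapper, PreTokenizerWrapper,
    //             PostProcessorWrapper, DecoderWrapper>::new(bpe)
    // After it, the generics are hidden behind the erased newtype:
    let tokenizer = Tokenizer::new(BPE::default());
    let encoding = tokenizer.encode("Hey there!", false)?;
    println!("{:?}", encoding.get_tokens());
    Ok(())
}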

View File

@@ -2,17 +2,17 @@ extern crate tokenizers as tk;
use crate::encoding::*;
use neon::prelude::*;
-use tk::tokenizer::{EncodeInput, Encoding, Tokenizer};
+use tk::tokenizer::{EncodeInput, Encoding, TokenizerImpl};
pub struct WorkingTokenizer {
_arc: std::sync::Arc<()>,
-ptr: *const Tokenizer,
+ptr: *const TokenizerImpl,
}
impl WorkingTokenizer {
/// This is unsafe because the caller must ensure that the given tokenizer
/// wont be modified for the duration of the task. We keep an arc here to let the
/// caller know when we are done with our pointer on Tokenizer
-pub unsafe fn new(tokenizer: &Tokenizer, arc: std::sync::Arc<()>) -> Self {
+pub unsafe fn new(tokenizer: &TokenizerImpl, arc: std::sync::Arc<()>) -> Self {
WorkingTokenizer {
_arc: arc,
ptr: tokenizer as *const _,
@@ -41,7 +41,7 @@ impl Task for EncodeTask {
EncodeTask::Single(worker, input, add_special_tokens) => {
let mut input: Option<EncodeInput> =
unsafe { std::ptr::replace(input as *const _ as *mut _, None) };
-let tokenizer: &Tokenizer = unsafe { &*worker.ptr };
+let tokenizer: &TokenizerImpl = unsafe { &*worker.ptr };
tokenizer
.encode(
input.take().ok_or("No provided input")?,
@@ -53,7 +53,7 @@
EncodeTask::Batch(worker, input, add_special_tokens) => {
let mut input: Option<Vec<EncodeInput>> =
unsafe { std::ptr::replace(input as *const _ as *mut _, None) };
-let tokenizer: &Tokenizer = unsafe { &*worker.ptr };
+let tokenizer: &TokenizerImpl = unsafe { &*worker.ptr };
tokenizer
.encode_batch(
input.take().ok_or("No provided input")?,
@@ -120,14 +120,14 @@ impl Task for DecodeTask {
fn perform(&self) -> Result<Self::Output, Self::Error> {
match self {
DecodeTask::Single(worker, ids, skip_special_tokens) => {
-let tokenizer: &Tokenizer = unsafe { &*worker.ptr };
+let tokenizer: &TokenizerImpl = unsafe { &*worker.ptr };
tokenizer
.decode(ids.to_vec(), *skip_special_tokens)
.map_err(|e| format!("{}", e))
.map(DecodeOutput::Single)
}
DecodeTask::Batch(worker, ids, skip_special_tokens) => {
-let tokenizer: &Tokenizer = unsafe { &*worker.ptr };
+let tokenizer: &TokenizerImpl = unsafe { &*worker.ptr };
tokenizer
.decode_batch(ids.to_vec(), *skip_special_tokens)
.map_err(|e| format!("{}", e))

View File

@@ -345,7 +345,7 @@ pub struct PaddingParams(#[serde(with = "PaddingParamsDef")] pub tk::PaddingPara
/// Tokenizer
pub struct Tokenizer {
-tokenizer: tk::Tokenizer,
+tokenizer: tk::TokenizerImpl,
/// Whether we have a running task. We keep this to make sure we never
/// modify the underlying tokenizer while a task is running
@@ -1016,7 +1016,7 @@ declare_types! {
pub fn tokenizer_from_string(mut cx: FunctionContext) -> JsResult<JsTokenizer> {
let s = cx.extract::<String>(0)?;
-let tokenizer: tk::tokenizer::Tokenizer = s
+let tokenizer: tk::tokenizer::TokenizerImpl = s
.parse()
.map_err(|e| cx.throw_error::<_, ()>(format!("{}", e)).unwrap_err())?;
@@ -1030,7 +1030,7 @@ pub fn tokenizer_from_string(mut cx: FunctionContext) -> JsResult<JsTokenizer> {
pub fn tokenizer_from_file(mut cx: FunctionContext) -> JsResult<JsTokenizer> {
let s = cx.extract::<String>(0)?;
-let tokenizer = tk::tokenizer::Tokenizer::from_file(s)
+let tokenizer = tk::tokenizer::TokenizerImpl::from_file(s)
.map_err(|e| cx.throw_error::<_, ()>(format!("{}", e)).unwrap_err())?;
let mut js_tokenizer = JsTokenizer::new::<_, JsTokenizer, _>(&mut cx, vec![])?;

View File

@@ -7,7 +7,7 @@ use pyo3::types::*;
use pyo3::PyObjectProtocol;
use tk::models::bpe::BPE;
use tk::tokenizer::{
-PaddingDirection, PaddingParams, PaddingStrategy, PostProcessor, Tokenizer, TruncationParams,
+PaddingDirection, PaddingParams, PaddingStrategy, PostProcessor, TokenizerImpl, TruncationParams,
TruncationStrategy,
};
use tokenizers as tk;
@@ -267,16 +267,16 @@ impl From<PreTokenizedEncodeInput> for tk::tokenizer::EncodeInput {
}
}
-type TokenizerImpl = Tokenizer<PyModel, PyNormalizer, PyPreTokenizer, PyPostProcessor, PyDecoder>;
+type Tokenizer = TokenizerImpl<PyModel, PyNormalizer, PyPreTokenizer, PyPostProcessor, PyDecoder>;
#[pyclass(dict, module = "tokenizers", name=Tokenizer)]
#[derive(Clone)]
pub struct PyTokenizer {
-tokenizer: TokenizerImpl,
+tokenizer: Tokenizer,
}
impl PyTokenizer {
-fn new(tokenizer: TokenizerImpl) -> Self {
+fn new(tokenizer: Tokenizer) -> Self {
PyTokenizer { tokenizer }
}
@@ -331,7 +331,7 @@ impl PyTokenizer {
#[staticmethod]
fn from_file(path: &str) -> PyResult<Self> {
-let tokenizer: PyResult<_> = ToPyResult(Tokenizer::from_file(path)).into();
+let tokenizer: PyResult<_> = ToPyResult(TokenizerImpl::from_file(path)).into();
Ok(Self {
tokenizer: tokenizer?,
})

View File

@@ -6,20 +6,16 @@ use std::fs::File;
use std::io::{BufRead, BufReader};
use std::path::Path;
use std::time::{Duration, Instant};
-use tokenizers::decoders::DecoderWrapper;
use tokenizers::models::bpe::{BpeTrainerBuilder, BPE};
-use tokenizers::normalizers::NormalizerWrapper;
+use tokenizers::models::{ModelWrapper, TrainerWrapper};
use tokenizers::pre_tokenizers::byte_level::ByteLevel;
use tokenizers::pre_tokenizers::whitespace::Whitespace;
-use tokenizers::processors::PostProcessorWrapper;
-use tokenizers::tokenizer::{AddedToken, EncodeInput, Tokenizer, Trainer};
-use tokenizers::{Decoder, Model, Normalizer, PostProcessor, PreTokenizer};
+use tokenizers::tokenizer::{AddedToken, EncodeInput, Trainer};
+use tokenizers::Tokenizer;
static BATCH_SIZE: usize = 1_000;
-fn create_gpt2_tokenizer(
-bpe: BPE,
-) -> Tokenizer<BPE, NormalizerWrapper, ByteLevel, PostProcessorWrapper, ByteLevel> {
+fn create_gpt2_tokenizer(bpe: BPE) -> Tokenizer {
let mut tokenizer = Tokenizer::new(bpe);
tokenizer.with_pre_tokenizer(ByteLevel::default());
tokenizer.with_decoder(ByteLevel::default());
@@ -28,18 +24,7 @@ fn create_gpt2_tokenizer(
tokenizer
}
-fn iter_bench_encode<M, N, PT, PP, D>(
-iters: u64,
-tokenizer: &Tokenizer<M, N, PT, PP, D>,
-lines: &[EncodeInput],
-) -> Duration
-where
-M: Model,
-N: Normalizer,
-PT: PreTokenizer,
-PP: PostProcessor,
-D: Decoder,
-{
+fn iter_bench_encode(iters: u64, tokenizer: &Tokenizer, lines: &[EncodeInput]) -> Duration {
let mut duration = Duration::new(0, 0);
let mut line_index: usize = 0;
for _i in 0..iters {
@@ -54,18 +39,11 @@ where
duration
}
-fn iter_bench_encode_batch<M, N, PT, PP, D>(
+fn iter_bench_encode_batch(
iters: u64,
-tokenizer: &Tokenizer<M, N, PT, PP, D>,
+tokenizer: &Tokenizer,
batches: &[Vec<EncodeInput>],
-) -> Duration
-where
-M: Model,
-N: Normalizer,
-PT: PreTokenizer,
-PP: PostProcessor,
-D: Decoder,
-{
+) -> Duration {
let mut duration = Duration::new(0, 0);
let mut batch_index: usize = 0;
for _i in 0..iters {
@@ -119,28 +97,30 @@ fn bench_gpt2(c: &mut Criterion) {
});
}
-fn iter_bench_train<T, M, PT>(
+fn iter_bench_train<T>(
iters: u64,
-mut tokenizer: Tokenizer<M, NormalizerWrapper, PT, PostProcessorWrapper, DecoderWrapper>,
+tokenizer: Tokenizer,
trainer: &T,
files: Vec<String>,
) -> Duration
where
-M: Model,
-PT: PreTokenizer,
-T: Trainer<Model = M>,
+T: Trainer<Model = ModelWrapper>,
{
+let mut tokenizer = tokenizer.into_inner();
let mut duration = Duration::new(0, 0);
for _i in 0..iters {
let start = Instant::now();
-tokenizer = black_box(tokenizer.train(trainer, files.clone())).unwrap();
+tokenizer = black_box(tokenizer.train(trainer, files.clone()).unwrap());
duration = duration.checked_add(start.elapsed()).unwrap();
}
duration
}
fn bench_train(c: &mut Criterion) {
-let trainer = BpeTrainerBuilder::default().show_progress(false).build();
+let trainer: TrainerWrapper = BpeTrainerBuilder::default()
+.show_progress(false)
+.build()
+.into();
c.bench_function("BPE Train vocabulary (small)", |b| {
b.iter_custom(|iters| {
let mut tokenizer = Tokenizer::new(BPE::default());
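Why the `.into()` on the trainer: the erased `Tokenizer` stores a `ModelWrapper`, so `iter_bench_train` now requires `T: Trainer<Model = ModelWrapper>`, which the concrete `BpeTrainer` only satisfies once wrapped in `TrainerWrapper`. A hedged sketch of the same flow outside the benchmark harness (the training file path is hypothetical):

use tokenizers::models::bpe::{BpeTrainerBuilder, BPE};
use tokenizers::models::TrainerWrapper;
use tokenizers::tokenizer::Result;
use tokenizers::Tokenizer;

fn train_small_bpe() -> Result<()> {
    // Erase the concrete BpeTrainer into the TrainerWrapper enum so its
    // associated Model type becomes ModelWrapper, matching the erased Tokenizer.
    let trainer: TrainerWrapper = BpeTrainerBuilder::default()
        .show_progress(false)
        .build()
        .into();

    // train() consumes a TokenizerImpl and returns the retrained one, which
    // is why iter_bench_train unwraps the newtype with into_inner() first.
    let tokenizer = Tokenizer::new(BPE::default());
    let _trained = tokenizer
        .into_inner()
        .train(&trainer, vec!["data/small.txt".into()])?; // hypothetical path
    Ok(())
}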

View File

@@ -5,10 +5,9 @@
use clap::{App, AppSettings, Arg, ArgMatches, SubCommand};
use std::io::{self, BufRead, Write};
use tokenizers::models::bpe::BPE;
-use tokenizers::normalizers::NormalizerWrapper;
use tokenizers::pre_tokenizers::byte_level::ByteLevel;
-use tokenizers::processors::PostProcessorWrapper;
-use tokenizers::tokenizer::{AddedToken, Result, Tokenizer};
+use tokenizers::tokenizer::{AddedToken, Result};
+use tokenizers::Tokenizer;
fn shell(matches: &ArgMatches) -> Result<()> {
let vocab = matches
@@ -19,8 +18,7 @@ fn shell(matches: &ArgMatches) -> Result<()> {
.expect("Must give a merges.txt file");
let bpe = BPE::from_files(vocab, merges).build()?;
-let mut tokenizer =
-Tokenizer::<_, NormalizerWrapper, ByteLevel, PostProcessorWrapper, ByteLevel>::new(bpe);
+let mut tokenizer = Tokenizer::new(bpe);
tokenizer.with_pre_tokenizer(ByteLevel::default());
tokenizer.with_decoder(ByteLevel::default());

View File

@@ -39,11 +39,7 @@
//! .unk_token("[UNK]".into())
//! .build()?;
//!
-//! let mut tokenizer = Tokenizer::<_,
-//! NormalizerWrapper,
-//! PreTokenizerWrapper,
-//! PostProcessorWrapper,
-//! DecoderWrapper>::new(bpe);
+//! let mut tokenizer = Tokenizer::new(bpe);
//!
//! let encoding = tokenizer.encode("Hey there!", false)?;
//! println!("{:?}", encoding.get_tokens());

View File

@@ -15,6 +15,7 @@ use std::{
fs::File,
io::prelude::*,
io::BufReader,
+ops::{Deref, DerefMut},
path::{Path, PathBuf},
};
@@ -23,6 +24,11 @@ use serde::de::DeserializeOwned;
use serde::export::Formatter;
use serde::{Deserialize, Serialize};
+use crate::decoders::DecoderWrapper;
+use crate::models::ModelWrapper;
+use crate::normalizers::NormalizerWrapper;
+use crate::pre_tokenizers::PreTokenizerWrapper;
+use crate::processors::PostProcessorWrapper;
use crate::tokenizer::normalizer::Range;
use crate::utils::parallelism::*;
@@ -264,11 +270,11 @@ where
/// Convert the TokenizerBuilder to a Tokenizer.
///
/// Conversion fails if the `model` is missing.
-pub fn build(self) -> Result<Tokenizer<M, N, PT, PP, D>> {
+pub fn build(self) -> Result<TokenizerImpl<M, N, PT, PP, D>> {
let model = self
.model
.ok_or_else(|| Box::new(BuilderError("Model missing.".into())))?;
-Ok(Tokenizer {
+Ok(TokenizerImpl {
normalizer: self.normalizer,
pre_tokenizer: self.pre_tokenizer,
model,
@@ -324,9 +330,82 @@ where
}
}
+#[derive(Serialize, Deserialize)]
+pub struct Tokenizer(
+TokenizerImpl<
+ModelWrapper,
+NormalizerWrapper,
+PreTokenizerWrapper,
+PostProcessorWrapper,
+DecoderWrapper,
+>,
+);
+impl Tokenizer {
+/// Construct a new Tokenizer based on the model.
+pub fn new(model: impl Into<ModelWrapper>) -> Self {
+Self(TokenizerImpl::new(model.into()))
+}
+/// Unwrap the TokenizerImpl.
+pub fn into_inner(
+self,
+) -> TokenizerImpl<
+ModelWrapper,
+NormalizerWrapper,
+PreTokenizerWrapper,
+PostProcessorWrapper,
+DecoderWrapper,
+> {
+self.0
+}
+}
+impl<M, N, PT, PP, D> From<TokenizerImpl<M, N, PT, PP, D>> for Tokenizer
+where
+M: Into<ModelWrapper>,
+N: Into<NormalizerWrapper>,
+PT: Into<PreTokenizerWrapper>,
+PP: Into<PostProcessorWrapper>,
+D: Into<DecoderWrapper>,
+{
+fn from(t: TokenizerImpl<M, N, PT, PP, D>) -> Self {
+Self(TokenizerImpl {
+model: t.model.into(),
+normalizer: t.normalizer.map(Into::into),
+pre_tokenizer: t.pre_tokenizer.map(Into::into),
+post_processor: t.post_processor.map(Into::into),
+decoder: t.decoder.map(Into::into),
+added_vocabulary: t.added_vocabulary,
+padding: t.padding,
+truncation: t.truncation,
+})
+}
+}
+impl Deref for Tokenizer {
+type Target = TokenizerImpl<
+ModelWrapper,
+NormalizerWrapper,
+PreTokenizerWrapper,
+PostProcessorWrapper,
+DecoderWrapper,
+>;
+fn deref(&self) -> &Self::Target {
+&self.0
+}
+}
+impl DerefMut for Tokenizer {
+fn deref_mut(&mut self) -> &mut Self::Target {
+&mut self.0
+}
+}
/// A `Tokenizer` is capable of encoding/decoding any text.
#[derive(Clone)]
-pub struct Tokenizer<M, N, PT, PP, D> {
+pub struct TokenizerImpl<M, N, PT, PP, D> {
// Tokenizer parts
normalizer: Option<N>,
pre_tokenizer: Option<PT>,
@@ -342,7 +421,7 @@ pub struct Tokenizer<M, N, PT, PP, D> {
padding: Option<PaddingParams>,
}
-impl<M, N, PT, PP, D> Tokenizer<M, N, PT, PP, D>
+impl<M, N, PT, PP, D> TokenizerImpl<M, N, PT, PP, D>
where
M: Model,
N: Normalizer,
@@ -352,7 +431,7 @@ where
{
/// Instantiate a new Tokenizer, with the given Model
pub fn new(model: M) -> Self {
-Tokenizer {
+TokenizerImpl {
normalizer: None,
pre_tokenizer: None,
model,
@@ -367,8 +446,8 @@ where
}
/// Set the normalizer
-pub fn with_normalizer(&mut self, normalizer: N) -> &Self {
-self.normalizer = Some(normalizer);
+pub fn with_normalizer(&mut self, normalizer: impl Into<N>) -> &Self {
+self.normalizer = Some(normalizer.into());
self
}
@@ -378,8 +457,8 @@ where
}
/// Set the pre tokenizer
-pub fn with_pre_tokenizer(&mut self, pre_tokenizer: PT) -> &Self {
-self.pre_tokenizer = Some(pre_tokenizer);
+pub fn with_pre_tokenizer(&mut self, pre_tokenizer: impl Into<PT>) -> &Self {
+self.pre_tokenizer = Some(pre_tokenizer.into());
self
}
@@ -389,8 +468,8 @@ where
}
/// Set the post processor
-pub fn with_post_processor(&mut self, post_processor: PP) -> &Self {
-self.post_processor = Some(post_processor);
+pub fn with_post_processor(&mut self, post_processor: impl Into<PP>) -> &Self {
+self.post_processor = Some(post_processor.into());
self
}
@@ -400,8 +479,8 @@ where
}
/// Set the decoder
-pub fn with_decoder(&mut self, decoder: D) -> &Self {
-self.decoder = Some(decoder);
+pub fn with_decoder(&mut self, decoder: impl Into<D>) -> &Self {
+self.decoder = Some(decoder.into());
self
}
@@ -411,8 +490,8 @@ where
}
/// Set the model
-pub fn with_model(&mut self, model: M) -> &mut Self {
-self.model = model;
+pub fn with_model(&mut self, model: impl Into<M>) -> &mut Self {
+self.model = model.into();
self
}
@@ -587,13 +666,7 @@ where
/// # use tokenizers::pre_tokenizers::PreTokenizerWrapper;
/// # use tokenizers::processors::PostProcessorWrapper;
/// # use tokenizers::decoders::DecoderWrapper;
-/// # let mut tokenizer =
-/// # Tokenizer::<_,
-/// # NormalizerWrapper,
-/// # PreTokenizerWrapper,
-/// # PostProcessorWrapper,
-/// # DecoderWrapper
-/// # >::new(BPE::default());
+/// # let mut tokenizer = Tokenizer::new(BPE::default());
/// #
/// // Sequences:
/// tokenizer.encode("Single sequence", false);
@@ -772,7 +845,7 @@ where
self,
trainer: &T,
files: Vec<String>,
-) -> Result<Tokenizer<TM, N, PT, PP, D>>
+) -> Result<TokenizerImpl<TM, N, PT, PP, D>>
where
T: Trainer<Model = TM>,
TM: Model,
@@ -780,7 +853,7 @@ where
let words = self.word_count(trainer, files)?;
let (model, special_tokens) = trainer.train(words)?;
-let mut new_tok = Tokenizer {
+let mut new_tok = TokenizerImpl {
normalizer: self.normalizer,
pre_tokenizer: self.pre_tokenizer,
model,
@@ -932,7 +1005,7 @@ where
}
}
-impl<M, N, PT, PP, D> std::str::FromStr for Tokenizer<M, N, PT, PP, D>
+impl<M, N, PT, PP, D> std::str::FromStr for TokenizerImpl<M, N, PT, PP, D>
where
M: for<'de> Deserialize<'de> + Model,
N: for<'de> Deserialize<'de> + Normalizer,
@@ -947,7 +1020,7 @@ where
}
}
-impl<M, N, PT, PP, D> Tokenizer<M, N, PT, PP, D>
+impl<M, N, PT, PP, D> TokenizerImpl<M, N, PT, PP, D>
where
M: DeserializeOwned + Model,
N: DeserializeOwned + Normalizer,
@@ -963,7 +1036,7 @@ where
}
}
-impl<M, N, PT, PP, D> Tokenizer<M, N, PT, PP, D>
+impl<M, N, PT, PP, D> TokenizerImpl<M, N, PT, PP, D>
where
M: Serialize,
N: Serialize,
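Stepping back, the hunks above implement a classic type-erasing newtype: `Tokenizer` pins every slot of `TokenizerImpl` to its wrapper enum, `Deref`/`DerefMut` re-expose the whole inherent API, the blanket `From` impl erases any concrete pipeline, and `into_inner()` recovers ownership. A sketch of how the two layers interoperate, assuming only the impls shown in this file:

use tokenizers::models::bpe::BPE;
use tokenizers::normalizers::NormalizerWrapper;
use tokenizers::pre_tokenizers::byte_level::ByteLevel;
use tokenizers::processors::PostProcessorWrapper;
use tokenizers::tokenizer::TokenizerImpl;
use tokenizers::Tokenizer;

fn erase_and_recover() {
    // Erased front door: no generics, and the `impl Into<...>` setters wrap
    // concrete components on the way in.
    let mut tokenizer = Tokenizer::new(BPE::default());
    tokenizer.with_pre_tokenizer(ByteLevel::default());

    // Generic code can still name every component exactly...
    let concrete = TokenizerImpl::<
        BPE,
        NormalizerWrapper,
        ByteLevel,
        PostProcessorWrapper,
        ByteLevel, // ByteLevel also implements Decoder
    >::new(BPE::default());

    // ...and erase it afterwards: the From impl maps each part into its
    // wrapper enum (every component must be Into<its wrapper>).
    let erased: Tokenizer = concrete.into();

    // into_inner() goes the other way when ownership of the wrapped
    // TokenizerImpl is needed, e.g. for train(), which consumes self.
    let _inner = erased.into_inner();
}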

View File

@@ -7,12 +7,12 @@ use serde::{
Deserialize, Deserializer, Serialize, Serializer,
};
-use super::{added_vocabulary::AddedTokenWithId, Tokenizer};
+use super::{added_vocabulary::AddedTokenWithId, TokenizerImpl};
use crate::{Decoder, Model, Normalizer, PostProcessor, PreTokenizer, TokenizerBuilder};
static SERIALIZATION_VERSION: &str = "1.0";
-impl<M, N, PT, PP, D> Serialize for Tokenizer<M, N, PT, PP, D>
+impl<M, N, PT, PP, D> Serialize for TokenizerImpl<M, N, PT, PP, D>
where
M: Serialize,
N: Serialize,
@@ -47,7 +47,7 @@ where
}
}
-impl<'de, M, N, PT, PP, D> Deserialize<'de> for Tokenizer<M, N, PT, PP, D>
+impl<'de, M, N, PT, PP, D> Deserialize<'de> for TokenizerImpl<M, N, PT, PP, D>
where
M: Deserialize<'de> + Model,
N: Deserialize<'de> + Normalizer,
@@ -99,7 +99,7 @@ where
PP: Deserialize<'de> + PostProcessor,
D: Deserialize<'de> + Decoder,
{
-type Value = Tokenizer<M, N, PT, PP, D>;
+type Value = TokenizerImpl<M, N, PT, PP, D>;
fn expecting(&self, fmt: &mut std::fmt::Formatter) -> std::fmt::Result {
write!(fmt, "struct Tokenizer")

View File

@@ -1,19 +1,14 @@
use tokenizers::decoders::wordpiece::WordPiece as WordPieceDecoder;
-use tokenizers::decoders::DecoderWrapper;
use tokenizers::models::bpe::BPE;
use tokenizers::models::wordpiece::WordPiece;
use tokenizers::normalizers::bert::BertNormalizer;
-use tokenizers::normalizers::NormalizerWrapper;
use tokenizers::pre_tokenizers::bert::BertPreTokenizer;
use tokenizers::pre_tokenizers::byte_level::ByteLevel;
-use tokenizers::pre_tokenizers::PreTokenizerWrapper;
use tokenizers::processors::bert::BertProcessing;
-use tokenizers::processors::PostProcessorWrapper;
use tokenizers::tokenizer::{Model, Tokenizer};
#[allow(dead_code)]
-pub fn get_empty(
-) -> Tokenizer<BPE, NormalizerWrapper, PreTokenizerWrapper, PostProcessorWrapper, DecoderWrapper> {
+pub fn get_empty() -> Tokenizer {
Tokenizer::new(BPE::default())
}
@@ -25,10 +20,7 @@ pub fn get_byte_level_bpe() -> BPE {
}
#[allow(dead_code)]
-pub fn get_byte_level(
-add_prefix_space: bool,
-trim_offsets: bool,
-) -> Tokenizer<BPE, NormalizerWrapper, ByteLevel, ByteLevel, ByteLevel> {
+pub fn get_byte_level(add_prefix_space: bool, trim_offsets: bool) -> Tokenizer {
let mut tokenizer = Tokenizer::new(get_byte_level_bpe());
tokenizer.with_pre_tokenizer(ByteLevel::default().add_prefix_space(add_prefix_space));
tokenizer.with_decoder(ByteLevel::default());
@@ -45,21 +37,16 @@ pub fn get_bert_wordpiece() -> WordPiece {
}
#[allow(dead_code)]
-pub fn get_bert(
-) -> Tokenizer<WordPiece, BertNormalizer, BertPreTokenizer, BertProcessing, WordPieceDecoder> {
+pub fn get_bert() -> Tokenizer {
let mut tokenizer = Tokenizer::new(get_bert_wordpiece());
tokenizer.with_normalizer(BertNormalizer::default());
tokenizer.with_pre_tokenizer(BertPreTokenizer);
tokenizer.with_decoder(WordPieceDecoder::default());
+let sep = tokenizer.get_model().token_to_id("[SEP]").unwrap();
+let cls = tokenizer.get_model().token_to_id("[CLS]").unwrap();
tokenizer.with_post_processor(BertProcessing::new(
-(
-String::from("[SEP]"),
-tokenizer.get_model().token_to_id("[SEP]").unwrap(),
-),
-(
-String::from("[CLS]"),
-tokenizer.get_model().token_to_id("[CLS]").unwrap(),
-),
+(String::from("[SEP]"), sep),
+(String::from("[CLS]"), cls),
));
tokenizer
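With the helpers de-generified, test call sites shrink accordingly. A hypothetical test against `get_bert()` (not part of this commit), just to show that no type parameters leak out of the helper anymore:

use tokenizers::tokenizer::{Model, Tokenizer};

#[test]
fn bert_helper_returns_erased_tokenizer() {
    // The return type is the plain erased Tokenizer.
    let tokenizer: Tokenizer = get_bert();
    // get_model() is reached through Deref on the newtype; token_to_id
    // comes from the Model trait, imported above.
    assert!(tokenizer.get_model().token_to_id("[CLS]").is_some());
}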