Support None to reset pre_tokenizers and normalizers, and index sequences (#1590)

* initial commit

* support None

* fix clippy

* cleanup

* clean?

* propagate to pre_tokenizer

* fix test

* fix rust tests

* fix node

* propagate to decoder and post processor

* fix calls

* lint

* fmt

* node be happy I am fixing you

* add a small test

* styling

* style merge

* fix merge test

* fmt

* nits

* update test
Author: Arthur
Date: 2024-08-07 12:52:35 +02:00
Committed by: GitHub
Parent: eea8e1ae6f
Commit: bded212356
13 changed files with 134 additions and 67 deletions
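
In short: the Rust `with_*` builder methods now take an `Option`, so passing `None` clears the component, the Python and node setters accept `None`, and `Sequence` normalizers/pre-tokenizers become indexable. A minimal Python sketch of the reset behaviour, mirroring the new `test_setting_to_none` test further down (the specific components used here are just examples):

    from tokenizers import Tokenizer
    from tokenizers.models import BPE
    from tokenizers.normalizers import Strip
    from tokenizers.pre_tokenizers import Metaspace

    tokenizer = Tokenizer(BPE())

    # Attach a normalizer, then remove it again by assigning None.
    tokenizer.normalizer = Strip()
    tokenizer.normalizer = None
    assert tokenizer.normalizer is None

    # The same works for the pre-tokenizer (and, per this PR, the decoder
    # and post-processor setters also accept None).
    tokenizer.pre_tokenizer = Metaspace()
    tokenizer.pre_tokenizer = None
    assert tokenizer.pre_tokenizer is None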

View File

@@ -208,7 +208,7 @@ impl Tokenizer {
       .tokenizer
       .write()
       .unwrap()
-      .with_pre_tokenizer((*pre_tokenizer).clone());
+      .with_pre_tokenizer(Some((*pre_tokenizer).clone()));
   }

   #[napi]
@@ -217,7 +217,7 @@ impl Tokenizer {
       .tokenizer
       .write()
       .unwrap()
-      .with_decoder((*decoder).clone());
+      .with_decoder(Some((*decoder).clone()));
   }

   #[napi]
@@ -231,7 +231,7 @@ impl Tokenizer {
       .tokenizer
       .write()
       .unwrap()
-      .with_post_processor((*post_processor).clone());
+      .with_post_processor(Some((*post_processor).clone()));
   }

   #[napi]
@@ -240,7 +240,7 @@ impl Tokenizer {
       .tokenizer
       .write()
       .unwrap()
-      .with_normalizer((*normalizer).clone());
+      .with_normalizer(Some((*normalizer).clone()));
   }

   #[napi]

View File

@@ -1,8 +1,6 @@
-use std::sync::{Arc, RwLock};
-
-use pyo3::exceptions;
-use pyo3::prelude::*;
 use pyo3::types::*;
+use pyo3::{exceptions, prelude::*};
+use std::sync::{Arc, RwLock};

 use crate::error::ToPyResult;
 use crate::utils::{PyNormalizedString, PyNormalizedStringRefMut, PyPattern};
@@ -354,6 +352,7 @@ impl PyNFKC {
 /// A list of Normalizer to be run as a sequence
 #[pyclass(extends=PyNormalizer, module = "tokenizers.normalizers", name = "Sequence")]
 pub struct PySequence {}
+
 #[pymethods]
 impl PySequence {
     #[new]
@@ -380,6 +379,22 @@ impl PySequence {
     fn __len__(&self) -> usize {
         0
     }
+
+    fn __getitem__(self_: PyRef<'_, Self>, py: Python<'_>, index: usize) -> PyResult<Py<PyAny>> {
+        match &self_.as_ref().normalizer {
+            PyNormalizerTypeWrapper::Sequence(inner) => match inner.get(index) {
+                Some(item) => PyNormalizer::new(PyNormalizerTypeWrapper::Single(Arc::clone(item)))
+                    .get_as_subtype(py),
+                _ => Err(PyErr::new::<pyo3::exceptions::PyIndexError, _>(
+                    "Index not found",
+                )),
+            },
+            PyNormalizerTypeWrapper::Single(inner) => {
+                PyNormalizer::new(PyNormalizerTypeWrapper::Single(Arc::clone(inner)))
+                    .get_as_subtype(py)
+            }
+        }
+    }
 }

 /// Lowercase Normalizer
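
The `__getitem__` binding above is what makes `tokenizers.normalizers.Sequence` (and, in the next file, `tokenizers.pre_tokenizers.Sequence`) indexable from Python. A small usage sketch mirroring the new tests (the chosen components are illustrative):

    from tokenizers import normalizers, pre_tokenizers

    # Indexing a normalizers.Sequence returns each item as its concrete subtype.
    norm = normalizers.Sequence([normalizers.NFKC(), normalizers.Lowercase()])
    assert norm[1].__class__ == normalizers.Lowercase
    # Out-of-range indices raise IndexError ("Index not found").

    # Same for pre_tokenizers.Sequence; returned items expose their usual attributes.
    pre = pre_tokenizers.Sequence([pre_tokenizers.Whitespace(), pre_tokenizers.Punctuation()])
    assert pre[1].__class__ == pre_tokenizers.Punctuation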

View File

@@ -463,6 +463,24 @@ impl PySequence {
     fn __getnewargs__<'p>(&self, py: Python<'p>) -> Bound<'p, PyTuple> {
         PyTuple::new_bound(py, [PyList::empty_bound(py)])
     }
+
+    fn __getitem__(self_: PyRef<'_, Self>, py: Python<'_>, index: usize) -> PyResult<Py<PyAny>> {
+        match &self_.as_ref().pretok {
+            PyPreTokenizerTypeWrapper::Sequence(inner) => match inner.get(index) {
+                Some(item) => {
+                    PyPreTokenizer::new(PyPreTokenizerTypeWrapper::Single(Arc::clone(item)))
+                        .get_as_subtype(py)
+                }
+                _ => Err(PyErr::new::<pyo3::exceptions::PyIndexError, _>(
+                    "Index not found",
+                )),
+            },
+            PyPreTokenizerTypeWrapper::Single(inner) => {
+                PyPreTokenizer::new(PyPreTokenizerTypeWrapper::Single(Arc::clone(inner)))
+                    .get_as_subtype(py)
+            }
+        }
+    }
 }

 pub(crate) fn from_string(string: String) -> Result<PrependScheme, PyErr> {

View File

@@ -1371,8 +1371,9 @@ impl PyTokenizer {
     /// Set the :class:`~tokenizers.normalizers.Normalizer`
     #[setter]
-    fn set_normalizer(&mut self, normalizer: PyRef<PyNormalizer>) {
-        self.tokenizer.with_normalizer(normalizer.clone());
+    fn set_normalizer(&mut self, normalizer: Option<PyRef<PyNormalizer>>) {
+        let normalizer_option = normalizer.map(|norm| norm.clone());
+        self.tokenizer.with_normalizer(normalizer_option);
     }

     /// The `optional` :class:`~tokenizers.pre_tokenizers.PreTokenizer` in use by the Tokenizer
@@ -1387,8 +1388,9 @@ impl PyTokenizer {
     /// Set the :class:`~tokenizers.normalizers.Normalizer`
     #[setter]
-    fn set_pre_tokenizer(&mut self, pretok: PyRef<PyPreTokenizer>) {
-        self.tokenizer.with_pre_tokenizer(pretok.clone());
+    fn set_pre_tokenizer(&mut self, pretok: Option<PyRef<PyPreTokenizer>>) {
+        self.tokenizer
+            .with_pre_tokenizer(pretok.map(|pre| pre.clone()));
     }

     /// The `optional` :class:`~tokenizers.processors.PostProcessor` in use by the Tokenizer
@@ -1403,8 +1405,9 @@ impl PyTokenizer {
     /// Set the :class:`~tokenizers.processors.PostProcessor`
     #[setter]
-    fn set_post_processor(&mut self, processor: PyRef<PyPostProcessor>) {
-        self.tokenizer.with_post_processor(processor.clone());
+    fn set_post_processor(&mut self, processor: Option<PyRef<PyPostProcessor>>) {
+        self.tokenizer
+            .with_post_processor(processor.map(|p| p.clone()));
     }

     /// The `optional` :class:`~tokenizers.decoders.Decoder` in use by the Tokenizer
@@ -1419,8 +1422,8 @@ impl PyTokenizer {
     /// Set the :class:`~tokenizers.decoders.Decoder`
     #[setter]
-    fn set_decoder(&mut self, decoder: PyRef<PyDecoder>) {
-        self.tokenizer.with_decoder(decoder.clone());
+    fn set_decoder(&mut self, decoder: Option<PyRef<PyDecoder>>) {
+        self.tokenizer.with_decoder(decoder.map(|d| d.clone()));
     }
 }
@@ -1436,10 +1439,12 @@ mod test {
     #[test]
     fn serialize() {
         let mut tokenizer = Tokenizer::new(PyModel::from(BPE::default()));
-        tokenizer.with_normalizer(PyNormalizer::new(PyNormalizerTypeWrapper::Sequence(vec![
-            Arc::new(RwLock::new(NFKC.into())),
-            Arc::new(RwLock::new(Lowercase.into())),
-        ])));
+        tokenizer.with_normalizer(Some(PyNormalizer::new(PyNormalizerTypeWrapper::Sequence(
+            vec![
+                Arc::new(RwLock::new(NFKC.into())),
+                Arc::new(RwLock::new(Lowercase.into())),
+            ],
+        ))));

         let tmp = NamedTempFile::new().unwrap().into_temp_path();
         tokenizer.save(&tmp, false).unwrap();
@@ -1450,10 +1455,12 @@ mod test {
     #[test]
     fn serde_pyo3() {
         let mut tokenizer = Tokenizer::new(PyModel::from(BPE::default()));
-        tokenizer.with_normalizer(PyNormalizer::new(PyNormalizerTypeWrapper::Sequence(vec![
-            Arc::new(RwLock::new(NFKC.into())),
-            Arc::new(RwLock::new(Lowercase.into())),
-        ])));
+        tokenizer.with_normalizer(Some(PyNormalizer::new(PyNormalizerTypeWrapper::Sequence(
+            vec![
+                Arc::new(RwLock::new(NFKC.into())),
+                Arc::new(RwLock::new(Lowercase.into())),
+            ],
+        ))));

         let output = crate::utils::serde_pyo3::to_string(&tokenizer).unwrap();
         assert_eq!(output, "Tokenizer(version=\"1.0\", truncation=None, padding=None, added_tokens=[], normalizer=Sequence(normalizers=[NFKC(), Lowercase()]), pre_tokenizer=None, post_processor=None, decoder=None, model=BPE(dropout=None, unk_token=None, continuing_subword_prefix=None, end_of_word_suffix=None, fuse_unk=False, byte_fallback=False, ignore_merges=False, vocab={}, merges=[]))");

View File

@@ -67,6 +67,14 @@ class TestSequence:
         output = normalizer.normalize_str(" HELLO ")
         assert output == "hello"

+    def test_items(self):
+        normalizers = Sequence([BertNormalizer(True, True), Prepend()])
+        assert normalizers[1].__class__ == Prepend
+        normalizers[0].lowercase = False
+        assert not normalizers[0].lowercase
+        with pytest.raises(IndexError):
+            print(normalizers[2])
+

 class TestLowercase:
     def test_instantiate(self):

View File

@@ -169,6 +169,13 @@ class TestSequence:
             ("?", (29, 30)),
         ]

+    def test_items(self):
+        pre_tokenizers = Sequence([Metaspace("a", "never", split=True), Punctuation()])
+        assert pre_tokenizers[1].__class__ == Punctuation
+        assert pre_tokenizers[0].__class__ == Metaspace
+        pre_tokenizers[0].split = False
+        assert not pre_tokenizers[0].split
+

 class TestDigits:
     def test_instantiate(self):

View File

@@ -6,10 +6,11 @@ import pytest

 from tokenizers import AddedToken, Encoding, Tokenizer
 from tokenizers.implementations import BertWordPieceTokenizer
 from tokenizers.models import BPE, Model, Unigram
-from tokenizers.pre_tokenizers import ByteLevel
+from tokenizers.pre_tokenizers import ByteLevel, Metaspace
 from tokenizers.processors import RobertaProcessing, TemplateProcessing
+from tokenizers.normalizers import Strip, Lowercase, Sequence

 from ..utils import bert_files, data_dir, multiprocessing_with_parallelism, roberta_files
@@ -551,6 +552,16 @@ class TestTokenizer:
         assert output == "name is john"
         assert tokenizer.get_added_tokens_decoder()[0] == AddedToken("my", special=True)

+    def test_setting_to_none(self):
+        tokenizer = Tokenizer(BPE())
+        tokenizer.normalizer = Strip()
+        tokenizer.normalizer = None
+        assert tokenizer.normalizer == None
+
+        tokenizer.pre_tokenizer = Metaspace()
+        tokenizer.pre_tokenizer = None
+        assert tokenizer.pre_tokenizer == None
+

 class TestTokenizerRepr:
     def test_repr(self):

View File

@@ -34,13 +34,13 @@ fn create_bert_tokenizer(wp: WordPiece) -> BertTokenizer {
     let sep_id = *wp.get_vocab().get("[SEP]").unwrap();
     let cls_id = *wp.get_vocab().get("[CLS]").unwrap();
     let mut tokenizer = TokenizerImpl::new(wp);
-    tokenizer.with_pre_tokenizer(BertPreTokenizer);
-    tokenizer.with_normalizer(BertNormalizer::default());
-    tokenizer.with_decoder(decoders::wordpiece::WordPiece::default());
-    tokenizer.with_post_processor(BertProcessing::new(
+    tokenizer.with_pre_tokenizer(Some(BertPreTokenizer));
+    tokenizer.with_normalizer(Some(BertNormalizer::default()));
+    tokenizer.with_decoder(Some(decoders::wordpiece::WordPiece::default()));
+    tokenizer.with_post_processor(Some(BertProcessing::new(
         ("[SEP]".to_string(), sep_id),
         ("[CLS]".to_string(), cls_id),
-    ));
+    )));
     tokenizer
 }
@@ -81,7 +81,7 @@ fn bench_train(c: &mut Criterion) {
         DecoderWrapper,
     >;
     let mut tokenizer = Tok::new(WordPiece::default());
-    tokenizer.with_pre_tokenizer(Whitespace {});
+    tokenizer.with_pre_tokenizer(Some(Whitespace {}));
     c.bench_function("WordPiece Train vocabulary (small)", |b| {
         b.iter_custom(|iters| {
             iter_bench_train(
@@ -94,7 +94,7 @@ fn bench_train(c: &mut Criterion) {
     });

     let mut tokenizer = Tok::new(WordPiece::default());
-    tokenizer.with_pre_tokenizer(Whitespace {});
+    tokenizer.with_pre_tokenizer(Some(Whitespace {}));
     c.bench_function("WordPiece Train vocabulary (big)", |b| {
         b.iter_custom(|iters| {
             iter_bench_train(

View File

@@ -22,8 +22,8 @@ static BATCH_SIZE: usize = 1_000;

 fn create_gpt2_tokenizer(bpe: BPE) -> Tokenizer {
     let mut tokenizer = Tokenizer::new(bpe);
-    tokenizer.with_pre_tokenizer(ByteLevel::default());
-    tokenizer.with_decoder(ByteLevel::default());
+    tokenizer.with_pre_tokenizer(Some(ByteLevel::default()));
+    tokenizer.with_decoder(Some(ByteLevel::default()));
     tokenizer.add_tokens(&[AddedToken::from("ing", false).single_word(false)]);
     tokenizer.add_special_tokens(&[AddedToken::from("[ENT]", true).single_word(true)]);
     tokenizer
@@ -74,7 +74,7 @@ fn bench_train(c: &mut Criterion) {
         .build()
         .into();
     let mut tokenizer = Tokenizer::new(BPE::default()).into_inner();
-    tokenizer.with_pre_tokenizer(Whitespace {});
+    tokenizer.with_pre_tokenizer(Some(Whitespace {}));
     c.bench_function("BPE Train vocabulary (small)", |b| {
         b.iter_custom(|iters| {
             iter_bench_train(
@@ -87,7 +87,7 @@ fn bench_train(c: &mut Criterion) {
     });

     let mut tokenizer = Tokenizer::new(BPE::default()).into_inner();
-    tokenizer.with_pre_tokenizer(Whitespace {});
+    tokenizer.with_pre_tokenizer(Some(Whitespace {}));
     c.bench_function("BPE Train vocabulary (big)", |b| {
         b.iter_custom(|iters| {
             iter_bench_train(

View File

@@ -550,19 +550,18 @@ where
     }

     /// Set the normalizer
-    pub fn with_normalizer(&mut self, normalizer: impl Into<N>) -> &mut Self {
-        self.normalizer = Some(normalizer.into());
+    pub fn with_normalizer(&mut self, normalizer: Option<impl Into<N>>) -> &mut Self {
+        self.normalizer = normalizer.map(|norm| norm.into());
         self
     }

     /// Get the normalizer
     pub fn get_normalizer(&self) -> Option<&N> {
         self.normalizer.as_ref()
     }

     /// Set the pre tokenizer
-    pub fn with_pre_tokenizer(&mut self, pre_tokenizer: impl Into<PT>) -> &mut Self {
-        self.pre_tokenizer = Some(pre_tokenizer.into());
+    pub fn with_pre_tokenizer(&mut self, pre_tokenizer: Option<impl Into<PT>>) -> &mut Self {
+        self.pre_tokenizer = pre_tokenizer.map(|tok| tok.into());
         self
     }
@@ -572,8 +571,8 @@ where
     }

     /// Set the post processor
-    pub fn with_post_processor(&mut self, post_processor: impl Into<PP>) -> &mut Self {
-        self.post_processor = Some(post_processor.into());
+    pub fn with_post_processor(&mut self, post_processor: Option<impl Into<PP>>) -> &mut Self {
+        self.post_processor = post_processor.map(|post_proc| post_proc.into());
         self
     }
@@ -583,8 +582,8 @@ where
     }

     /// Set the decoder
-    pub fn with_decoder(&mut self, decoder: impl Into<D>) -> &mut Self {
-        self.decoder = Some(decoder.into());
+    pub fn with_decoder(&mut self, decoder: Option<impl Into<D>>) -> &mut Self {
+        self.decoder = decoder.map(|dec| dec.into());
         self
     }

View File

@@ -23,9 +23,11 @@ pub fn get_byte_level_bpe() -> BPE {
 pub fn get_byte_level(add_prefix_space: bool, trim_offsets: bool) -> Tokenizer {
     let mut tokenizer = Tokenizer::new(get_byte_level_bpe());
     tokenizer
-        .with_pre_tokenizer(ByteLevel::default().add_prefix_space(add_prefix_space))
-        .with_decoder(ByteLevel::default())
-        .with_post_processor(ByteLevel::default().trim_offsets(trim_offsets));
+        .with_pre_tokenizer(Some(
+            ByteLevel::default().add_prefix_space(add_prefix_space),
+        ))
+        .with_decoder(Some(ByteLevel::default()))
+        .with_post_processor(Some(ByteLevel::default().trim_offsets(trim_offsets)));

     tokenizer
 }
@@ -43,13 +45,13 @@ pub fn get_bert() -> Tokenizer {
     let sep = tokenizer.get_model().token_to_id("[SEP]").unwrap();
     let cls = tokenizer.get_model().token_to_id("[CLS]").unwrap();
     tokenizer
-        .with_normalizer(BertNormalizer::default())
-        .with_pre_tokenizer(BertPreTokenizer)
-        .with_decoder(WordPieceDecoder::default())
-        .with_post_processor(BertProcessing::new(
+        .with_normalizer(Some(BertNormalizer::default()))
+        .with_pre_tokenizer(Some(BertPreTokenizer))
+        .with_decoder(Some(WordPieceDecoder::default()))
+        .with_post_processor(Some(BertProcessing::new(
             (String::from("[SEP]"), sep),
             (String::from("[CLS]"), cls),
-        ));
+        )));

     tokenizer
 }

View File

@@ -93,7 +93,7 @@ fn quicktour_slow_train() -> tokenizers::Result<()> {
     // START quicktour_init_pretok
     use tokenizers::pre_tokenizers::whitespace::Whitespace;

-    tokenizer.with_pre_tokenizer(Whitespace {});
+    tokenizer.with_pre_tokenizer(Some(Whitespace {}));
     // END quicktour_init_pretok

     // START quicktour_train
@@ -157,7 +157,7 @@ fn quicktour() -> tokenizers::Result<()> {
         ("[CLS]", tokenizer.token_to_id("[CLS]").unwrap()),
         ("[SEP]", tokenizer.token_to_id("[SEP]").unwrap()),
     ];
-    tokenizer.with_post_processor(
+    tokenizer.with_post_processor(Some(
         TemplateProcessing::builder()
             .try_single("[CLS] $A [SEP]")
             .unwrap()
@@ -165,7 +165,7 @@ fn quicktour() -> tokenizers::Result<()> {
             .unwrap()
             .special_tokens(special_tokens)
             .build()?,
-    );
+    ));
     // END quicktour_init_template_processing
     // START quicktour_print_special_tokens
     let output = tokenizer.encode("Hello, y'all! How are you 😁 ?", true)?;
@@ -261,7 +261,7 @@ fn pipeline() -> tokenizers::Result<()> {
     // END pipeline_test_normalizer
     assert_eq!(normalized.get(), "Hello how are u?");
     // START pipeline_replace_normalizer
-    tokenizer.with_normalizer(normalizer);
+    tokenizer.with_normalizer(Some(normalizer));
     // END pipeline_replace_normalizer
     // START pipeline_setup_pre_tokenizer
     use tokenizers::pre_tokenizers::whitespace::Whitespace;
@@ -325,12 +325,12 @@ fn pipeline() -> tokenizers::Result<()> {
         ]
     );
     // START pipeline_replace_pre_tokenizer
-    tokenizer.with_pre_tokenizer(pre_tokenizer);
+    tokenizer.with_pre_tokenizer(Some(pre_tokenizer));
     // END pipeline_replace_pre_tokenizer
     // START pipeline_setup_processor
     use tokenizers::processors::template::TemplateProcessing;

-    tokenizer.with_post_processor(
+    tokenizer.with_post_processor(Some(
         TemplateProcessing::builder()
             .try_single("[CLS] $A [SEP]")
             .unwrap()
@@ -339,7 +339,7 @@ fn pipeline() -> tokenizers::Result<()> {
             .special_tokens(vec![("[CLS]", 1), ("[SEP]", 2)])
             .build()
             .unwrap(),
-    );
+    ));
     // END pipeline_setup_processor
     // START pipeline_test_decoding
     let output = tokenizer.encode("Hello, y'all! How are you 😁 ?", true)?;
@@ -375,21 +375,21 @@ fn train_pipeline_bert() -> tokenizers::Result<()> {
     use tokenizers::normalizers::utils::Sequence as NormalizerSequence;
     use tokenizers::normalizers::{strip::StripAccents, unicode::NFD, utils::Lowercase};

-    bert_tokenizer.with_normalizer(NormalizerSequence::new(vec![
+    bert_tokenizer.with_normalizer(Some(NormalizerSequence::new(vec![
         NFD.into(),
         Lowercase.into(),
         StripAccents.into(),
-    ]));
+    ])));
     // END bert_setup_normalizer
     // START bert_setup_pre_tokenizer
     use tokenizers::pre_tokenizers::whitespace::Whitespace;

-    bert_tokenizer.with_pre_tokenizer(Whitespace {});
+    bert_tokenizer.with_pre_tokenizer(Some(Whitespace {}));
     // END bert_setup_pre_tokenizer
     // START bert_setup_processor
     use tokenizers::processors::template::TemplateProcessing;

-    bert_tokenizer.with_post_processor(
+    bert_tokenizer.with_post_processor(Some(
         TemplateProcessing::builder()
             .try_single("[CLS] $A [SEP]")
             .unwrap()
@@ -398,7 +398,7 @@ fn train_pipeline_bert() -> tokenizers::Result<()> {
             .special_tokens(vec![("[CLS]", 1), ("[SEP]", 2)])
             .build()
             .unwrap(),
-    );
+    ));
     // END bert_setup_processor
     // START bert_train_tokenizer
     use tokenizers::models::{wordpiece::WordPieceTrainer, TrainerWrapper};
@@ -450,7 +450,7 @@ fn pipeline_bert() -> tokenizers::Result<()> {
     // START bert_proper_decoding
     use tokenizers::decoders::wordpiece::WordPiece as WordPieceDecoder;

-    bert_tokenizer.with_decoder(WordPieceDecoder::default());
+    bert_tokenizer.with_decoder(Some(WordPieceDecoder::default()));

     let decoded = bert_tokenizer.decode(output.get_ids(), true)?;
     // "welcome to the tokenizers library."
     // END bert_proper_decoding

View File

@@ -203,7 +203,7 @@ fn models() {
 fn tokenizer() {
     let wordpiece = WordPiece::default();
     let mut tokenizer = Tokenizer::new(wordpiece);
-    tokenizer.with_normalizer(NFC);
+    tokenizer.with_normalizer(Some(NFC));
     let ser = serde_json::to_string(&tokenizer).unwrap();
     let _: Tokenizer = serde_json::from_str(&ser).unwrap();
     let unwrapped_nfc_tok: TokenizerImpl<