Mirror of https://github.com/mii443/tokenizers.git (synced 2025-08-22 16:25:30 +00:00)

Support None to reset pre_tokenizers and normalizers, and index sequences (#1590)

* initial commit
* support None
* fix clippy
* cleanup
* clean?
* propagate to pre_tokenizer
* fix test
* fix rust tests
* fix node
* propagate to decoder and post processor
* fix calls
* lint
* fmt
* node be happy I am fixing you
* add a small test
* styling
* style merge
* fix merge test
* fmt
* nits
* update test
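
With this change, the with_normalizer / with_pre_tokenizer / with_post_processor / with_decoder setters take an Option: Some(...) installs a component and None clears it again. A minimal sketch of the new calling convention, assuming a tokenizers build that includes this commit (BPE and NFKC are just convenient stand-ins for any model and normalizer):

    use tokenizers::models::bpe::BPE;
    use tokenizers::normalizers::{unicode::NFKC, NormalizerWrapper};
    use tokenizers::Tokenizer;

    fn main() {
        let mut tokenizer = Tokenizer::new(BPE::default());

        // Installing a component now means wrapping it in Some(...).
        tokenizer.with_normalizer(Some(NFKC));
        assert!(tokenizer.get_normalizer().is_some());

        // Passing None clears it again. The turbofish is needed because the
        // parameter is Option<impl Into<N>>, so a bare None has no type the
        // compiler could infer.
        tokenizer.with_normalizer(None::<NormalizerWrapper>);
        assert!(tokenizer.get_normalizer().is_none());
    }

The diff below applies this convention across the node and Python bindings, the core crate, the benches, the doc tests, and the test suites.
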
@@ -208,7 +208,7 @@ impl Tokenizer {
       .tokenizer
       .write()
       .unwrap()
-      .with_pre_tokenizer((*pre_tokenizer).clone());
+      .with_pre_tokenizer(Some((*pre_tokenizer).clone()));
   }

   #[napi]
@@ -217,7 +217,7 @@ impl Tokenizer {
       .tokenizer
       .write()
       .unwrap()
-      .with_decoder((*decoder).clone());
+      .with_decoder(Some((*decoder).clone()));
   }

   #[napi]
@@ -231,7 +231,7 @@ impl Tokenizer {
       .tokenizer
       .write()
       .unwrap()
-      .with_post_processor((*post_processor).clone());
+      .with_post_processor(Some((*post_processor).clone()));
   }

   #[napi]
@@ -240,7 +240,7 @@ impl Tokenizer {
       .tokenizer
       .write()
       .unwrap()
-      .with_normalizer((*normalizer).clone());
+      .with_normalizer(Some((*normalizer).clone()));
   }

   #[napi]
@@ -1,8 +1,6 @@
-use std::sync::{Arc, RwLock};
-
-use pyo3::exceptions;
-use pyo3::prelude::*;
 use pyo3::types::*;
+use pyo3::{exceptions, prelude::*};
+use std::sync::{Arc, RwLock};

 use crate::error::ToPyResult;
 use crate::utils::{PyNormalizedString, PyNormalizedStringRefMut, PyPattern};
@@ -354,6 +352,7 @@ impl PyNFKC {
 /// A list of Normalizer to be run as a sequence
 #[pyclass(extends=PyNormalizer, module = "tokenizers.normalizers", name = "Sequence")]
 pub struct PySequence {}
+
 #[pymethods]
 impl PySequence {
     #[new]
@@ -380,6 +379,22 @@ impl PySequence {
     fn __len__(&self) -> usize {
         0
     }
+
+    fn __getitem__(self_: PyRef<'_, Self>, py: Python<'_>, index: usize) -> PyResult<Py<PyAny>> {
+        match &self_.as_ref().normalizer {
+            PyNormalizerTypeWrapper::Sequence(inner) => match inner.get(index) {
+                Some(item) => PyNormalizer::new(PyNormalizerTypeWrapper::Single(Arc::clone(item)))
+                    .get_as_subtype(py),
+                _ => Err(PyErr::new::<pyo3::exceptions::PyIndexError, _>(
+                    "Index not found",
+                )),
+            },
+            PyNormalizerTypeWrapper::Single(inner) => {
+                PyNormalizer::new(PyNormalizerTypeWrapper::Single(Arc::clone(inner)))
+                    .get_as_subtype(py)
+            }
+        }
+    }
 }

 /// Lowercase Normalizer
@@ -463,6 +463,24 @@ impl PySequence {
     fn __getnewargs__<'p>(&self, py: Python<'p>) -> Bound<'p, PyTuple> {
         PyTuple::new_bound(py, [PyList::empty_bound(py)])
     }
+
+    fn __getitem__(self_: PyRef<'_, Self>, py: Python<'_>, index: usize) -> PyResult<Py<PyAny>> {
+        match &self_.as_ref().pretok {
+            PyPreTokenizerTypeWrapper::Sequence(inner) => match inner.get(index) {
+                Some(item) => {
+                    PyPreTokenizer::new(PyPreTokenizerTypeWrapper::Single(Arc::clone(item)))
+                        .get_as_subtype(py)
+                }
+                _ => Err(PyErr::new::<pyo3::exceptions::PyIndexError, _>(
+                    "Index not found",
+                )),
+            },
+            PyPreTokenizerTypeWrapper::Single(inner) => {
+                PyPreTokenizer::new(PyPreTokenizerTypeWrapper::Single(Arc::clone(inner)))
+                    .get_as_subtype(py)
+            }
+        }
+    }
 }

 pub(crate) fn from_string(string: String) -> Result<PrependScheme, PyErr> {
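
Both __getitem__ implementations above (for normalizer and pre-tokenizer sequences) build the returned wrapper from Arc::clone of the stored item, so indexing hands back a handle that aliases the element inside the sequence rather than a copy; that is why the Python tests further down can set normalizers[0].lowercase and see the change when they index again. A standalone sketch of that aliasing, using a plain Vec<Arc<RwLock<String>>> in place of the wrapper types:

    use std::sync::{Arc, RwLock};

    fn main() {
        // Stand-in for the Sequence variant: a list of shared, lockable items.
        let seq: Vec<Arc<RwLock<String>>> = vec![
            Arc::new(RwLock::new(String::from("lowercase"))),
            Arc::new(RwLock::new(String::from("prepend"))),
        ];

        // What __getitem__ does, conceptually: clone the Arc at the index,
        // not the value behind it.
        let item = Arc::clone(&seq[0]);

        // Writing through the handle is visible through the sequence as well.
        *item.write().unwrap() = String::from("nfkc");
        assert_eq!(seq[0].read().unwrap().as_str(), "nfkc");
    }
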
@@ -1371,8 +1371,9 @@ impl PyTokenizer {

     /// Set the :class:`~tokenizers.normalizers.Normalizer`
     #[setter]
-    fn set_normalizer(&mut self, normalizer: PyRef<PyNormalizer>) {
-        self.tokenizer.with_normalizer(normalizer.clone());
+    fn set_normalizer(&mut self, normalizer: Option<PyRef<PyNormalizer>>) {
+        let normalizer_option = normalizer.map(|norm| norm.clone());
+        self.tokenizer.with_normalizer(normalizer_option);
     }

     /// The `optional` :class:`~tokenizers.pre_tokenizers.PreTokenizer` in use by the Tokenizer
@@ -1387,8 +1388,9 @@ impl PyTokenizer {

     /// Set the :class:`~tokenizers.normalizers.Normalizer`
     #[setter]
-    fn set_pre_tokenizer(&mut self, pretok: PyRef<PyPreTokenizer>) {
-        self.tokenizer.with_pre_tokenizer(pretok.clone());
+    fn set_pre_tokenizer(&mut self, pretok: Option<PyRef<PyPreTokenizer>>) {
+        self.tokenizer
+            .with_pre_tokenizer(pretok.map(|pre| pre.clone()));
     }

     /// The `optional` :class:`~tokenizers.processors.PostProcessor` in use by the Tokenizer
@@ -1403,8 +1405,9 @@ impl PyTokenizer {

     /// Set the :class:`~tokenizers.processors.PostProcessor`
     #[setter]
-    fn set_post_processor(&mut self, processor: PyRef<PyPostProcessor>) {
-        self.tokenizer.with_post_processor(processor.clone());
+    fn set_post_processor(&mut self, processor: Option<PyRef<PyPostProcessor>>) {
+        self.tokenizer
+            .with_post_processor(processor.map(|p| p.clone()));
     }

     /// The `optional` :class:`~tokenizers.decoders.Decoder` in use by the Tokenizer
@@ -1419,8 +1422,8 @@ impl PyTokenizer {

     /// Set the :class:`~tokenizers.decoders.Decoder`
     #[setter]
-    fn set_decoder(&mut self, decoder: PyRef<PyDecoder>) {
-        self.tokenizer.with_decoder(decoder.clone());
+    fn set_decoder(&mut self, decoder: Option<PyRef<PyDecoder>>) {
+        self.tokenizer.with_decoder(decoder.map(|d| d.clone()));
     }
 }

@@ -1436,10 +1439,12 @@ mod test {
     #[test]
     fn serialize() {
         let mut tokenizer = Tokenizer::new(PyModel::from(BPE::default()));
-        tokenizer.with_normalizer(PyNormalizer::new(PyNormalizerTypeWrapper::Sequence(vec![
+        tokenizer.with_normalizer(Some(PyNormalizer::new(PyNormalizerTypeWrapper::Sequence(
+            vec![
             Arc::new(RwLock::new(NFKC.into())),
             Arc::new(RwLock::new(Lowercase.into())),
-        ])));
+            ],
+        ))));

         let tmp = NamedTempFile::new().unwrap().into_temp_path();
         tokenizer.save(&tmp, false).unwrap();
@@ -1450,10 +1455,12 @@ mod test {
     #[test]
     fn serde_pyo3() {
         let mut tokenizer = Tokenizer::new(PyModel::from(BPE::default()));
-        tokenizer.with_normalizer(PyNormalizer::new(PyNormalizerTypeWrapper::Sequence(vec![
+        tokenizer.with_normalizer(Some(PyNormalizer::new(PyNormalizerTypeWrapper::Sequence(
+            vec![
             Arc::new(RwLock::new(NFKC.into())),
             Arc::new(RwLock::new(Lowercase.into())),
-        ])));
+            ],
+        ))));

         let output = crate::utils::serde_pyo3::to_string(&tokenizer).unwrap();
         assert_eq!(output, "Tokenizer(version=\"1.0\", truncation=None, padding=None, added_tokens=[], normalizer=Sequence(normalizers=[NFKC(), Lowercase()]), pre_tokenizer=None, post_processor=None, decoder=None, model=BPE(dropout=None, unk_token=None, continuing_subword_prefix=None, end_of_word_suffix=None, fuse_unk=False, byte_fallback=False, ignore_merges=False, vocab={}, merges=[]))");
@@ -67,6 +67,14 @@ class TestSequence:
         output = normalizer.normalize_str(" HELLO ")
         assert output == "hello"

+    def test_items(self):
+        normalizers = Sequence([BertNormalizer(True, True), Prepend()])
+        assert normalizers[1].__class__ == Prepend
+        normalizers[0].lowercase = False
+        assert not normalizers[0].lowercase
+        with pytest.raises(IndexError):
+            print(normalizers[2])
+

 class TestLowercase:
     def test_instantiate(self):
@@ -169,6 +169,13 @@ class TestSequence:
             ("?", (29, 30)),
         ]

+    def test_items(self):
+        pre_tokenizers = Sequence([Metaspace("a", "never", split=True), Punctuation()])
+        assert pre_tokenizers[1].__class__ == Punctuation
+        assert pre_tokenizers[0].__class__ == Metaspace
+        pre_tokenizers[0].split = False
+        assert not pre_tokenizers[0].split
+

 class TestDigits:
     def test_instantiate(self):
@@ -6,10 +6,11 @@ import pytest
 from tokenizers import AddedToken, Encoding, Tokenizer
 from tokenizers.implementations import BertWordPieceTokenizer
 from tokenizers.models import BPE, Model, Unigram
-from tokenizers.pre_tokenizers import ByteLevel
+from tokenizers.pre_tokenizers import ByteLevel, Metaspace
 from tokenizers.processors import RobertaProcessing, TemplateProcessing
 from tokenizers.normalizers import Strip, Lowercase, Sequence
+

 from ..utils import bert_files, data_dir, multiprocessing_with_parallelism, roberta_files


@@ -551,6 +552,16 @@ class TestTokenizer:
         assert output == "name is john"
         assert tokenizer.get_added_tokens_decoder()[0] == AddedToken("my", special=True)

+    def test_setting_to_none(self):
+        tokenizer = Tokenizer(BPE())
+        tokenizer.normalizer = Strip()
+        tokenizer.normalizer = None
+        assert tokenizer.normalizer == None
+
+        tokenizer.pre_tokenizer = Metaspace()
+        tokenizer.pre_tokenizer = None
+        assert tokenizer.pre_tokenizer == None
+

 class TestTokenizerRepr:
     def test_repr(self):
@@ -34,13 +34,13 @@ fn create_bert_tokenizer(wp: WordPiece) -> BertTokenizer {
     let sep_id = *wp.get_vocab().get("[SEP]").unwrap();
     let cls_id = *wp.get_vocab().get("[CLS]").unwrap();
     let mut tokenizer = TokenizerImpl::new(wp);
-    tokenizer.with_pre_tokenizer(BertPreTokenizer);
-    tokenizer.with_normalizer(BertNormalizer::default());
-    tokenizer.with_decoder(decoders::wordpiece::WordPiece::default());
-    tokenizer.with_post_processor(BertProcessing::new(
+    tokenizer.with_pre_tokenizer(Some(BertPreTokenizer));
+    tokenizer.with_normalizer(Some(BertNormalizer::default()));
+    tokenizer.with_decoder(Some(decoders::wordpiece::WordPiece::default()));
+    tokenizer.with_post_processor(Some(BertProcessing::new(
         ("[SEP]".to_string(), sep_id),
         ("[CLS]".to_string(), cls_id),
-    ));
+    )));
     tokenizer
 }

@@ -81,7 +81,7 @@ fn bench_train(c: &mut Criterion) {
         DecoderWrapper,
     >;
     let mut tokenizer = Tok::new(WordPiece::default());
-    tokenizer.with_pre_tokenizer(Whitespace {});
+    tokenizer.with_pre_tokenizer(Some(Whitespace {}));
     c.bench_function("WordPiece Train vocabulary (small)", |b| {
         b.iter_custom(|iters| {
             iter_bench_train(
@@ -94,7 +94,7 @@ fn bench_train(c: &mut Criterion) {
     });

     let mut tokenizer = Tok::new(WordPiece::default());
-    tokenizer.with_pre_tokenizer(Whitespace {});
+    tokenizer.with_pre_tokenizer(Some(Whitespace {}));
     c.bench_function("WordPiece Train vocabulary (big)", |b| {
         b.iter_custom(|iters| {
             iter_bench_train(
@@ -22,8 +22,8 @@ static BATCH_SIZE: usize = 1_000;

 fn create_gpt2_tokenizer(bpe: BPE) -> Tokenizer {
     let mut tokenizer = Tokenizer::new(bpe);
-    tokenizer.with_pre_tokenizer(ByteLevel::default());
-    tokenizer.with_decoder(ByteLevel::default());
+    tokenizer.with_pre_tokenizer(Some(ByteLevel::default()));
+    tokenizer.with_decoder(Some(ByteLevel::default()));
     tokenizer.add_tokens(&[AddedToken::from("ing", false).single_word(false)]);
     tokenizer.add_special_tokens(&[AddedToken::from("[ENT]", true).single_word(true)]);
     tokenizer
@@ -74,7 +74,7 @@ fn bench_train(c: &mut Criterion) {
         .build()
         .into();
     let mut tokenizer = Tokenizer::new(BPE::default()).into_inner();
-    tokenizer.with_pre_tokenizer(Whitespace {});
+    tokenizer.with_pre_tokenizer(Some(Whitespace {}));
     c.bench_function("BPE Train vocabulary (small)", |b| {
         b.iter_custom(|iters| {
             iter_bench_train(
@@ -87,7 +87,7 @@ fn bench_train(c: &mut Criterion) {
     });

     let mut tokenizer = Tokenizer::new(BPE::default()).into_inner();
-    tokenizer.with_pre_tokenizer(Whitespace {});
+    tokenizer.with_pre_tokenizer(Some(Whitespace {}));
     c.bench_function("BPE Train vocabulary (big)", |b| {
         b.iter_custom(|iters| {
             iter_bench_train(
@@ -550,19 +550,18 @@ where
     }

     /// Set the normalizer
-    pub fn with_normalizer(&mut self, normalizer: impl Into<N>) -> &mut Self {
-        self.normalizer = Some(normalizer.into());
+    pub fn with_normalizer(&mut self, normalizer: Option<impl Into<N>>) -> &mut Self {
+        self.normalizer = normalizer.map(|norm| norm.into());
         self
     }

     /// Get the normalizer
     pub fn get_normalizer(&self) -> Option<&N> {
         self.normalizer.as_ref()
     }

     /// Set the pre tokenizer
-    pub fn with_pre_tokenizer(&mut self, pre_tokenizer: impl Into<PT>) -> &mut Self {
-        self.pre_tokenizer = Some(pre_tokenizer.into());
+    pub fn with_pre_tokenizer(&mut self, pre_tokenizer: Option<impl Into<PT>>) -> &mut Self {
+        self.pre_tokenizer = pre_tokenizer.map(|tok| tok.into());
         self
     }

@@ -572,8 +571,8 @@ where
     }

     /// Set the post processor
-    pub fn with_post_processor(&mut self, post_processor: impl Into<PP>) -> &mut Self {
-        self.post_processor = Some(post_processor.into());
+    pub fn with_post_processor(&mut self, post_processor: Option<impl Into<PP>>) -> &mut Self {
+        self.post_processor = post_processor.map(|post_proc| post_proc.into());
         self
     }

@@ -583,8 +582,8 @@ where
     }

     /// Set the decoder
-    pub fn with_decoder(&mut self, decoder: impl Into<D>) -> &mut Self {
-        self.decoder = Some(decoder.into());
+    pub fn with_decoder(&mut self, decoder: Option<impl Into<D>>) -> &mut Self {
+        self.decoder = decoder.map(|dec| dec.into());
         self
     }

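
The three hunks above are the core of the change: each builder on TokenizerImpl now accepts Option<impl Into<...>> and maps it, instead of unconditionally wrapping the argument in Some. One side effect of that signature is that a bare None gives the compiler nothing to infer the impl type from, so a call site that wants to clear a component has to name a concrete type. The sketch below is illustrative only; Holder and set_label are made-up names, not tokenizers API:

    // Same shape as the new with_* setters, on a made-up type.
    struct Holder {
        label: Option<String>,
    }

    impl Holder {
        // Map the Option instead of wrapping the argument in Some(...).
        fn set_label(&mut self, label: Option<impl Into<String>>) -> &mut Self {
            self.label = label.map(|l| l.into());
            self
        }
    }

    fn main() {
        let mut holder = Holder { label: None };

        holder.set_label(Some("hello")); // &str is Into<String>
        assert_eq!(holder.label.as_deref(), Some("hello"));

        // holder.set_label(None) would not compile: nothing pins down the
        // impl Into<String> type, so it has to be spelled out.
        holder.set_label(None::<String>);
        assert!(holder.label.is_none());
    }
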
@@ -23,9 +23,11 @@ pub fn get_byte_level_bpe() -> BPE {
 pub fn get_byte_level(add_prefix_space: bool, trim_offsets: bool) -> Tokenizer {
     let mut tokenizer = Tokenizer::new(get_byte_level_bpe());
     tokenizer
-        .with_pre_tokenizer(ByteLevel::default().add_prefix_space(add_prefix_space))
-        .with_decoder(ByteLevel::default())
-        .with_post_processor(ByteLevel::default().trim_offsets(trim_offsets));
+        .with_pre_tokenizer(Some(
+            ByteLevel::default().add_prefix_space(add_prefix_space),
+        ))
+        .with_decoder(Some(ByteLevel::default()))
+        .with_post_processor(Some(ByteLevel::default().trim_offsets(trim_offsets)));

     tokenizer
 }
@@ -43,13 +45,13 @@ pub fn get_bert() -> Tokenizer {
     let sep = tokenizer.get_model().token_to_id("[SEP]").unwrap();
     let cls = tokenizer.get_model().token_to_id("[CLS]").unwrap();
     tokenizer
-        .with_normalizer(BertNormalizer::default())
-        .with_pre_tokenizer(BertPreTokenizer)
-        .with_decoder(WordPieceDecoder::default())
-        .with_post_processor(BertProcessing::new(
+        .with_normalizer(Some(BertNormalizer::default()))
+        .with_pre_tokenizer(Some(BertPreTokenizer))
+        .with_decoder(Some(WordPieceDecoder::default()))
+        .with_post_processor(Some(BertProcessing::new(
             (String::from("[SEP]"), sep),
             (String::from("[CLS]"), cls),
-        ));
+        )));

     tokenizer
 }
@@ -93,7 +93,7 @@ fn quicktour_slow_train() -> tokenizers::Result<()> {
     // START quicktour_init_pretok
     use tokenizers::pre_tokenizers::whitespace::Whitespace;

-    tokenizer.with_pre_tokenizer(Whitespace {});
+    tokenizer.with_pre_tokenizer(Some(Whitespace {}));
     // END quicktour_init_pretok

     // START quicktour_train
@@ -157,7 +157,7 @@ fn quicktour() -> tokenizers::Result<()> {
         ("[CLS]", tokenizer.token_to_id("[CLS]").unwrap()),
         ("[SEP]", tokenizer.token_to_id("[SEP]").unwrap()),
     ];
-    tokenizer.with_post_processor(
+    tokenizer.with_post_processor(Some(
         TemplateProcessing::builder()
             .try_single("[CLS] $A [SEP]")
             .unwrap()
@@ -165,7 +165,7 @@ fn quicktour() -> tokenizers::Result<()> {
             .unwrap()
             .special_tokens(special_tokens)
             .build()?,
-    );
+    ));
     // END quicktour_init_template_processing
     // START quicktour_print_special_tokens
     let output = tokenizer.encode("Hello, y'all! How are you 😁 ?", true)?;
@@ -261,7 +261,7 @@ fn pipeline() -> tokenizers::Result<()> {
     // END pipeline_test_normalizer
     assert_eq!(normalized.get(), "Hello how are u?");
     // START pipeline_replace_normalizer
-    tokenizer.with_normalizer(normalizer);
+    tokenizer.with_normalizer(Some(normalizer));
     // END pipeline_replace_normalizer
     // START pipeline_setup_pre_tokenizer
     use tokenizers::pre_tokenizers::whitespace::Whitespace;
@@ -325,12 +325,12 @@ fn pipeline() -> tokenizers::Result<()> {
         ]
     );
     // START pipeline_replace_pre_tokenizer
-    tokenizer.with_pre_tokenizer(pre_tokenizer);
+    tokenizer.with_pre_tokenizer(Some(pre_tokenizer));
     // END pipeline_replace_pre_tokenizer
     // START pipeline_setup_processor
     use tokenizers::processors::template::TemplateProcessing;

-    tokenizer.with_post_processor(
+    tokenizer.with_post_processor(Some(
         TemplateProcessing::builder()
             .try_single("[CLS] $A [SEP]")
             .unwrap()
@@ -339,7 +339,7 @@ fn pipeline() -> tokenizers::Result<()> {
             .special_tokens(vec![("[CLS]", 1), ("[SEP]", 2)])
             .build()
             .unwrap(),
-    );
+    ));
     // END pipeline_setup_processor
     // START pipeline_test_decoding
     let output = tokenizer.encode("Hello, y'all! How are you 😁 ?", true)?;
@@ -375,21 +375,21 @@ fn train_pipeline_bert() -> tokenizers::Result<()> {
     use tokenizers::normalizers::utils::Sequence as NormalizerSequence;
     use tokenizers::normalizers::{strip::StripAccents, unicode::NFD, utils::Lowercase};

-    bert_tokenizer.with_normalizer(NormalizerSequence::new(vec![
+    bert_tokenizer.with_normalizer(Some(NormalizerSequence::new(vec![
         NFD.into(),
         Lowercase.into(),
         StripAccents.into(),
-    ]));
+    ])));
     // END bert_setup_normalizer
     // START bert_setup_pre_tokenizer
     use tokenizers::pre_tokenizers::whitespace::Whitespace;

-    bert_tokenizer.with_pre_tokenizer(Whitespace {});
+    bert_tokenizer.with_pre_tokenizer(Some(Whitespace {}));
     // END bert_setup_pre_tokenizer
     // START bert_setup_processor
     use tokenizers::processors::template::TemplateProcessing;

-    bert_tokenizer.with_post_processor(
+    bert_tokenizer.with_post_processor(Some(
         TemplateProcessing::builder()
             .try_single("[CLS] $A [SEP]")
             .unwrap()
@@ -398,7 +398,7 @@ fn train_pipeline_bert() -> tokenizers::Result<()> {
             .special_tokens(vec![("[CLS]", 1), ("[SEP]", 2)])
             .build()
             .unwrap(),
-    );
+    ));
     // END bert_setup_processor
     // START bert_train_tokenizer
     use tokenizers::models::{wordpiece::WordPieceTrainer, TrainerWrapper};
@@ -450,7 +450,7 @@ fn pipeline_bert() -> tokenizers::Result<()> {
     // START bert_proper_decoding
     use tokenizers::decoders::wordpiece::WordPiece as WordPieceDecoder;

-    bert_tokenizer.with_decoder(WordPieceDecoder::default());
+    bert_tokenizer.with_decoder(Some(WordPieceDecoder::default()));
     let decoded = bert_tokenizer.decode(output.get_ids(), true)?;
     // "welcome to the tokenizers library."
     // END bert_proper_decoding
@@ -203,7 +203,7 @@ fn models() {
 fn tokenizer() {
     let wordpiece = WordPiece::default();
     let mut tokenizer = Tokenizer::new(wordpiece);
-    tokenizer.with_normalizer(NFC);
+    tokenizer.with_normalizer(Some(NFC));
     let ser = serde_json::to_string(&tokenizer).unwrap();
     let _: Tokenizer = serde_json::from_str(&ser).unwrap();
     let unwrapped_nfc_tok: TokenizerImpl<