Support None to reset pre_tokenizers and normalizers, and index sequences (#1590)
* initial commit
* support None
* fix clippy
* cleanup
* clean?
* propagate to pre_tokenizer
* fix test
* fix rust tests
* fix node
* propagate to decoder and post processor
* fix calls
* lint
* fmt
* node be happy I am fixing you
* add a small test
* styling
* style merge
* fix merge test
* fmt
* nits
* update test
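In user-facing terms, this is what the change enables from Python — a minimal sketch that only mirrors the tests added in this diff (Tokenizer, BPE, and the normalizers used here are existing tokenizers classes):

    from tokenizers import Tokenizer
    from tokenizers.models import BPE
    from tokenizers.normalizers import Lowercase, Sequence, Strip

    tokenizer = Tokenizer(BPE())

    # A component can now be cleared by assigning None.
    tokenizer.normalizer = Strip()
    tokenizer.normalizer = None
    assert tokenizer.normalizer is None

    # Sequence normalizers (and pre-tokenizers) are now indexable.
    seq = Sequence([Strip(), Lowercase()])
    assert seq[1].__class__ == Lowercase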
@@ -208,7 +208,7 @@ impl Tokenizer {
             .tokenizer
             .write()
             .unwrap()
-            .with_pre_tokenizer((*pre_tokenizer).clone());
+            .with_pre_tokenizer(Some((*pre_tokenizer).clone()));
     }
 
     #[napi]
@@ -217,7 +217,7 @@ impl Tokenizer {
             .tokenizer
             .write()
             .unwrap()
-            .with_decoder((*decoder).clone());
+            .with_decoder(Some((*decoder).clone()));
     }
 
     #[napi]
@@ -231,7 +231,7 @@ impl Tokenizer {
             .tokenizer
             .write()
             .unwrap()
-            .with_post_processor((*post_processor).clone());
+            .with_post_processor(Some((*post_processor).clone()));
     }
 
     #[napi]
@@ -240,7 +240,7 @@ impl Tokenizer {
             .tokenizer
             .write()
             .unwrap()
-            .with_normalizer((*normalizer).clone());
+            .with_normalizer(Some((*normalizer).clone()));
     }
 
     #[napi]
@@ -1,8 +1,6 @@
-use std::sync::{Arc, RwLock};
-
-use pyo3::exceptions;
-use pyo3::prelude::*;
 use pyo3::types::*;
+use pyo3::{exceptions, prelude::*};
+use std::sync::{Arc, RwLock};
 
 use crate::error::ToPyResult;
 use crate::utils::{PyNormalizedString, PyNormalizedStringRefMut, PyPattern};
@@ -354,6 +352,7 @@ impl PyNFKC {
 /// A list of Normalizer to be run as a sequence
 #[pyclass(extends=PyNormalizer, module = "tokenizers.normalizers", name = "Sequence")]
 pub struct PySequence {}
+
 #[pymethods]
 impl PySequence {
     #[new]
@@ -380,6 +379,22 @@ impl PySequence {
     fn __len__(&self) -> usize {
         0
     }
+
+    fn __getitem__(self_: PyRef<'_, Self>, py: Python<'_>, index: usize) -> PyResult<Py<PyAny>> {
+        match &self_.as_ref().normalizer {
+            PyNormalizerTypeWrapper::Sequence(inner) => match inner.get(index) {
+                Some(item) => PyNormalizer::new(PyNormalizerTypeWrapper::Single(Arc::clone(item)))
+                    .get_as_subtype(py),
+                _ => Err(PyErr::new::<pyo3::exceptions::PyIndexError, _>(
+                    "Index not found",
+                )),
+            },
+            PyNormalizerTypeWrapper::Single(inner) => {
+                PyNormalizer::new(PyNormalizerTypeWrapper::Single(Arc::clone(inner)))
+                    .get_as_subtype(py)
+            }
+        }
+    }
 }
 
 /// Lowercase Normalizer
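A short Python sketch of what the `__getitem__` above enables, mirroring the `test_items` test added further down in this diff (BertNormalizer and Prepend are existing normalizers; the IndexError comes from the "Index not found" branch above):

    from tokenizers.normalizers import BertNormalizer, Prepend, Sequence

    norms = Sequence([BertNormalizer(True, True), Prepend()])
    assert norms[1].__class__ == Prepend

    # Items come back as their concrete subtype, so their attributes can be edited.
    norms[0].lowercase = False
    assert not norms[0].lowercase

    # Out-of-range indices raise IndexError, e.g. norms[2].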
@@ -463,6 +463,24 @@ impl PySequence {
    fn __getnewargs__<'p>(&self, py: Python<'p>) -> Bound<'p, PyTuple> {
        PyTuple::new_bound(py, [PyList::empty_bound(py)])
    }
+
+    fn __getitem__(self_: PyRef<'_, Self>, py: Python<'_>, index: usize) -> PyResult<Py<PyAny>> {
+        match &self_.as_ref().pretok {
+            PyPreTokenizerTypeWrapper::Sequence(inner) => match inner.get(index) {
+                Some(item) => {
+                    PyPreTokenizer::new(PyPreTokenizerTypeWrapper::Single(Arc::clone(item)))
+                        .get_as_subtype(py)
+                }
+                _ => Err(PyErr::new::<pyo3::exceptions::PyIndexError, _>(
+                    "Index not found",
+                )),
+            },
+            PyPreTokenizerTypeWrapper::Single(inner) => {
+                PyPreTokenizer::new(PyPreTokenizerTypeWrapper::Single(Arc::clone(inner)))
+                    .get_as_subtype(py)
+            }
+        }
+    }
 }
 
 pub(crate) fn from_string(string: String) -> Result<PrependScheme, PyErr> {
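The pre-tokenizer Sequence gains the same indexing; a sketch mirroring the pre-tokenizer test added below (Metaspace and Punctuation are existing pre-tokenizers):

    from tokenizers.pre_tokenizers import Metaspace, Punctuation, Sequence

    pre_toks = Sequence([Metaspace("a", "never", split=True), Punctuation()])
    assert pre_toks[0].__class__ == Metaspace

    # The returned item is the live pre-tokenizer, so its flags can be flipped in place.
    pre_toks[0].split = False
    assert not pre_toks[0].split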
@@ -1371,8 +1371,9 @@ impl PyTokenizer {
 
     /// Set the :class:`~tokenizers.normalizers.Normalizer`
     #[setter]
-    fn set_normalizer(&mut self, normalizer: PyRef<PyNormalizer>) {
-        self.tokenizer.with_normalizer(normalizer.clone());
+    fn set_normalizer(&mut self, normalizer: Option<PyRef<PyNormalizer>>) {
+        let normalizer_option = normalizer.map(|norm| norm.clone());
+        self.tokenizer.with_normalizer(normalizer_option);
     }
 
     /// The `optional` :class:`~tokenizers.pre_tokenizers.PreTokenizer` in use by the Tokenizer
@@ -1387,8 +1388,9 @@ impl PyTokenizer {
 
     /// Set the :class:`~tokenizers.normalizers.Normalizer`
     #[setter]
-    fn set_pre_tokenizer(&mut self, pretok: PyRef<PyPreTokenizer>) {
-        self.tokenizer.with_pre_tokenizer(pretok.clone());
+    fn set_pre_tokenizer(&mut self, pretok: Option<PyRef<PyPreTokenizer>>) {
+        self.tokenizer
+            .with_pre_tokenizer(pretok.map(|pre| pre.clone()));
     }
 
     /// The `optional` :class:`~tokenizers.processors.PostProcessor` in use by the Tokenizer
@@ -1403,8 +1405,9 @@ impl PyTokenizer {
 
     /// Set the :class:`~tokenizers.processors.PostProcessor`
     #[setter]
-    fn set_post_processor(&mut self, processor: PyRef<PyPostProcessor>) {
-        self.tokenizer.with_post_processor(processor.clone());
+    fn set_post_processor(&mut self, processor: Option<PyRef<PyPostProcessor>>) {
+        self.tokenizer
+            .with_post_processor(processor.map(|p| p.clone()));
     }
 
     /// The `optional` :class:`~tokenizers.decoders.Decoder` in use by the Tokenizer
@@ -1419,8 +1422,8 @@ impl PyTokenizer {
 
     /// Set the :class:`~tokenizers.decoders.Decoder`
     #[setter]
-    fn set_decoder(&mut self, decoder: PyRef<PyDecoder>) {
-        self.tokenizer.with_decoder(decoder.clone());
+    fn set_decoder(&mut self, decoder: Option<PyRef<PyDecoder>>) {
+        self.tokenizer.with_decoder(decoder.map(|d| d.clone()));
     }
 }
 
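Taken together, all four setters now accept an optional value from Python. The tests below only exercise normalizer and pre_tokenizer; the following sketch assumes, based on the setter changes above (not on a test in this diff), that decoder and post_processor can be reset the same way:

    from tokenizers import Tokenizer
    from tokenizers.decoders import ByteLevel as ByteLevelDecoder
    from tokenizers.models import BPE
    from tokenizers.processors import ByteLevel as ByteLevelProcessor

    tokenizer = Tokenizer(BPE())

    tokenizer.decoder = ByteLevelDecoder()
    tokenizer.decoder = None          # clears the decoder
    assert tokenizer.decoder is None

    tokenizer.post_processor = ByteLevelProcessor(trim_offsets=True)
    tokenizer.post_processor = None   # clears the post-processor
    assert tokenizer.post_processor is None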
@@ -1436,10 +1439,12 @@ mod test {
     #[test]
     fn serialize() {
         let mut tokenizer = Tokenizer::new(PyModel::from(BPE::default()));
-        tokenizer.with_normalizer(PyNormalizer::new(PyNormalizerTypeWrapper::Sequence(vec![
-            Arc::new(RwLock::new(NFKC.into())),
-            Arc::new(RwLock::new(Lowercase.into())),
-        ])));
+        tokenizer.with_normalizer(Some(PyNormalizer::new(PyNormalizerTypeWrapper::Sequence(
+            vec![
+                Arc::new(RwLock::new(NFKC.into())),
+                Arc::new(RwLock::new(Lowercase.into())),
+            ],
+        ))));
 
         let tmp = NamedTempFile::new().unwrap().into_temp_path();
         tokenizer.save(&tmp, false).unwrap();
@@ -1450,10 +1455,12 @@ mod test {
     #[test]
     fn serde_pyo3() {
         let mut tokenizer = Tokenizer::new(PyModel::from(BPE::default()));
-        tokenizer.with_normalizer(PyNormalizer::new(PyNormalizerTypeWrapper::Sequence(vec![
-            Arc::new(RwLock::new(NFKC.into())),
-            Arc::new(RwLock::new(Lowercase.into())),
-        ])));
+        tokenizer.with_normalizer(Some(PyNormalizer::new(PyNormalizerTypeWrapper::Sequence(
+            vec![
+                Arc::new(RwLock::new(NFKC.into())),
+                Arc::new(RwLock::new(Lowercase.into())),
+            ],
+        ))));
 
         let output = crate::utils::serde_pyo3::to_string(&tokenizer).unwrap();
         assert_eq!(output, "Tokenizer(version=\"1.0\", truncation=None, padding=None, added_tokens=[], normalizer=Sequence(normalizers=[NFKC(), Lowercase()]), pre_tokenizer=None, post_processor=None, decoder=None, model=BPE(dropout=None, unk_token=None, continuing_subword_prefix=None, end_of_word_suffix=None, fuse_unk=False, byte_fallback=False, ignore_merges=False, vocab={}, merges=[]))");
@@ -67,6 +67,14 @@ class TestSequence:
         output = normalizer.normalize_str(" HELLO ")
         assert output == "hello"
 
+    def test_items(self):
+        normalizers = Sequence([BertNormalizer(True, True), Prepend()])
+        assert normalizers[1].__class__ == Prepend
+        normalizers[0].lowercase = False
+        assert not normalizers[0].lowercase
+        with pytest.raises(IndexError):
+            print(normalizers[2])
+
 
 class TestLowercase:
     def test_instantiate(self):
@@ -169,6 +169,13 @@ class TestSequence:
             ("?", (29, 30)),
         ]
 
+    def test_items(self):
+        pre_tokenizers = Sequence([Metaspace("a", "never", split=True), Punctuation()])
+        assert pre_tokenizers[1].__class__ == Punctuation
+        assert pre_tokenizers[0].__class__ == Metaspace
+        pre_tokenizers[0].split = False
+        assert not pre_tokenizers[0].split
+
 
 class TestDigits:
     def test_instantiate(self):
@@ -6,10 +6,11 @@ import pytest
 from tokenizers import AddedToken, Encoding, Tokenizer
 from tokenizers.implementations import BertWordPieceTokenizer
 from tokenizers.models import BPE, Model, Unigram
-from tokenizers.pre_tokenizers import ByteLevel
+from tokenizers.pre_tokenizers import ByteLevel, Metaspace
 from tokenizers.processors import RobertaProcessing, TemplateProcessing
+from tokenizers.normalizers import Strip, Lowercase, Sequence
 
 
 from ..utils import bert_files, data_dir, multiprocessing_with_parallelism, roberta_files
 
 
@@ -551,6 +552,16 @@ class TestTokenizer:
         assert output == "name is john"
         assert tokenizer.get_added_tokens_decoder()[0] == AddedToken("my", special=True)
 
+    def test_setting_to_none(self):
+        tokenizer = Tokenizer(BPE())
+        tokenizer.normalizer = Strip()
+        tokenizer.normalizer = None
+        assert tokenizer.normalizer == None
+
+        tokenizer.pre_tokenizer = Metaspace()
+        tokenizer.pre_tokenizer = None
+        assert tokenizer.pre_tokenizer == None
+
 
 class TestTokenizerRepr:
     def test_repr(self):