🚨 Support updating template processors (#1652)
* current updates
* simplify
* set_item works, but `tokenizer._tokenizer.post_processor[1].single = ["$0", "</s>"]` does not!
* fix: `normalizers` deserialization and other refactoring
* fix: `pre_tokenizer` deserialization
* feat: add `__len__` implementation for `normalizer::PySequence`
* feat: add `__setitem__` impl for `normalizers::PySequence`
* feat: add `__setitem__` impl to `pre_tokenizer::PySequence`
* feat: add `__setitem__` impl to `post_processor::PySequence`
* test: add normalizer sequence setter check
* refactor: allow unused `processors::setter` macro
* test: add `__setitem__` test for processors & pretok
* refactor: `unwrap` -> `PyException::new_err()?`
* refactor: fmt
* refactor: remove unnecessary `pub`
* feat(bindings): add missing getters & setters for pretoks
* feat(bindings): add missing getters & setters for processors
* refactor(bindings): rewrite RwLock poison error msg
* refactor: remove debug print
* feat(bindings): add description as to why custom deser is needed
* feat: make post proc sequence elements mutable
* fix(binding): serialization

---------

Co-authored-by: Luc Georges <luc.sydney.georges@gmail.com>
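In user-facing terms, the headline change is that sequence components become indexable and their elements mutable in place. A minimal sketch of the new surface (the checkpoint id is a placeholder; it assumes a transformers fast tokenizer whose backing post-processor is a processors.Sequence with a TemplateProcessing at index 1):

    # Sketch only: "some/checkpoint" is illustrative, not a real repo id.
    from transformers import AutoTokenizer
    from tokenizers import processors

    tok = AutoTokenizer.from_pretrained("some/checkpoint")
    post = tok._tokenizer.post_processor  # a processors.Sequence

    # __getitem__ now returns the concrete subtype of the element...
    template = post[1]  # e.g. a TemplateProcessing
    # ...whose fields are writable, as in the commit message:
    post[1].single = ["$0", "</s>"]
    # __setitem__ swaps out a whole element of the sequence:
    post[0] = processors.ByteLevel(trim_offsets=False)  # illustrative replacement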
--- a/.github/workflows/python.yml
+++ b/.github/workflows/python.yml
@@ -55,7 +55,6 @@ jobs:
     steps:
       - name: Checkout repository
         uses: actions/checkout@v4
 
-
       - name: Install Rust
         uses: actions-rs/toolchain@v1
--- a/bindings/python/Cargo.toml
+++ b/bindings/python/Cargo.toml
@@ -14,7 +14,7 @@ serde = { version = "1.0", features = ["rc", "derive"] }
 serde_json = "1.0"
 libc = "0.2"
 env_logger = "0.11"
-pyo3 = { version = "0.23", features = ["abi3", "abi3-py39"] }
+pyo3 = { version = "0.23", features = ["abi3", "abi3-py39", "py-clone"] }
 numpy = "0.23"
 ndarray = "0.16"
 itertools = "0.12"
--- a/bindings/python/pyproject.toml
+++ b/bindings/python/pyproject.toml
@@ -2,8 +2,8 @@
 name = 'tokenizers'
 requires-python = '>=3.9'
 authors = [
-  {name = 'Nicolas Patry', email = 'patry.nicolas@protonmail.com'},
-  {name = 'Anthony Moi', email = 'anthony@huggingface.co'}
+  { name = 'Nicolas Patry', email = 'patry.nicolas@protonmail.com' },
+  { name = 'Anthony Moi', email = 'anthony@huggingface.co' },
 ]
 classifiers = [
   "Development Status :: 5 - Production/Stable",
@@ -21,12 +21,7 @@ classifiers = [
   "Topic :: Scientific/Engineering :: Artificial Intelligence",
 ]
 keywords = ["NLP", "tokenizer", "BPE", "transformer", "deep learning"]
-dynamic = [
-  'description',
-  'license',
-  'readme',
-  'version',
-]
+dynamic = ['description', 'license', 'readme', 'version']
 dependencies = ["huggingface_hub>=0.16.4,<1.0"]
 
 [project.urls]
@@ -58,16 +53,16 @@ target-version = ['py35']
 line-length = 119
 target-version = "py311"
 lint.ignore = [
-    # a == None in tests vs is None.
-    "E711",
-    # a == False in tests vs is False.
-    "E712",
-    # try.. import except.. pattern without using the lib.
-    "F401",
-    # Raw type equality is required in asserts
-    "E721",
-    # Import order
-    "E402",
-    # Fixtures unused import
-    "F811",
+  # a == None in tests vs is None.
+  "E711",
+  # a == False in tests vs is False.
+  "E712",
+  # try.. import except.. pattern without using the lib.
+  "F401",
+  # Raw type equality is required in asserts
+  "E721",
+  # Import order
+  "E402",
+  # Fixtures unused import
+  "F811",
 ]
--- a/bindings/python/src/normalizers.rs
+++ b/bindings/python/src/normalizers.rs
@@ -1,3 +1,4 @@
+use pyo3::exceptions::PyException;
 use pyo3::types::*;
 use pyo3::{exceptions, prelude::*};
 use std::sync::{Arc, RwLock};
@@ -41,7 +42,7 @@ impl PyNormalizedStringMut<'_> {
 /// This class is not supposed to be instantiated directly. Instead, any implementation of a
 /// Normalizer will return an instance of this class when instantiated.
 #[pyclass(dict, module = "tokenizers.normalizers", name = "Normalizer", subclass)]
-#[derive(Clone, Serialize, Deserialize)]
+#[derive(Clone, Debug, Serialize, Deserialize)]
 #[serde(transparent)]
 pub struct PyNormalizer {
     pub(crate) normalizer: PyNormalizerTypeWrapper,
@@ -58,7 +59,11 @@ impl PyNormalizer {
                 .into_pyobject(py)?
                 .into_any()
                 .into(),
-            PyNormalizerTypeWrapper::Single(ref inner) => match &*inner.as_ref().read().unwrap() {
+            PyNormalizerTypeWrapper::Single(ref inner) => match &*inner
+                .as_ref()
+                .read()
+                .map_err(|_| PyException::new_err("RwLock synchronisation primitive is poisoned, cannot get subtype of PyNormalizer"))?
+            {
                 PyNormalizerWrapper::Custom(_) => {
                     Py::new(py, base)?.into_pyobject(py)?.into_any().into()
                 }
@@ -218,7 +223,9 @@ macro_rules! getter {
    ($self: ident, $variant: ident, $name: ident) => {{
        let super_ = $self.as_ref();
        if let PyNormalizerTypeWrapper::Single(ref norm) = super_.normalizer {
-           let wrapper = norm.read().unwrap();
+           let wrapper = norm.read().expect(
+               "RwLock synchronisation primitive is poisoned, cannot get subtype of PyNormalizer",
+           );
            if let PyNormalizerWrapper::Wrapped(NormalizerWrapper::$variant(o)) = (&*wrapper) {
                o.$name.clone()
            } else {
@@ -234,7 +241,9 @@ macro_rules! setter {
    ($self: ident, $variant: ident, $name: ident, $value: expr) => {{
        let super_ = $self.as_ref();
        if let PyNormalizerTypeWrapper::Single(ref norm) = super_.normalizer {
-           let mut wrapper = norm.write().unwrap();
+           let mut wrapper = norm.write().expect(
+               "RwLock synchronisation primitive is poisoned, cannot get subtype of PyNormalizer",
+           );
            if let PyNormalizerWrapper::Wrapped(NormalizerWrapper::$variant(ref mut o)) = *wrapper {
                o.$name = $value;
            }
@@ -410,25 +419,55 @@ impl PySequence {
        PyTuple::new(py, [PyList::empty(py)])
    }
 
-    fn __len__(&self) -> usize {
-        0
+    fn __len__(self_: PyRef<'_, Self>) -> usize {
+        match &self_.as_ref().normalizer {
+            PyNormalizerTypeWrapper::Sequence(inner) => inner.len(),
+            PyNormalizerTypeWrapper::Single(_) => 1,
+        }
     }
 
     fn __getitem__(self_: PyRef<'_, Self>, py: Python<'_>, index: usize) -> PyResult<Py<PyAny>> {
         match &self_.as_ref().normalizer {
             PyNormalizerTypeWrapper::Sequence(inner) => match inner.get(index) {
-                Some(item) => PyNormalizer::new(PyNormalizerTypeWrapper::Single(Arc::clone(item)))
+                Some(item) => PyNormalizer::new(PyNormalizerTypeWrapper::Single(item.clone()))
                     .get_as_subtype(py),
                 _ => Err(PyErr::new::<pyo3::exceptions::PyIndexError, _>(
                     "Index not found",
                 )),
             },
             PyNormalizerTypeWrapper::Single(inner) => {
-                PyNormalizer::new(PyNormalizerTypeWrapper::Single(Arc::clone(inner)))
-                    .get_as_subtype(py)
+                PyNormalizer::new(PyNormalizerTypeWrapper::Single(inner.clone())).get_as_subtype(py)
             }
         }
     }
+
+    fn __setitem__(self_: PyRef<'_, Self>, index: usize, value: Bound<'_, PyAny>) -> PyResult<()> {
+        let norm: PyNormalizer = value.extract()?;
+        let PyNormalizerTypeWrapper::Single(norm) = norm.normalizer else {
+            return Err(PyException::new_err("normalizer should not be a sequence"));
+        };
+        match &self_.as_ref().normalizer {
+            PyNormalizerTypeWrapper::Sequence(inner) => match inner.get(index) {
+                Some(item) => {
+                    *item
+                        .write()
+                        .map_err(|_| PyException::new_err("RwLock synchronisation primitive is poisoned, cannot get subtype of PyNormalizer"))? = norm
+                        .read()
+                        .map_err(|_| PyException::new_err("RwLock synchronisation primitive is poisoned, cannot get subtype of PyNormalizer"))?
+                        .clone();
+                }
+                _ => {
+                    return Err(PyErr::new::<pyo3::exceptions::PyIndexError, _>(
+                        "Index not found",
+                    ))
+                }
+            },
+            PyNormalizerTypeWrapper::Single(_) => {
+                return Err(PyException::new_err("normalizer is not a sequence"))
+            }
+        };
+        Ok(())
+    }
 }
 
 /// Lowercase Normalizer
@@ -570,9 +609,31 @@ impl PyReplace {
             ToPyResult(Replace::new(pattern, content)).into_py()?.into(),
         ))
     }
+
+    #[getter]
+    fn get_pattern(_self: PyRef<Self>) -> PyResult<()> {
+        Err(PyException::new_err("Cannot get pattern"))
+    }
+
+    #[setter]
+    fn set_pattern(_self: PyRef<Self>, _pattern: PyPattern) -> PyResult<()> {
+        Err(PyException::new_err(
+            "Cannot set pattern, please instantiate a new replace pattern instead",
+        ))
+    }
+
+    #[getter]
+    fn get_content(self_: PyRef<Self>) -> String {
+        getter!(self_, Replace, content)
+    }
+
+    #[setter]
+    fn set_content(self_: PyRef<Self>, content: String) {
+        setter!(self_, Replace, content, content)
+    }
 }
 
-#[derive(Debug)]
+#[derive(Clone, Debug)]
 pub(crate) struct CustomNormalizer {
     inner: PyObject,
 }
@@ -615,7 +676,7 @@ impl<'de> Deserialize<'de> for CustomNormalizer {
     }
 }
 
-#[derive(Debug, Deserialize)]
+#[derive(Clone, Debug, Deserialize)]
 #[serde(untagged)]
 pub(crate) enum PyNormalizerWrapper {
     Custom(CustomNormalizer),
@@ -634,13 +695,27 @@ impl Serialize for PyNormalizerWrapper {
     }
 }
 
-#[derive(Debug, Clone, Deserialize)]
-#[serde(untagged)]
+#[derive(Debug, Clone)]
 pub(crate) enum PyNormalizerTypeWrapper {
     Sequence(Vec<Arc<RwLock<PyNormalizerWrapper>>>),
     Single(Arc<RwLock<PyNormalizerWrapper>>),
 }
 
+/// XXX: we need to manually implement deserialize here because of the structure of the
+/// PyNormalizerTypeWrapper enum. Given the underlying PyNormalizerWrapper can contain a Sequence,
+/// default deserialization will give us a PyNormalizerTypeWrapper::Single(Sequence) when we'd like
+/// it to be PyNormalizerTypeWrapper::Sequence(// ...).
+impl<'de> Deserialize<'de> for PyNormalizerTypeWrapper {
+    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
+    where
+        D: Deserializer<'de>,
+    {
+        let wrapper = NormalizerWrapper::deserialize(deserializer)?;
+        let py_wrapper: PyNormalizerWrapper = wrapper.into();
+        Ok(py_wrapper.into())
+    }
+}
+
 impl Serialize for PyNormalizerTypeWrapper {
     fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
     where
@@ -672,7 +747,17 @@ where
     I: Into<PyNormalizerWrapper>,
 {
     fn from(norm: I) -> Self {
-        PyNormalizerTypeWrapper::Single(Arc::new(RwLock::new(norm.into())))
+        let norm = norm.into();
+        match norm {
+            PyNormalizerWrapper::Wrapped(NormalizerWrapper::Sequence(seq)) => {
+                PyNormalizerTypeWrapper::Sequence(
+                    seq.into_iter()
+                        .map(|e| Arc::new(RwLock::new(PyNormalizerWrapper::Wrapped(e.clone()))))
+                        .collect(),
+                )
+            }
+            _ => PyNormalizerTypeWrapper::Single(Arc::new(RwLock::new(norm))),
+        }
     }
 }
 
@@ -690,10 +775,15 @@ where
 impl Normalizer for PyNormalizerTypeWrapper {
     fn normalize(&self, normalized: &mut NormalizedString) -> tk::Result<()> {
         match self {
-            PyNormalizerTypeWrapper::Single(inner) => inner.read().unwrap().normalize(normalized),
-            PyNormalizerTypeWrapper::Sequence(inner) => inner
-                .iter()
-                .try_for_each(|n| n.read().unwrap().normalize(normalized)),
+            PyNormalizerTypeWrapper::Single(inner) => inner
+                .read()
+                .map_err(|_| PyException::new_err("RwLock synchronisation primitive is poisoned, cannot get subtype of PyNormalizer"))?
+                .normalize(normalized),
+            PyNormalizerTypeWrapper::Sequence(inner) => inner.iter().try_for_each(|n| {
+                n.read()
+                    .map_err(|_| PyException::new_err("RwLock synchronisation primitive is poisoned, cannot get subtype of PyNormalizer"))?
+                    .normalize(normalized)
+            }),
         }
     }
 }
@@ -793,18 +883,14 @@ mod test {
         let normalizer: PyNormalizer = serde_json::from_str(&sequence_string).unwrap();
 
         match normalizer.normalizer {
-            PyNormalizerTypeWrapper::Single(inner) => match &*inner.as_ref().read().unwrap() {
-                PyNormalizerWrapper::Wrapped(NormalizerWrapper::Sequence(sequence)) => {
-                    let normalizers = sequence.get_normalizers();
-                    assert_eq!(normalizers.len(), 1);
-                    match normalizers[0] {
-                        NormalizerWrapper::NFKC(_) => {}
-                        _ => panic!("Expected NFKC"),
-                    }
-                }
-                _ => panic!("Expected sequence"),
-            },
-            _ => panic!("Expected single"),
+            PyNormalizerTypeWrapper::Sequence(inner) => {
+                assert_eq!(inner.len(), 1);
+                match *inner[0].as_ref().read().unwrap() {
+                    PyNormalizerWrapper::Wrapped(NormalizerWrapper::NFKC(_)) => {}
+                    _ => panic!("Expected NFKC"),
+                };
+            }
+            _ => panic!("Expected sequence"),
         };
     }
 }
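On the normalizers side, the diff above makes the sequence protocol work end to end; a short Python sketch (the concrete normalizer choices are illustrative):

    from tokenizers import normalizers

    seq = normalizers.Sequence([normalizers.NFC(), normalizers.Lowercase()])
    assert len(seq) == 2                        # __len__ added above
    assert isinstance(seq[0], normalizers.NFC)  # __getitem__ yields the concrete subtype
    seq[0] = normalizers.NFKC()                 # __setitem__ swaps the element in place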
--- a/bindings/python/src/pre_tokenizers.rs
+++ b/bindings/python/src/pre_tokenizers.rs
@@ -1,6 +1,7 @@
 use std::sync::{Arc, RwLock};
 
 use pyo3::exceptions;
+use pyo3::exceptions::PyException;
 use pyo3::prelude::*;
 use pyo3::types::*;
 use serde::ser::SerializeStruct;
@@ -48,13 +49,17 @@ impl PyPreTokenizer {
 
     pub(crate) fn get_as_subtype(&self, py: Python<'_>) -> PyResult<PyObject> {
         let base = self.clone();
-        Ok(match &self.pretok {
+        Ok(match self.pretok {
             PyPreTokenizerTypeWrapper::Sequence(_) => Py::new(py, (PySequence {}, base))?
                 .into_pyobject(py)?
                 .into_any()
                 .into(),
             PyPreTokenizerTypeWrapper::Single(ref inner) => {
-                match &*inner.as_ref().read().unwrap() {
+                match &*inner
+                    .as_ref()
+                    .read()
+                    .map_err(|_| PyException::new_err("RwLock synchronisation primitive is poisoned, cannot get subtype of PyPreTokenizer"))?
+                {
                     PyPreTokenizerWrapper::Custom(_) => {
                         Py::new(py, base)?.into_pyobject(py)?.into_any().into()
                     }
@@ -222,7 +227,7 @@ macro_rules! getter {
         let super_ = $self.as_ref();
         if let PyPreTokenizerTypeWrapper::Single(ref single) = super_.pretok {
             if let PyPreTokenizerWrapper::Wrapped(PreTokenizerWrapper::$variant(ref pretok)) =
-                *single.read().unwrap() {
+                *single.read().expect("RwLock synchronisation primitive is poisoned, cannot get subtype of PyPreTokenizer") {
                 pretok.$($name)+
             } else {
                 unreachable!()
@@ -238,7 +243,7 @@ macro_rules! setter {
         let super_ = $self.as_ref();
         if let PyPreTokenizerTypeWrapper::Single(ref single) = super_.pretok {
             if let PyPreTokenizerWrapper::Wrapped(PreTokenizerWrapper::$variant(ref mut pretok)) =
-                *single.write().unwrap()
+                *single.write().expect("RwLock synchronisation primitive is poisoned, cannot get subtype of PyPreTokenizer")
             {
                 pretok.$name = $value;
             }
@@ -248,7 +253,7 @@ macro_rules! setter {
         let super_ = $self.as_ref();
         if let PyPreTokenizerTypeWrapper::Single(ref single) = super_.pretok {
             if let PyPreTokenizerWrapper::Wrapped(PreTokenizerWrapper::$variant(ref mut pretok)) =
-                *single.write().unwrap()
+                *single.write().expect("RwLock synchronisation primitive is poisoned, cannot get subtype of PyPreTokenizer")
             {
                 pretok.$name($value);
             }
@@ -292,6 +297,16 @@ impl PyByteLevel {
         setter!(self_, ByteLevel, use_regex, use_regex);
     }
 
+    #[getter]
+    fn get_trim_offsets(self_: PyRef<Self>) -> bool {
+        getter!(self_, ByteLevel, trim_offsets)
+    }
+
+    #[setter]
+    fn set_trim_offsets(self_: PyRef<Self>, trim_offsets: bool) {
+        setter!(self_, ByteLevel, trim_offsets, trim_offsets)
+    }
+
     #[new]
     #[pyo3(signature = (add_prefix_space = true, use_regex = true, **_kwargs), text_signature = "(self, add_prefix_space=True, use_regex=True)")]
     fn new(
@@ -392,6 +407,52 @@ impl PySplit {
     fn __getnewargs__<'p>(&self, py: Python<'p>) -> PyResult<Bound<'p, PyTuple>> {
         PyTuple::new(py, [" ", "removed"])
     }
+
+    #[getter]
+    fn get_pattern(_self: PyRef<Self>) -> PyResult<()> {
+        Err(PyException::new_err("Cannot get pattern"))
+    }
+
+    #[setter]
+    fn set_pattern(_self: PyRef<Self>, _pattern: PyPattern) -> PyResult<()> {
+        Err(PyException::new_err(
+            "Cannot set pattern, please instantiate a new split pattern instead",
+        ))
+    }
+
+    #[getter]
+    fn get_behavior(self_: PyRef<Self>) -> String {
+        getter!(self_, Split, behavior).to_string().to_lowercase()
+    }
+
+    #[setter]
+    fn set_behavior(self_: PyRef<Self>, behavior: String) -> PyResult<()> {
+        let behavior = match behavior.as_ref() {
+            "removed" => SplitDelimiterBehavior::Removed,
+            "isolated" => SplitDelimiterBehavior::Isolated,
+            "merged_with_previous" => SplitDelimiterBehavior::MergedWithPrevious,
+            "merged_with_next" => SplitDelimiterBehavior::MergedWithNext,
+            "contiguous" => SplitDelimiterBehavior::Contiguous,
+            _ => {
+                return Err(exceptions::PyValueError::new_err(
+                    "Wrong value for SplitDelimiterBehavior, expected one of: \
+                    `removed, isolated, merged_with_previous, merged_with_next, contiguous`",
+                ))
+            }
+        };
+        setter!(self_, Split, behavior, behavior);
+        Ok(())
+    }
+
+    #[getter]
+    fn get_invert(self_: PyRef<Self>) -> bool {
+        getter!(self_, Split, invert)
+    }
+
+    #[setter]
+    fn set_invert(self_: PyRef<Self>, invert: bool) {
+        setter!(self_, Split, invert, invert)
+    }
 }
 
 /// This pre-tokenizer simply splits on the provided char. Works like `.split(delimiter)`
@@ -458,6 +519,32 @@ impl PyPunctuation {
     fn new(behavior: PySplitDelimiterBehavior) -> (Self, PyPreTokenizer) {
         (PyPunctuation {}, Punctuation::new(behavior.into()).into())
     }
+
+    #[getter]
+    fn get_behavior(self_: PyRef<Self>) -> String {
+        getter!(self_, Punctuation, behavior)
+            .to_string()
+            .to_lowercase()
+    }
+
+    #[setter]
+    fn set_behavior(self_: PyRef<Self>, behavior: String) -> PyResult<()> {
+        let behavior = match behavior.as_ref() {
+            "removed" => SplitDelimiterBehavior::Removed,
+            "isolated" => SplitDelimiterBehavior::Isolated,
+            "merged_with_previous" => SplitDelimiterBehavior::MergedWithPrevious,
+            "merged_with_next" => SplitDelimiterBehavior::MergedWithNext,
+            "contiguous" => SplitDelimiterBehavior::Contiguous,
+            _ => {
+                return Err(exceptions::PyValueError::new_err(
+                    "Wrong value for SplitDelimiterBehavior, expected one of: \
+                    `removed, isolated, merged_with_previous, merged_with_next, contiguous`",
+                ))
+            }
+        };
+        setter!(self_, Punctuation, behavior, behavior);
+        Ok(())
+    }
 }
 
 /// This pre-tokenizer composes other pre_tokenizers and applies them in sequence
@@ -491,20 +578,47 @@ impl PySequence {
     fn __getitem__(self_: PyRef<'_, Self>, py: Python<'_>, index: usize) -> PyResult<Py<PyAny>> {
         match &self_.as_ref().pretok {
             PyPreTokenizerTypeWrapper::Sequence(inner) => match inner.get(index) {
-                Some(item) => {
-                    PyPreTokenizer::new(PyPreTokenizerTypeWrapper::Single(Arc::clone(item)))
-                        .get_as_subtype(py)
-                }
+                Some(item) => PyPreTokenizer::new(PyPreTokenizerTypeWrapper::Single(item.clone()))
+                    .get_as_subtype(py),
                 _ => Err(PyErr::new::<pyo3::exceptions::PyIndexError, _>(
                     "Index not found",
                 )),
             },
-            PyPreTokenizerTypeWrapper::Single(inner) => {
-                PyPreTokenizer::new(PyPreTokenizerTypeWrapper::Single(Arc::clone(inner)))
-                    .get_as_subtype(py)
-            }
+            _ => Err(PyErr::new::<pyo3::exceptions::PyIndexError, _>(
+                "This processor is not a Sequence, it does not support __getitem__",
+            )),
         }
     }
+
+    fn __setitem__(self_: PyRef<'_, Self>, index: usize, value: Bound<'_, PyAny>) -> PyResult<()> {
+        let pretok: PyPreTokenizer = value.extract()?;
+        let PyPreTokenizerTypeWrapper::Single(pretok) = pretok.pretok else {
+            return Err(PyException::new_err(
+                "pre tokenizer should not be a sequence",
+            ));
+        };
+        match &self_.as_ref().pretok {
+            PyPreTokenizerTypeWrapper::Sequence(inner) => match inner.get(index) {
+                Some(item) => {
+                    *item
+                        .write()
+                        .map_err(|_| PyException::new_err("RwLock synchronisation primitive is poisoned, cannot get subtype of PyPreTokenizer"))? = (*pretok
+                        .read()
+                        .map_err(|_| PyException::new_err("RwLock synchronisation primitive is poisoned, cannot get subtype of PyPreTokenizer"))?)
+                    .clone();
+                }
+                _ => {
+                    return Err(PyErr::new::<pyo3::exceptions::PyIndexError, _>(
+                        "Index not found",
+                    ))
+                }
+            },
+            PyPreTokenizerTypeWrapper::Single(_) => {
+                return Err(PyException::new_err("pre tokenizer is not a sequence"))
+            }
+        };
+        Ok(())
+    }
 }
 
 pub(crate) fn from_string(string: String) -> Result<PrependScheme, PyErr> {
@@ -565,13 +679,7 @@ impl PyMetaspace {
     #[getter]
     fn get_prepend_scheme(self_: PyRef<Self>) -> String {
         // Assuming Metaspace has a method to get the prepend_scheme as a string
-        let scheme: PrependScheme = getter!(self_, Metaspace, get_prepend_scheme());
-        match scheme {
-            PrependScheme::First => "first",
-            PrependScheme::Never => "never",
-            PrependScheme::Always => "always",
-        }
-        .to_string()
+        getter!(self_, Metaspace, get_prepend_scheme()).to_string()
     }
 
     #[setter]
@@ -642,6 +750,7 @@ impl PyUnicodeScripts {
     }
 }
 
+#[derive(Clone)]
 pub(crate) struct CustomPreTokenizer {
     inner: PyObject,
 }
@@ -685,7 +794,7 @@ impl<'de> Deserialize<'de> for CustomPreTokenizer {
     }
 }
 
-#[derive(Deserialize)]
+#[derive(Clone, Deserialize)]
 #[serde(untagged)]
 pub(crate) enum PyPreTokenizerWrapper {
     Custom(CustomPreTokenizer),
@@ -704,13 +813,23 @@ impl Serialize for PyPreTokenizerWrapper {
     }
 }
 
-#[derive(Clone, Deserialize)]
-#[serde(untagged)]
+#[derive(Clone)]
 pub(crate) enum PyPreTokenizerTypeWrapper {
     Sequence(Vec<Arc<RwLock<PyPreTokenizerWrapper>>>),
     Single(Arc<RwLock<PyPreTokenizerWrapper>>),
 }
 
+impl<'de> Deserialize<'de> for PyPreTokenizerTypeWrapper {
+    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
+    where
+        D: Deserializer<'de>,
+    {
+        let wrapper = PreTokenizerWrapper::deserialize(deserializer)?;
+        let py_wrapper: PyPreTokenizerWrapper = wrapper.into();
+        Ok(py_wrapper.into())
+    }
+}
+
 impl Serialize for PyPreTokenizerTypeWrapper {
     fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
     where
@@ -742,7 +861,17 @@ where
     I: Into<PyPreTokenizerWrapper>,
 {
     fn from(pretok: I) -> Self {
-        PyPreTokenizerTypeWrapper::Single(Arc::new(RwLock::new(pretok.into())))
+        let pretok = pretok.into();
+        match pretok {
+            PyPreTokenizerWrapper::Wrapped(PreTokenizerWrapper::Sequence(seq)) => {
+                PyPreTokenizerTypeWrapper::Sequence(
+                    seq.into_iter()
+                        .map(|e| Arc::new(RwLock::new(PyPreTokenizerWrapper::Wrapped(e.clone()))))
+                        .collect(),
+                )
+            }
+            _ => PyPreTokenizerTypeWrapper::Single(Arc::new(RwLock::new(pretok))),
+        }
     }
 }
 
@@ -760,10 +889,15 @@ where
 impl PreTokenizer for PyPreTokenizerTypeWrapper {
     fn pre_tokenize(&self, pretok: &mut PreTokenizedString) -> tk::Result<()> {
         match self {
-            PyPreTokenizerTypeWrapper::Single(inner) => inner.read().unwrap().pre_tokenize(pretok),
-            PyPreTokenizerTypeWrapper::Sequence(inner) => inner
-                .iter()
-                .try_for_each(|n| n.read().unwrap().pre_tokenize(pretok)),
+            PyPreTokenizerTypeWrapper::Single(inner) => inner
+                .read()
+                .map_err(|_| PyException::new_err("RwLock synchronisation primitive is poisoned, cannot get subtype of PyPreTokenizer"))?
+                .pre_tokenize(pretok),
+            PyPreTokenizerTypeWrapper::Sequence(inner) => inner.iter().try_for_each(|n| {
+                n.read()
+                    .map_err(|_| PyException::new_err("RwLock synchronisation primitive is poisoned, cannot get subtype of PyPreTokenizer"))?
+                    .pre_tokenize(pretok)
+            }),
        }
    }
 }
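The pre-tokenizer bindings gain the same sequence protocol plus the per-type getters and setters added above; sketched below with illustrative values:

    from tokenizers import pre_tokenizers

    seq = pre_tokenizers.Sequence(
        [pre_tokenizers.ByteLevel(), pre_tokenizers.Punctuation()]
    )
    seq[0].use_regex = False              # pre-existing ByteLevel setter
    seq[0].trim_offsets = False           # trim_offsets getter/setter added in this diff
    seq[1].behavior = "removed"           # Punctuation behavior is now settable by name
    seq[1] = pre_tokenizers.Whitespace()  # __setitem__ replaces an element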
@ -1,17 +1,20 @@
|
|||||||
use std::convert::TryInto;
|
use std::convert::TryInto;
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
|
use std::sync::RwLock;
|
||||||
use pyo3::exceptions;
|
|
||||||
use pyo3::prelude::*;
|
|
||||||
use pyo3::types::*;
|
|
||||||
|
|
||||||
use crate::encoding::PyEncoding;
|
use crate::encoding::PyEncoding;
|
||||||
use crate::error::ToPyResult;
|
use crate::error::ToPyResult;
|
||||||
|
use pyo3::exceptions;
|
||||||
|
use pyo3::exceptions::PyException;
|
||||||
|
use pyo3::prelude::*;
|
||||||
|
use pyo3::types::*;
|
||||||
|
use serde::ser::SerializeStruct;
|
||||||
|
use serde::Deserializer;
|
||||||
|
use serde::Serializer;
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
use tk::processors::bert::BertProcessing;
|
use tk::processors::bert::BertProcessing;
|
||||||
use tk::processors::byte_level::ByteLevel;
|
use tk::processors::byte_level::ByteLevel;
|
||||||
use tk::processors::roberta::RobertaProcessing;
|
use tk::processors::roberta::RobertaProcessing;
|
||||||
use tk::processors::sequence::Sequence;
|
|
||||||
use tk::processors::template::{SpecialToken, Template};
|
use tk::processors::template::{SpecialToken, Template};
|
||||||
use tk::processors::PostProcessorWrapper;
|
use tk::processors::PostProcessorWrapper;
|
||||||
use tk::{Encoding, PostProcessor};
|
use tk::{Encoding, PostProcessor};
|
||||||
@ -30,42 +33,67 @@ use tokenizers as tk;
|
|||||||
#[derive(Clone, Deserialize, Serialize)]
|
#[derive(Clone, Deserialize, Serialize)]
|
||||||
#[serde(transparent)]
|
#[serde(transparent)]
|
||||||
pub struct PyPostProcessor {
|
pub struct PyPostProcessor {
|
||||||
pub processor: Arc<PostProcessorWrapper>,
|
processor: PyPostProcessorTypeWrapper,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<I> From<I> for PyPostProcessor
|
||||||
|
where
|
||||||
|
I: Into<PostProcessorWrapper>,
|
||||||
|
{
|
||||||
|
fn from(processor: I) -> Self {
|
||||||
|
PyPostProcessor {
|
||||||
|
processor: processor.into().into(),
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl PyPostProcessor {
|
impl PyPostProcessor {
|
||||||
pub fn new(processor: Arc<PostProcessorWrapper>) -> Self {
|
pub(crate) fn new(processor: PyPostProcessorTypeWrapper) -> Self {
|
||||||
PyPostProcessor { processor }
|
PyPostProcessor { processor }
|
||||||
}
|
}
|
||||||
|
|
||||||
pub(crate) fn get_as_subtype(&self, py: Python<'_>) -> PyResult<PyObject> {
|
pub(crate) fn get_as_subtype(&self, py: Python<'_>) -> PyResult<PyObject> {
|
||||||
let base = self.clone();
|
let base = self.clone();
|
||||||
Ok(match self.processor.as_ref() {
|
Ok(
|
||||||
PostProcessorWrapper::ByteLevel(_) => Py::new(py, (PyByteLevel {}, base))?
|
match self.processor {
|
||||||
|
PyPostProcessorTypeWrapper::Sequence(_) => Py::new(py, (PySequence {}, base))?
|
||||||
.into_pyobject(py)?
|
.into_pyobject(py)?
|
||||||
.into_any()
|
.into_any()
|
||||||
.into(),
|
.into(),
|
||||||
PostProcessorWrapper::Bert(_) => Py::new(py, (PyBertProcessing {}, base))?
|
PyPostProcessorTypeWrapper::Single(ref inner) => {
|
||||||
.into_pyobject(py)?
|
|
||||||
.into_any()
|
match &*inner.read().map_err(|_| {
|
||||||
.into(),
|
PyException::new_err("RwLock synchronisation primitive is poisoned, cannot get subtype of PyPostProcessor")
|
||||||
PostProcessorWrapper::Roberta(_) => Py::new(py, (PyRobertaProcessing {}, base))?
|
})? {
|
||||||
.into_pyobject(py)?
|
PostProcessorWrapper::ByteLevel(_) => Py::new(py, (PyByteLevel {}, base))?
|
||||||
.into_any()
|
.into_pyobject(py)?
|
||||||
.into(),
|
.into_any()
|
||||||
PostProcessorWrapper::Template(_) => Py::new(py, (PyTemplateProcessing {}, base))?
|
.into(),
|
||||||
.into_pyobject(py)?
|
PostProcessorWrapper::Bert(_) => Py::new(py, (PyBertProcessing {}, base))?
|
||||||
.into_any()
|
.into_pyobject(py)?
|
||||||
.into(),
|
.into_any()
|
||||||
PostProcessorWrapper::Sequence(_) => Py::new(py, (PySequence {}, base))?
|
.into(),
|
||||||
.into_pyobject(py)?
|
PostProcessorWrapper::Roberta(_) => Py::new(py, (PyRobertaProcessing {}, base))?
|
||||||
.into_any()
|
.into_pyobject(py)?
|
||||||
.into(),
|
.into_any()
|
||||||
})
|
.into(),
|
||||||
|
PostProcessorWrapper::Template(_) => Py::new(py, (PyTemplateProcessing {}, base))?
|
||||||
|
.into_pyobject(py)?
|
||||||
|
.into_any()
|
||||||
|
.into(),
|
||||||
|
PostProcessorWrapper::Sequence(_) => Py::new(py, (PySequence {}, base))?
|
||||||
|
.into_pyobject(py)?
|
||||||
|
.into_any()
|
||||||
|
.into(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl PostProcessor for PyPostProcessor {
|
impl PostProcessor for PyPostProcessor {
|
||||||
|
// TODO: update signature to `tk::Result<usize>`
|
||||||
fn added_tokens(&self, is_pair: bool) -> usize {
|
fn added_tokens(&self, is_pair: bool) -> usize {
|
||||||
self.processor.added_tokens(is_pair)
|
self.processor.added_tokens(is_pair)
|
||||||
}
|
}
|
||||||
@ -83,7 +111,7 @@ impl PostProcessor for PyPostProcessor {
|
|||||||
#[pymethods]
|
#[pymethods]
|
||||||
impl PyPostProcessor {
|
impl PyPostProcessor {
|
||||||
fn __getstate__(&self, py: Python) -> PyResult<PyObject> {
|
fn __getstate__(&self, py: Python) -> PyResult<PyObject> {
|
||||||
let data = serde_json::to_string(self.processor.as_ref()).map_err(|e| {
|
let data = serde_json::to_string(&self.processor).map_err(|e| {
|
||||||
exceptions::PyException::new_err(format!(
|
exceptions::PyException::new_err(format!(
|
||||||
"Error while attempting to pickle PostProcessor: {}",
|
"Error while attempting to pickle PostProcessor: {}",
|
||||||
e
|
e
|
||||||
@ -116,8 +144,8 @@ impl PyPostProcessor {
|
|||||||
/// Returns:
|
/// Returns:
|
||||||
/// :obj:`int`: The number of tokens to add
|
/// :obj:`int`: The number of tokens to add
|
||||||
#[pyo3(text_signature = "(self, is_pair)")]
|
#[pyo3(text_signature = "(self, is_pair)")]
|
||||||
fn num_special_tokens_to_add(&self, is_pair: bool) -> usize {
|
fn num_special_tokens_to_add(&self, is_pair: bool) -> PyResult<usize> {
|
||||||
self.processor.added_tokens(is_pair)
|
Ok(self.processor.added_tokens(is_pair))
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Post-process the given encodings, generating the final one
|
/// Post-process the given encodings, generating the final one
|
||||||
@ -162,6 +190,132 @@ impl PyPostProcessor {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
macro_rules! getter {
|
||||||
|
($self: ident, $variant: ident, $($name: tt)+) => {{
|
||||||
|
let super_ = $self.as_ref();
|
||||||
|
if let PyPostProcessorTypeWrapper::Single(ref single) = super_.processor {
|
||||||
|
if let PostProcessorWrapper::$variant(ref post) = *single.read().expect(
|
||||||
|
"RwLock synchronisation primitive is poisoned, cannot get subtype of PyPostProcessor"
|
||||||
|
) {
|
||||||
|
post.$($name)+
|
||||||
|
} else {
|
||||||
|
unreachable!()
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
unreachable!()
|
||||||
|
}
|
||||||
|
}};
|
||||||
|
}
|
||||||
|
|
||||||
|
macro_rules! setter {
|
||||||
|
($self: ident, $variant: ident, $name: ident, $value: expr) => {{
|
||||||
|
let super_ = $self.as_ref();
|
||||||
|
if let PyPostProcessorTypeWrapper::Single(ref single) = super_.processor {
|
||||||
|
if let PostProcessorWrapper::$variant(ref mut post) = *single.write().expect(
|
||||||
|
"RwLock synchronisation primitive is poisoned, cannot get subtype of PyPostProcessor",
|
||||||
|
) {
|
||||||
|
post.$name = $value;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}};
|
||||||
|
($self: ident, $variant: ident, @$name: ident, $value: expr) => {{
|
||||||
|
let super_ = $self.as_ref();
|
||||||
|
if let PyPostProcessorTypeWrapper::Single(ref single) = super_.processor {
|
||||||
|
if let PostProcessorWrapper::$variant(ref mut post) = *single.write().expect(
|
||||||
|
"RwLock synchronisation primitive is poisoned, cannot get subtype of PyPostProcessor",
|
||||||
|
) {
|
||||||
|
post.$name($value);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};};
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Clone)]
|
||||||
|
pub(crate) enum PyPostProcessorTypeWrapper {
|
||||||
|
Sequence(Vec<Arc<RwLock<PostProcessorWrapper>>>),
|
||||||
|
Single(Arc<RwLock<PostProcessorWrapper>>),
|
||||||
|
}
|
||||||
|
|
||||||
|
impl PostProcessor for PyPostProcessorTypeWrapper {
|
||||||
|
fn added_tokens(&self, is_pair: bool) -> usize {
|
||||||
|
match self {
|
||||||
|
PyPostProcessorTypeWrapper::Single(inner) => inner
|
||||||
|
.read()
|
||||||
|
.expect("RwLock synchronisation primitive is poisoned, cannot get subtype of PyPostProcessor")
|
||||||
|
.added_tokens(is_pair),
|
||||||
|
PyPostProcessorTypeWrapper::Sequence(inner) => inner.iter().map(|p| {
|
||||||
|
p.read()
|
||||||
|
.expect("RwLock synchronisation primitive is poisoned, cannot get subtype of PyPostProcessor")
|
||||||
|
.added_tokens(is_pair)
|
||||||
|
}).sum::<usize>(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn process_encodings(
|
||||||
|
&self,
|
||||||
|
mut encodings: Vec<Encoding>,
|
||||||
|
add_special_tokens: bool,
|
||||||
|
) -> tk::Result<Vec<Encoding>> {
|
||||||
|
match self {
|
||||||
|
PyPostProcessorTypeWrapper::Single(inner) => inner
|
||||||
|
.read()
|
||||||
|
.map_err(|_| PyException::new_err("RwLock synchronisation primitive is poisoned, cannot get subtype of PyPreTokenizer"))?
|
||||||
|
.process_encodings(encodings, add_special_tokens),
|
||||||
|
PyPostProcessorTypeWrapper::Sequence(inner) => {
|
||||||
|
for processor in inner.iter() {
|
||||||
|
encodings = processor
|
||||||
|
.read()
|
||||||
|
.map_err(|_| PyException::new_err("RwLock synchronisation primitive is poisoned, cannot get subtype of PyPreTokenizer"))?
|
||||||
|
.process_encodings(encodings, add_special_tokens)?;
|
||||||
|
}
|
||||||
|
Ok(encodings)
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'de> Deserialize<'de> for PyPostProcessorTypeWrapper {
|
||||||
|
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
|
||||||
|
where
|
||||||
|
D: Deserializer<'de>,
|
||||||
|
{
|
||||||
|
let wrapper = PostProcessorWrapper::deserialize(deserializer)?;
|
||||||
|
Ok(wrapper.into())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Serialize for PyPostProcessorTypeWrapper {
|
||||||
|
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
|
||||||
|
where
|
||||||
|
S: Serializer,
|
||||||
|
{
|
||||||
|
match self {
|
||||||
|
PyPostProcessorTypeWrapper::Sequence(seq) => {
|
||||||
|
let mut ser = serializer.serialize_struct("Sequence", 2)?;
|
||||||
|
ser.serialize_field("type", "Sequence")?;
|
||||||
|
ser.serialize_field("processors", seq)?;
|
||||||
|
ser.end()
|
||||||
|
}
|
||||||
|
PyPostProcessorTypeWrapper::Single(inner) => inner.serialize(serializer),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<I> From<I> for PyPostProcessorTypeWrapper
|
||||||
|
where
|
||||||
|
I: Into<PostProcessorWrapper>,
|
||||||
|
{
|
||||||
|
fn from(processor: I) -> Self {
|
||||||
|
let processor = processor.into();
|
||||||
|
match processor {
|
||||||
|
PostProcessorWrapper::Sequence(seq) => PyPostProcessorTypeWrapper::Sequence(
|
||||||
|
seq.into_iter().map(|p| Arc::new(RwLock::new(p))).collect(),
|
||||||
|
),
|
||||||
|
_ => PyPostProcessorTypeWrapper::Single(Arc::new(RwLock::new(processor.clone()))),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
/// This post-processor takes care of adding the special tokens needed by
|
/// This post-processor takes care of adding the special tokens needed by
|
||||||
/// a Bert model:
|
/// a Bert model:
|
||||||
///
|
///
|
||||||
@ -181,15 +335,46 @@ impl PyBertProcessing {
|
|||||||
#[new]
|
#[new]
|
||||||
#[pyo3(text_signature = "(self, sep, cls)")]
|
#[pyo3(text_signature = "(self, sep, cls)")]
|
||||||
fn new(sep: (String, u32), cls: (String, u32)) -> (Self, PyPostProcessor) {
|
fn new(sep: (String, u32), cls: (String, u32)) -> (Self, PyPostProcessor) {
|
||||||
(
|
(PyBertProcessing {}, BertProcessing::new(sep, cls).into())
|
||||||
PyBertProcessing {},
|
|
||||||
PyPostProcessor::new(Arc::new(BertProcessing::new(sep, cls).into())),
|
|
||||||
)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
fn __getnewargs__<'p>(&self, py: Python<'p>) -> PyResult<Bound<'p, PyTuple>> {
|
fn __getnewargs__<'p>(&self, py: Python<'p>) -> PyResult<Bound<'p, PyTuple>> {
|
||||||
PyTuple::new(py, [("", 0), ("", 0)])
|
PyTuple::new(py, [("", 0), ("", 0)])
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[getter]
|
||||||
|
fn get_sep(self_: PyRef<Self>) -> Result<Bound<'_, PyTuple>, PyErr> {
|
||||||
|
let py = self_.py();
|
||||||
|
let (tok, id) = getter!(self_, Bert, get_sep_copy());
|
||||||
|
PyTuple::new(
|
||||||
|
py,
|
||||||
|
Vec::<PyObject>::from([tok.into_pyobject(py)?.into(), id.into_pyobject(py)?.into()]),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
#[setter]
|
||||||
|
fn set_sep(self_: PyRef<Self>, sep: Bound<'_, PyTuple>) -> PyResult<()> {
|
||||||
|
let sep = sep.extract()?;
|
||||||
|
setter!(self_, Bert, sep, sep);
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
#[getter]
|
||||||
|
fn get_cls(self_: PyRef<Self>) -> Result<Bound<'_, PyTuple>, PyErr> {
|
||||||
|
let py = self_.py();
|
||||||
|
let (tok, id) = getter!(self_, Bert, get_cls_copy());
|
||||||
|
PyTuple::new(
|
||||||
|
py,
|
||||||
|
Vec::<PyObject>::from([tok.into_pyobject(py)?.into(), id.into_pyobject(py)?.into()]),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
#[setter]
|
||||||
|
fn set_cls(self_: PyRef<Self>, cls: Bound<'_, PyTuple>) -> PyResult<()> {
|
||||||
|
let cls = cls.extract()?;
|
||||||
|
setter!(self_, Bert, cls, cls);
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// This post-processor takes care of adding the special tokens needed by
|
/// This post-processor takes care of adding the special tokens needed by
|
||||||
@ -231,15 +416,66 @@ impl PyRobertaProcessing {
|
|||||||
let proc = RobertaProcessing::new(sep, cls)
|
let proc = RobertaProcessing::new(sep, cls)
|
||||||
.trim_offsets(trim_offsets)
|
.trim_offsets(trim_offsets)
|
||||||
.add_prefix_space(add_prefix_space);
|
.add_prefix_space(add_prefix_space);
|
||||||
(
|
(PyRobertaProcessing {}, proc.into())
|
||||||
PyRobertaProcessing {},
|
|
||||||
PyPostProcessor::new(Arc::new(proc.into())),
|
|
||||||
)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
fn __getnewargs__<'p>(&self, py: Python<'p>) -> PyResult<Bound<'p, PyTuple>> {
|
fn __getnewargs__<'p>(&self, py: Python<'p>) -> PyResult<Bound<'p, PyTuple>> {
|
||||||
PyTuple::new(py, [("", 0), ("", 0)])
|
PyTuple::new(py, [("", 0), ("", 0)])
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[getter]
|
||||||
|
fn get_sep(self_: PyRef<Self>) -> Result<Bound<'_, PyTuple>, PyErr> {
|
||||||
|
let py = self_.py();
|
||||||
|
let (tok, id) = getter!(self_, Roberta, get_sep_copy());
|
||||||
|
PyTuple::new(
|
||||||
|
py,
|
||||||
|
Vec::<PyObject>::from([tok.into_pyobject(py)?.into(), id.into_pyobject(py)?.into()]),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
#[setter]
|
||||||
|
fn set_sep(self_: PyRef<Self>, sep: Bound<'_, PyTuple>) -> PyResult<()> {
|
||||||
|
let sep = sep.extract()?;
|
||||||
|
setter!(self_, Roberta, sep, sep);
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
#[getter]
|
||||||
|
fn get_cls(self_: PyRef<Self>) -> Result<Bound<'_, PyTuple>, PyErr> {
|
||||||
|
let py = self_.py();
|
||||||
|
let (tok, id) = getter!(self_, Roberta, get_cls_copy());
|
||||||
|
PyTuple::new(
|
||||||
|
py,
|
||||||
|
Vec::<PyObject>::from([tok.into_pyobject(py)?.into(), id.into_pyobject(py)?.into()]),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
#[setter]
|
||||||
|
fn set_cls(self_: PyRef<Self>, cls: Bound<'_, PyTuple>) -> PyResult<()> {
|
||||||
|
let cls = cls.extract()?;
|
||||||
|
setter!(self_, Roberta, cls, cls);
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
#[getter]
|
||||||
|
fn get_trim_offsets(self_: PyRef<Self>) -> bool {
|
||||||
|
getter!(self_, Roberta, trim_offsets)
|
||||||
|
}
|
||||||
|
|
||||||
|
#[setter]
|
||||||
|
fn set_trim_offsets(self_: PyRef<Self>, trim_offsets: bool) {
|
||||||
|
setter!(self_, Roberta, trim_offsets, trim_offsets)
|
||||||
|
}
|
||||||
|
|
||||||
|
#[getter]
|
||||||
|
fn get_add_prefix_space(self_: PyRef<Self>) -> bool {
|
||||||
|
getter!(self_, Roberta, add_prefix_space)
|
||||||
|
}
|
||||||
|
|
||||||
|
#[setter]
|
||||||
|
fn set_add_prefix_space(self_: PyRef<Self>, add_prefix_space: bool) {
|
||||||
|
setter!(self_, Roberta, add_prefix_space, add_prefix_space)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// This post-processor takes care of trimming the offsets.
|
/// This post-processor takes care of trimming the offsets.
|
||||||
@ -255,21 +491,58 @@ pub struct PyByteLevel {}
|
|||||||
#[pymethods]
|
#[pymethods]
|
||||||
impl PyByteLevel {
|
impl PyByteLevel {
|
||||||
#[new]
|
#[new]
|
||||||
#[pyo3(signature = (trim_offsets = None, **_kwargs), text_signature = "(self, trim_offsets=True)")]
|
#[pyo3(signature = (add_prefix_space = None, trim_offsets = None, use_regex = None, **_kwargs), text_signature = "(self, trim_offsets=True)")]
|
||||||
fn new(
|
fn new(
|
||||||
|
add_prefix_space: Option<bool>,
|
||||||
trim_offsets: Option<bool>,
|
trim_offsets: Option<bool>,
|
||||||
|
use_regex: Option<bool>,
|
||||||
_kwargs: Option<&Bound<'_, PyDict>>,
|
_kwargs: Option<&Bound<'_, PyDict>>,
|
||||||
) -> (Self, PyPostProcessor) {
|
) -> (Self, PyPostProcessor) {
|
||||||
let mut byte_level = ByteLevel::default();
|
let mut byte_level = ByteLevel::default();
|
||||||
|
|
||||||
|
if let Some(aps) = add_prefix_space {
|
||||||
|
byte_level = byte_level.add_prefix_space(aps);
|
||||||
|
}
|
||||||
|
|
||||||
if let Some(to) = trim_offsets {
|
if let Some(to) = trim_offsets {
|
||||||
byte_level = byte_level.trim_offsets(to);
|
byte_level = byte_level.trim_offsets(to);
|
||||||
}
|
}
|
||||||
|
|
||||||
(
|
if let Some(ur) = use_regex {
|
||||||
PyByteLevel {},
|
byte_level = byte_level.use_regex(ur);
|
||||||
PyPostProcessor::new(Arc::new(byte_level.into())),
|
}
|
||||||
)
|
|
||||||
|
(PyByteLevel {}, byte_level.into())
|
||||||
|
}
|
||||||
|
|
||||||
|
#[getter]
|
||||||
|
fn get_add_prefix_space(self_: PyRef<Self>) -> bool {
|
||||||
|
getter!(self_, ByteLevel, add_prefix_space)
|
||||||
|
}
|
||||||
|
|
||||||
|
#[setter]
|
||||||
|
fn set_add_prefix_space(self_: PyRef<Self>, add_prefix_space: bool) {
|
||||||
|
setter!(self_, ByteLevel, add_prefix_space, add_prefix_space)
|
||||||
|
}
|
||||||
|
|
||||||
|
#[getter]
|
||||||
|
fn get_trim_offsets(self_: PyRef<Self>) -> bool {
|
||||||
|
getter!(self_, ByteLevel, trim_offsets)
|
||||||
|
}
|
||||||
|
|
||||||
|
#[setter]
|
||||||
|
fn set_trim_offsets(self_: PyRef<Self>, trim_offsets: bool) {
|
||||||
|
setter!(self_, ByteLevel, trim_offsets, trim_offsets)
|
||||||
|
}
|
||||||
|
|
||||||
|
#[getter]
|
||||||
|
fn get_use_regex(self_: PyRef<Self>) -> bool {
|
||||||
|
getter!(self_, ByteLevel, use_regex)
|
||||||
|
}
|
||||||
|
|
||||||
|
#[setter]
|
||||||
|
fn set_use_regex(self_: PyRef<Self>, use_regex: bool) {
|
||||||
|
setter!(self_, ByteLevel, use_regex, use_regex)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -430,10 +703,26 @@ impl PyTemplateProcessing {
|
|||||||
.build()
|
.build()
|
||||||
.map_err(|e| exceptions::PyValueError::new_err(e.to_string()))?;
|
.map_err(|e| exceptions::PyValueError::new_err(e.to_string()))?;
|
||||||
|
|
||||||
Ok((
|
Ok((PyTemplateProcessing {}, processor.into()))
|
||||||
PyTemplateProcessing {},
|
}
|
||||||
PyPostProcessor::new(Arc::new(processor.into())),
|
|
||||||
))
|
#[getter]
|
||||||
|
fn get_single(self_: PyRef<Self>) -> String {
|
||||||
|
getter!(self_, Template, get_single())
|
||||||
|
}
|
||||||
|
|
||||||
|
#[setter]
|
||||||
|
fn set_single(self_: PyRef<Self>, single: PyTemplate) -> PyResult<()> {
|
||||||
|
let template: Template = Template::from(single);
|
||||||
|
let super_ = self_.as_ref();
|
||||||
|
if let PyPostProcessorTypeWrapper::Single(ref inner) = super_.processor {
|
||||||
|
if let PostProcessorWrapper::Template(ref mut post) = *inner
|
||||||
|
.write()
|
||||||
|
.map_err(|_| PyException::new_err("RwLock synchronisation primitive is poisoned, cannot get subtype of PyPostProcessor"))? {
|
||||||
|
post.set_single(template);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Ok(())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -444,27 +733,79 @@ impl PyTemplateProcessing {
|
|||||||
/// The processors that need to be chained
|
/// The processors that need to be chained
|
||||||
 #[pyclass(extends=PyPostProcessor, module = "tokenizers.processors", name = "Sequence")]
 pub struct PySequence {}
 
 #[pymethods]
 impl PySequence {
     #[new]
     #[pyo3(signature = (processors_py), text_signature = "(self, processors)")]
-    fn new(processors_py: &Bound<'_, PyList>) -> (Self, PyPostProcessor) {
-        let mut processors: Vec<PostProcessorWrapper> = Vec::with_capacity(processors_py.len());
+    fn new(processors_py: &Bound<'_, PyList>) -> PyResult<(Self, PyPostProcessor)> {
+        let mut processors = Vec::with_capacity(processors_py.len());
         for n in processors_py.iter() {
-            let processor: PyRef<PyPostProcessor> = n.extract().unwrap();
-            let processor = processor.processor.as_ref();
-            processors.push(processor.clone());
+            let processor: PyRef<PyPostProcessor> = n.extract()?;
+            match &processor.processor {
+                PyPostProcessorTypeWrapper::Sequence(inner) => {
+                    processors.extend(inner.iter().cloned())
+                }
+                PyPostProcessorTypeWrapper::Single(inner) => processors.push(inner.clone()),
+            }
         }
-        let sequence_processor = Sequence::new(processors);
-        (
+        Ok((
             PySequence {},
-            PyPostProcessor::new(Arc::new(PostProcessorWrapper::Sequence(sequence_processor))),
-        )
+            PyPostProcessor::new(PyPostProcessorTypeWrapper::Sequence(processors)),
+        ))
     }
 
     fn __getnewargs__<'p>(&self, py: Python<'p>) -> PyResult<Bound<'p, PyTuple>> {
         PyTuple::new(py, [PyList::empty(py)])
     }
 
+    fn __getitem__(self_: PyRef<'_, Self>, py: Python<'_>, index: usize) -> PyResult<Py<PyAny>> {
+        match &self_.as_ref().processor {
+            PyPostProcessorTypeWrapper::Sequence(ref inner) => match inner.get(index) {
+                Some(item) => {
+                    PyPostProcessor::new(PyPostProcessorTypeWrapper::Single(item.clone()))
+                        .get_as_subtype(py)
+                }
+                _ => Err(PyErr::new::<pyo3::exceptions::PyIndexError, _>(
+                    "Index not found",
+                )),
+            },
+            _ => Err(PyErr::new::<pyo3::exceptions::PyIndexError, _>(
+                "This processor is not a Sequence, it does not support __getitem__",
+            )),
+        }
+    }
+
+    fn __setitem__(self_: PyRef<'_, Self>, index: usize, value: Bound<'_, PyAny>) -> PyResult<()> {
+        let processor: PyPostProcessor = value.extract()?;
+        let PyPostProcessorTypeWrapper::Single(processor) = processor.processor else {
+            return Err(PyException::new_err("processor should not be a sequence"));
+        };
+
+        match &self_.as_ref().processor {
+            PyPostProcessorTypeWrapper::Sequence(inner) => match inner.get(index) {
+                Some(item) => {
+                    *item
+                        .write()
+                        .map_err(|_| PyException::new_err("RwLock synchronisation primitive is poisoned, cannot get subtype of PyPostProcessor"))? = processor
+                        .read()
+                        .map_err(|_| PyException::new_err("RwLock synchronisation primitive is poisoned, cannot get subtype of PyPostProcessor"))?
+                        .clone();
+                }
+                _ => {
+                    return Err(PyErr::new::<pyo3::exceptions::PyIndexError, _>(
+                        "Index not found",
+                    ))
+                }
+            },
+            _ => {
+                return Err(PyException::new_err(
+                    "This processor is not a Sequence, it does not support __setitem__",
+                ))
+            }
+        };
+        Ok(())
+    }
 }
 
 /// Processors Module
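
On the Python side, the net effect of the bindings above is that tokenizers.processors.Sequence behaves like a mutable list of post-processors. A minimal usage sketch, mirroring the test added further down in this diff (all names come from the public `tokenizers.processors` module):

from tokenizers.processors import ByteLevel, RobertaProcessing, Sequence

seq = Sequence([RobertaProcessing(("</s>", 1), ("<s>", 0)), ByteLevel()])
assert seq[0].__class__ == RobertaProcessing  # __getitem__ returns the concrete subtype
seq[0] = ByteLevel()  # __setitem__ swaps the element behind the RwLock in place
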
@@ -481,20 +822,20 @@ pub fn processors(m: &Bound<'_, PyModule>) -> PyResult<()> {
 
 #[cfg(test)]
 mod test {
-    use std::sync::Arc;
+    use std::sync::{Arc, RwLock};
 
     use pyo3::prelude::*;
     use tk::processors::bert::BertProcessing;
     use tk::processors::PostProcessorWrapper;
 
-    use crate::processors::PyPostProcessor;
+    use crate::processors::{PyPostProcessor, PyPostProcessorTypeWrapper};
 
     #[test]
     fn get_subtype() {
         Python::with_gil(|py| {
-            let py_proc = PyPostProcessor::new(Arc::new(
-                BertProcessing::new(("SEP".into(), 0), ("CLS".into(), 1)).into(),
-            ));
+            let py_proc = PyPostProcessor::new(PyPostProcessorTypeWrapper::Single(Arc::new(
+                RwLock::new(BertProcessing::new(("SEP".into(), 0), ("CLS".into(), 1)).into()),
+            )));
             let py_bert = py_proc.get_as_subtype(py).unwrap();
             assert_eq!(
                 "BertProcessing",
@@ -510,21 +851,29 @@ mod test {
         let rs_processing_ser = serde_json::to_string(&rs_processing).unwrap();
         let rs_wrapper_ser = serde_json::to_string(&rs_wrapper).unwrap();
 
-        let py_processing = PyPostProcessor::new(Arc::new(rs_wrapper));
+        let py_processing = PyPostProcessor::new(PyPostProcessorTypeWrapper::Single(Arc::new(
+            RwLock::new(rs_wrapper),
+        )));
         let py_ser = serde_json::to_string(&py_processing).unwrap();
         assert_eq!(py_ser, rs_processing_ser);
         assert_eq!(py_ser, rs_wrapper_ser);
 
         let py_processing: PyPostProcessor = serde_json::from_str(&rs_processing_ser).unwrap();
-        match py_processing.processor.as_ref() {
-            PostProcessorWrapper::Bert(_) => (),
-            _ => panic!("Expected Bert postprocessor."),
+        match py_processing.processor {
+            PyPostProcessorTypeWrapper::Single(inner) => match *inner.as_ref().read().unwrap() {
+                PostProcessorWrapper::Bert(_) => (),
+                _ => panic!("Expected Bert postprocessor."),
+            },
+            _ => panic!("Expected a single processor, got a sequence"),
         }
 
         let py_processing: PyPostProcessor = serde_json::from_str(&rs_wrapper_ser).unwrap();
-        match py_processing.processor.as_ref() {
-            PostProcessorWrapper::Bert(_) => (),
-            _ => panic!("Expected Bert postprocessor."),
-        }
+        match py_processing.processor {
+            PyPostProcessorTypeWrapper::Single(inner) => match *inner.as_ref().read().unwrap() {
+                PostProcessorWrapper::Bert(_) => (),
+                _ => panic!("Expected Bert postprocessor."),
+            },
+            _ => panic!("Expected a single processor, got a sequence"),
+        };
     }
 }
@@ -117,6 +117,12 @@ impl From<PySplitDelimiterBehavior> for SplitDelimiterBehavior {
     }
 }
 
+impl From<SplitDelimiterBehavior> for PySplitDelimiterBehavior {
+    fn from(v: SplitDelimiterBehavior) -> Self {
+        Self(v)
+    }
+}
+
 fn filter(normalized: &mut NormalizedString, func: &Bound<'_, PyAny>) -> PyResult<()> {
     let err = "`filter` expect a callable with the signature: `fn(char) -> bool`";
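
The added `From<SplitDelimiterBehavior> for PySplitDelimiterBehavior` conversion is plausibly what lets the new getters hand a behavior value back to Python after it has been set as a plain string. A hedged round-trip sketch, mirroring the pre-tokenizer tests later in this diff:

from tokenizers.pre_tokenizers import Punctuation

punct = Punctuation(behavior="removed")
punct.behavior = "isolated"  # string converted into SplitDelimiterBehavior on write
assert punct.behavior == "isolated"  # and converted back into a string on read
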
@@ -3,7 +3,16 @@ import pickle
 import pytest
 
 from tokenizers import NormalizedString
-from tokenizers.normalizers import BertNormalizer, Lowercase, Normalizer, Sequence, Strip, Prepend
+from tokenizers.normalizers import (
+    BertNormalizer,
+    Lowercase,
+    Normalizer,
+    Precompiled,
+    Sequence,
+    Strip,
+    Prepend,
+    Replace,
+)
 
 
 class TestBertNormalizer:
@@ -67,14 +76,57 @@ class TestSequence:
         output = normalizer.normalize_str(" HELLO ")
         assert output == "hello"
 
-    def test_items(self):
-        normalizers = Sequence([BertNormalizer(True, True), Prepend(), Strip()])
+    def test_set_item(self):
+        normalizers = Sequence(
+            [
+                BertNormalizer(True, True),
+                Prepend(prepend="test"),
+            ]
+        )
+        assert normalizers[0].__class__ == BertNormalizer
         assert normalizers[1].__class__ == Prepend
-        normalizers[0].lowercase = False
-        assert not normalizers[0].lowercase
-        assert normalizers[2].__class__ == Strip
+        normalizers[1] = Strip()
+        assert normalizers[1].__class__ == Strip
         with pytest.raises(IndexError):
-            print(normalizers[3])
+            print(normalizers[2])
 
+    def test_item_getters_and_setters(self):
+        normalizers = Sequence(
+            [
+                BertNormalizer(clean_text=True, handle_chinese_chars=True, strip_accents=True, lowercase=True),
+                Strip(left=True, right=True),
+                Prepend(prepend="_"),
+                Replace(pattern="something", content="else"),
+            ]
+        )
+
+        assert normalizers[0].__class__ == BertNormalizer
+        normalizers[0].clean_text = False
+        normalizers[0].handle_chinese_chars = False
+        normalizers[0].strip_accents = False
+        normalizers[0].lowercase = False
+        assert not normalizers[0].clean_text
+        assert not normalizers[0].handle_chinese_chars
+        assert not normalizers[0].strip_accents
+        assert not normalizers[0].lowercase
+
+        assert normalizers[1].__class__ == Strip
+        normalizers[1].left = False
+        normalizers[1].right = False
+        assert not normalizers[1].left
+        assert not normalizers[1].right
+
+        assert normalizers[2].__class__ == Prepend
+        normalizers[2].prepend = " "
+        assert normalizers[2].prepend == " "
+
+        assert normalizers[3].__class__ == Replace
+        with pytest.raises(Exception):
+            normalizers[3].pattern = "test"
+        with pytest.raises(Exception):
+            print(normalizers[3].pattern)
+        normalizers[3].content = "test"
+        assert normalizers[3].content == "test"
 
 
 class TestLowercase:
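
The tests above exercise the setters in isolation; end to end, replacing an element of a normalizer Sequence immediately changes what normalize_str produces. A minimal sketch using the standard `tokenizers.normalizers` API (the `__setitem__` assignment is the behaviour added by this PR):

from tokenizers.normalizers import Lowercase, Sequence, Strip

norm = Sequence([Lowercase(), Strip()])
assert norm.normalize_str("  ABC  ") == "abc"
norm[0] = Strip()  # drop the lowercasing step in place
assert norm.normalize_str("  ABC  ") == "ABC"
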
@@ -169,12 +169,69 @@ class TestSequence:
             ("?", (29, 30)),
         ]
 
-    def test_items(self):
-        pre_tokenizers = Sequence([Metaspace("a", "never", split=True), Punctuation()])
-        assert pre_tokenizers[1].__class__ == Punctuation
-        assert pre_tokenizers[0].__class__ == Metaspace
-        pre_tokenizers[0].split = False
-        assert not pre_tokenizers[0].split
+    def test_set_item(self):
+        pre_tokenizers = Sequence(
+            [
+                ByteLevel(),
+                Split(pattern="/test/", behavior="removed"),
+            ]
+        )
+        assert pre_tokenizers[0].__class__ == ByteLevel
+        assert pre_tokenizers[1].__class__ == Split
+        pre_tokenizers[1] = Metaspace()
+        assert pre_tokenizers[1].__class__ == Metaspace
+        with pytest.raises(IndexError):
+            print(pre_tokenizers[2])
 
+    def test_item_getters_and_setters(self):
+        pre_tokenizers = Sequence(
+            [
+                ByteLevel(add_prefix_space=True, trim_offsets=True, use_regex=True),
+                Split(pattern="/test/", behavior="removed", invert=False),
+                Metaspace("a", "never", split=False),
+                CharDelimiterSplit(delimiter=" "),
+                Punctuation(behavior="removed"),
+                Digits(individual_digits=True),
+            ]
+        )
+
+        assert pre_tokenizers[0].__class__ == ByteLevel
+        pre_tokenizers[0].add_prefix_space = False
+        pre_tokenizers[0].trim_offsets = False
+        pre_tokenizers[0].use_regex = False
+        assert not pre_tokenizers[0].add_prefix_space
+        assert not pre_tokenizers[0].trim_offsets
+        assert not pre_tokenizers[0].use_regex
+
+        assert pre_tokenizers[1].__class__ == Split
+        with pytest.raises(Exception):
+            pre_tokenizers[1].pattern = "/pattern/"
+        pre_tokenizers[1].behavior = "isolated"
+        pre_tokenizers[1].invert = True
+        with pytest.raises(Exception):
+            pre_tokenizers[1].pattern
+        assert pre_tokenizers[1].behavior == "isolated"
+        assert pre_tokenizers[1].invert
+
+        assert pre_tokenizers[2].__class__ == Metaspace
+        pre_tokenizers[2].replacement = " "
+        pre_tokenizers[2].prepend_scheme = "always"
+        pre_tokenizers[2].split = True
+        assert pre_tokenizers[2].replacement == " "
+        assert pre_tokenizers[2].prepend_scheme == "always"
+        assert pre_tokenizers[2].split
+
+        assert pre_tokenizers[3].__class__ == CharDelimiterSplit
+        pre_tokenizers[3].delimiter = "_"
+        assert pre_tokenizers[3].delimiter == "_"
+
+        assert pre_tokenizers[4].__class__ == Punctuation
+        pre_tokenizers[4].behavior = "isolated"
+        assert pre_tokenizers[4].behavior == "isolated"
+
+        assert pre_tokenizers[5].__class__ == Digits
+        pre_tokenizers[5].individual_digits = False
+        assert not pre_tokenizers[5].individual_digits
 
 
 class TestDigits:
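
As with normalizers, mutating a pre-tokenizer in place is visible on the next call. A short sketch, assuming the same `tokenizers.pre_tokenizers` classes the tests use, plus Whitespace:

from tokenizers.pre_tokenizers import Digits, Sequence, Whitespace

pre = Sequence([Whitespace(), Digits(individual_digits=True)])
print(pre.pre_tokenize_str("ab 12"))  # "12" is split into "1" and "2"
pre[1].individual_digits = False
print(pre.pre_tokenize_str("ab 12"))  # "12" now stays together
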
@@ -227,3 +227,18 @@ class TestSequenceProcessing:
         # assert pair.ids == [1, 2, 3, 4, 5, 0, 6, 0]
         assert pair.type_ids == [0, 0, 0, 0, 0, 0, 1, 1]
         assert pair.offsets == [(0, 0), (0, 2), (3, 7), (8, 10), (12, 16), (0, 0), (0, 4), (0, 0)]
+
+    def test_items(self):
+        processors = Sequence([RobertaProcessing(("</s>", 1), ("<s>", 0)), ByteLevel()])
+        assert processors[0].__class__ == RobertaProcessing
+        assert processors[1].__class__ == ByteLevel
+        processors[0] = ByteLevel(add_prefix_space=False, trim_offsets=False, use_regex=False)
+        print(processors[0])
+        processors[0].add_prefix_space = True
+        processors[0].trim_offsets = True
+        processors[0].use_regex = True
+        print(processors[0])
+        assert processors[0].__class__ == ByteLevel
+        assert processors[0].add_prefix_space
+        assert processors[0].trim_offsets
+        assert processors[0].use_regex
@ -43,14 +43,14 @@ impl<'de> Deserialize<'de> for NormalizerWrapper {
|
|||||||
where
|
where
|
||||||
D: Deserializer<'de>,
|
D: Deserializer<'de>,
|
||||||
{
|
{
|
||||||
#[derive(Deserialize)]
|
#[derive(Debug, Deserialize)]
|
||||||
pub struct Tagged {
|
pub struct Tagged {
|
||||||
#[serde(rename = "type")]
|
#[serde(rename = "type")]
|
||||||
variant: EnumType,
|
variant: EnumType,
|
||||||
#[serde(flatten)]
|
#[serde(flatten)]
|
||||||
rest: serde_json::Value,
|
rest: serde_json::Value,
|
||||||
}
|
}
|
||||||
#[derive(Serialize, Deserialize)]
|
#[derive(Debug, Serialize, Deserialize)]
|
||||||
pub enum EnumType {
|
pub enum EnumType {
|
||||||
Bert,
|
Bert,
|
||||||
Strip,
|
Strip,
|
||||||
@ -168,7 +168,7 @@ impl<'de> Deserialize<'de> for NormalizerWrapper {
|
|||||||
NormalizerUntagged::NFD(bpe) => NormalizerWrapper::NFD(bpe),
|
NormalizerUntagged::NFD(bpe) => NormalizerWrapper::NFD(bpe),
|
||||||
NormalizerUntagged::NFKC(bpe) => NormalizerWrapper::NFKC(bpe),
|
NormalizerUntagged::NFKC(bpe) => NormalizerWrapper::NFKC(bpe),
|
||||||
NormalizerUntagged::NFKD(bpe) => NormalizerWrapper::NFKD(bpe),
|
NormalizerUntagged::NFKD(bpe) => NormalizerWrapper::NFKD(bpe),
|
||||||
NormalizerUntagged::Sequence(bpe) => NormalizerWrapper::Sequence(bpe),
|
NormalizerUntagged::Sequence(seq) => NormalizerWrapper::Sequence(seq),
|
||||||
NormalizerUntagged::Lowercase(bpe) => NormalizerWrapper::Lowercase(bpe),
|
NormalizerUntagged::Lowercase(bpe) => NormalizerWrapper::Lowercase(bpe),
|
||||||
NormalizerUntagged::Nmt(bpe) => NormalizerWrapper::Nmt(bpe),
|
NormalizerUntagged::Nmt(bpe) => NormalizerWrapper::Nmt(bpe),
|
||||||
NormalizerUntagged::Precompiled(bpe) => NormalizerWrapper::Precompiled(bpe),
|
NormalizerUntagged::Precompiled(bpe) => NormalizerWrapper::Precompiled(bpe),
|
||||||
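
The Debug derives make failures in this two-stage deserialization printable; the custom implementation itself exists so a normalizer payload can be matched whether or not it carries a "type" tag. The Python bindings round-trip through this path when pickling, so a hedged sketch of the observable behaviour is:

import pickle
from tokenizers.normalizers import Lowercase, Sequence, Strip

norm = Sequence([Lowercase(), Strip()])
restored = pickle.loads(pickle.dumps(norm))  # serializes to JSON and back (assumed code path)
assert restored.normalize_str("  ABC  ") == "abc"
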
@@ -46,7 +46,7 @@ impl std::convert::TryFrom<ReplaceDeserializer> for Replace {
 #[serde(tag = "type", try_from = "ReplaceDeserializer")]
 pub struct Replace {
     pattern: ReplacePattern,
-    content: String,
+    pub content: String,
     #[serde(skip)]
     regex: SysRegex,
 }
@@ -16,16 +16,29 @@ impl Sequence {
     pub fn new(normalizers: Vec<NormalizerWrapper>) -> Self {
         Self { normalizers }
     }
+}
 
-    pub fn get_normalizers(&self) -> &[NormalizerWrapper] {
+impl AsRef<[NormalizerWrapper]> for Sequence {
+    fn as_ref(&self) -> &[NormalizerWrapper] {
         &self.normalizers
     }
+}
 
-    pub fn get_normalizers_mut(&mut self) -> &mut [NormalizerWrapper] {
+impl AsMut<[NormalizerWrapper]> for Sequence {
+    fn as_mut(&mut self) -> &mut [NormalizerWrapper] {
         &mut self.normalizers
     }
 }
+
+impl IntoIterator for Sequence {
+    type Item = NormalizerWrapper;
+    type IntoIter = std::vec::IntoIter<Self::Item>;
+
+    fn into_iter(self) -> Self::IntoIter {
+        self.normalizers.into_iter()
+    }
+}
 
 impl Normalizer for Sequence {
     fn normalize(&self, normalized: &mut NormalizedString) -> Result<()> {
         for normalizer in &self.normalizers {
@@ -13,6 +13,12 @@ pub enum PrependScheme {
     Always,
 }
 
+impl std::fmt::Display for PrependScheme {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        self.serialize(f)
+    }
+}
+
 #[derive(Debug, Clone, PartialEq, Serialize, Eq)]
 /// Replaces all the whitespaces by the provided meta character and then
 /// splits on this character
@@ -12,7 +12,7 @@ fn is_punc(x: char) -> bool {
 #[macro_rules_attribute(impl_serde_type!)]
 pub struct Punctuation {
     #[serde(default = "default_split")]
-    behavior: SplitDelimiterBehavior,
+    pub behavior: SplitDelimiterBehavior,
 }
 
 fn default_split() -> SplitDelimiterBehavior {
@@ -13,16 +13,29 @@ impl Sequence {
     pub fn new(pretokenizers: Vec<PreTokenizerWrapper>) -> Self {
         Self { pretokenizers }
     }
+}
 
-    pub fn get_pre_tokenizers(&self) -> &[PreTokenizerWrapper] {
+impl AsRef<[PreTokenizerWrapper]> for Sequence {
+    fn as_ref(&self) -> &[PreTokenizerWrapper] {
         &self.pretokenizers
     }
+}
 
-    pub fn get_pre_tokenizers_mut(&mut self) -> &mut [PreTokenizerWrapper] {
+impl AsMut<[PreTokenizerWrapper]> for Sequence {
+    fn as_mut(&mut self) -> &mut [PreTokenizerWrapper] {
         &mut self.pretokenizers
     }
 }
+
+impl IntoIterator for Sequence {
+    type Item = PreTokenizerWrapper;
+    type IntoIter = std::vec::IntoIter<Self::Item>;
+
+    fn into_iter(self) -> Self::IntoIter {
+        self.pretokenizers.into_iter()
+    }
+}
 
 impl PreTokenizer for Sequence {
     fn pre_tokenize(&self, pretokenized: &mut PreTokenizedString) -> Result<()> {
         for pretokenizer in &self.pretokenizers {
@@ -27,11 +27,11 @@ impl From<&str> for SplitPattern {
 #[derive(Debug, Serialize)]
 #[serde(tag = "type")]
 pub struct Split {
-    pattern: SplitPattern,
+    pub pattern: SplitPattern,
     #[serde(skip)]
-    regex: SysRegex,
-    behavior: SplitDelimiterBehavior,
-    invert: bool,
+    pub regex: SysRegex,
+    pub behavior: SplitDelimiterBehavior,
+    pub invert: bool,
 }
 
 impl<'de> Deserialize<'de> for Split {
@@ -6,8 +6,8 @@ use std::iter::FromIterator;
 #[derive(Serialize, Deserialize, Clone, Debug, PartialEq, Eq)]
 #[serde(tag = "type")]
 pub struct BertProcessing {
-    sep: (String, u32),
-    cls: (String, u32),
+    pub sep: (String, u32),
+    pub cls: (String, u32),
 }
 
 impl Default for BertProcessing {
@@ -23,6 +23,14 @@ impl BertProcessing {
     pub fn new(sep: (String, u32), cls: (String, u32)) -> Self {
         Self { sep, cls }
     }
+
+    pub fn get_sep_copy(&self) -> (String, u32) {
+        (self.sep.0.clone(), self.sep.1)
+    }
+
+    pub fn get_cls_copy(&self) -> (String, u32) {
+        (self.cls.0.clone(), self.cls.1)
+    }
 }
 
 #[derive(thiserror::Error, Debug)]
@@ -7,10 +7,10 @@ use std::iter::FromIterator;
 #[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Eq)]
 #[serde(tag = "type")]
 pub struct RobertaProcessing {
-    sep: (String, u32),
-    cls: (String, u32),
-    trim_offsets: bool,
-    add_prefix_space: bool,
+    pub sep: (String, u32),
+    pub cls: (String, u32),
+    pub trim_offsets: bool,
+    pub add_prefix_space: bool,
 }
 
 impl Default for RobertaProcessing {
@@ -44,6 +44,14 @@ impl RobertaProcessing {
         self.add_prefix_space = v;
         self
     }
+
+    pub fn get_sep_copy(&self) -> (String, u32) {
+        (self.sep.0.clone(), self.sep.1)
+    }
+
+    pub fn get_cls_copy(&self) -> (String, u32) {
+        (self.cls.0.clone(), self.cls.1)
+    }
 }
 
 impl PostProcessor for RobertaProcessing {
@@ -13,6 +13,39 @@ impl Sequence {
     pub fn new(processors: Vec<PostProcessorWrapper>) -> Self {
         Self { processors }
     }
+
+    pub fn get(&self, index: usize) -> Option<&PostProcessorWrapper> {
+        self.processors.get(index)
+    }
+
+    pub fn get_mut(&mut self, index: usize) -> Option<&mut PostProcessorWrapper> {
+        self.processors.get_mut(index)
+    }
+
+    pub fn set_mut(&mut self, index: usize, post_proc: PostProcessorWrapper) {
+        self.processors[index] = post_proc;
+    }
+}
+
+impl AsRef<[PostProcessorWrapper]> for Sequence {
+    fn as_ref(&self) -> &[PostProcessorWrapper] {
+        &self.processors
+    }
+}
+
+impl AsMut<[PostProcessorWrapper]> for Sequence {
+    fn as_mut(&mut self) -> &mut [PostProcessorWrapper] {
+        &mut self.processors
+    }
+}
+
+impl IntoIterator for Sequence {
+    type Item = PostProcessorWrapper;
+    type IntoIter = std::vec::IntoIter<Self::Item>;
+
+    fn into_iter(self) -> Self::IntoIter {
+        self.processors.into_iter()
+    }
 }
 
 impl PostProcessor for Sequence {
@@ -338,7 +338,7 @@ impl From<HashMap<String, SpecialToken>> for Tokens {
 #[builder(build_fn(validate = "Self::validate"))]
 pub struct TemplateProcessing {
     #[builder(try_setter, default = "\"$0\".try_into().unwrap()")]
-    single: Template,
+    pub single: Template,
     #[builder(try_setter, default = "\"$A:0 $B:1\".try_into().unwrap()")]
     pair: Template,
     #[builder(setter(skip), default = "self.default_added(true)")]
@@ -351,6 +351,58 @@ pub struct TemplateProcessing {
     special_tokens: Tokens,
 }
 
+impl TemplateProcessing {
+    // Getter for `single`
+    pub fn get_single(&self) -> String {
+        format!("{:?}", self.single)
+    }
+
+    // Setter for `single`
+    pub fn set_single(&mut self, single: Template) {
+        self.single = single;
+    }
+
+    // Getter for `pair`
+    pub fn get_pair(&self) -> &Template {
+        &self.pair
+    }
+
+    // Setter for `pair`
+    pub fn set_pair(&mut self, pair: Template) {
+        self.pair = pair;
+    }
+
+    // Getter for `added_single`
+    pub fn get_added_single(&self) -> usize {
+        self.added_single
+    }
+
+    // Setter for `added_single`
+    pub fn set_added_single(&mut self, added_single: usize) {
+        self.added_single = added_single;
+    }
+
+    // Getter for `added_pair`
+    pub fn get_added_pair(&self) -> usize {
+        self.added_pair
+    }
+
+    // Setter for `added_pair`
+    pub fn set_added_pair(&mut self, added_pair: usize) {
+        self.added_pair = added_pair;
+    }
+
+    // Getter for `special_tokens`
+    pub fn get_special_tokens(&self) -> &Tokens {
+        &self.special_tokens
+    }
+
+    // Setter for `special_tokens`
+    pub fn set_special_tokens(&mut self, special_tokens: Tokens) {
+        self.special_tokens = special_tokens;
+    }
+}
+
 impl From<&str> for TemplateProcessingBuilderError {
     fn from(e: &str) -> Self {
         e.to_string().into()
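
Making single public, together with the getters and setters above, enables the headline use case of this PR: updating a template processor after construction instead of rebuilding it. A hedged sketch; the attribute-style assignment from Python is the behaviour this PR adds, and the accepted template forms (string vs. list of pieces) follow the TemplateProcessing docs:

from tokenizers.processors import TemplateProcessing

proc = TemplateProcessing(
    single="[CLS] $A [SEP]",
    special_tokens=[("[CLS]", 1), ("[SEP]", 2)],
)
proc.single = "$A [SEP]"  # swap the single-sequence template in place (assumed setter)
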
@@ -87,6 +87,12 @@ pub enum SplitDelimiterBehavior {
     Contiguous,
 }
 
+impl std::fmt::Display for SplitDelimiterBehavior {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        self.serialize(f)
+    }
+}
+
 /// A `NormalizedString` takes care of processing an "original" string to modify
 /// it and obtain a "normalized" string. It keeps both version of the string,
 /// alignments information between both and provides an interface to retrieve