Adding a 3 new PreTokenizers:

- Deduplication : Removes duplicate spaces within strings
- Punctuation: Splits punctuation characters as isolated tokens
- Sequence: Applies a list of pretokenizers iteratively
This commit is contained in:
Nicolas Patry
2020-08-21 16:37:38 +02:00
committed by Anthony MOI
parent 50ac90d338
commit 7ed7f0f26a
9 changed files with 341 additions and 4 deletions

View File

@@ -3,6 +3,9 @@ from .. import pre_tokenizers
PreTokenizer = pre_tokenizers.PreTokenizer
ByteLevel = pre_tokenizers.ByteLevel
Whitespace = pre_tokenizers.Whitespace
Deduplication = pre_tokenizers.Deduplication
Punctuation = pre_tokenizers.Punctuation
Sequence = pre_tokenizers.Sequence
WhitespaceSplit = pre_tokenizers.WhitespaceSplit
BertPreTokenizer = pre_tokenizers.BertPreTokenizer
Metaspace = pre_tokenizers.Metaspace

View File

@@ -107,3 +107,34 @@ class CharDelimiterSplit(PreTokenizer):
The delimiter char that will be used to split input
"""
pass
class Deduplication(PreTokenizer):
""" Deduplication PreTokenizer
This pre-tokenizer simply splits using the following regex: `\w+|[^\w\s]+`
"""
def __init__(self) -> None:
""" Instantiate a new Deduplication PreTokenizer """
pass
class Punctuation(PreTokenizer):
""" Punctuation PreTokenizer
This pre-tokenizer simply splits using the following regex: `\w+|[^\w\s]+`
"""
def __init__(self) -> None:
""" Instantiate a new Punctuation PreTokenizer """
pass
class Sequence(PreTokenizer):
""" Sequence PreTokenizer
This pre-tokenizer simply splits using the following regex: `\w+|[^\w\s]+`
"""
def __init__(self) -> None:
""" Instantiate a new Sequence PreTokenizer """
pass

View File

@@ -64,6 +64,9 @@ fn pre_tokenizers(_py: Python, m: &PyModule) -> PyResult<()> {
m.add_class::<pre_tokenizers::PyBertPreTokenizer>()?;
m.add_class::<pre_tokenizers::PyMetaspace>()?;
m.add_class::<pre_tokenizers::PyCharDelimiterSplit>()?;
m.add_class::<pre_tokenizers::PyDeduplication>()?;
m.add_class::<pre_tokenizers::PyPunctuation>()?;
m.add_class::<pre_tokenizers::PySequence>()?;
Ok(())
}

View File

@@ -3,12 +3,16 @@ use std::sync::Arc;
use pyo3::exceptions;
use pyo3::prelude::*;
use pyo3::types::*;
use serde::ser::SerializeStruct;
use serde::{Deserialize, Deserializer, Serialize, Serializer};
use tk::pre_tokenizers::bert::BertPreTokenizer;
use tk::pre_tokenizers::byte_level::ByteLevel;
use tk::pre_tokenizers::deduplication::Deduplication;
use tk::pre_tokenizers::delimiter::CharDelimiterSplit;
use tk::pre_tokenizers::metaspace::Metaspace;
use tk::pre_tokenizers::punctuation::Punctuation;
// use tk::pre_tokenizers::sequence::Sequence;
use tk::pre_tokenizers::whitespace::{Whitespace, WhitespaceSplit};
use tk::pre_tokenizers::PreTokenizerWrapper;
use tk::tokenizer::Offsets;
@@ -36,10 +40,22 @@ impl PyPreTokenizer {
let py = gil.python();
match &self.pretok {
PyPreTokenizerWrapper::Custom(_) => Py::new(py, base).map(Into::into),
PyPreTokenizerWrapper::Sequence(_) => {
Py::new(py, (PySequence {}, base)).map(Into::into)
}
PyPreTokenizerWrapper::Wrapped(inner) => match inner.as_ref() {
PreTokenizerWrapper::Whitespace(_) => {
Py::new(py, (PyWhitespace {}, base)).map(Into::into)
}
PreTokenizerWrapper::Deduplication(_) => {
Py::new(py, (PyDeduplication {}, base)).map(Into::into)
}
PreTokenizerWrapper::Punctuation(_) => {
Py::new(py, (PyPunctuation {}, base)).map(Into::into)
}
PreTokenizerWrapper::Sequence(_) => {
Py::new(py, (PySequence {}, base)).map(Into::into)
}
PreTokenizerWrapper::Metaspace(_) => {
Py::new(py, (PyMetaspace {}, base)).map(Into::into)
}
@@ -201,6 +217,56 @@ impl PyBertPreTokenizer {
}
}
#[pyclass(extends=PyPreTokenizer, module = "tokenizers.pre_tokenizers", name=Deduplication)]
pub struct PyDeduplication {}
#[pymethods]
impl PyDeduplication {
#[new]
fn new() -> PyResult<(Self, PyPreTokenizer)> {
Ok((PyDeduplication {}, Deduplication.into()))
}
}
#[pyclass(extends=PyPreTokenizer, module = "tokenizers.pre_tokenizers", name=Punctuation)]
pub struct PyPunctuation {}
#[pymethods]
impl PyPunctuation {
#[new]
fn new() -> PyResult<(Self, PyPreTokenizer)> {
Ok((PyPunctuation {}, Punctuation.into()))
}
}
#[pyclass(extends=PyPreTokenizer, module = "tokenizers.pre_tokenizers", name=Sequence)]
pub struct PySequence {}
#[pymethods]
impl PySequence {
#[new]
fn new(pre_tokenizers: &PyList) -> PyResult<(Self, PyPreTokenizer)> {
let mut sequence = Vec::with_capacity(pre_tokenizers.len());
for n in pre_tokenizers.iter() {
let pretokenizer: PyRef<PyPreTokenizer> = n.extract()?;
match &pretokenizer.pretok {
PyPreTokenizerWrapper::Sequence(inner) => {
sequence.extend(inner.iter().map(|i| i.clone()))
}
PyPreTokenizerWrapper::Wrapped(inner) => sequence.push(inner.clone()),
PyPreTokenizerWrapper::Custom(_) => unreachable!(
"Custom pretokenizers are currently disabled, how did you get here?"
),
}
}
Ok((
PySequence {},
PyPreTokenizer::new(PyPreTokenizerWrapper::Sequence(sequence)),
))
}
fn __getnewargs__<'p>(&self, py: Python<'p>) -> PyResult<&'p PyTuple> {
Ok(PyTuple::new(py, &[PyList::empty(py)]))
}
}
#[pyclass(extends=PyPreTokenizer, module = "tokenizers.pre_tokenizers", name=Metaspace)]
pub struct PyMetaspace {}
#[pymethods]
@@ -295,13 +361,34 @@ impl<'de> Deserialize<'de> for CustomPreTokenizer {
}
}
#[derive(Clone, Deserialize, Serialize)]
#[derive(Clone, Deserialize)]
#[serde(untagged)]
pub(crate) enum PyPreTokenizerWrapper {
Sequence(Vec<Arc<PreTokenizerWrapper>>),
Custom(Arc<CustomPreTokenizer>),
Wrapped(Arc<PreTokenizerWrapper>),
}
impl Serialize for PyPreTokenizerWrapper {
fn serialize<S>(&self, serializer: S) -> Result<<S as Serializer>::Ok, <S as Serializer>::Error>
where
S: Serializer,
{
match self {
PyPreTokenizerWrapper::Sequence(seq) => {
let mut ser = serializer.serialize_struct("Sequence", 2)?;
ser.serialize_field("type", "Sequence")?;
ser.serialize_field("pretokenizers", seq)?;
ser.end()
}
PyPreTokenizerWrapper::Wrapped(inner) => inner.serialize(serializer),
PyPreTokenizerWrapper::Custom(_) => {
unreachable!("Custom pretokenizers are currently disabled, how did you get here?")
}
}
}
}
impl<I> From<I> for PyPreTokenizerWrapper
where
I: Into<PreTokenizerWrapper>,
@@ -326,6 +413,9 @@ impl PreTokenizer for PyPreTokenizerWrapper {
fn pre_tokenize(&self, normalized: &mut PreTokenizedString) -> tk::Result<()> {
match self {
PyPreTokenizerWrapper::Wrapped(inner) => inner.pre_tokenize(normalized),
PyPreTokenizerWrapper::Sequence(inner) => {
inner.iter().map(|n| n.pre_tokenize(normalized)).collect()
}
PyPreTokenizerWrapper::Custom(_) => {
unreachable!("Custom pretokenizers are currently disabled, how did you get here?")
}

View File

@@ -9,6 +9,9 @@ from tokenizers.pre_tokenizers import (
BertPreTokenizer,
Metaspace,
CharDelimiterSplit,
Deduplication,
Punctuation,
Sequence,
)
@@ -39,7 +42,9 @@ class TestWhitespaceSplit:
assert WhitespaceSplit() is not None
assert isinstance(WhitespaceSplit(), PreTokenizer)
assert isinstance(WhitespaceSplit(), WhitespaceSplit)
assert isinstance(pickle.loads(pickle.dumps(WhitespaceSplit())), WhitespaceSplit)
assert isinstance(
pickle.loads(pickle.dumps(WhitespaceSplit())), WhitespaceSplit
)
class TestBertPreTokenizer:
@@ -47,7 +52,9 @@ class TestBertPreTokenizer:
assert BertPreTokenizer() is not None
assert isinstance(BertPreTokenizer(), PreTokenizer)
assert isinstance(BertPreTokenizer(), BertPreTokenizer)
assert isinstance(pickle.loads(pickle.dumps(BertPreTokenizer())), BertPreTokenizer)
assert isinstance(
pickle.loads(pickle.dumps(BertPreTokenizer())), BertPreTokenizer
)
class TestMetaspace:
@@ -69,4 +76,50 @@ class TestCharDelimiterSplit:
CharDelimiterSplit("")
assert isinstance(CharDelimiterSplit(" "), PreTokenizer)
assert isinstance(CharDelimiterSplit(" "), CharDelimiterSplit)
assert isinstance(pickle.loads(pickle.dumps(CharDelimiterSplit("-"))), CharDelimiterSplit)
assert isinstance(
pickle.loads(pickle.dumps(CharDelimiterSplit("-"))), CharDelimiterSplit
)
class TestDeduplication:
def test_instantiate(self):
assert Deduplication() is not None
assert isinstance(Deduplication(), PreTokenizer)
assert isinstance(Deduplication(), Deduplication)
assert isinstance(pickle.loads(pickle.dumps(Deduplication())), Deduplication)
class TestPunctuation:
def test_instantiate(self):
assert Punctuation() is not None
assert isinstance(Punctuation(), PreTokenizer)
assert isinstance(Punctuation(), Punctuation)
assert isinstance(pickle.loads(pickle.dumps(Punctuation())), Punctuation)
class TestSequence:
def test_instantiate(self):
assert Sequence([]) is not None
assert isinstance(Sequence([]), PreTokenizer)
assert isinstance(Sequence([]), Sequence)
dumped = pickle.dumps(Sequence([]))
assert isinstance(pickle.loads(dumped), Sequence)
def test_bert_like(self):
pre_tokenizer = Sequence([Deduplication(), Punctuation()])
assert isinstance(Sequence([]), PreTokenizer)
assert isinstance(Sequence([]), Sequence)
assert isinstance(pickle.loads(pickle.dumps(pre_tokenizer)), Sequence)
result = pre_tokenizer.pre_tokenize("Hey friend! How are you?!?")
assert result == [
("Hey", (0, 3)),
("friend", (4, 10)),
("!", (10, 11)),
("How", (16, 19)),
("are", (20, 23)),
("you", (24, 27)),
("?", (27, 28)),
("!", (28, 29)),
("?", (29, 30)),
]

View File

@@ -0,0 +1,38 @@
use crate::tokenizer::{PreTokenizedString, PreTokenizer, Result, SplitDelimiterBehavior};
#[derive(Copy, Clone, Debug)]
pub struct Deduplication;
impl_serde_unit_struct!(DeduplicationVisitor, Deduplication);
impl PreTokenizer for Deduplication {
fn pre_tokenize(&self, pretokenized: &mut PreTokenizedString) -> Result<()> {
pretokenized.split(|_, s| s.split(char::is_whitespace, SplitDelimiterBehavior::Removed))
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::OffsetReferential;
#[test]
fn deduplication_basic() {
let pretok = Deduplication;
let mut pretokenized: PreTokenizedString = "Hey friend! How are you?!?".into();
pretok.pre_tokenize(&mut pretokenized).unwrap();
assert_eq!(
pretokenized
.get_splits(OffsetReferential::Original)
.into_iter()
.map(|(s, o, _)| (s, o))
.collect::<Vec<_>>(),
vec![
("Hey", (0, 3)),
("friend!", (4, 11)),
("How", (16, 19)),
("are", (20, 23)),
("you?!?", (24, 30)),
]
);
}
}

View File

@@ -1,15 +1,21 @@
pub mod bert;
pub mod byte_level;
pub mod deduplication;
pub mod delimiter;
pub mod metaspace;
pub mod punctuation;
pub mod sequence;
pub mod whitespace;
use serde::{Deserialize, Serialize};
use crate::pre_tokenizers::bert::BertPreTokenizer;
use crate::pre_tokenizers::byte_level::ByteLevel;
use crate::pre_tokenizers::deduplication::Deduplication;
use crate::pre_tokenizers::delimiter::CharDelimiterSplit;
use crate::pre_tokenizers::metaspace::Metaspace;
use crate::pre_tokenizers::punctuation::Punctuation;
use crate::pre_tokenizers::sequence::Sequence;
use crate::pre_tokenizers::whitespace::{Whitespace, WhitespaceSplit};
use crate::{PreTokenizedString, PreTokenizer};
@@ -21,6 +27,9 @@ pub enum PreTokenizerWrapper {
Delimiter(CharDelimiterSplit),
Metaspace(Metaspace),
Whitespace(Whitespace),
Sequence(Sequence),
Deduplication(Deduplication),
Punctuation(Punctuation),
WhitespaceSplit(WhitespaceSplit),
}
@@ -32,6 +41,9 @@ impl PreTokenizer for PreTokenizerWrapper {
PreTokenizerWrapper::Delimiter(dpt) => dpt.pre_tokenize(normalized),
PreTokenizerWrapper::Metaspace(mspt) => mspt.pre_tokenize(normalized),
PreTokenizerWrapper::Whitespace(wspt) => wspt.pre_tokenize(normalized),
PreTokenizerWrapper::Deduplication(tok) => tok.pre_tokenize(normalized),
PreTokenizerWrapper::Punctuation(tok) => tok.pre_tokenize(normalized),
PreTokenizerWrapper::Sequence(tok) => tok.pre_tokenize(normalized),
PreTokenizerWrapper::WhitespaceSplit(wspt) => wspt.pre_tokenize(normalized),
}
}
@@ -41,5 +53,8 @@ impl_enum_from!(BertPreTokenizer, PreTokenizerWrapper, BertPreTokenizer);
impl_enum_from!(ByteLevel, PreTokenizerWrapper, ByteLevel);
impl_enum_from!(CharDelimiterSplit, PreTokenizerWrapper, Delimiter);
impl_enum_from!(Whitespace, PreTokenizerWrapper, Whitespace);
impl_enum_from!(Deduplication, PreTokenizerWrapper, Deduplication);
impl_enum_from!(Punctuation, PreTokenizerWrapper, Punctuation);
impl_enum_from!(Sequence, PreTokenizerWrapper, Sequence);
impl_enum_from!(Metaspace, PreTokenizerWrapper, Metaspace);
impl_enum_from!(WhitespaceSplit, PreTokenizerWrapper, WhitespaceSplit);

View File

@@ -0,0 +1,44 @@
use crate::tokenizer::{PreTokenizedString, PreTokenizer, Result, SplitDelimiterBehavior};
use unicode_categories::UnicodeCategories;
fn is_bert_punc(x: char) -> bool {
char::is_ascii_punctuation(&x) || x.is_punctuation()
}
#[derive(Copy, Clone, Debug)]
pub struct Punctuation;
impl_serde_unit_struct!(PunctuationVisitor, Punctuation);
impl PreTokenizer for Punctuation {
fn pre_tokenize(&self, pretokenized: &mut PreTokenizedString) -> Result<()> {
pretokenized.split(|_, s| s.split(is_bert_punc, SplitDelimiterBehavior::Isolated))
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::OffsetReferential;
#[test]
fn punctuation_basic() {
let pretok = Punctuation;
let mut pretokenized: PreTokenizedString = "Hey friend! How are you?!?".into();
pretok.pre_tokenize(&mut pretokenized).unwrap();
assert_eq!(
pretokenized
.get_splits(OffsetReferential::Original)
.into_iter()
.map(|(s, o, _)| (s, o))
.collect::<Vec<_>>(),
vec![
("Hey friend", (0, 10)),
("!", (10, 11)),
(" How are you", (11, 27)),
("?", (27, 28)),
("!", (28, 29)),
("?", (29, 30)),
]
);
}
}

View File

@@ -0,0 +1,60 @@
use crate::pre_tokenizers::PreTokenizerWrapper;
use crate::tokenizer::{PreTokenizedString, PreTokenizer, Result};
use serde::{Deserialize, Serialize};
#[derive(Clone, Debug, Serialize, Deserialize)]
#[serde(tag = "type")]
pub struct Sequence {
pretokenizers: Vec<PreTokenizerWrapper>,
}
impl Sequence {
pub fn new(pretokenizers: Vec<PreTokenizerWrapper>) -> Self {
Self { pretokenizers }
}
}
impl PreTokenizer for Sequence {
fn pre_tokenize(&self, pretokenized: &mut PreTokenizedString) -> Result<()> {
for pretokenizer in &self.pretokenizers {
pretokenizer.pre_tokenize(pretokenized)?;
}
Ok(())
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::pre_tokenizers::{deduplication::Deduplication, punctuation::Punctuation};
use crate::OffsetReferential;
#[test]
fn sequence_basic() {
let pretokenizers = vec![
PreTokenizerWrapper::Deduplication(Deduplication),
PreTokenizerWrapper::Punctuation(Punctuation),
];
let pretok = Sequence::new(pretokenizers);
let mut pretokenized: PreTokenizedString = "Hey friend! How are you?!?".into();
pretok.pre_tokenize(&mut pretokenized).unwrap();
assert_eq!(
pretokenized
.get_splits(OffsetReferential::Original)
.into_iter()
.map(|(s, o, _)| (s, o))
.collect::<Vec<_>>(),
vec![
("Hey", (0, 3)),
("friend", (4, 10)),
("!", (10, 11)),
("How", (16, 19)),
("are", (20, 23)),
("you", (24, 27)),
("?", (27, 28)),
("!", (28, 29)),
("?", (29, 30)),
]
);
}
}