Adding a new pre_tokenizer: Digits.

Makes it easier to split on digits:

Digits(individual_digits=False) -> 'Call 123 please' becomes 'Call ', '123', ' please'
Digits(individual_digits=True)  -> 'Call 123 please' becomes 'Call ', '1', '2', '3', ' please'
Nicolas Patry
2020-09-03 17:27:58 +02:00
parent b8f1eb48cb
commit 7b2caca764
10 changed files with 293 additions and 18 deletions
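
Before the diff, a quick Rust sketch of the new behavior, adapted from the tests this commit adds. It is a sketch only: the root-level re-export of OffsetReferential and the Result alias in tokenizers::tokenizer are assumed from the crate-internal imports visible in the diff below.

// Sketch: mirrors the tests added in this commit.
use tokenizers::pre_tokenizers::digits::Digits;
use tokenizers::tokenizer::{PreTokenizedString, PreTokenizer, Result};
use tokenizers::OffsetReferential;

fn splits(pretok: &impl PreTokenizer, input: &str) -> Result<Vec<String>> {
    let mut pretokenized = PreTokenizedString::from(input);
    pretok.pre_tokenize(&mut pretokenized)?;
    Ok(pretokenized
        .get_splits(OffsetReferential::Original)
        .into_iter()
        .map(|(s, _, _)| s.to_string())
        .collect())
}

fn main() -> Result<()> {
    // Default: runs of digits stay grouped in one token.
    assert_eq!(
        splits(&Digits::new(false), "Call 123 please")?,
        vec!["Call ", "123", " please"]
    );
    // individual_digits=true: every digit is isolated.
    assert_eq!(
        splits(&Digits::new(true), "Call 123 please")?,
        vec!["Call ", "1", "2", "3", " please"]
    );
    Ok(())
}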


@@ -91,7 +91,7 @@ dependencies = [
  "ansi_term",
  "atty",
  "bitflags",
- "strsim",
+ "strsim 0.8.0",
  "textwrap",
  "unicode-width",
  "vec_map",
@@ -168,6 +168,66 @@ version = "0.2.0"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "697c714f50560202b1f4e2e09cd50a421881c83e9025db75d15f276616f04f40"
 
+[[package]]
+name = "darling"
+version = "0.10.2"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "0d706e75d87e35569db781a9b5e2416cff1236a47ed380831f959382ccd5f858"
+dependencies = [
+ "darling_core",
+ "darling_macro",
+]
+
+[[package]]
+name = "darling_core"
+version = "0.10.2"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "f0c960ae2da4de88a91b2d920c2a7233b400bc33cb28453a2987822d8392519b"
+dependencies = [
+ "fnv",
+ "ident_case",
+ "proc-macro2",
+ "quote",
+ "strsim 0.9.3",
+ "syn",
+]
+
+[[package]]
+name = "darling_macro"
+version = "0.10.2"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "d9b5a2f4ac4969822c62224815d069952656cadc7084fdca9751e6d959189b72"
+dependencies = [
+ "darling_core",
+ "quote",
+ "syn",
+]
+
+[[package]]
+name = "derive_builder"
+version = "0.9.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "a2658621297f2cf68762a6f7dc0bb7e1ff2cfd6583daef8ee0fed6f7ec468ec0"
+dependencies = [
+ "darling",
+ "derive_builder_core",
+ "proc-macro2",
+ "quote",
+ "syn",
+]
+
+[[package]]
+name = "derive_builder_core"
+version = "0.9.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "2791ea3e372c8495c0bc2033991d76b512cd799d07491fbd6890124db9458bef"
+dependencies = [
+ "darling",
+ "proc-macro2",
+ "quote",
+ "syn",
+]
+
 [[package]]
 name = "either"
 version = "1.5.3"
@@ -190,6 +250,21 @@ dependencies = [
  "version_check",
 ]
 
+[[package]]
+name = "esaxx-rs"
+version = "0.1.3"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "a3f0bf221d15f92461d05eea094c77aec5a00e3574740159e178beab2c58ea64"
+dependencies = [
+ "cc",
+]
+
+[[package]]
+name = "fnv"
+version = "1.0.7"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "3f9eec918d3f24069decb9af1554cad7c880e2da24a9afd88aca000531ab82c1"
+
 [[package]]
 name = "getrandom"
 version = "0.1.14"
@@ -216,6 +291,12 @@ dependencies = [
  "libc",
 ]
 
+[[package]]
+name = "ident_case"
+version = "1.0.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "b9e0384b61958566e926dc50660321d12159025e767c18e043daf26b70104c39"
+
 [[package]]
 name = "indicatif"
 version = "0.14.0"
@@ -678,6 +759,12 @@ version = "0.8.0"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "8ea5119cdb4c55b55d432abb513a0429384878c15dde60cc77b1c99de1a95a6a"
 
+[[package]]
+name = "strsim"
+version = "0.9.3"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "6446ced80d6c486436db5c078dde11a9f73d42b57fb273121e160b84f63d894c"
+
 [[package]]
 name = "syn"
 version = "1.0.37"
@@ -731,6 +818,8 @@ name = "tokenizers"
 version = "0.10.1"
 dependencies = [
  "clap",
+ "derive_builder",
+ "esaxx-rs",
  "indicatif",
  "itertools 0.9.0",
  "lazy_static",
@@ -744,6 +833,7 @@ dependencies = [
  "serde",
  "serde_json",
  "unicode-normalization-alignments",
+ "unicode-segmentation",
  "unicode_categories",
 ]
@@ -756,6 +846,12 @@ dependencies = [
  "smallvec",
 ]
 
+[[package]]
+name = "unicode-segmentation"
+version = "1.6.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "e83e153d1053cbb5a118eeff7fd5be06ed99153f00dbcd8ae310c5fb2b22edc0"
+
 [[package]]
 name = "unicode-width"
 version = "0.1.8"


@@ -842,7 +842,7 @@ dependencies = [
 
 [[package]]
 name = "tokenizers-python"
-version = "0.9.0-dev0"
+version = "0.9.0-dev1"
 dependencies = [
  "env_logger 0.7.1 (registry+https://github.com/rust-lang/crates.io-index)",
  "libc 0.2.74 (registry+https://github.com/rust-lang/crates.io-index)",


@@ -9,3 +9,4 @@ WhitespaceSplit = pre_tokenizers.WhitespaceSplit
 BertPreTokenizer = pre_tokenizers.BertPreTokenizer
 Metaspace = pre_tokenizers.Metaspace
 CharDelimiterSplit = pre_tokenizers.CharDelimiterSplit
+Digits = pre_tokenizers.Digits


@@ -127,3 +127,20 @@ class Sequence(PreTokenizer):
     def __init__(self) -> None:
         """ Instantiate a new Sequence PreTokenizer """
         pass
+
+class Digits(PreTokenizer):
+    """Digits PreTokenizer
+
+    This pre-tokenizer simply splits digits off into separate tokens.
+    """
+
+    def __init__(self, individual_digits: bool) -> None:
+        """Instantiate a new Digits
+
+        Args:
+            individual_digits: bool:
+                If set to True, digits are each separated: "Call 123 please" -> "Call ", "1", "2", "3", " please"
+                If set to False, digits are grouped: "Call 123 please" -> "Call ", "123", " please"
+        """
+        pass


@@ -68,6 +68,7 @@ fn pre_tokenizers(_py: Python, m: &PyModule) -> PyResult<()> {
     m.add_class::<pre_tokenizers::PyCharDelimiterSplit>()?;
     m.add_class::<pre_tokenizers::PyPunctuation>()?;
     m.add_class::<pre_tokenizers::PySequence>()?;
+    m.add_class::<pre_tokenizers::PyDigits>()?;
     Ok(())
 }


@@ -9,6 +9,7 @@ use serde::{Deserialize, Deserializer, Serialize, Serializer};
 use tk::pre_tokenizers::bert::BertPreTokenizer;
 use tk::pre_tokenizers::byte_level::ByteLevel;
 use tk::pre_tokenizers::delimiter::CharDelimiterSplit;
+use tk::pre_tokenizers::digits::Digits;
 use tk::pre_tokenizers::metaspace::Metaspace;
 use tk::pre_tokenizers::punctuation::Punctuation;
 use tk::pre_tokenizers::whitespace::{Whitespace, WhitespaceSplit};
@@ -66,6 +67,7 @@ impl PyPreTokenizer {
             PreTokenizerWrapper::BertPreTokenizer(_) => {
                 Py::new(py, (PyBertPreTokenizer {}, base)).map(Into::into)
             }
+            PreTokenizerWrapper::Digits(_) => Py::new(py, (PyDigits {}, base)).map(Into::into),
         },
     }
 }
@@ -281,6 +283,30 @@ impl PyMetaspace {
     }
 }
 
+#[pyclass(extends=PyPreTokenizer, module = "tokenizers.pre_tokenizers", name=Digits)]
+pub struct PyDigits {}
+#[pymethods]
+impl PyDigits {
+    #[new]
+    #[args(kwargs = "**")]
+    fn new(kwargs: Option<&PyDict>) -> PyResult<(Self, PyPreTokenizer)> {
+        let mut individual_digits = false;
+        if let Some(kwargs) = kwargs {
+            for (key, value) in kwargs {
+                let key: &str = key.extract()?;
+                match key {
+                    "individual_digits" => {
+                        individual_digits = value.extract()?;
+                    }
+                    _ => println!("Ignored unknown kwarg option {}", key),
+                }
+            }
+        }
+        Ok((PyDigits {}, Digits::new(individual_digits).into()))
+    }
+}
+
 // this is not accessible in python since the custom method is disabled.
 #[allow(dead_code)]
 pub(crate) struct CustomPreTokenizer {


@@ -11,6 +11,7 @@ from tokenizers.pre_tokenizers import (
     CharDelimiterSplit,
     Punctuation,
     Sequence,
+    Digits,
 )
@@ -108,3 +109,13 @@ class TestSequence:
             ("!", (28, 29)),
             ("?", (29, 30)),
         ]
+
+
+class TestDigits:
+    def test_instantiate(self):
+        assert Digits() is not None
+        assert isinstance(Digits(), PreTokenizer)
+        assert isinstance(Digits(), Digits)
+        assert isinstance(Digits(individual_digits=True), Digits)
+        assert isinstance(Digits(individual_digits=False), Digits)
+        assert isinstance(pickle.loads(pickle.dumps(Digits())), Digits)


@@ -0,0 +1,100 @@
use serde::{Deserialize, Serialize};

use crate::tokenizer::{PreTokenizedString, PreTokenizer, Result, SplitDelimiterBehavior};

#[derive(Serialize, Deserialize, Clone, Debug)]
/// Splits the digits off from the rest of the text. If `individual_digits`
/// is true, each digit ends up in its own split.
#[serde(tag = "type")]
pub struct Digits {
    individual_digits: bool,
}

impl Digits {
    pub fn new(individual_digits: bool) -> Self {
        Self { individual_digits }
    }
}

impl Default for Digits {
    fn default() -> Self {
        Self::new(false)
    }
}

impl PreTokenizer for Digits {
    fn pre_tokenize(&self, pretokenized: &mut PreTokenizedString) -> Result<()> {
        if self.individual_digits {
            pretokenized.split(|_, normalized| {
                normalized.split(char::is_numeric, SplitDelimiterBehavior::Isolated)
            })
        } else {
            pretokenized.split(|_, normalized| {
                normalized.split(char::is_numeric, SplitDelimiterBehavior::Contiguous)
            })
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::OffsetReferential;

    #[test]
    fn numbers() {
        let pretok = Digits::new(false);
        let mut pretokenized = PreTokenizedString::from("Hey 123 friend!");
        pretok.pre_tokenize(&mut pretokenized).unwrap();
        assert_eq!(
            pretokenized
                .get_splits(OffsetReferential::Normalized)
                .into_iter()
                .map(|(s, o, _)| (s, o))
                .collect::<Vec<_>>(),
            vec![("Hey ", (0, 4)), ("123", (4, 7)), (" friend!", (7, 15))]
        );
        assert_eq!(
            pretokenized
                .get_splits(OffsetReferential::Original)
                .into_iter()
                .map(|(s, o, _)| (s, o))
                .collect::<Vec<_>>(),
            vec![("Hey ", (0, 4)), ("123", (4, 7)), (" friend!", (7, 15))]
        );
    }

    #[test]
    fn individual_digits() {
        let pretok = Digits::new(true);
        let mut pretokenized = PreTokenizedString::from("Hey 123 friend!");
        pretok.pre_tokenize(&mut pretokenized).unwrap();
        assert_eq!(
            pretokenized
                .get_splits(OffsetReferential::Normalized)
                .into_iter()
                .map(|(s, o, _)| (s, o))
                .collect::<Vec<_>>(),
            vec![
                ("Hey ", (0, 4)),
                ("1", (4, 5)),
                ("2", (5, 6)),
                ("3", (6, 7)),
                (" friend!", (7, 15))
            ]
        );
        assert_eq!(
            pretokenized
                .get_splits(OffsetReferential::Original)
                .into_iter()
                .map(|(s, o, _)| (s, o))
                .collect::<Vec<_>>(),
            vec![
                ("Hey ", (0, 4)),
                ("1", (4, 5)),
                ("2", (5, 6)),
                ("3", (6, 7)),
                (" friend!", (7, 15))
            ]
        );
    }
}
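
Because Digits derives Serialize/Deserialize with #[serde(tag = "type")], it round-trips through JSON configs (and, via the bindings, Python pickling, which the new test exercises). A minimal sketch of that representation, assuming serde_json is available as a dev-dependency:

// Sketch: the tagged representation implied by #[serde(tag = "type")].
use tokenizers::pre_tokenizers::digits::Digits;

fn main() -> serde_json::Result<()> {
    let pretok = Digits::new(true);
    let json = serde_json::to_string(&pretok)?;
    // The struct name becomes the "type" tag.
    assert_eq!(json, r#"{"type":"Digits","individual_digits":true}"#);

    // And it parses back into an equivalent Digits.
    let back: Digits = serde_json::from_str(&json)?;
    assert_eq!(serde_json::to_string(&back)?, json);
    Ok(())
}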


@@ -1,6 +1,7 @@
 pub mod bert;
 pub mod byte_level;
 pub mod delimiter;
+pub mod digits;
 pub mod metaspace;
 pub mod punctuation;
 pub mod sequence;
@@ -11,6 +12,7 @@ use serde::{Deserialize, Serialize};
 use crate::pre_tokenizers::bert::BertPreTokenizer;
 use crate::pre_tokenizers::byte_level::ByteLevel;
 use crate::pre_tokenizers::delimiter::CharDelimiterSplit;
+use crate::pre_tokenizers::digits::Digits;
 use crate::pre_tokenizers::metaspace::Metaspace;
 use crate::pre_tokenizers::punctuation::Punctuation;
 use crate::pre_tokenizers::sequence::Sequence;
@@ -28,6 +30,7 @@ pub enum PreTokenizerWrapper {
     Sequence(Sequence),
     Punctuation(Punctuation),
     WhitespaceSplit(WhitespaceSplit),
+    Digits(Digits),
 }
 
 impl PreTokenizer for PreTokenizerWrapper {
@@ -41,6 +44,7 @@ impl PreTokenizer for PreTokenizerWrapper {
             PreTokenizerWrapper::Punctuation(tok) => tok.pre_tokenize(normalized),
             PreTokenizerWrapper::Sequence(tok) => tok.pre_tokenize(normalized),
             PreTokenizerWrapper::WhitespaceSplit(wspt) => wspt.pre_tokenize(normalized),
+            PreTokenizerWrapper::Digits(wspt) => wspt.pre_tokenize(normalized),
         }
     }
 }
@@ -53,3 +57,4 @@ impl_enum_from!(Punctuation, PreTokenizerWrapper, Punctuation);
 impl_enum_from!(Sequence, PreTokenizerWrapper, Sequence);
 impl_enum_from!(Metaspace, PreTokenizerWrapper, Metaspace);
 impl_enum_from!(WhitespaceSplit, PreTokenizerWrapper, WhitespaceSplit);
+impl_enum_from!(Digits, PreTokenizerWrapper, Digits);
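
The impl_enum_from! line is what makes Digits::new(individual_digits).into() work in the Python binding above: it generates a From<Digits> impl for the wrapper enum. A small sketch of the conversion and the dispatch, using only items shown in this diff:

// Sketch: converting a concrete pre-tokenizer into the serializable wrapper
// enum and dispatching through it (types taken from this diff).
use tokenizers::pre_tokenizers::digits::Digits;
use tokenizers::pre_tokenizers::PreTokenizerWrapper;
use tokenizers::tokenizer::{PreTokenizedString, PreTokenizer, Result};

fn main() -> Result<()> {
    // Generated by impl_enum_from!(Digits, PreTokenizerWrapper, Digits).
    let wrapper: PreTokenizerWrapper = Digits::new(false).into();

    // The wrapper's PreTokenizer impl forwards to the inner Digits.
    let mut pretokenized = PreTokenizedString::from("route 66");
    wrapper.pre_tokenize(&mut pretokenized)?;
    Ok(())
}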


@@ -1,5 +1,3 @@
-#![allow(clippy::reversed_empty_ranges)]
-
 use crate::pattern::Pattern;
 use crate::{Offsets, Result};
 use std::ops::{Bound, RangeBounds};
@@ -89,11 +87,13 @@
 /// - Isolated => `[ "the", "-", "final", "-", "-", "countdown" ]`
 /// - MergedWithPrevious => `[ "the-", "final-", "-", "countdown" ]`
 /// - MergedWithNext => `[ "the", "-final", "-", "-countdown" ]`
+/// - Contiguous => `[ "the", "-", "final", "--", "countdown" ]`
 pub enum SplitDelimiterBehavior {
     Removed,
     Isolated,
     MergedWithPrevious,
     MergedWithNext,
+    Contiguous,
 }
 
 /// A `NormalizedString` takes care of processing an "original" string to modify
@@ -784,6 +784,24 @@ impl NormalizedString {
                 .map(|(offsets, _)| (offsets, false))
                 .collect(),
             Removed => matches,
+            Contiguous => {
+                let mut previous_match = false;
+                matches
+                    .into_iter()
+                    .fold(vec![], |mut acc, (offsets, is_match)| {
+                        if is_match == previous_match {
+                            if let Some(((_, end), _)) = acc.last_mut() {
+                                *end = offsets.1;
+                            } else {
+                                acc.push((offsets, false));
+                            }
+                        } else {
+                            acc.push((offsets, false));
+                        }
+                        previous_match = is_match;
+                        acc
+                    })
+            }
             MergedWithPrevious => {
                 let mut previous_match = false;
                 matches
@@ -1038,7 +1056,6 @@ impl From<&str> for NormalizedString {
 
 #[cfg(test)]
 mod tests {
-    #![allow(clippy::reversed_empty_ranges)]
     use super::*;
     use regex::Regex;
     use unicode_categories::UnicodeCategories;
@@ -1489,6 +1506,7 @@ mod tests {
         test(Isolated, vec!["The", "-", "final", "-", "-", "countdown"]);
        test(MergedWithPrevious, vec!["The-", "final-", "-", "countdown"]);
         test(MergedWithNext, vec!["The", "-final", "-", "-countdown"]);
+        test(Contiguous, vec!["The", "-", "final", "--", "countdown"]);
     }
 
     #[test]
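
To see the splitting behaviors side by side, here is a hedged sketch based on the doc comment and the test above. It assumes NormalizedString and SplitDelimiterBehavior are re-exported at the crate root and that NormalizedString::get() publicly returns the current text; if those re-exports differ at this revision, the same calls work from inside the crate as in the test.

// Sketch comparing delimiter behaviors on the doc-comment example.
use tokenizers::{NormalizedString, SplitDelimiterBehavior};

fn pieces(behavior: SplitDelimiterBehavior) -> Vec<String> {
    NormalizedString::from("the-final--countdown")
        .split('-', behavior)
        .unwrap()
        .into_iter()
        .map(|n| n.get().to_string())
        .collect()
}

fn main() {
    // Isolated keeps every delimiter as its own piece.
    assert_eq!(
        pieces(SplitDelimiterBehavior::Isolated),
        vec!["the", "-", "final", "-", "-", "countdown"]
    );
    // The new Contiguous variant keeps runs of delimiters together.
    assert_eq!(
        pieces(SplitDelimiterBehavior::Contiguous),
        vec!["the", "-", "final", "--", "countdown"]
    );
}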