diff --git a/bindings/node/native/src/pre_tokenizers.rs b/bindings/node/native/src/pre_tokenizers.rs
index c350bd04..65bbc676 100644
--- a/bindings/node/native/src/pre_tokenizers.rs
+++ b/bindings/node/native/src/pre_tokenizers.rs
@@ -125,12 +125,15 @@ declare_types! {
     }
 }
 
-/// byte_level(addPrefixSpace: bool = true)
+/// byte_level(addPrefixSpace: bool = true, useRegex: bool = true)
 fn byte_level(mut cx: FunctionContext) -> JsResult<JsPreTokenizer> {
     let mut byte_level = tk::pre_tokenizers::byte_level::ByteLevel::default();
     if let Some(add_prefix_space) = cx.extract_opt::<bool>(0)? {
         byte_level = byte_level.add_prefix_space(add_prefix_space);
     }
+    if let Some(use_regex) = cx.extract_opt::<bool>(1)? {
+        byte_level = byte_level.use_regex(use_regex);
+    }
 
     let mut pretok = JsPreTokenizer::new::<_, JsPreTokenizer, _>(&mut cx, vec![])?;
     let guard = cx.lock();
diff --git a/bindings/python/py_src/tokenizers/pre_tokenizers/__init__.pyi b/bindings/python/py_src/tokenizers/pre_tokenizers/__init__.pyi
index e3ffbbbc..b149ee17 100644
--- a/bindings/python/py_src/tokenizers/pre_tokenizers/__init__.pyi
+++ b/bindings/python/py_src/tokenizers/pre_tokenizers/__init__.pyi
@@ -102,7 +102,7 @@ class ByteLevel(PreTokenizer):
             lets us treat `hello` exactly like `say hello`.
     """
 
-    def __init__(self, add_prefix_space=True):
+    def __init__(self, add_prefix_space=True, use_regex=True):
         pass
     @staticmethod
     def alphabet():
diff --git a/bindings/python/src/pre_tokenizers.rs b/bindings/python/src/pre_tokenizers.rs
index 947e267e..5edc6f2e 100644
--- a/bindings/python/src/pre_tokenizers.rs
+++ b/bindings/python/src/pre_tokenizers.rs
@@ -229,7 +229,7 @@ macro_rules! setter {
 ///         Whether to add a space to the first word if there isn't already one. This
 ///         lets us treat `hello` exactly like `say hello`.
 #[pyclass(extends=PyPreTokenizer, module = "tokenizers.pre_tokenizers", name=ByteLevel)]
-#[text_signature = "(self, add_prefix_space=True)"]
+#[text_signature = "(self, add_prefix_space=True, use_regex=True)"]
 pub struct PyByteLevel {}
 #[pymethods]
 impl PyByteLevel {
@@ -243,13 +243,28 @@ impl PyByteLevel {
         setter!(self_, ByteLevel, add_prefix_space, add_prefix_space);
     }
 
+    #[getter]
+    fn get_use_regex(self_: PyRef<Self>) -> bool {
+        getter!(self_, ByteLevel, use_regex)
+    }
+
+    #[setter]
+    fn set_use_regex(self_: PyRef<Self>, use_regex: bool) {
+        setter!(self_, ByteLevel, use_regex, use_regex);
+    }
+
     #[new]
-    #[args(add_prefix_space = "true", _kwargs = "**")]
-    fn new(add_prefix_space: bool, _kwargs: Option<&PyDict>) -> (Self, PyPreTokenizer) {
+    #[args(add_prefix_space = "true", use_regex = "true", _kwargs = "**")]
+    fn new(
+        add_prefix_space: bool,
+        use_regex: bool,
+        _kwargs: Option<&PyDict>,
+    ) -> (Self, PyPreTokenizer) {
         (
             PyByteLevel {},
             ByteLevel::default()
                 .add_prefix_space(add_prefix_space)
+                .use_regex(use_regex)
                 .into(),
         )
     }
diff --git a/tokenizers/src/pre_tokenizers/byte_level.rs b/tokenizers/src/pre_tokenizers/byte_level.rs
index 27c2e14b..4bb9c07c 100644
--- a/tokenizers/src/pre_tokenizers/byte_level.rs
+++ b/tokenizers/src/pre_tokenizers/byte_level.rs
@@ -53,6 +53,15 @@ pub struct ByteLevel {
     pub add_prefix_space: bool,
     /// Whether the post processing step should trim offsets to avoid including whitespaces.
     pub trim_offsets: bool,
+
+    /// Whether to use the standard GPT2 regex for whitespace splitting.
+    /// Set it to False if you want to use your own splitting.
+    #[serde(default = "default_true")]
+    pub use_regex: bool,
+}
+
+fn default_true() -> bool {
+    true
 }
 
 impl Default for ByteLevel {
@@ -60,15 +69,17 @@
         Self {
             add_prefix_space: true,
             trim_offsets: true,
+            use_regex: true,
         }
     }
 }
 
 impl ByteLevel {
-    pub fn new(add_prefix_space: bool, trim_offsets: bool) -> Self {
+    pub fn new(add_prefix_space: bool, trim_offsets: bool, use_regex: bool) -> Self {
         Self {
             add_prefix_space,
             trim_offsets,
+            use_regex,
         }
     }
 
@@ -87,6 +98,12 @@ impl ByteLevel {
         self.trim_offsets = v;
         self
     }
+
+    #[must_use]
+    pub fn use_regex(mut self, v: bool) -> Self {
+        self.use_regex = v;
+        self
+    }
 }
 
 /// As a `PreTokenizer`, `ByteLevel` is in charge of transforming all the unicode characters into
@@ -99,7 +116,11 @@ impl PreTokenizer for ByteLevel {
             if self.add_prefix_space && !normalized.get().starts_with(' ') {
                 normalized.prepend(" ");
             }
-            normalized.split(re_ref, SplitDelimiterBehavior::Isolated)
+            if self.use_regex {
+                normalized.split(re_ref, SplitDelimiterBehavior::Isolated)
+            } else {
+                Ok(vec![normalized])
+            }
         })?;
         pretokenized.normalize(|normalized| {
             let s = normalized.get();
@@ -247,6 +268,21 @@ mod tests {
         );
     }
 
+    #[test]
+    fn pre_tokenization_no_regex() {
+        let bytelevel = ByteLevel::default().use_regex(false);
+        let mut pretokenized: PreTokenizedString = "Hello my friend, how is your day going?".into();
+        bytelevel.pre_tokenize(&mut pretokenized).unwrap();
+        assert_eq!(
+            pretokenized
+                .get_splits(OffsetReferential::Original, OffsetType::Byte)
+                .into_iter()
+                .map(|(s, o, _)| (s, o))
+                .collect::<Vec<_>>(),
+            vec![("ĠHelloĠmyĠfriend,ĠhowĠisĠyourĠdayĠgoing?", (0, 39))]
+        );
+    }
+
     #[test]
     fn decoding() {
         let bytelevel = ByteLevel::default().add_prefix_space(false);
@@ -513,4 +549,27 @@ mod tests {
             vec!["Hello there dear friend! [PA D]"]
        );
     }
+
+    #[test]
+    fn deserialization() {
+        // Before use_regex
+        let byte_level: ByteLevel = serde_json::from_str(
+            r#"{"type": "ByteLevel", "add_prefix_space": true, "trim_offsets": false}"#,
+        )
+        .unwrap();
+        assert!(byte_level.use_regex);
+
+        // Loading works, new future BC test.
+        let byte_level: ByteLevel = serde_json::from_str(
+            r#"{"type": "ByteLevel", "add_prefix_space": true, "trim_offsets": false, "use_regex": true}"#,
+        )
+        .unwrap();
+        assert!(byte_level.use_regex);
+
+        let byte_level: ByteLevel = serde_json::from_str(
+            r#"{"type": "ByteLevel", "add_prefix_space": true, "trim_offsets": false, "use_regex": false}"#,
+        )
+        .unwrap();
+        assert!(!byte_level.use_regex);
+    }
 }
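
For context, a minimal usage sketch of the flag added above, as seen from the Python binding. This is an illustration, not part of the patch: it assumes the usual `pre_tokenize_str` helper on Python pre-tokenizers, and the commented outputs are indicative only.

from tokenizers.pre_tokenizers import ByteLevel

# Default behaviour: split on the GPT-2 regex, then apply the byte-level mapping.
splitting = ByteLevel(add_prefix_space=False, use_regex=True)
print(splitting.pre_tokenize_str("Hello my friend"))
# roughly: [('Hello', (0, 5)), ('Ġmy', (5, 8)), ('Ġfriend', (8, 15))]

# use_regex=False: skip the regex split and keep the text as a single piece,
# leaving any word splitting to the caller (for example an earlier
# pre-tokenizer combined in a Sequence).
passthrough = ByteLevel(add_prefix_space=False, use_regex=False)
print(passthrough.pre_tokenize_str("Hello my friend"))
# roughly: [('HelloĠmyĠfriend', (0, 15))]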