diff --git a/bindings/python/Cargo.toml b/bindings/python/Cargo.toml index 3b1b1bbf..b494e408 100644 --- a/bindings/python/Cargo.toml +++ b/bindings/python/Cargo.toml @@ -17,7 +17,6 @@ env_logger = "0.11" pyo3 = { version = "0.21" } numpy = "0.21" ndarray = "0.15" -onig = { version = "6.4", default-features = false } itertools = "0.12" [dependencies.tokenizers] diff --git a/bindings/python/src/utils/regex.rs b/bindings/python/src/utils/regex.rs index 82893ca7..16f70682 100644 --- a/bindings/python/src/utils/regex.rs +++ b/bindings/python/src/utils/regex.rs @@ -1,11 +1,11 @@ -use onig::Regex; use pyo3::exceptions; use pyo3::prelude::*; +use tk::utils::SysRegex; /// Instantiate a new Regex with the given pattern #[pyclass(module = "tokenizers", name = "Regex")] pub struct PyRegex { - pub inner: Regex, + pub inner: SysRegex, pub pattern: String, } @@ -15,8 +15,8 @@ impl PyRegex { #[pyo3(text_signature = "(self, pattern)")] fn new(s: &str) -> PyResult { Ok(Self { - inner: Regex::new(s) - .map_err(|e| exceptions::PyException::new_err(e.description().to_owned()))?, + inner: SysRegex::new(s) + .map_err(|e| exceptions::PyException::new_err(e.to_string().to_owned()))?, pattern: s.to_owned(), }) } diff --git a/tokenizers/src/utils/fancy.rs b/tokenizers/src/utils/fancy.rs index 9d94fd7a..9d44bc74 100644 --- a/tokenizers/src/utils/fancy.rs +++ b/tokenizers/src/utils/fancy.rs @@ -1,3 +1,5 @@ +use crate::tokenizer::pattern::Pattern; +use crate::Offsets; use fancy_regex::Regex; use std::error::Error; @@ -31,3 +33,31 @@ impl<'r, 't> Iterator for Matches<'r, 't> { } } } + +impl Pattern for &Regex { + fn find_matches( + &self, + inside: &str, + ) -> Result, Box> { + if inside.is_empty() { + return Ok(vec![((0, 0), false)]); + } + + let mut prev = 0; + let mut splits = Vec::with_capacity(inside.len()); + for match_ in self.find_iter(inside) { + let match_ = match_?; + let start = match_.start(); + let end = match_.end(); + if prev != start { + splits.push(((prev, start), false)); + } + splits.push(((start, end), true)); + prev = end; + } + if prev != inside.len() { + splits.push(((prev, inside.len()), false)) + } + Ok(splits) + } +}