diff --git a/bindings/node/lib/bindings/pre-tokenizers.d.ts b/bindings/node/lib/bindings/pre-tokenizers.d.ts index 17971486..617d6560 100644 --- a/bindings/node/lib/bindings/pre-tokenizers.d.ts +++ b/bindings/node/lib/bindings/pre-tokenizers.d.ts @@ -90,10 +90,14 @@ export function charDelimiterSplitPreTokenizer(delimiter: string): PreTokenizer; /** * Returns a new Punctuation PreTokenizer. - * This pre-tokenizer splits tokens on punctuation. - * Each occurrence of a punctuation character will be treated separately. + * This pre-tokenizer splits tokens on punctuation according to the provided behavior. + * Each occurrence of a punctuation character is treated separately. + * + * @param [behavior="isolated"] The behavior to use when splitting. + * Choices: "removed", "isolated", "mergedWithPrevious", "mergedWithNext", + * "contiguous" */ -export function punctuationPreTokenizer(): PreTokenizer; +export function punctuationPreTokenizer(behavior?: string): PreTokenizer; /** * Returns a new Sequence PreTokenizer. 
diff --git a/bindings/node/lib/bindings/pre-tokenizers.test.ts b/bindings/node/lib/bindings/pre-tokenizers.test.ts index 39a6eb40..6850acb9 100644 --- a/bindings/node/lib/bindings/pre-tokenizers.test.ts +++ b/bindings/node/lib/bindings/pre-tokenizers.test.ts @@ -43,6 +43,11 @@ describe("punctuationPreTokenizer", () => { const processor = punctuationPreTokenizer(); expect(processor.constructor.name).toEqual("PreTokenizer"); }); + + it("instantiates correctly with non-default split delimiter", () => { + const processor = punctuationPreTokenizer("removed"); + expect(processor.constructor.name).toEqual("PreTokenizer"); + }); }); describe("splitPreTokenizer", () => { diff --git a/bindings/node/native/src/pre_tokenizers.rs b/bindings/node/native/src/pre_tokenizers.rs index 3477f8d8..c350bd04 100644 --- a/bindings/node/native/src/pre_tokenizers.rs +++ b/bindings/node/native/src/pre_tokenizers.rs @@ -203,9 +203,15 @@ fn split(mut cx: FunctionContext) -> JsResult { /// punctuation() fn punctuation(mut cx: FunctionContext) -> JsResult { + let behavior: JsSplitDelimiterBehavior = cx + .extract_opt::(0)? + .unwrap_or(JsSplitDelimiterBehavior(SplitDelimiterBehavior::Isolated)); + let mut pretok = JsPreTokenizer::new::<_, JsPreTokenizer, _>(&mut cx, vec![])?; let guard = cx.lock(); - pretok.borrow_mut(&guard).pretok = Some(tk::pre_tokenizers::punctuation::Punctuation.into()); + pretok.borrow_mut(&guard).pretok = + Some(tk::pre_tokenizers::punctuation::Punctuation::new(behavior.into()).into()); + Ok(pretok) } diff --git a/bindings/python/CHANGELOG.md b/bindings/python/CHANGELOG.md index db38e705..d2077a56 100644 --- a/bindings/python/CHANGELOG.md +++ b/bindings/python/CHANGELOG.md @@ -4,6 +4,11 @@ All notable changes to this project will be documented in this file. The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). 
+## [Unreleased] + +### Added +- [#657]: Add SplitDelimiterBehavior customization to Punctuation constructor + ## [0.10.3] ### Fixed @@ -326,6 +331,7 @@ delimiter (Works like `.split(delimiter)`) [#693]: https://github.com/huggingface/tokenizers/pull/693 [#686]: https://github.com/huggingface/tokenizers/pull/686 [#674]: https://github.com/huggingface/tokenizers/pull/674 +[#657]: https://github.com/huggingface/tokenizers/pull/657 [#656]: https://github.com/huggingface/tokenizers/pull/656 [#652]: https://github.com/huggingface/tokenizers/pull/652 [#621]: https://github.com/huggingface/tokenizers/pull/621 diff --git a/bindings/python/py_src/tokenizers/pre_tokenizers/__init__.pyi b/bindings/python/py_src/tokenizers/pre_tokenizers/__init__.pyi index b4764e12..e3ffbbbc 100644 --- a/bindings/python/py_src/tokenizers/pre_tokenizers/__init__.pyi +++ b/bindings/python/py_src/tokenizers/pre_tokenizers/__init__.pyi @@ -308,10 +308,16 @@ class Metaspace(PreTokenizer): class Punctuation(PreTokenizer): """ - This pre-tokenizer simply splits on punctuation as individual characters.` + This pre-tokenizer simply splits on punctuation as individual characters. + + Args: + behavior (:class:`~tokenizers.SplitDelimiterBehavior`): + The behavior to use when splitting. 
+ Choices: "removed", "isolated" (default), "merged_with_previous", "merged_with_next", + "contiguous" """ - def __init__(self): + def __init__(self, behavior="isolated"): pass def pre_tokenize(self, pretok): """ diff --git a/bindings/python/src/pre_tokenizers.rs b/bindings/python/src/pre_tokenizers.rs index c71b64f0..77840012 100644 --- a/bindings/python/src/pre_tokenizers.rs +++ b/bindings/python/src/pre_tokenizers.rs @@ -6,6 +6,7 @@ use pyo3::types::*; use serde::ser::SerializeStruct; use serde::{Deserialize, Deserializer, Serialize, Serializer}; +use tk::normalizer::SplitDelimiterBehavior; use tk::pre_tokenizers::bert::BertPreTokenizer; use tk::pre_tokenizers::byte_level::ByteLevel; use tk::pre_tokenizers::delimiter::CharDelimiterSplit; @@ -384,15 +385,22 @@ impl PyBertPreTokenizer { } } -/// This pre-tokenizer simply splits on punctuation as individual characters.` +/// This pre-tokenizer simply splits on punctuation as individual characters. +/// +/// Args: +/// behavior (:class:`~tokenizers.SplitDelimiterBehavior`): +/// The behavior to use when splitting. 
+/// Choices: "removed", "isolated" (default), "merged_with_previous", "merged_with_next", +/// "contiguous" #[pyclass(extends=PyPreTokenizer, module = "tokenizers.pre_tokenizers", name=Punctuation)] -#[text_signature = "(self)"] +#[text_signature = "(self, behavior=\"isolated\")"] pub struct PyPunctuation {} #[pymethods] impl PyPunctuation { #[new] - fn new() -> (Self, PyPreTokenizer) { - (PyPunctuation {}, Punctuation.into()) + #[args(behavior = "PySplitDelimiterBehavior(SplitDelimiterBehavior::Isolated)")] + fn new(behavior: PySplitDelimiterBehavior) -> (Self, PyPreTokenizer) { + (PyPunctuation {}, Punctuation::new(behavior.into()).into()) } } diff --git a/bindings/python/src/utils/normalization.rs b/bindings/python/src/utils/normalization.rs index e0e97d65..59380b9d 100644 --- a/bindings/python/src/utils/normalization.rs +++ b/bindings/python/src/utils/normalization.rs @@ -92,7 +92,7 @@ impl PyRange<'_> { } #[derive(Clone)] -pub struct PySplitDelimiterBehavior(SplitDelimiterBehavior); +pub struct PySplitDelimiterBehavior(pub SplitDelimiterBehavior); impl FromPyObject<'_> for PySplitDelimiterBehavior { fn extract(obj: &PyAny) -> PyResult { diff --git a/bindings/python/tests/bindings/test_pre_tokenizers.py b/bindings/python/tests/bindings/test_pre_tokenizers.py index b99aa5ca..e73b6ba5 100644 --- a/bindings/python/tests/bindings/test_pre_tokenizers.py +++ b/bindings/python/tests/bindings/test_pre_tokenizers.py @@ -132,6 +132,7 @@ class TestCharDelimiterSplit: class TestPunctuation: def test_instantiate(self): assert Punctuation() is not None + assert Punctuation("removed") is not None assert isinstance(Punctuation(), PreTokenizer) assert isinstance(Punctuation(), Punctuation) assert isinstance(pickle.loads(pickle.dumps(Punctuation())), Punctuation) diff --git a/tokenizers/src/pre_tokenizers/punctuation.rs b/tokenizers/src/pre_tokenizers/punctuation.rs index b8efa6ed..eee7058f 100644 --- a/tokenizers/src/pre_tokenizers/punctuation.rs +++ 
b/tokenizers/src/pre_tokenizers/punctuation.rs @@ -1,3 +1,5 @@ +use serde::{Deserialize, Serialize}; + use crate::tokenizer::{PreTokenizedString, PreTokenizer, Result, SplitDelimiterBehavior}; use unicode_categories::UnicodeCategories; @@ -5,13 +7,27 @@ fn is_punc(x: char) -> bool { char::is_ascii_punctuation(&x) || x.is_punctuation() } -#[derive(Copy, Clone, Debug)] -pub struct Punctuation; -impl_serde_unit_struct!(PunctuationVisitor, Punctuation); +#[derive(Serialize, Deserialize, Copy, Clone, Debug)] +#[serde(tag = "type")] +pub struct Punctuation { + behavior: SplitDelimiterBehavior, +} + +impl Punctuation { + pub fn new(behavior: SplitDelimiterBehavior) -> Self { + Self { behavior } + } +} + +impl Default for Punctuation { + fn default() -> Self { + Self::new(SplitDelimiterBehavior::Isolated) + } +} impl PreTokenizer for Punctuation { fn pre_tokenize(&self, pretokenized: &mut PreTokenizedString) -> Result<()> { - pretokenized.split(|_, s| s.split(is_punc, SplitDelimiterBehavior::Isolated)) + pretokenized.split(|_, s| s.split(is_punc, self.behavior)) } } @@ -22,7 +38,7 @@ mod tests { #[test] fn punctuation_basic() { - let pretok = Punctuation; + let pretok = Punctuation::default(); let mut pretokenized: PreTokenizedString = "Hey friend! How are you?!?".into(); pretok.pre_tokenize(&mut pretokenized).unwrap(); assert_eq!( diff --git a/tokenizers/src/pre_tokenizers/sequence.rs b/tokenizers/src/pre_tokenizers/sequence.rs index ef57fcab..7fd43c94 100644 --- a/tokenizers/src/pre_tokenizers/sequence.rs +++ b/tokenizers/src/pre_tokenizers/sequence.rs @@ -33,7 +33,7 @@ mod tests { fn sequence_basic() { let pretokenizers = vec![ PreTokenizerWrapper::WhitespaceSplit(WhitespaceSplit), - PreTokenizerWrapper::Punctuation(Punctuation), + PreTokenizerWrapper::Punctuation(Punctuation::default()), ]; let pretok = Sequence::new(pretokenizers); let mut pretokenized: PreTokenizedString = "Hey friend! How are you?!?".into();