From 3aaf4946b3c82e7a04db6ecde7c8bb4e474e54af Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Fri, 24 Mar 2023 10:14:49 +0100 Subject: [PATCH] Add `content` to Strip decoder to allow decoding mid tokens. (#1199) * Add `content` to Strip decoder to allow decoding mid tokens. * Stub. * Clippy. --- bindings/node/lib/bindings/decoders.d.ts | 3 +- bindings/node/lib/bindings/decoders.test.ts | 6 ++- bindings/node/native/src/decoders.rs | 9 ++-- .../py_src/tokenizers/decoders/__init__.pyi | 2 +- bindings/python/src/decoders.rs | 34 +++++++----- .../python/tests/bindings/test_decoders.py | 10 ++-- tokenizers/src/decoders/strip.rs | 53 ++++++++++++++----- 7 files changed, 78 insertions(+), 39 deletions(-) diff --git a/bindings/node/lib/bindings/decoders.d.ts b/bindings/node/lib/bindings/decoders.d.ts index be9dc3f7..9543cb7a 100644 --- a/bindings/node/lib/bindings/decoders.d.ts +++ b/bindings/node/lib/bindings/decoders.d.ts @@ -42,10 +42,11 @@ export function fuseDecoder(): Decoder; /** * Instantiate a new Strip Decoder + * @param [content] The character to strip * @param [left] The number of chars to remove from the left of each token * @param [right] The number of chars to remove from the right of each token */ -export function stripDecoder(left: number, right: number): Decoder; +export function stripDecoder(content: string, left: number, right: number): Decoder; /** * Instantiate a new Metaspace diff --git a/bindings/node/lib/bindings/decoders.test.ts b/bindings/node/lib/bindings/decoders.test.ts index 624803b1..031e5f0c 100644 --- a/bindings/node/lib/bindings/decoders.test.ts +++ b/bindings/node/lib/bindings/decoders.test.ts @@ -65,11 +65,13 @@ describe("fuseDecoder", () => { describe("stripDecoder", () => { it("accepts `undefined` as first parameter", () => { - expect(stripDecoder(0, 0)).toBeDefined(); + expect(stripDecoder("_", 0, 0)).toBeDefined(); }); it("can decode arrays of strings", () => { - expect(stripDecoder(1, 0).decode(["Hel", "lo"])).toEqual("elo"); + expect(stripDecoder("_", 1, 0).decode(["_Hel", "lo", "__there"])).toEqual( + "Hello_there" + ); }); }); diff --git a/bindings/node/native/src/decoders.rs b/bindings/node/native/src/decoders.rs index fa3c712e..2c0b3709 100644 --- a/bindings/node/native/src/decoders.rs +++ b/bindings/node/native/src/decoders.rs @@ -104,14 +104,15 @@ fn fuse(mut cx: FunctionContext) -> JsResult { Ok(decoder) } -/// strip() +/// strip(content: char, left: usize, right: usize) fn strip(mut cx: FunctionContext) -> JsResult { - let left: usize = cx.extract(0)?; - let right: usize = cx.extract(1)?; + let content: char = cx.extract(0)?; + let left: usize = cx.extract(1)?; + let right: usize = cx.extract(2)?; let mut decoder = JsDecoder::new::<_, JsDecoder, _>(&mut cx, vec![])?; let guard = cx.lock(); decoder.borrow_mut(&guard).decoder = Some(Arc::new( - tk::decoders::strip::Strip::new(left, right).into(), + tk::decoders::strip::Strip::new(content, left, right).into(), )); Ok(decoder) } diff --git a/bindings/python/py_src/tokenizers/decoders/__init__.pyi b/bindings/python/py_src/tokenizers/decoders/__init__.pyi index 21fe746a..83a0e827 100644 --- a/bindings/python/py_src/tokenizers/decoders/__init__.pyi +++ b/bindings/python/py_src/tokenizers/decoders/__init__.pyi @@ -226,7 +226,7 @@ class Strip(Decoder): Strips n left characters of each token, or n right characters of each token """ - def __init__(self, left=0, right=0): + def __init__(self, content, left=0, right=0): pass def decode(self, tokens): """ diff --git a/bindings/python/src/decoders.rs b/bindings/python/src/decoders.rs index e1b5bd79..f6e0388c 100644 --- a/bindings/python/src/decoders.rs +++ b/bindings/python/src/decoders.rs @@ -261,34 +261,44 @@ impl PyFuseDec { /// Strip normalizer /// Strips n left characters of each token, or n right characters of each token #[pyclass(extends=PyDecoder, module = "tokenizers.decoders", name = "Strip")] -#[pyo3(text_signature = "(self, left=0, right=0)")] +#[pyo3(text_signature = "(self, content, left=0, right=0)")] pub struct PyStrip {} #[pymethods] impl PyStrip { #[getter] - fn get_left(self_: PyRef) -> usize { - getter!(self_, Strip, left) + fn get_start(self_: PyRef) -> usize { + getter!(self_, Strip, start) } #[setter] - fn set_left(self_: PyRef, left: usize) { - setter!(self_, Strip, left, left) + fn set_start(self_: PyRef, start: usize) { + setter!(self_, Strip, start, start) } #[getter] - fn get_right(self_: PyRef) -> usize { - getter!(self_, Strip, right) + fn get_stop(self_: PyRef) -> usize { + getter!(self_, Strip, stop) } #[setter] - fn set_right(self_: PyRef, right: usize) { - setter!(self_, Strip, right, right) + fn set_stop(self_: PyRef, stop: usize) { + setter!(self_, Strip, stop, stop) + } + + #[getter] + fn get_content(self_: PyRef) -> char { + getter!(self_, Strip, content) + } + + #[setter] + fn set_content(self_: PyRef, content: char) { + setter!(self_, Strip, content, content) } #[new] - #[pyo3(signature = (left=0, right=0))] - fn new(left: usize, right: usize) -> (Self, PyDecoder) { - (PyStrip {}, Strip::new(left, right).into()) + #[pyo3(signature = (content=' ', left=0, right=0))] + fn new(content: char, left: usize, right: usize) -> (Self, PyDecoder) { + (PyStrip {}, Strip::new(content, left, right).into()) } } diff --git a/bindings/python/tests/bindings/test_decoders.py b/bindings/python/tests/bindings/test_decoders.py index fe21e022..c8b1396b 100644 --- a/bindings/python/tests/bindings/test_decoders.py +++ b/bindings/python/tests/bindings/test_decoders.py @@ -111,13 +111,13 @@ class TestFuse: class TestStrip: def test_instantiate(self): assert Strip(left=0, right=0) is not None - assert isinstance(Strip(left=0, right=0), Decoder) - assert isinstance(Strip(left=0, right=0), Strip) - assert isinstance(pickle.loads(pickle.dumps(Strip(left=0, right=0))), Strip) + assert isinstance(Strip(content="_", left=0, right=0), Decoder) + assert isinstance(Strip(content="_", left=0, right=0), Strip) + assert isinstance(pickle.loads(pickle.dumps(Strip(content="_", left=0, right=0))), Strip) def test_decoding(self): - decoder = Strip(left=1, right=0) - assert decoder.decode(["My", " na", "me"]) == "ynae" + decoder = Strip(content="_", left=1, right=0) + assert decoder.decode(["_My", " na", "me", " _-", "__-"]) == "My name _-_-" class TestMetaspace: diff --git a/tokenizers/src/decoders/strip.rs b/tokenizers/src/decoders/strip.rs index 0f6e0426..b095fc37 100644 --- a/tokenizers/src/decoders/strip.rs +++ b/tokenizers/src/decoders/strip.rs @@ -9,13 +9,18 @@ use serde::{Deserialize, Serialize}; #[serde(tag = "type")] #[non_exhaustive] pub struct Strip { - pub left: usize, - pub right: usize, + pub content: char, + pub start: usize, + pub stop: usize, } impl Strip { - pub fn new(left: usize, right: usize) -> Self { - Self { left, right } + pub fn new(content: char, start: usize, stop: usize) -> Self { + Self { + content, + start, + stop, + } } } @@ -24,11 +29,31 @@ impl Decoder for Strip { Ok(tokens .into_iter() .map(|token| { - token - .chars() - .skip(self.left) - .take(token.len() - self.left - self.right) - .collect() + let chars: Vec = token.chars().collect(); + + let mut start_cut = 0; + for (i, &c) in chars.iter().enumerate().take(self.start) { + if c == self.content { + start_cut = i + 1; + continue; + } else { + break; + } + } + + let mut stop_cut = chars.len(); + for i in 0..self.stop { + let index = chars.len() - i - 1; + if chars[index] == self.content { + stop_cut = index; + continue; + } else { + break; + } + } + + let new_token: String = chars[start_cut..stop_cut].iter().collect(); + new_token }) .collect()) } @@ -40,16 +65,16 @@ mod tests { #[test] fn decode() { - let decoder = Strip::new(1, 0); + let decoder = Strip::new('H', 1, 0); let res = decoder - .decode_chain(vec!["Hey".into(), " friend!".into()]) + .decode_chain(vec!["Hey".into(), " friend!".into(), "HHH".into()]) .unwrap(); - assert_eq!(res, vec!["ey", "friend!"]); + assert_eq!(res, vec!["ey", " friend!", "HH"]); - let decoder = Strip::new(0, 1); + let decoder = Strip::new('y', 0, 1); let res = decoder .decode_chain(vec!["Hey".into(), " friend!".into()]) .unwrap(); - assert_eq!(res, vec!["He", " friend"]); + assert_eq!(res, vec!["He", " friend!"]); } }