Add content to Strip decoder to allow decoding mid tokens. (#1199)

* Add `content` to Strip decoder to allow decoding mid tokens.

* Stub.

* Clippy.
This commit is contained in:
Nicolas Patry
2023-03-24 10:14:49 +01:00
committed by GitHub
parent 8a6a8dc9d5
commit 3aaf4946b3
7 changed files with 78 additions and 39 deletions

View File

@@ -42,10 +42,11 @@ export function fuseDecoder(): Decoder;
/** /**
* Instantiate a new Strip Decoder * Instantiate a new Strip Decoder
* @param [content] The character to strip
* @param [left] The number of chars to remove from the left of each token * @param [left] The number of chars to remove from the left of each token
* @param [right] The number of chars to remove from the right of each token * @param [right] The number of chars to remove from the right of each token
*/ */
export function stripDecoder(left: number, right: number): Decoder; export function stripDecoder(content: string, left: number, right: number): Decoder;
/** /**
* Instantiate a new Metaspace * Instantiate a new Metaspace

View File

@@ -65,11 +65,13 @@ describe("fuseDecoder", () => {
describe("stripDecoder", () => { describe("stripDecoder", () => {
it("accepts `undefined` as first parameter", () => { it("accepts `undefined` as first parameter", () => {
expect(stripDecoder(0, 0)).toBeDefined(); expect(stripDecoder("_", 0, 0)).toBeDefined();
}); });
it("can decode arrays of strings", () => { it("can decode arrays of strings", () => {
expect(stripDecoder(1, 0).decode(["Hel", "lo"])).toEqual("elo"); expect(stripDecoder("_", 1, 0).decode(["_Hel", "lo", "__there"])).toEqual(
"Hello_there"
);
}); });
}); });

View File

@@ -104,14 +104,15 @@ fn fuse(mut cx: FunctionContext) -> JsResult<JsDecoder> {
Ok(decoder) Ok(decoder)
} }
/// strip() /// strip(content: char, left: usize, right: usize)
fn strip(mut cx: FunctionContext) -> JsResult<JsDecoder> { fn strip(mut cx: FunctionContext) -> JsResult<JsDecoder> {
let left: usize = cx.extract(0)?; let content: char = cx.extract(0)?;
let right: usize = cx.extract(1)?; let left: usize = cx.extract(1)?;
let right: usize = cx.extract(2)?;
let mut decoder = JsDecoder::new::<_, JsDecoder, _>(&mut cx, vec![])?; let mut decoder = JsDecoder::new::<_, JsDecoder, _>(&mut cx, vec![])?;
let guard = cx.lock(); let guard = cx.lock();
decoder.borrow_mut(&guard).decoder = Some(Arc::new( decoder.borrow_mut(&guard).decoder = Some(Arc::new(
tk::decoders::strip::Strip::new(left, right).into(), tk::decoders::strip::Strip::new(content, left, right).into(),
)); ));
Ok(decoder) Ok(decoder)
} }

View File

@@ -226,7 +226,7 @@ class Strip(Decoder):
Strips n left characters of each token, or n right characters of each token Strips n left characters of each token, or n right characters of each token
""" """
def __init__(self, left=0, right=0): def __init__(self, content, left=0, right=0):
pass pass
def decode(self, tokens): def decode(self, tokens):
""" """

View File

@@ -261,34 +261,44 @@ impl PyFuseDec {
/// Strip normalizer /// Strip normalizer
/// Strips n left characters of each token, or n right characters of each token /// Strips n left characters of each token, or n right characters of each token
#[pyclass(extends=PyDecoder, module = "tokenizers.decoders", name = "Strip")] #[pyclass(extends=PyDecoder, module = "tokenizers.decoders", name = "Strip")]
#[pyo3(text_signature = "(self, left=0, right=0)")] #[pyo3(text_signature = "(self, content, left=0, right=0)")]
pub struct PyStrip {} pub struct PyStrip {}
#[pymethods] #[pymethods]
impl PyStrip { impl PyStrip {
#[getter] #[getter]
fn get_left(self_: PyRef<Self>) -> usize { fn get_start(self_: PyRef<Self>) -> usize {
getter!(self_, Strip, left) getter!(self_, Strip, start)
} }
#[setter] #[setter]
fn set_left(self_: PyRef<Self>, left: usize) { fn set_start(self_: PyRef<Self>, start: usize) {
setter!(self_, Strip, left, left) setter!(self_, Strip, start, start)
} }
#[getter] #[getter]
fn get_right(self_: PyRef<Self>) -> usize { fn get_stop(self_: PyRef<Self>) -> usize {
getter!(self_, Strip, right) getter!(self_, Strip, stop)
} }
#[setter] #[setter]
fn set_right(self_: PyRef<Self>, right: usize) { fn set_stop(self_: PyRef<Self>, stop: usize) {
setter!(self_, Strip, right, right) setter!(self_, Strip, stop, stop)
}
#[getter]
fn get_content(self_: PyRef<Self>) -> char {
getter!(self_, Strip, content)
}
#[setter]
fn set_content(self_: PyRef<Self>, content: char) {
setter!(self_, Strip, content, content)
} }
#[new] #[new]
#[pyo3(signature = (left=0, right=0))] #[pyo3(signature = (content=' ', left=0, right=0))]
fn new(left: usize, right: usize) -> (Self, PyDecoder) { fn new(content: char, left: usize, right: usize) -> (Self, PyDecoder) {
(PyStrip {}, Strip::new(left, right).into()) (PyStrip {}, Strip::new(content, left, right).into())
} }
} }

View File

@@ -111,13 +111,13 @@ class TestFuse:
class TestStrip: class TestStrip:
def test_instantiate(self): def test_instantiate(self):
assert Strip(left=0, right=0) is not None assert Strip(left=0, right=0) is not None
assert isinstance(Strip(left=0, right=0), Decoder) assert isinstance(Strip(content="_", left=0, right=0), Decoder)
assert isinstance(Strip(left=0, right=0), Strip) assert isinstance(Strip(content="_", left=0, right=0), Strip)
assert isinstance(pickle.loads(pickle.dumps(Strip(left=0, right=0))), Strip) assert isinstance(pickle.loads(pickle.dumps(Strip(content="_", left=0, right=0))), Strip)
def test_decoding(self): def test_decoding(self):
decoder = Strip(left=1, right=0) decoder = Strip(content="_", left=1, right=0)
assert decoder.decode(["My", " na", "me"]) == "ynae" assert decoder.decode(["_My", " na", "me", " _-", "__-"]) == "My name _-_-"
class TestMetaspace: class TestMetaspace:

View File

@@ -9,13 +9,18 @@ use serde::{Deserialize, Serialize};
#[serde(tag = "type")] #[serde(tag = "type")]
#[non_exhaustive] #[non_exhaustive]
pub struct Strip { pub struct Strip {
pub left: usize, pub content: char,
pub right: usize, pub start: usize,
pub stop: usize,
} }
impl Strip { impl Strip {
pub fn new(left: usize, right: usize) -> Self { pub fn new(content: char, start: usize, stop: usize) -> Self {
Self { left, right } Self {
content,
start,
stop,
}
} }
} }
@@ -24,11 +29,31 @@ impl Decoder for Strip {
Ok(tokens Ok(tokens
.into_iter() .into_iter()
.map(|token| { .map(|token| {
token let chars: Vec<char> = token.chars().collect();
.chars()
.skip(self.left) let mut start_cut = 0;
.take(token.len() - self.left - self.right) for (i, &c) in chars.iter().enumerate().take(self.start) {
.collect() if c == self.content {
start_cut = i + 1;
continue;
} else {
break;
}
}
let mut stop_cut = chars.len();
for i in 0..self.stop {
let index = chars.len() - i - 1;
if chars[index] == self.content {
stop_cut = index;
continue;
} else {
break;
}
}
let new_token: String = chars[start_cut..stop_cut].iter().collect();
new_token
}) })
.collect()) .collect())
} }
@@ -40,16 +65,16 @@ mod tests {
#[test] #[test]
fn decode() { fn decode() {
let decoder = Strip::new(1, 0); let decoder = Strip::new('H', 1, 0);
let res = decoder let res = decoder
.decode_chain(vec!["Hey".into(), " friend!".into()]) .decode_chain(vec!["Hey".into(), " friend!".into(), "HHH".into()])
.unwrap(); .unwrap();
assert_eq!(res, vec!["ey", "friend!"]); assert_eq!(res, vec!["ey", " friend!", "HH"]);
let decoder = Strip::new(0, 1); let decoder = Strip::new('y', 0, 1);
let res = decoder let res = decoder
.decode_chain(vec!["Hey".into(), " friend!".into()]) .decode_chain(vec!["Hey".into(), " friend!".into()])
.unwrap(); .unwrap();
assert_eq!(res, vec!["He", " friend"]); assert_eq!(res, vec!["He", " friend!"]);
} }
} }