Add content to Strip decoder to allow decoding mid tokens. (#1199)

* Add `content` to Strip decoder to allow decoding mid tokens.

* Stub.

* Clippy.
This commit is contained in:
Nicolas Patry
2023-03-24 10:14:49 +01:00
committed by GitHub
parent 8a6a8dc9d5
commit 3aaf4946b3
7 changed files with 78 additions and 39 deletions

View File

@ -42,10 +42,11 @@ export function fuseDecoder(): Decoder;
/**
* Instantiate a new Strip Decoder
* @param content The character to strip
* @param left The maximum number of `content` chars to remove from the left of each token
* @param right The maximum number of `content` chars to remove from the right of each token
*/
export function stripDecoder(left: number, right: number): Decoder;
export function stripDecoder(content: string, left: number, right: number): Decoder;
/**
* Instantiate a new Metaspace

View File

@ -65,11 +65,13 @@ describe("fuseDecoder", () => {
describe("stripDecoder", () => {
it("accepts `undefined` as first parameter", () => {
expect(stripDecoder(0, 0)).toBeDefined();
expect(stripDecoder("_", 0, 0)).toBeDefined();
});
it("can decode arrays of strings", () => {
expect(stripDecoder(1, 0).decode(["Hel", "lo"])).toEqual("elo");
expect(stripDecoder("_", 1, 0).decode(["_Hel", "lo", "__there"])).toEqual(
"Hello_there"
);
});
});

View File

@ -104,14 +104,15 @@ fn fuse(mut cx: FunctionContext) -> JsResult<JsDecoder> {
Ok(decoder)
}
/// strip()
/// strip(content: char, left: usize, right: usize)
fn strip(mut cx: FunctionContext) -> JsResult<JsDecoder> {
let left: usize = cx.extract(0)?;
let right: usize = cx.extract(1)?;
let content: char = cx.extract(0)?;
let left: usize = cx.extract(1)?;
let right: usize = cx.extract(2)?;
let mut decoder = JsDecoder::new::<_, JsDecoder, _>(&mut cx, vec![])?;
let guard = cx.lock();
decoder.borrow_mut(&guard).decoder = Some(Arc::new(
tk::decoders::strip::Strip::new(left, right).into(),
tk::decoders::strip::Strip::new(content, left, right).into(),
));
Ok(decoder)
}

View File

@ -226,7 +226,7 @@ class Strip(Decoder):
Strips up to `left` leading and up to `right` trailing occurrences of the `content` character from each token
"""
def __init__(self, left=0, right=0):
def __init__(self, content, left=0, right=0):
pass
def decode(self, tokens):
"""

View File

@ -261,34 +261,44 @@ impl PyFuseDec {
/// Strip decoder
/// Strips up to `left` leading and up to `right` trailing occurrences of the `content` character from each token
#[pyclass(extends=PyDecoder, module = "tokenizers.decoders", name = "Strip")]
#[pyo3(text_signature = "(self, left=0, right=0)")]
#[pyo3(text_signature = "(self, content, left=0, right=0)")]
pub struct PyStrip {}
#[pymethods]
impl PyStrip {
#[getter]
fn get_left(self_: PyRef<Self>) -> usize {
getter!(self_, Strip, left)
fn get_start(self_: PyRef<Self>) -> usize {
getter!(self_, Strip, start)
}
#[setter]
fn set_left(self_: PyRef<Self>, left: usize) {
setter!(self_, Strip, left, left)
fn set_start(self_: PyRef<Self>, start: usize) {
setter!(self_, Strip, start, start)
}
#[getter]
fn get_right(self_: PyRef<Self>) -> usize {
getter!(self_, Strip, right)
fn get_stop(self_: PyRef<Self>) -> usize {
getter!(self_, Strip, stop)
}
#[setter]
fn set_right(self_: PyRef<Self>, right: usize) {
setter!(self_, Strip, right, right)
fn set_stop(self_: PyRef<Self>, stop: usize) {
setter!(self_, Strip, stop, stop)
}
#[getter]
fn get_content(self_: PyRef<Self>) -> char {
getter!(self_, Strip, content)
}
#[setter]
fn set_content(self_: PyRef<Self>, content: char) {
setter!(self_, Strip, content, content)
}
#[new]
#[pyo3(signature = (left=0, right=0))]
fn new(left: usize, right: usize) -> (Self, PyDecoder) {
(PyStrip {}, Strip::new(left, right).into())
#[pyo3(signature = (content=' ', left=0, right=0))]
fn new(content: char, left: usize, right: usize) -> (Self, PyDecoder) {
(PyStrip {}, Strip::new(content, left, right).into())
}
}

View File

@ -111,13 +111,13 @@ class TestFuse:
class TestStrip:
def test_instantiate(self):
assert Strip(left=0, right=0) is not None
assert isinstance(Strip(left=0, right=0), Decoder)
assert isinstance(Strip(left=0, right=0), Strip)
assert isinstance(pickle.loads(pickle.dumps(Strip(left=0, right=0))), Strip)
assert isinstance(Strip(content="_", left=0, right=0), Decoder)
assert isinstance(Strip(content="_", left=0, right=0), Strip)
assert isinstance(pickle.loads(pickle.dumps(Strip(content="_", left=0, right=0))), Strip)
def test_decoding(self):
decoder = Strip(left=1, right=0)
assert decoder.decode(["My", " na", "me"]) == "ynae"
decoder = Strip(content="_", left=1, right=0)
assert decoder.decode(["_My", " na", "me", " _-", "__-"]) == "My name _-_-"
class TestMetaspace:

View File

@ -9,13 +9,18 @@ use serde::{Deserialize, Serialize};
#[serde(tag = "type")]
#[non_exhaustive]
pub struct Strip {
pub left: usize,
pub right: usize,
pub content: char,
pub start: usize,
pub stop: usize,
}
impl Strip {
pub fn new(left: usize, right: usize) -> Self {
Self { left, right }
/// Builds a `Strip` decoder.
///
/// The arguments are stored unchanged in the `content`, `start` and `stop`
/// fields: `content` is the character to strip, while `start` and `stop` cap
/// how many leading and trailing occurrences of it are removed from each
/// token when decoding.
pub fn new(content: char, start: usize, stop: usize) -> Self {
Self {
content,
start,
stop,
}
}
}
@ -24,11 +29,31 @@ impl Decoder for Strip {
Ok(tokens
.into_iter()
.map(|token| {
token
.chars()
.skip(self.left)
.take(token.len() - self.left - self.right)
.collect()
// Work on `char`s rather than bytes so cuts always land on valid
// character boundaries, even for multi-byte tokens.
let chars: Vec<char> = token.chars().collect();
// Number of leading `content` characters to drop, capped at `self.start`.
// Counting stops at the first character that is not `content`.
let start_cut = chars
    .iter()
    .take(self.start)
    .take_while(|&&c| c == self.content)
    .count();
// Number of trailing `content` characters to drop, capped at `self.stop`.
// Only the part left after the leading cut is inspected, so the two cuts
// can never overlap and a token shorter than `self.stop` (including the
// empty token) cannot underflow or index out of bounds.
let trailing = chars[start_cut..]
    .iter()
    .rev()
    .take(self.stop)
    .take_while(|&&c| c == self.content)
    .count();
let stop_cut = chars.len() - trailing;
let new_token: String = chars[start_cut..stop_cut].iter().collect();
new_token
})
.collect())
}
@ -40,16 +65,16 @@ mod tests {
#[test]
fn decode() {
let decoder = Strip::new(1, 0);
let decoder = Strip::new('H', 1, 0);
let res = decoder
.decode_chain(vec!["Hey".into(), " friend!".into()])
.decode_chain(vec!["Hey".into(), " friend!".into(), "HHH".into()])
.unwrap();
assert_eq!(res, vec!["ey", "friend!"]);
assert_eq!(res, vec!["ey", " friend!", "HH"]);
let decoder = Strip::new(0, 1);
let decoder = Strip::new('y', 0, 1);
let res = decoder
.decode_chain(vec!["Hey".into(), " friend!".into()])
.unwrap();
assert_eq!(res, vec!["He", " friend"]);
assert_eq!(res, vec!["He", " friend!"]);
}
}