mirror of
https://github.com/mii443/tokenizers.git
synced 2025-08-22 16:25:30 +00:00
Add `content` to Strip decoder to allow decoding mid tokens. (#1199)
* Add `content` to Strip decoder to allow decoding mid tokens.
* Stub.
* Clippy.
This commit is contained in:
3
bindings/node/lib/bindings/decoders.d.ts
vendored
3
bindings/node/lib/bindings/decoders.d.ts
vendored
@ -42,10 +42,11 @@ export function fuseDecoder(): Decoder;
|
||||
|
||||
/**
|
||||
* Instantiate a new Strip Decoder
|
||||
* @param [content] The character to strip
|
||||
* @param [left] The maximum number of leading `content` chars to remove from each token
|
||||
* @param [right] The maximum number of trailing `content` chars to remove from each token
|
||||
*/
|
||||
export function stripDecoder(left: number, right: number): Decoder;
|
||||
export function stripDecoder(content: string, left: number, right: number): Decoder;
|
||||
|
||||
/**
|
||||
* Instantiate a new Metaspace
|
||||
|
@ -65,11 +65,13 @@ describe("fuseDecoder", () => {
|
||||
|
||||
describe("stripDecoder", () => {
|
||||
it("accepts `undefined` as first parameter", () => {
|
||||
expect(stripDecoder(0, 0)).toBeDefined();
|
||||
expect(stripDecoder("_", 0, 0)).toBeDefined();
|
||||
});
|
||||
|
||||
it("can decode arrays of strings", () => {
|
||||
expect(stripDecoder(1, 0).decode(["Hel", "lo"])).toEqual("elo");
|
||||
expect(stripDecoder("_", 1, 0).decode(["_Hel", "lo", "__there"])).toEqual(
|
||||
"Hello_there"
|
||||
);
|
||||
});
|
||||
});
|
||||
|
||||
|
@ -104,14 +104,15 @@ fn fuse(mut cx: FunctionContext) -> JsResult<JsDecoder> {
|
||||
Ok(decoder)
|
||||
}
|
||||
|
||||
/// strip()
|
||||
/// strip(content: char, left: usize, right: usize)
|
||||
fn strip(mut cx: FunctionContext) -> JsResult<JsDecoder> {
|
||||
let left: usize = cx.extract(0)?;
|
||||
let right: usize = cx.extract(1)?;
|
||||
let content: char = cx.extract(0)?;
|
||||
let left: usize = cx.extract(1)?;
|
||||
let right: usize = cx.extract(2)?;
|
||||
let mut decoder = JsDecoder::new::<_, JsDecoder, _>(&mut cx, vec![])?;
|
||||
let guard = cx.lock();
|
||||
decoder.borrow_mut(&guard).decoder = Some(Arc::new(
|
||||
tk::decoders::strip::Strip::new(left, right).into(),
|
||||
tk::decoders::strip::Strip::new(content, left, right).into(),
|
||||
));
|
||||
Ok(decoder)
|
||||
}
|
||||
|
@ -226,7 +226,7 @@ class Strip(Decoder):
|
||||
Strips up to `left` leading and up to `right` trailing occurrences of the `content` character from each token
|
||||
"""
|
||||
|
||||
def __init__(self, left=0, right=0):
|
||||
def __init__(self, content, left=0, right=0):
|
||||
pass
|
||||
def decode(self, tokens):
|
||||
"""
|
||||
|
@ -261,34 +261,44 @@ impl PyFuseDec {
|
||||
/// Strip decoder
|
||||
/// Strips up to `left` leading and up to `right` trailing occurrences of the `content` character from each token
|
||||
#[pyclass(extends=PyDecoder, module = "tokenizers.decoders", name = "Strip")]
|
||||
#[pyo3(text_signature = "(self, left=0, right=0)")]
|
||||
#[pyo3(text_signature = "(self, content, left=0, right=0)")]
|
||||
pub struct PyStrip {}
|
||||
#[pymethods]
|
||||
impl PyStrip {
|
||||
#[getter]
|
||||
fn get_left(self_: PyRef<Self>) -> usize {
|
||||
getter!(self_, Strip, left)
|
||||
fn get_start(self_: PyRef<Self>) -> usize {
|
||||
getter!(self_, Strip, start)
|
||||
}
|
||||
|
||||
#[setter]
|
||||
fn set_left(self_: PyRef<Self>, left: usize) {
|
||||
setter!(self_, Strip, left, left)
|
||||
fn set_start(self_: PyRef<Self>, start: usize) {
|
||||
setter!(self_, Strip, start, start)
|
||||
}
|
||||
|
||||
#[getter]
|
||||
fn get_right(self_: PyRef<Self>) -> usize {
|
||||
getter!(self_, Strip, right)
|
||||
fn get_stop(self_: PyRef<Self>) -> usize {
|
||||
getter!(self_, Strip, stop)
|
||||
}
|
||||
|
||||
#[setter]
|
||||
fn set_right(self_: PyRef<Self>, right: usize) {
|
||||
setter!(self_, Strip, right, right)
|
||||
fn set_stop(self_: PyRef<Self>, stop: usize) {
|
||||
setter!(self_, Strip, stop, stop)
|
||||
}
|
||||
|
||||
#[getter]
|
||||
fn get_content(self_: PyRef<Self>) -> char {
|
||||
getter!(self_, Strip, content)
|
||||
}
|
||||
|
||||
#[setter]
|
||||
fn set_content(self_: PyRef<Self>, content: char) {
|
||||
setter!(self_, Strip, content, content)
|
||||
}
|
||||
|
||||
#[new]
|
||||
#[pyo3(signature = (left=0, right=0))]
|
||||
fn new(left: usize, right: usize) -> (Self, PyDecoder) {
|
||||
(PyStrip {}, Strip::new(left, right).into())
|
||||
#[pyo3(signature = (content=' ', left=0, right=0))]
|
||||
fn new(content: char, left: usize, right: usize) -> (Self, PyDecoder) {
|
||||
(PyStrip {}, Strip::new(content, left, right).into())
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -111,13 +111,13 @@ class TestFuse:
|
||||
class TestStrip:
|
||||
def test_instantiate(self):
|
||||
assert Strip(left=0, right=0) is not None
|
||||
assert isinstance(Strip(left=0, right=0), Decoder)
|
||||
assert isinstance(Strip(left=0, right=0), Strip)
|
||||
assert isinstance(pickle.loads(pickle.dumps(Strip(left=0, right=0))), Strip)
|
||||
assert isinstance(Strip(content="_", left=0, right=0), Decoder)
|
||||
assert isinstance(Strip(content="_", left=0, right=0), Strip)
|
||||
assert isinstance(pickle.loads(pickle.dumps(Strip(content="_", left=0, right=0))), Strip)
|
||||
|
||||
def test_decoding(self):
|
||||
decoder = Strip(left=1, right=0)
|
||||
assert decoder.decode(["My", " na", "me"]) == "ynae"
|
||||
decoder = Strip(content="_", left=1, right=0)
|
||||
assert decoder.decode(["_My", " na", "me", " _-", "__-"]) == "My name _-_-"
|
||||
|
||||
|
||||
class TestMetaspace:
|
||||
|
@ -9,13 +9,18 @@ use serde::{Deserialize, Serialize};
|
||||
#[serde(tag = "type")]
|
||||
#[non_exhaustive]
|
||||
pub struct Strip {
|
||||
pub left: usize,
|
||||
pub right: usize,
|
||||
pub content: char,
|
||||
pub start: usize,
|
||||
pub stop: usize,
|
||||
}
|
||||
|
||||
impl Strip {
|
||||
pub fn new(left: usize, right: usize) -> Self {
|
||||
Self { left, right }
|
||||
pub fn new(content: char, start: usize, stop: usize) -> Self {
|
||||
Self {
|
||||
content,
|
||||
start,
|
||||
stop,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -24,11 +29,31 @@ impl Decoder for Strip {
|
||||
Ok(tokens
|
||||
.into_iter()
|
||||
.map(|token| {
|
||||
token
|
||||
.chars()
|
||||
.skip(self.left)
|
||||
.take(token.len() - self.left - self.right)
|
||||
.collect()
|
||||
let chars: Vec<char> = token.chars().collect();
|
||||
|
||||
let mut start_cut = 0;
|
||||
for (i, &c) in chars.iter().enumerate().take(self.start) {
|
||||
if c == self.content {
|
||||
start_cut = i + 1;
|
||||
continue;
|
||||
} else {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
let mut stop_cut = chars.len();
|
||||
for i in 0..self.stop {
|
||||
let index = chars.len() - i - 1;
|
||||
if chars[index] == self.content {
|
||||
stop_cut = index;
|
||||
continue;
|
||||
} else {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
let new_token: String = chars[start_cut..stop_cut].iter().collect();
|
||||
new_token
|
||||
})
|
||||
.collect())
|
||||
}
|
||||
@ -40,16 +65,16 @@ mod tests {
|
||||
|
||||
#[test]
|
||||
fn decode() {
|
||||
let decoder = Strip::new(1, 0);
|
||||
let decoder = Strip::new('H', 1, 0);
|
||||
let res = decoder
|
||||
.decode_chain(vec!["Hey".into(), " friend!".into()])
|
||||
.decode_chain(vec!["Hey".into(), " friend!".into(), "HHH".into()])
|
||||
.unwrap();
|
||||
assert_eq!(res, vec!["ey", "friend!"]);
|
||||
assert_eq!(res, vec!["ey", " friend!", "HH"]);
|
||||
|
||||
let decoder = Strip::new(0, 1);
|
||||
let decoder = Strip::new('y', 0, 1);
|
||||
let res = decoder
|
||||
.decode_chain(vec!["Hey".into(), " friend!".into()])
|
||||
.unwrap();
|
||||
assert_eq!(res, vec!["He", " friend"]);
|
||||
assert_eq!(res, vec!["He", " friend!"]);
|
||||
}
|
||||
}
|
||||
|
Reference in New Issue
Block a user