mirror of
https://github.com/mii443/tokenizers.git
synced 2025-08-22 16:25:30 +00:00
Add `content` to Strip decoder to allow decoding mid tokens. (#1199)
* Add `content` to Strip decoder to allow decoding mid tokens.
* Stub.
* Clippy.
This commit is contained in:
3
bindings/node/lib/bindings/decoders.d.ts
vendored
3
bindings/node/lib/bindings/decoders.d.ts
vendored
@ -42,10 +42,11 @@ export function fuseDecoder(): Decoder;
|
|||||||
|
|
||||||
/**
|
/**
|
||||||
* Instantiate a new Strip Decoder
|
* Instantiate a new Strip Decoder
|
||||||
|
* @param [content] The character to strip
|
||||||
* @param [left] The number of chars to remove from the left of each token
|
* @param [left] The number of chars to remove from the left of each token
|
||||||
* @param [right] The number of chars to remove from the right of each token
|
* @param [right] The number of chars to remove from the right of each token
|
||||||
*/
|
*/
|
||||||
export function stripDecoder(left: number, right: number): Decoder;
|
export function stripDecoder(content: string, left: number, right: number): Decoder;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Instantiate a new Metaspace
|
* Instantiate a new Metaspace
|
||||||
|
@ -65,11 +65,13 @@ describe("fuseDecoder", () => {
|
|||||||
|
|
||||||
describe("stripDecoder", () => {
|
describe("stripDecoder", () => {
|
||||||
it("accepts `undefined` as first parameter", () => {
|
it("accepts `undefined` as first parameter", () => {
|
||||||
expect(stripDecoder(0, 0)).toBeDefined();
|
expect(stripDecoder("_", 0, 0)).toBeDefined();
|
||||||
});
|
});
|
||||||
|
|
||||||
it("can decode arrays of strings", () => {
|
it("can decode arrays of strings", () => {
|
||||||
expect(stripDecoder(1, 0).decode(["Hel", "lo"])).toEqual("elo");
|
expect(stripDecoder("_", 1, 0).decode(["_Hel", "lo", "__there"])).toEqual(
|
||||||
|
"Hello_there"
|
||||||
|
);
|
||||||
});
|
});
|
||||||
});
|
});
|
||||||
|
|
||||||
|
@ -104,14 +104,15 @@ fn fuse(mut cx: FunctionContext) -> JsResult<JsDecoder> {
|
|||||||
Ok(decoder)
|
Ok(decoder)
|
||||||
}
|
}
|
||||||
|
|
||||||
/// strip()
|
/// strip(content: char, left: usize, right: usize)
|
||||||
fn strip(mut cx: FunctionContext) -> JsResult<JsDecoder> {
|
fn strip(mut cx: FunctionContext) -> JsResult<JsDecoder> {
|
||||||
let left: usize = cx.extract(0)?;
|
let content: char = cx.extract(0)?;
|
||||||
let right: usize = cx.extract(1)?;
|
let left: usize = cx.extract(1)?;
|
||||||
|
let right: usize = cx.extract(2)?;
|
||||||
let mut decoder = JsDecoder::new::<_, JsDecoder, _>(&mut cx, vec![])?;
|
let mut decoder = JsDecoder::new::<_, JsDecoder, _>(&mut cx, vec![])?;
|
||||||
let guard = cx.lock();
|
let guard = cx.lock();
|
||||||
decoder.borrow_mut(&guard).decoder = Some(Arc::new(
|
decoder.borrow_mut(&guard).decoder = Some(Arc::new(
|
||||||
tk::decoders::strip::Strip::new(left, right).into(),
|
tk::decoders::strip::Strip::new(content, left, right).into(),
|
||||||
));
|
));
|
||||||
Ok(decoder)
|
Ok(decoder)
|
||||||
}
|
}
|
||||||
|
@ -226,7 +226,7 @@ class Strip(Decoder):
|
|||||||
Strips n left characters of each token, or n right characters of each token
|
Strips n left characters of each token, or n right characters of each token
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, left=0, right=0):
|
def __init__(self, content, left=0, right=0):
|
||||||
pass
|
pass
|
||||||
def decode(self, tokens):
|
def decode(self, tokens):
|
||||||
"""
|
"""
|
||||||
|
@ -261,34 +261,44 @@ impl PyFuseDec {
|
|||||||
/// Strip normalizer
|
/// Strip normalizer
|
||||||
/// Strips n left characters of each token, or n right characters of each token
|
/// Strips n left characters of each token, or n right characters of each token
|
||||||
#[pyclass(extends=PyDecoder, module = "tokenizers.decoders", name = "Strip")]
|
#[pyclass(extends=PyDecoder, module = "tokenizers.decoders", name = "Strip")]
|
||||||
#[pyo3(text_signature = "(self, left=0, right=0)")]
|
#[pyo3(text_signature = "(self, content, left=0, right=0)")]
|
||||||
pub struct PyStrip {}
|
pub struct PyStrip {}
|
||||||
#[pymethods]
|
#[pymethods]
|
||||||
impl PyStrip {
|
impl PyStrip {
|
||||||
#[getter]
|
#[getter]
|
||||||
fn get_left(self_: PyRef<Self>) -> usize {
|
fn get_start(self_: PyRef<Self>) -> usize {
|
||||||
getter!(self_, Strip, left)
|
getter!(self_, Strip, start)
|
||||||
}
|
}
|
||||||
|
|
||||||
#[setter]
|
#[setter]
|
||||||
fn set_left(self_: PyRef<Self>, left: usize) {
|
fn set_start(self_: PyRef<Self>, start: usize) {
|
||||||
setter!(self_, Strip, left, left)
|
setter!(self_, Strip, start, start)
|
||||||
}
|
}
|
||||||
|
|
||||||
#[getter]
|
#[getter]
|
||||||
fn get_right(self_: PyRef<Self>) -> usize {
|
fn get_stop(self_: PyRef<Self>) -> usize {
|
||||||
getter!(self_, Strip, right)
|
getter!(self_, Strip, stop)
|
||||||
}
|
}
|
||||||
|
|
||||||
#[setter]
|
#[setter]
|
||||||
fn set_right(self_: PyRef<Self>, right: usize) {
|
fn set_stop(self_: PyRef<Self>, stop: usize) {
|
||||||
setter!(self_, Strip, right, right)
|
setter!(self_, Strip, stop, stop)
|
||||||
|
}
|
||||||
|
|
||||||
|
#[getter]
|
||||||
|
fn get_content(self_: PyRef<Self>) -> char {
|
||||||
|
getter!(self_, Strip, content)
|
||||||
|
}
|
||||||
|
|
||||||
|
#[setter]
|
||||||
|
fn set_content(self_: PyRef<Self>, content: char) {
|
||||||
|
setter!(self_, Strip, content, content)
|
||||||
}
|
}
|
||||||
|
|
||||||
#[new]
|
#[new]
|
||||||
#[pyo3(signature = (left=0, right=0))]
|
#[pyo3(signature = (content=' ', left=0, right=0))]
|
||||||
fn new(left: usize, right: usize) -> (Self, PyDecoder) {
|
fn new(content: char, left: usize, right: usize) -> (Self, PyDecoder) {
|
||||||
(PyStrip {}, Strip::new(left, right).into())
|
(PyStrip {}, Strip::new(content, left, right).into())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -111,13 +111,13 @@ class TestFuse:
|
|||||||
class TestStrip:
|
class TestStrip:
|
||||||
def test_instantiate(self):
|
def test_instantiate(self):
|
||||||
assert Strip(left=0, right=0) is not None
|
assert Strip(left=0, right=0) is not None
|
||||||
assert isinstance(Strip(left=0, right=0), Decoder)
|
assert isinstance(Strip(content="_", left=0, right=0), Decoder)
|
||||||
assert isinstance(Strip(left=0, right=0), Strip)
|
assert isinstance(Strip(content="_", left=0, right=0), Strip)
|
||||||
assert isinstance(pickle.loads(pickle.dumps(Strip(left=0, right=0))), Strip)
|
assert isinstance(pickle.loads(pickle.dumps(Strip(content="_", left=0, right=0))), Strip)
|
||||||
|
|
||||||
def test_decoding(self):
|
def test_decoding(self):
|
||||||
decoder = Strip(left=1, right=0)
|
decoder = Strip(content="_", left=1, right=0)
|
||||||
assert decoder.decode(["My", " na", "me"]) == "ynae"
|
assert decoder.decode(["_My", " na", "me", " _-", "__-"]) == "My name _-_-"
|
||||||
|
|
||||||
|
|
||||||
class TestMetaspace:
|
class TestMetaspace:
|
||||||
|
@ -9,13 +9,18 @@ use serde::{Deserialize, Serialize};
|
|||||||
#[serde(tag = "type")]
|
#[serde(tag = "type")]
|
||||||
#[non_exhaustive]
|
#[non_exhaustive]
|
||||||
pub struct Strip {
|
pub struct Strip {
|
||||||
pub left: usize,
|
pub content: char,
|
||||||
pub right: usize,
|
pub start: usize,
|
||||||
|
pub stop: usize,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Strip {
|
impl Strip {
|
||||||
pub fn new(left: usize, right: usize) -> Self {
|
pub fn new(content: char, start: usize, stop: usize) -> Self {
|
||||||
Self { left, right }
|
Self {
|
||||||
|
content,
|
||||||
|
start,
|
||||||
|
stop,
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -24,11 +29,31 @@ impl Decoder for Strip {
|
|||||||
Ok(tokens
|
Ok(tokens
|
||||||
.into_iter()
|
.into_iter()
|
||||||
.map(|token| {
|
.map(|token| {
|
||||||
token
|
let chars: Vec<char> = token.chars().collect();
|
||||||
.chars()
|
|
||||||
.skip(self.left)
|
let mut start_cut = 0;
|
||||||
.take(token.len() - self.left - self.right)
|
for (i, &c) in chars.iter().enumerate().take(self.start) {
|
||||||
.collect()
|
if c == self.content {
|
||||||
|
start_cut = i + 1;
|
||||||
|
continue;
|
||||||
|
} else {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
let mut stop_cut = chars.len();
|
||||||
|
for i in 0..self.stop {
|
||||||
|
let index = chars.len() - i - 1;
|
||||||
|
if chars[index] == self.content {
|
||||||
|
stop_cut = index;
|
||||||
|
continue;
|
||||||
|
} else {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
let new_token: String = chars[start_cut..stop_cut].iter().collect();
|
||||||
|
new_token
|
||||||
})
|
})
|
||||||
.collect())
|
.collect())
|
||||||
}
|
}
|
||||||
@ -40,16 +65,16 @@ mod tests {
|
|||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn decode() {
|
fn decode() {
|
||||||
let decoder = Strip::new(1, 0);
|
let decoder = Strip::new('H', 1, 0);
|
||||||
let res = decoder
|
let res = decoder
|
||||||
.decode_chain(vec!["Hey".into(), " friend!".into()])
|
.decode_chain(vec!["Hey".into(), " friend!".into(), "HHH".into()])
|
||||||
.unwrap();
|
.unwrap();
|
||||||
assert_eq!(res, vec!["ey", "friend!"]);
|
assert_eq!(res, vec!["ey", " friend!", "HH"]);
|
||||||
|
|
||||||
let decoder = Strip::new(0, 1);
|
let decoder = Strip::new('y', 0, 1);
|
||||||
let res = decoder
|
let res = decoder
|
||||||
.decode_chain(vec!["Hey".into(), " friend!".into()])
|
.decode_chain(vec!["Hey".into(), " friend!".into()])
|
||||||
.unwrap();
|
.unwrap();
|
||||||
assert_eq!(res, vec!["He", " friend"]);
|
assert_eq!(res, vec!["He", " friend!"]);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
Reference in New Issue
Block a user