mirror of https://github.com/mii443/tokenizers.git (synced 2025-08-22 16:25:30 +00:00)

This reverts commit cdabef14c4.

@@ -12,7 +12,7 @@ describe("wordPieceDecoder", () => {
 
   it("can decode arrays of strings", () => {
     expect(
       wordPieceDecoder().decode(["Hel", "##lo", "there", "my", "fr", "##iend"])
-    ).toEqual(["Hel", "lo", " there", " my", " fr", "iend"]);
+    ).toEqual("Hello there my friend");
   });
 });
@@ -39,6 +39,6 @@ describe("ctcDecoder", () => {
   it("encodes correctly", () => {
     expect(
       ctcDecoder().decode(["<pad>", "h", "h", "e", "e", "l", "l", "<pad>", "l", "l", "o"])
-    ).toEqual(["h", "e", "l", "l", "o"]);
+    ).toEqual("hello");
   });
 });

@@ -14,7 +14,7 @@ pub struct Decoder {
 }
 
 impl tk::Decoder for Decoder {
-    fn decode(&self, tokens: Vec<String>) -> tk::Result<Vec<String>> {
+    fn decode(&self, tokens: Vec<String>) -> tk::Result<String> {
         self.decoder
             .as_ref()
             .ok_or("Uninitialized Decoder")?
@@ -41,13 +41,7 @@ declare_types! {
             .decode(tokens)
             .map_err(|e| Error(format!("{}", e)))?;
 
-        let decoded = JsArray::new(&mut cx, output.len() as u32);
-        for (i, token) in output.into_iter().enumerate() {
-            let js_token = cx.string(token);
-            decoded.set(&mut cx, i as u32, js_token)?;
-        }
-
-        Ok(decoded.upcast())
+        Ok(cx.string(output).upcast())
     }
 }

@@ -51,7 +51,7 @@ impl PyDecoder {
 }
 
 impl Decoder for PyDecoder {
-    fn decode(&self, tokens: Vec<String>) -> tk::Result<Vec<String>> {
+    fn decode(&self, tokens: Vec<String>) -> tk::Result<String> {
         self.decoder.decode(tokens)
     }
 }
@@ -98,7 +98,7 @@ impl PyDecoder {
     /// Returns:
     ///     :obj:`str`: The decoded string
     #[text_signature = "(self, tokens)"]
-    fn decode(&self, tokens: Vec<String>) -> PyResult<Vec<String>> {
+    fn decode(&self, tokens: Vec<String>) -> PyResult<String> {
         ToPyResult(self.decoder.decode(tokens)).into()
     }
 }
@@ -337,12 +337,12 @@ impl CustomDecoder {
 }
 
 impl Decoder for CustomDecoder {
-    fn decode(&self, tokens: Vec<String>) -> tk::Result<Vec<String>> {
+    fn decode(&self, tokens: Vec<String>) -> tk::Result<String> {
         Python::with_gil(|py| {
             let decoded = self
                 .inner
                 .call_method(py, "decode", (tokens,), None)?
-                .extract(py)?;
+                .extract::<String>(py)?;
             Ok(decoded)
         })
    }
@@ -396,7 +396,7 @@ where
 }
 
 impl Decoder for PyDecoderWrapper {
-    fn decode(&self, tokens: Vec<String>) -> tk::Result<Vec<String>> {
+    fn decode(&self, tokens: Vec<String>) -> tk::Result<String> {
         match self {
             PyDecoderWrapper::Wrapped(inner) => inner.read().unwrap().decode(tokens),
             PyDecoderWrapper::Custom(inner) => inner.read().unwrap().decode(tokens),

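For a custom decoder supplied from Python, the revert means the Python-side `decode` must hand back a `str` rather than a `List[str]`, and the binding extracts it as a single Rust `String`. A minimal sketch of that call pattern, assuming the pyo3 version these bindings use (`py_decoder` and `call_python_decode` are illustrative names, not crate API):

    use pyo3::prelude::*;

    // Call the Python object's `decode(tokens)` and pull the result back out
    // as one Rust String, mirroring the CustomDecoder change above.
    fn call_python_decode(py_decoder: &PyObject, tokens: Vec<String>) -> PyResult<String> {
        Python::with_gil(|py| {
            py_decoder
                .call_method(py, "decode", (tokens,), None)? // expects a Python `str`
                .extract::<String>(py)                       // errors if it is not one
        })
    }
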
@@ -14,7 +14,7 @@ class TestByteLevel:
 
     def test_decoding(self):
         decoder = ByteLevel()
-        assert decoder.decode(["My", "Ġname", "Ġis", "ĠJohn"]) == ["My name is John"]
+        assert decoder.decode(["My", "Ġname", "Ġis", "ĠJohn"]) == "My name is John"
 
     def test_manual_reload(self):
         byte_level = ByteLevel()
@@ -34,25 +34,11 @@ class TestWordPiece:
 
     def test_decoding(self):
         decoder = WordPiece()
-        assert decoder.decode(["My", "na", "##me", "is", "Jo", "##hn"]) == [
-            "My",
-            " na",
-            "me",
-            " is",
-            " Jo",
-            "hn",
-        ]
-        assert decoder.decode(["I", "'m", "Jo", "##hn"]) == ["I", "'m", " Jo", "hn"]
+        assert decoder.decode(["My", "na", "##me", "is", "Jo", "##hn"]) == "My name is John"
+        assert decoder.decode(["I", "'m", "Jo", "##hn"]) == "I'm John"
         decoder = WordPiece(prefix="__", cleanup=False)
-        assert decoder.decode(["My", "na", "__me", "is", "Jo", "__hn"]) == [
-            "My",
-            " na",
-            "me",
-            " is",
-            " Jo",
-            "hn",
-        ]
-        assert decoder.decode(["I", "'m", "Jo", "__hn"]) == ["I", " 'm", " Jo", "hn"]
+        assert decoder.decode(["My", "na", "__me", "is", "Jo", "__hn"]) == "My name is John"
+        assert decoder.decode(["I", "'m", "Jo", "__hn"]) == "I 'm John"
 
     def test_can_modify(self):
         decoder = WordPiece(prefix="$$", cleanup=False)
@@ -80,9 +66,9 @@ class TestMetaspace:
 
     def test_decoding(self):
         decoder = Metaspace()
-        assert decoder.decode(["▁My", "▁name", "▁is", "▁John"]) == ["My", " name", " is", " John"]
+        assert decoder.decode(["▁My", "▁name", "▁is", "▁John"]) == "My name is John"
         decoder = Metaspace(replacement="-", add_prefix_space=False)
-        assert decoder.decode(["-My", "-name", "-is", "-John"]) == [" My", " name", " is", " John"]
+        assert decoder.decode(["-My", "-name", "-is", "-John"]) == " My name is John"
 
     def test_can_modify(self):
         decoder = Metaspace(replacement="*", add_prefix_space=False)
@@ -107,23 +93,12 @@ class TestBPEDecoder:
 
     def test_decoding(self):
         decoder = BPEDecoder()
-        assert decoder.decode(["My</w>", "na", "me</w>", "is</w>", "Jo", "hn</w>"]) == [
-            "My ",
-            "na",
-            "me ",
-            "is ",
-            "Jo",
-            "hn",
-        ]
+        assert (
+            decoder.decode(["My</w>", "na", "me</w>", "is</w>", "Jo", "hn</w>"])
+            == "My name is John"
+        )
         decoder = BPEDecoder(suffix="_")
-        assert decoder.decode(["My_", "na", "me_", "is_", "Jo", "hn_"]) == [
-            "My ",
-            "na",
-            "me ",
-            "is ",
-            "Jo",
-            "hn",
-        ]
+        assert decoder.decode(["My_", "na", "me_", "is_", "Jo", "hn_"]) == "My name is John"
 
     def test_can_modify(self):
         decoder = BPEDecoder(suffix="123")
@@ -145,13 +120,19 @@ class TestCTCDecoder:
 
     def test_decoding(self):
         decoder = CTC()
-        assert decoder.decode(
-            ["<pad>", "<pad>", "h", "e", "e", "l", "l", "<pad>", "l", "o", "o", "o", "<pad>"]
-        ) == ["h", "e", "l", "l", "o"]
+        assert (
+            decoder.decode(
+                ["<pad>", "<pad>", "h", "e", "e", "l", "l", "<pad>", "l", "o", "o", "o", "<pad>"]
+            )
+            == "hello"
+        )
         decoder = CTC(pad_token="[PAD]")
-        assert decoder.decode(
-            ["[PAD]", "[PAD]", "h", "e", "e", "l", "l", "[PAD]", "l", "o", "o", "o", "[PAD]"]
-        ) == ["h", "e", "l", "l", "o"]
+        assert (
+            decoder.decode(
+                ["[PAD]", "[PAD]", "h", "e", "e", "l", "l", "[PAD]", "l", "o", "o", "o", "[PAD]"]
+            )
+            == "hello"
+        )
 
     def test_can_modify(self):
         decoder = CTC(pad_token="[PAD]")

@@ -24,15 +24,7 @@ impl Default for BPEDecoder {
 }
 
 impl Decoder for BPEDecoder {
-    fn decode(&self, tokens: Vec<String>) -> Result<Vec<String>> {
-        let n = tokens.len() - 1;
-        Ok(tokens
-            .into_iter()
-            .enumerate()
-            .map(|(i, token)| {
-                let replacement = if i == n { "" } else { " " };
-                token.replace(&self.suffix, replacement)
-            })
-            .collect())
+    fn decode(&self, tokens: Vec<String>) -> Result<String> {
+        Ok(tokens.join("").replace(&self.suffix, " ").trim().to_owned())
     }
 }

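The restored `BPEDecoder::decode` concatenates the tokens, turns every end-of-word suffix into a space, and trims the trailing one. A standalone sketch of that behavior (the free function `bpe_decode` is illustrative, not the crate's API):

    // Join the pieces, map each suffix ("</w>" by default) to a space, trim the end.
    fn bpe_decode(tokens: &[String], suffix: &str) -> String {
        tokens.concat().replace(suffix, " ").trim().to_owned()
    }

    fn main() {
        let tokens: Vec<String> = ["My</w>", "na", "me</w>", "is</w>", "Jo", "hn</w>"]
            .iter().map(|s| s.to_string()).collect();
        assert_eq!(bpe_decode(&tokens, "</w>"), "My name is John");
    }
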
@@ -42,23 +42,16 @@ impl Default for CTC {
 }
 
 impl Decoder for CTC {
-    fn decode(&self, tokens: Vec<String>) -> Result<Vec<String>> {
-        Ok(tokens
-            .into_iter()
-            .dedup()
-            .filter_map(|token| {
-                let mut replaced = token.replace(&self.pad_token, "");
-                if self.cleanup {
-                    replaced =
-                        wordpiece::cleanup(&replaced).replace(&self.word_delimiter_token, " ");
-                }
-                if replaced.is_empty() {
-                    None
-                } else {
-                    Some(replaced)
-                }
-            })
-            .collect())
+    fn decode(&self, tokens: Vec<String>) -> Result<String> {
+        let mut output = tokens
+            .into_iter()
+            .dedup()
+            .join("")
+            .replace(&self.pad_token, "");
+        if self.cleanup {
+            output = wordpiece::cleanup(output).replace(&self.word_delimiter_token, " ");
+        }
+        Ok(output)
     }
 }
 
@@ -74,7 +67,7 @@ mod tests {
             .collect();
         assert_eq!(
             ctc_decoder.decode(id_to_string_result).unwrap(),
-            vec!["h", "e", "l", "l", "o"]
+            "hello".to_string()
         );
     }
     #[test]
@@ -86,7 +79,7 @@ mod tests {
             .collect();
         assert_eq!(
             ctc_decoder.decode(id_to_string_result).unwrap(),
-            vec!["h", "e", "l", "l", "o", " ", "w", "o", "r", "l", "d"]
+            "hello world".to_string()
        );
     }
     #[test]
@@ -95,11 +88,7 @@ mod tests {
         let id_to_string_result = "<pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> A | | <pad> M <pad> <pad> <pad> <pad> A <pad> <pad> N <pad> <pad> <pad> | | | <pad> <pad> <pad> <pad> S <pad> <pad> <pad> A I <pad> D D | | T T <pad> O <pad> | | T H E E | | | <pad> U U <pad> N N <pad> I <pad> <pad> V <pad> <pad> <pad> E R R <pad> <pad> <pad> S E E | | <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> S S <pad> <pad> <pad> <pad> I <pad> R R <pad> <pad> | | | <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> I <pad> <pad> <pad> | <pad> <pad> <pad> E X <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> I <pad> S <pad> <pad> T <pad> <pad> | | <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad>".split(' ').map(|s| s.to_string()).collect();
         assert_eq!(
             ctc_decoder.decode(id_to_string_result).unwrap(),
-            vec![
-                "A", " ", "M", "A", "N", " ", "S", "A", "I", "D", " ", "T", "O", " ", "T", "H",
-                "E", " ", "U", "N", "I", "V", "E", "R", "S", "E", " ", "S", "I", "R", " ", "I",
-                " ", "E", "X", "I", "S", "T", " "
-            ]
+            "A MAN SAID TO THE UNIVERSE SIR I EXIST ".to_string()
         );
     }
     #[test]
@@ -108,13 +97,7 @@ mod tests {
         let id_to_string_result = "<pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> H <pad> I <pad> S S | | <pad> <pad> <pad> I N <pad> <pad> S <pad> T T <pad> <pad> A N C C T <pad> | | | | | <pad> <pad> <pad> <pad> P <pad> <pad> <pad> <pad> A <pad> <pad> N N N <pad> <pad> I <pad> C <pad> <pad> | | <pad> W <pad> <pad> A S <pad> | | <pad> <pad> <pad> F <pad> <pad> O L <pad> <pad> L L O O W E E D | | <pad> B <pad> <pad> <pad> Y <pad> | | | A | | <pad> S S S <pad> M M <pad> <pad> <pad> A L L <pad> <pad> <pad> <pad> L <pad> | | | <pad> <pad> <pad> <pad> S H H <pad> <pad> <pad> <pad> A R R <pad> <pad> P <pad> <pad> | <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> B <pad> <pad> L L <pad> <pad> <pad> <pad> <pad> O W W <pad> <pad> | | | <pad> <pad> <pad> <pad> <pad> <pad> <pad> H <pad> <pad> <pad> <pad> <pad> <pad> <pad> I G H H | | <pad> <pad> O N <pad> | | H <pad> I S S | | <pad> <pad> C H H <pad> <pad> <pad> E <pad> S S <pad> T T <pad> <pad> | | | <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad>".split(' ').map(|s| s.to_string()).collect();
         assert_eq!(
             ctc_decoder.decode(id_to_string_result).unwrap(),
-            vec![
-                "H", "I", "S", " ", "I", "N", "S", "T", "A", "N", "C", "T", " ", "P", "A", "N",
-                "I", "C", " ", "W", "A", "S", " ", "F", "O", "L", "L", "O", "W", "E", "D", " ",
-                "B", "Y", " ", "A", " ", "S", "M", "A", "L", "L", " ", "S", "H", "A", "R", "P",
-                " ", "B", "L", "O", "W", " ", "H", "I", "G", "H", " ", "O", "N", " ", "H", "I",
-                "S", " ", "C", "H", "E", "S", "T", " "
-            ]
+            "HIS INSTANCT PANIC WAS FOLLOWED BY A SMALL SHARP BLOW HIGH ON HIS CHEST ".to_string()
         );
     }
 }

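After the revert, CTC decoding collapses adjacent duplicate tokens, concatenates them, strips the pad token, and (with `cleanup` enabled) maps the word delimiter to a space. A dependency-free sketch of the same pipeline (`ctc_decode` is illustrative; the crate itself uses `itertools` for `dedup` and `join`):

    // Collapse runs of equal tokens, drop pads, then turn "|" delimiters into spaces.
    fn ctc_decode(tokens: &[String], pad: &str, delim: &str) -> String {
        let mut out = String::new();
        let mut prev: Option<&str> = None;
        for t in tokens {
            if prev != Some(t.as_str()) {
                out.push_str(t);
            }
            prev = Some(t.as_str());
        }
        out.replace(pad, "").replace(delim, " ")
    }

    fn main() {
        let toks: Vec<String> = ["<pad>", "h", "h", "e", "e", "l", "l", "<pad>", "l", "l", "o"]
            .iter().map(|s| s.to_string()).collect();
        assert_eq!(ctc_decode(&toks, "<pad>", "|"), "hello");
    }
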
@@ -26,7 +26,7 @@ pub enum DecoderWrapper {
 }
 
 impl Decoder for DecoderWrapper {
-    fn decode(&self, tokens: Vec<String>) -> Result<Vec<String>> {
+    fn decode(&self, tokens: Vec<String>) -> Result<String> {
         match self {
             Self::BPE(bpe) => bpe.decode(tokens),
             Self::ByteLevel(bl) => bl.decode(tokens),

@@ -28,7 +28,7 @@ impl Default for WordPiece {
         }
     }
 }
-pub fn cleanup(dirty_input: &str) -> String {
+pub fn cleanup(dirty_input: String) -> String {
     dirty_input
         .replace(" .", ".")
         .replace(" ?", "?")
@@ -44,21 +44,12 @@ pub fn cleanup(dirty_input: &str) -> String {
 }
 
 impl Decoder for WordPiece {
-    fn decode(&self, mut tokens: Vec<String>) -> Result<Vec<String>> {
-        tokens
-            .iter_mut()
-            .enumerate()
-            .map(|(i, token)| {
-                if token.starts_with(&self.prefix) {
-                    *token = token.replacen(&self.prefix, "", 1);
-                } else if i != 0 {
-                    *token = format!(" {}", token);
-                }
-                if self.cleanup {
-                    *token = cleanup(token);
-                }
-                Ok(token.to_string())
-            })
-            .collect::<Result<_>>()
+    fn decode(&self, tokens: Vec<String>) -> Result<String> {
+        let mut output = tokens.join(" ").replace(&format!(" {}", self.prefix), "");
+        if self.cleanup {
+            output = cleanup(output);
+        }
+
+        Ok(output)
     }
 }

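The restored WordPiece path joins tokens with spaces and then deletes every occurrence of a space followed by the prefix, which re-attaches `##` continuation pieces to the preceding word. A minimal sketch of just that join step (hypothetical `wordpiece_join`; the `cleanup` pass is omitted):

    // " ##" marks a continuation piece; deleting it glues the piece to its word.
    fn wordpiece_join(tokens: &[String], prefix: &str) -> String {
        tokens.join(" ").replace(&format!(" {}", prefix), "")
    }

    fn main() {
        let toks: Vec<String> = ["My", "na", "##me", "is", "Jo", "##hn"]
            .iter().map(|s| s.to_string()).collect();
        assert_eq!(wordpiece_join(&toks, "##"), "My name is John");
    }
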
@@ -145,11 +145,8 @@ impl PreTokenizer for ByteLevel {
 
 /// As a `Decoder`, `ByteLevel` is in charge of converting any byte-level characters to their
 /// unicode counterpart, before merging everything back into a single String.
-/// This decoder will consume the tokens and merge them in one step to alleviate
-/// the fact that single token decoded might be a byte not representable as
-/// as String.
 impl Decoder for ByteLevel {
-    fn decode(&self, tokens: Vec<String>) -> Result<Vec<String>> {
+    fn decode(&self, tokens: Vec<String>) -> Result<String> {
         let toks = tokens
             .into_iter()
             .flat_map(|t| {
@@ -162,8 +159,8 @@ impl Decoder for ByteLevel {
                 })
                 .unwrap_or_else(|| t.as_bytes().to_vec())
             })
-            .collect::<Vec<u8>>();
-        Ok(vec![String::from_utf8_lossy(&toks).to_string()])
+            .collect::<Vec<_>>();
+        Ok(String::from_utf8_lossy(&toks).into_owned())
     }
 }
@@ -287,6 +284,7 @@ mod tests {
     fn decoding() {
         let bytelevel = ByteLevel::default().add_prefix_space(false);
         assert_eq!(
+            "Hello my friend, how is your day going?",
             bytelevel
                 .decode(
                     vec![
@@ -297,8 +295,7 @@ mod tests {
                     .map(|s| s.into())
                     .collect::<Vec<String>>()
                 )
-                .unwrap(),
-            vec!["Hello my friend, how is your day going?"]
+                .unwrap()
         );
     }
 
@@ -350,7 +347,7 @@ mod tests {
             .iter()
             .flat_map(|(s, _, _)| s.split("").map(|t| t.into()))
             .collect::<Vec<_>>();
-        assert_eq!(sample, bytelevel.decode(separated_tokens).unwrap().join(""));
+        assert_eq!(sample, bytelevel.decode(separated_tokens).unwrap());
     }
 }
 
@@ -546,7 +543,7 @@ mod tests {
             "[PA D]".into()
         ])
         .unwrap(),
-        vec!["Hello there dear friend! [PA D]"]
+        "Hello there dear friend! [PA D]"
     );
 }

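The single-pass merge matters here because byte-level tokens can end partway through a multi-byte UTF-8 character, so only the concatenated byte stream is guaranteed to decode cleanly. A small sketch of that failure mode (`merge_bytes` is illustrative; the real decoder first maps the printable byte-level alphabet back to raw bytes):

    // Concatenate all byte pieces first, then decode once.
    fn merge_bytes(pieces: &[Vec<u8>]) -> String {
        let bytes: Vec<u8> = pieces.iter().flatten().copied().collect();
        String::from_utf8_lossy(&bytes).into_owned()
    }

    fn main() {
        // "é" is the two-byte sequence 0xC3 0xA9; neither half is valid UTF-8 on
        // its own, so decoding token-by-token would yield replacement characters.
        assert_eq!(merge_bytes(&[vec![0xC3], vec![0xA9]]), "é");
    }
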
@@ -77,14 +77,12 @@ impl PreTokenizer for Metaspace {
 }
 
 impl Decoder for Metaspace {
-    fn decode(&self, tokens: Vec<String>) -> Result<Vec<String>> {
+    fn decode(&self, tokens: Vec<String>) -> Result<String> {
         Ok(tokens
             .iter()
-            .enumerate()
-            .map(|(i, token)| {
-                token
-                    .chars()
-                    .flat_map(|c| {
+            .flat_map(|t| t.chars())
+            .enumerate()
+            .filter_map(|(i, c)| {
                 if c == self.replacement {
                     if i == 0 && self.add_prefix_space {
                         None
@@ -95,9 +93,7 @@ impl Decoder for Metaspace {
                     Some(c)
                 }
             })
-                    .collect::<String>()
-            })
-            .collect())
+            .collect::<String>())
     }
 }
 
@@ -194,6 +190,6 @@ mod tests {
         let res = decoder
             .decode(vec!["▁Hey".into(), "▁friend!".into()])
             .unwrap();
-        assert_eq!(res, vec!["Hey", " friend!"])
+        assert_eq!(&res, "Hey friend!")
     }
 }

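The restored Metaspace decoder streams over all characters of all tokens at once, so the `i == 0` check can tell an artificial leading prefix space apart from a real word separator. A standalone sketch (hypothetical `metaspace_decode`, not the crate's API):

    fn metaspace_decode(tokens: &[String], replacement: char, add_prefix_space: bool) -> String {
        tokens
            .iter()
            .flat_map(|t| t.chars())
            .enumerate()
            .filter_map(|(i, c)| {
                if c != replacement {
                    Some(c)
                } else if i == 0 && add_prefix_space {
                    None // drop the artificial leading space
                } else {
                    Some(' ')
                }
            })
            .collect()
    }

    fn main() {
        let toks: Vec<String> = ["▁My", "▁name", "▁is", "▁John"]
            .iter().map(|s| s.to_string()).collect();
        assert_eq!(metaspace_decode(&toks, '▁', true), "My name is John");
    }
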
@@ -119,9 +119,9 @@ impl dyn PostProcessor {
     }
 }
 
-/// A `Decoder` changes the raw tokens into its more readable form.
+/// A `Decoder` has the responsibility to merge the given `Vec<String>` in a `String`.
 pub trait Decoder {
-    fn decode(&self, tokens: Vec<String>) -> Result<Vec<String>>;
+    fn decode(&self, tokens: Vec<String>) -> Result<String>;
 }
 
 /// A `Trainer` has the responsibility to train a model. We feed it with lines/sentences
@@ -769,8 +769,7 @@ where
             .collect::<Vec<_>>();
 
         if let Some(decoder) = &self.decoder {
-            let tokens = decoder.decode(tokens)?;
-            Ok(tokens.join(""))
+            decoder.decode(tokens)
         } else {
             Ok(tokens.join(" "))
         }

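With the trait back to `fn decode(&self, tokens: Vec<String>) -> Result<String>`, an implementor produces the final string itself, and `Tokenizer::decode` can return its result directly instead of joining a `Vec<String>`. A self-contained sketch of the contract (simplified `Result` alias; `SpaceJoin` and `finish` are toy names, not crate types):

    type Result<T> = std::result::Result<T, Box<dyn std::error::Error + Send + Sync>>;

    // The reverted trait: one merged String out, not a Vec<String>.
    trait Decoder {
        fn decode(&self, tokens: Vec<String>) -> Result<String>;
    }

    struct SpaceJoin;
    impl Decoder for SpaceJoin {
        fn decode(&self, tokens: Vec<String>) -> Result<String> {
            Ok(tokens.join(" "))
        }
    }

    // Mirrors the dispatch at the end of the diff: use the decoder when present,
    // otherwise fall back to joining tokens with spaces.
    fn finish(tokens: Vec<String>, decoder: Option<&dyn Decoder>) -> Result<String> {
        match decoder {
            Some(d) => d.decode(tokens),
            None => Ok(tokens.join(" ")),
        }
    }

    fn main() {
        let toks = vec!["Hello".to_string(), "world".to_string()];
        assert_eq!(finish(toks, Some(&SpaceJoin)).unwrap(), "Hello world");
    }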