Revert "Changing Decoder trait to be more composable. (#938)" (#971)

This reverts commit cdabef14c4.
This commit is contained in:
Nicolas Patry
2022-04-04 09:43:28 +02:00
committed by GitHub
parent 23a22da18c
commit ec43947786
11 changed files with 80 additions and 147 deletions

View File

@ -12,7 +12,7 @@ describe("wordPieceDecoder", () => {
it("can decode arrays of strings", () => {
expect(
wordPieceDecoder().decode(["Hel", "##lo", "there", "my", "fr", "##iend"])
).toEqual(["Hel", "lo", " there", " my", " fr", "iend"]);
).toEqual("Hello there my friend");
});
});
@ -39,6 +39,6 @@ describe("ctcDecoder", () => {
it("encodes correctly", () => {
expect(
ctcDecoder().decode(["<pad>", "h", "h", "e", "e", "l", "l", "<pad>", "l", "l", "o"])
).toEqual(["h", "e", "l", "l", "o"]);
).toEqual("hello");
});
});

View File

@ -14,7 +14,7 @@ pub struct Decoder {
}
impl tk::Decoder for Decoder {
fn decode(&self, tokens: Vec<String>) -> tk::Result<Vec<String>> {
fn decode(&self, tokens: Vec<String>) -> tk::Result<String> {
self.decoder
.as_ref()
.ok_or("Uninitialized Decoder")?
@ -41,13 +41,7 @@ declare_types! {
.decode(tokens)
.map_err(|e| Error(format!("{}", e)))?;
let decoded = JsArray::new(&mut cx, output.len() as u32);
for (i, token) in output.into_iter().enumerate() {
let js_token = cx.string(token);
decoded.set(&mut cx, i as u32, js_token)?;
}
Ok(decoded.upcast())
Ok(cx.string(output).upcast())
}
}
}

View File

@ -51,7 +51,7 @@ impl PyDecoder {
}
impl Decoder for PyDecoder {
fn decode(&self, tokens: Vec<String>) -> tk::Result<Vec<String>> {
fn decode(&self, tokens: Vec<String>) -> tk::Result<String> {
self.decoder.decode(tokens)
}
}
@ -98,7 +98,7 @@ impl PyDecoder {
/// Returns:
/// :obj:`str`: The decoded string
#[text_signature = "(self, tokens)"]
fn decode(&self, tokens: Vec<String>) -> PyResult<Vec<String>> {
fn decode(&self, tokens: Vec<String>) -> PyResult<String> {
ToPyResult(self.decoder.decode(tokens)).into()
}
}
@ -337,12 +337,12 @@ impl CustomDecoder {
}
impl Decoder for CustomDecoder {
fn decode(&self, tokens: Vec<String>) -> tk::Result<Vec<String>> {
fn decode(&self, tokens: Vec<String>) -> tk::Result<String> {
Python::with_gil(|py| {
let decoded = self
.inner
.call_method(py, "decode", (tokens,), None)?
.extract(py)?;
.extract::<String>(py)?;
Ok(decoded)
})
}
@ -396,7 +396,7 @@ where
}
impl Decoder for PyDecoderWrapper {
fn decode(&self, tokens: Vec<String>) -> tk::Result<Vec<String>> {
fn decode(&self, tokens: Vec<String>) -> tk::Result<String> {
match self {
PyDecoderWrapper::Wrapped(inner) => inner.read().unwrap().decode(tokens),
PyDecoderWrapper::Custom(inner) => inner.read().unwrap().decode(tokens),

View File

@ -14,7 +14,7 @@ class TestByteLevel:
def test_decoding(self):
decoder = ByteLevel()
assert decoder.decode(["My", "Ġname", "Ġis", "ĠJohn"]) == ["My name is John"]
assert decoder.decode(["My", "Ġname", "Ġis", "ĠJohn"]) == "My name is John"
def test_manual_reload(self):
byte_level = ByteLevel()
@ -34,25 +34,11 @@ class TestWordPiece:
def test_decoding(self):
decoder = WordPiece()
assert decoder.decode(["My", "na", "##me", "is", "Jo", "##hn"]) == [
"My",
" na",
"me",
" is",
" Jo",
"hn",
]
assert decoder.decode(["I", "'m", "Jo", "##hn"]) == ["I", "'m", " Jo", "hn"]
assert decoder.decode(["My", "na", "##me", "is", "Jo", "##hn"]) == "My name is John"
assert decoder.decode(["I", "'m", "Jo", "##hn"]) == "I'm John"
decoder = WordPiece(prefix="__", cleanup=False)
assert decoder.decode(["My", "na", "__me", "is", "Jo", "__hn"]) == [
"My",
" na",
"me",
" is",
" Jo",
"hn",
]
assert decoder.decode(["I", "'m", "Jo", "__hn"]) == ["I", " 'm", " Jo", "hn"]
assert decoder.decode(["My", "na", "__me", "is", "Jo", "__hn"]) == "My name is John"
assert decoder.decode(["I", "'m", "Jo", "__hn"]) == "I 'm John"
def test_can_modify(self):
decoder = WordPiece(prefix="$$", cleanup=False)
@ -80,9 +66,9 @@ class TestMetaspace:
def test_decoding(self):
decoder = Metaspace()
assert decoder.decode(["▁My", "▁name", "▁is", "▁John"]) == ["My", " name", " is", " John"]
assert decoder.decode(["▁My", "▁name", "▁is", "▁John"]) == "My name is John"
decoder = Metaspace(replacement="-", add_prefix_space=False)
assert decoder.decode(["-My", "-name", "-is", "-John"]) == [" My", " name", " is", " John"]
assert decoder.decode(["-My", "-name", "-is", "-John"]) == " My name is John"
def test_can_modify(self):
decoder = Metaspace(replacement="*", add_prefix_space=False)
@ -107,23 +93,12 @@ class TestBPEDecoder:
def test_decoding(self):
decoder = BPEDecoder()
assert decoder.decode(["My</w>", "na", "me</w>", "is</w>", "Jo", "hn</w>"]) == [
"My ",
"na",
"me ",
"is ",
"Jo",
"hn",
]
assert (
decoder.decode(["My</w>", "na", "me</w>", "is</w>", "Jo", "hn</w>"])
== "My name is John"
)
decoder = BPEDecoder(suffix="_")
assert decoder.decode(["My_", "na", "me_", "is_", "Jo", "hn_"]) == [
"My ",
"na",
"me ",
"is ",
"Jo",
"hn",
]
assert decoder.decode(["My_", "na", "me_", "is_", "Jo", "hn_"]) == "My name is John"
def test_can_modify(self):
decoder = BPEDecoder(suffix="123")
@ -145,13 +120,19 @@ class TestCTCDecoder:
def test_decoding(self):
decoder = CTC()
assert decoder.decode(
assert (
decoder.decode(
["<pad>", "<pad>", "h", "e", "e", "l", "l", "<pad>", "l", "o", "o", "o", "<pad>"]
) == ["h", "e", "l", "l", "o"]
)
== "hello"
)
decoder = CTC(pad_token="[PAD]")
assert decoder.decode(
assert (
decoder.decode(
["[PAD]", "[PAD]", "h", "e", "e", "l", "l", "[PAD]", "l", "o", "o", "o", "[PAD]"]
) == ["h", "e", "l", "l", "o"]
)
== "hello"
)
def test_can_modify(self):
decoder = CTC(pad_token="[PAD]")

View File

@ -24,15 +24,7 @@ impl Default for BPEDecoder {
}
impl Decoder for BPEDecoder {
fn decode(&self, tokens: Vec<String>) -> Result<Vec<String>> {
let n = tokens.len() - 1;
Ok(tokens
.into_iter()
.enumerate()
.map(|(i, token)| {
let replacement = if i == n { "" } else { " " };
token.replace(&self.suffix, replacement)
})
.collect())
fn decode(&self, tokens: Vec<String>) -> Result<String> {
Ok(tokens.join("").replace(&self.suffix, " ").trim().to_owned())
}
}

View File

@ -42,23 +42,16 @@ impl Default for CTC {
}
impl Decoder for CTC {
fn decode(&self, tokens: Vec<String>) -> Result<Vec<String>> {
Ok(tokens
fn decode(&self, tokens: Vec<String>) -> Result<String> {
let mut output = tokens
.into_iter()
.dedup()
.filter_map(|token| {
let mut replaced = token.replace(&self.pad_token, "");
.join("")
.replace(&self.pad_token, "");
if self.cleanup {
replaced =
wordpiece::cleanup(&replaced).replace(&self.word_delimiter_token, " ");
output = wordpiece::cleanup(output).replace(&self.word_delimiter_token, " ");
}
if replaced.is_empty() {
None
} else {
Some(replaced)
}
})
.collect())
Ok(output)
}
}
@ -74,7 +67,7 @@ mod tests {
.collect();
assert_eq!(
ctc_decoder.decode(id_to_string_result).unwrap(),
vec!["h", "e", "l", "l", "o"]
"hello".to_string()
);
}
#[test]
@ -86,7 +79,7 @@ mod tests {
.collect();
assert_eq!(
ctc_decoder.decode(id_to_string_result).unwrap(),
vec!["h", "e", "l", "l", "o", " ", "w", "o", "r", "l", "d"]
"hello world".to_string()
);
}
#[test]
@ -95,11 +88,7 @@ mod tests {
let id_to_string_result = "<pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> A | | <pad> M <pad> <pad> <pad> <pad> A <pad> <pad> N <pad> <pad> <pad> | | | <pad> <pad> <pad> <pad> S <pad> <pad> <pad> A I <pad> D D | | T T <pad> O <pad> | | T H E E | | | <pad> U U <pad> N N <pad> I <pad> <pad> V <pad> <pad> <pad> E R R <pad> <pad> <pad> S E E | | <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> S S <pad> <pad> <pad> <pad> I <pad> R R <pad> <pad> | | | <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> I <pad> <pad> <pad> | <pad> <pad> <pad> E X <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> I <pad> S <pad> <pad> T <pad> <pad> | | <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad>".split(' ').map(|s| s.to_string()).collect();
assert_eq!(
ctc_decoder.decode(id_to_string_result).unwrap(),
vec![
"A", " ", "M", "A", "N", " ", "S", "A", "I", "D", " ", "T", "O", " ", "T", "H",
"E", " ", "U", "N", "I", "V", "E", "R", "S", "E", " ", "S", "I", "R", " ", "I",
" ", "E", "X", "I", "S", "T", " "
]
"A MAN SAID TO THE UNIVERSE SIR I EXIST ".to_string()
);
}
#[test]
@ -108,13 +97,7 @@ mod tests {
let id_to_string_result = "<pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> H <pad> I <pad> S S | | <pad> <pad> <pad> I N <pad> <pad> S <pad> T T <pad> <pad> A N C C T <pad> | | | | | <pad> <pad> <pad> <pad> P <pad> <pad> <pad> <pad> A <pad> <pad> N N N <pad> <pad> I <pad> C <pad> <pad> | | <pad> W <pad> <pad> A S <pad> | | <pad> <pad> <pad> F <pad> <pad> O L <pad> <pad> L L O O W E E D | | <pad> B <pad> <pad> <pad> Y <pad> | | | A | | <pad> S S S <pad> M M <pad> <pad> <pad> A L L <pad> <pad> <pad> <pad> L <pad> | | | <pad> <pad> <pad> <pad> S H H <pad> <pad> <pad> <pad> A R R <pad> <pad> P <pad> <pad> | <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> B <pad> <pad> L L <pad> <pad> <pad> <pad> <pad> O W W <pad> <pad> | | | <pad> <pad> <pad> <pad> <pad> <pad> <pad> H <pad> <pad> <pad> <pad> <pad> <pad> <pad> I G H H | | <pad> <pad> O N <pad> | | H <pad> I S S | | <pad> <pad> C H H <pad> <pad> <pad> E <pad> S S <pad> T T <pad> <pad> | | | <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad>".split(' ').map(|s| s.to_string()).collect();
assert_eq!(
ctc_decoder.decode(id_to_string_result).unwrap(),
vec![
"H", "I", "S", " ", "I", "N", "S", "T", "A", "N", "C", "T", " ", "P", "A", "N",
"I", "C", " ", "W", "A", "S", " ", "F", "O", "L", "L", "O", "W", "E", "D", " ",
"B", "Y", " ", "A", " ", "S", "M", "A", "L", "L", " ", "S", "H", "A", "R", "P",
" ", "B", "L", "O", "W", " ", "H", "I", "G", "H", " ", "O", "N", " ", "H", "I",
"S", " ", "C", "H", "E", "S", "T", " "
]
"HIS INSTANCT PANIC WAS FOLLOWED BY A SMALL SHARP BLOW HIGH ON HIS CHEST ".to_string()
);
}
}

View File

@ -26,7 +26,7 @@ pub enum DecoderWrapper {
}
impl Decoder for DecoderWrapper {
fn decode(&self, tokens: Vec<String>) -> Result<Vec<String>> {
fn decode(&self, tokens: Vec<String>) -> Result<String> {
match self {
Self::BPE(bpe) => bpe.decode(tokens),
Self::ByteLevel(bl) => bl.decode(tokens),

View File

@ -28,7 +28,7 @@ impl Default for WordPiece {
}
}
}
pub fn cleanup(dirty_input: &str) -> String {
pub fn cleanup(dirty_input: String) -> String {
dirty_input
.replace(" .", ".")
.replace(" ?", "?")
@ -44,21 +44,12 @@ pub fn cleanup(dirty_input: &str) -> String {
}
impl Decoder for WordPiece {
fn decode(&self, mut tokens: Vec<String>) -> Result<Vec<String>> {
tokens
.iter_mut()
.enumerate()
.map(|(i, token)| {
if token.starts_with(&self.prefix) {
*token = token.replacen(&self.prefix, "", 1);
} else if i != 0 {
*token = format!(" {}", token);
}
fn decode(&self, tokens: Vec<String>) -> Result<String> {
let mut output = tokens.join(" ").replace(&format!(" {}", self.prefix), "");
if self.cleanup {
*token = cleanup(token);
output = cleanup(output);
}
Ok(token.to_string())
})
.collect::<Result<_>>()
Ok(output)
}
}

View File

@ -145,11 +145,8 @@ impl PreTokenizer for ByteLevel {
/// As a `Decoder`, `ByteLevel` is in charge of converting any byte-level characters to their
/// unicode counterpart, before merging everything back into a single String.
/// This decoder will consume the tokens and merge them in one step to alleviate
/// the fact that single token decoded might be a byte not representable as
/// as String.
impl Decoder for ByteLevel {
fn decode(&self, tokens: Vec<String>) -> Result<Vec<String>> {
fn decode(&self, tokens: Vec<String>) -> Result<String> {
let toks = tokens
.into_iter()
.flat_map(|t| {
@ -162,8 +159,8 @@ impl Decoder for ByteLevel {
})
.unwrap_or_else(|| t.as_bytes().to_vec())
})
.collect::<Vec<u8>>();
Ok(vec![String::from_utf8_lossy(&toks).to_string()])
.collect::<Vec<_>>();
Ok(String::from_utf8_lossy(&toks).into_owned())
}
}
@ -287,6 +284,7 @@ mod tests {
fn decoding() {
let bytelevel = ByteLevel::default().add_prefix_space(false);
assert_eq!(
"Hello my friend, how is your day going?",
bytelevel
.decode(
vec![
@ -297,8 +295,7 @@ mod tests {
.map(|s| s.into())
.collect::<Vec<String>>()
)
.unwrap(),
vec!["Hello my friend, how is your day going?"]
.unwrap()
);
}
@ -350,7 +347,7 @@ mod tests {
.iter()
.flat_map(|(s, _, _)| s.split("").map(|t| t.into()))
.collect::<Vec<_>>();
assert_eq!(sample, bytelevel.decode(separated_tokens).unwrap().join(""));
assert_eq!(sample, bytelevel.decode(separated_tokens).unwrap());
}
}
@ -546,7 +543,7 @@ mod tests {
"[PA D]".into()
])
.unwrap(),
vec!["Hello there dear friend! [PA D]"]
"Hello there dear friend! [PA D]"
);
}

View File

@ -77,14 +77,12 @@ impl PreTokenizer for Metaspace {
}
impl Decoder for Metaspace {
fn decode(&self, tokens: Vec<String>) -> Result<Vec<String>> {
fn decode(&self, tokens: Vec<String>) -> Result<String> {
Ok(tokens
.iter()
.flat_map(|t| t.chars())
.enumerate()
.map(|(i, token)| {
token
.chars()
.flat_map(|c| {
.filter_map(|(i, c)| {
if c == self.replacement {
if i == 0 && self.add_prefix_space {
None
@ -95,9 +93,7 @@ impl Decoder for Metaspace {
Some(c)
}
})
.collect::<String>()
})
.collect())
.collect::<String>())
}
}
@ -194,6 +190,6 @@ mod tests {
let res = decoder
.decode(vec!["▁Hey".into(), "▁friend!".into()])
.unwrap();
assert_eq!(res, vec!["Hey", " friend!"])
assert_eq!(&res, "Hey friend!")
}
}

View File

@ -119,9 +119,9 @@ impl dyn PostProcessor {
}
}
/// A `Decoder` changes the raw tokens into its more readable form.
/// A `Decoder` has the responsibility to merge the given `Vec<String>` in a `String`.
pub trait Decoder {
fn decode(&self, tokens: Vec<String>) -> Result<Vec<String>>;
fn decode(&self, tokens: Vec<String>) -> Result<String>;
}
/// A `Trainer` has the responsibility to train a model. We feed it with lines/sentences
@ -769,8 +769,7 @@ where
.collect::<Vec<_>>();
if let Some(decoder) = &self.decoder {
let tokens = decoder.decode(tokens)?;
Ok(tokens.join(""))
decoder.decode(tokens)
} else {
Ok(tokens.join(" "))
}