From ec439477865c6c8d3f9512f74075ecca45f6af61 Mon Sep 17 00:00:00 2001
From: Nicolas Patry
Date: Mon, 4 Apr 2022 09:43:28 +0200
Subject: [PATCH] Revert "Changing `Decoder` trait to be more composable.
 (#938)" (#971)

This reverts commit cdabef14c4f991c56b96564edcf1cfd5b96dfd65.
---
 bindings/node/lib/bindings/decoders.test.ts   |  4 +-
 bindings/node/native/src/decoders.rs          | 10 +--
 bindings/python/src/decoders.rs               | 10 +--
 .../python/tests/bindings/test_decoders.py    | 67 +++++++------------
 tokenizers/src/decoders/bpe.rs                | 12 +---
 tokenizers/src/decoders/ctc.rs                | 41 ++++--------
 tokenizers/src/decoders/mod.rs                |  2 +-
 tokenizers/src/decoders/wordpiece.rs          | 25 +++----
 tokenizers/src/pre_tokenizers/byte_level.rs   | 17 ++---
 tokenizers/src/pre_tokenizers/metaspace.rs    | 32 ++++-----
 tokenizers/src/tokenizer/mod.rs               |  7 +-
 11 files changed, 80 insertions(+), 147 deletions(-)

diff --git a/bindings/node/lib/bindings/decoders.test.ts b/bindings/node/lib/bindings/decoders.test.ts
index 36da037f..b23f243f 100644
--- a/bindings/node/lib/bindings/decoders.test.ts
+++ b/bindings/node/lib/bindings/decoders.test.ts
@@ -12,7 +12,7 @@ describe("wordPieceDecoder", () => {
   it("can decode arrays of strings", () => {
     expect(
       wordPieceDecoder().decode(["Hel", "##lo", "there", "my", "fr", "##iend"])
-    ).toEqual(["Hel", "lo", " there", " my", " fr", "iend"]);
+    ).toEqual("Hello there my friend");
   });
 });
 
@@ -39,6 +39,6 @@ describe("ctcDecoder", () => {
   it("encodes correctly", () => {
     expect(
       ctcDecoder().decode(["<pad>", "h", "h", "e", "e", "l", "l", "<pad>", "l", "l", "o"])
-    ).toEqual(["h", "e", "l", "l", "o"]);
+    ).toEqual("hello");
   });
 });
diff --git a/bindings/node/native/src/decoders.rs b/bindings/node/native/src/decoders.rs
index ea86e60a..9a01bc4d 100644
--- a/bindings/node/native/src/decoders.rs
+++ b/bindings/node/native/src/decoders.rs
@@ -14,7 +14,7 @@ pub struct Decoder {
 }
 
 impl tk::Decoder for Decoder {
-    fn decode(&self, tokens: Vec<String>) -> tk::Result<Vec<String>> {
+    fn decode(&self, tokens: Vec<String>) -> tk::Result<String> {
         self.decoder
             .as_ref()
             .ok_or("Uninitialized Decoder")?
@@ -41,13 +41,7 @@ declare_types! {
                 .decode(tokens)
                 .map_err(|e| Error(format!("{}", e)))?;
 
-            let decoded = JsArray::new(&mut cx, output.len() as u32);
-            for (i, token) in output.into_iter().enumerate() {
-                let js_token = cx.string(token);
-                decoded.set(&mut cx, i as u32, js_token)?;
-            }
-
-            Ok(decoded.upcast())
+            Ok(cx.string(output).upcast())
         }
     }
 }
diff --git a/bindings/python/src/decoders.rs b/bindings/python/src/decoders.rs
index 92b9e213..5f15838c 100644
--- a/bindings/python/src/decoders.rs
+++ b/bindings/python/src/decoders.rs
@@ -51,7 +51,7 @@ impl PyDecoder {
 }
 
 impl Decoder for PyDecoder {
-    fn decode(&self, tokens: Vec<String>) -> tk::Result<Vec<String>> {
+    fn decode(&self, tokens: Vec<String>) -> tk::Result<String> {
         self.decoder.decode(tokens)
     }
 }
@@ -98,7 +98,7 @@ impl PyDecoder {
     /// Returns:
     ///     :obj:`str`: The decoded string
     #[text_signature = "(self, tokens)"]
-    fn decode(&self, tokens: Vec<String>) -> PyResult<Vec<String>> {
+    fn decode(&self, tokens: Vec<String>) -> PyResult<String> {
         ToPyResult(self.decoder.decode(tokens)).into()
     }
 }
@@ -337,12 +337,12 @@ impl CustomDecoder {
 }
 
 impl Decoder for CustomDecoder {
-    fn decode(&self, tokens: Vec<String>) -> tk::Result<Vec<String>> {
+    fn decode(&self, tokens: Vec<String>) -> tk::Result<String> {
         Python::with_gil(|py| {
             let decoded = self
                 .inner
                 .call_method(py, "decode", (tokens,), None)?
-                .extract(py)?;
+                .extract::<String>(py)?;
             Ok(decoded)
         })
     }
@@ -396,7 +396,7 @@ where
 }
 
 impl Decoder for PyDecoderWrapper {
-    fn decode(&self, tokens: Vec<String>) -> tk::Result<Vec<String>> {
+    fn decode(&self, tokens: Vec<String>) -> tk::Result<String> {
         match self {
             PyDecoderWrapper::Wrapped(inner) => inner.read().unwrap().decode(tokens),
             PyDecoderWrapper::Custom(inner) => inner.read().unwrap().decode(tokens),
diff --git a/bindings/python/tests/bindings/test_decoders.py b/bindings/python/tests/bindings/test_decoders.py
index 265d739b..41e7187e 100644
--- a/bindings/python/tests/bindings/test_decoders.py
+++ b/bindings/python/tests/bindings/test_decoders.py
@@ -14,7 +14,7 @@ class TestByteLevel:
 
     def test_decoding(self):
         decoder = ByteLevel()
-        assert decoder.decode(["My", "Ġname", "Ġis", "ĠJohn"]) == ["My name is John"]
+        assert decoder.decode(["My", "Ġname", "Ġis", "ĠJohn"]) == "My name is John"
 
     def test_manual_reload(self):
         byte_level = ByteLevel()
@@ -34,25 +34,11 @@ class TestWordPiece:
 
     def test_decoding(self):
         decoder = WordPiece()
-        assert decoder.decode(["My", "na", "##me", "is", "Jo", "##hn"]) == [
-            "My",
-            " na",
-            "me",
-            " is",
-            " Jo",
-            "hn",
-        ]
-        assert decoder.decode(["I", "'m", "Jo", "##hn"]) == ["I", "'m", " Jo", "hn"]
+        assert decoder.decode(["My", "na", "##me", "is", "Jo", "##hn"]) == "My name is John"
+        assert decoder.decode(["I", "'m", "Jo", "##hn"]) == "I'm John"
         decoder = WordPiece(prefix="__", cleanup=False)
-        assert decoder.decode(["My", "na", "__me", "is", "Jo", "__hn"]) == [
-            "My",
-            " na",
-            "me",
-            " is",
-            " Jo",
-            "hn",
-        ]
-        assert decoder.decode(["I", "'m", "Jo", "__hn"]) == ["I", " 'm", " Jo", "hn"]
+        assert decoder.decode(["My", "na", "__me", "is", "Jo", "__hn"]) == "My name is John"
+        assert decoder.decode(["I", "'m", "Jo", "__hn"]) == "I 'm John"
 
     def test_can_modify(self):
         decoder = WordPiece(prefix="$$", cleanup=False)
@@ -80,9 +66,9 @@ class TestMetaspace:
 
     def test_decoding(self):
         decoder = Metaspace()
-        assert decoder.decode(["▁My", "▁name", "▁is", "▁John"]) == ["My", " name", " is", " John"]
+        assert decoder.decode(["▁My", "▁name", "▁is", "▁John"]) == "My name is John"
         decoder = Metaspace(replacement="-", add_prefix_space=False)
-        assert decoder.decode(["-My", "-name", "-is", "-John"]) == [" My", " name", " is", " John"]
+        assert decoder.decode(["-My", "-name", "-is", "-John"]) == " My name is John"
 
     def test_can_modify(self):
         decoder = Metaspace(replacement="*", add_prefix_space=False)
@@ -107,23 +93,12 @@ class TestBPEDecoder:
 
     def test_decoding(self):
         decoder = BPEDecoder()
-        assert decoder.decode(["My", "na", "me", "is", "Jo", "hn"]) == [
-            "My ",
-            "na",
-            "me ",
-            "is ",
-            "Jo",
-            "hn",
-        ]
+        assert (
+            decoder.decode(["My", "na", "me", "is", "Jo", "hn"])
+            == "My name is John"
+        )
         decoder = BPEDecoder(suffix="_")
-        assert decoder.decode(["My_", "na", "me_", "is_", "Jo", "hn_"]) == [
-            "My ",
-            "na",
-            "me ",
-            "is ",
-            "Jo",
-            "hn",
-        ]
+        assert decoder.decode(["My_", "na", "me_", "is_", "Jo", "hn_"]) == "My name is John"
 
     def test_can_modify(self):
         decoder = BPEDecoder(suffix="123")
@@ -145,13 +120,19 @@ class TestCTCDecoder:
 
     def test_decoding(self):
         decoder = CTC()
-        assert decoder.decode(
-            ["<pad>", "<pad>", "h", "e", "e", "l", "l", "<pad>", "l", "o", "o", "o", "<pad>"]
-        ) == ["h", "e", "l", "l", "o"]
+        assert (
+            decoder.decode(
+                ["<pad>", "<pad>", "h", "e", "e", "l", "l", "<pad>", "l", "o", "o", "o", "<pad>"]
+            )
+            == "hello"
+        )
         decoder = CTC(pad_token="[PAD]")
-        assert decoder.decode(
-            ["[PAD]", "[PAD]", "h", "e", "e", "l", "l", "[PAD]", "l", "o", "o", "o", "[PAD]"]
-        ) == ["h", "e", "l", "l", "o"]
+        assert (
+            decoder.decode(
+                ["[PAD]", "[PAD]", "h", "e", "e", "l", "l", "[PAD]", "l", "o", "o", "o", "[PAD]"]
+            )
+            == "hello"
+        )
 
     def test_can_modify(self):
         decoder = CTC(pad_token="[PAD]")
diff --git a/tokenizers/src/decoders/bpe.rs b/tokenizers/src/decoders/bpe.rs
index 1354c89e..1d4115cb 100644
--- a/tokenizers/src/decoders/bpe.rs
+++ b/tokenizers/src/decoders/bpe.rs
@@ -24,15 +24,7 @@ impl Default for BPEDecoder {
 }
 
 impl Decoder for BPEDecoder {
-    fn decode(&self, tokens: Vec<String>) -> Result<Vec<String>> {
-        let n = tokens.len() - 1;
-        Ok(tokens
-            .into_iter()
-            .enumerate()
-            .map(|(i, token)| {
-                let replacement = if i == n { "" } else { " " };
-                token.replace(&self.suffix, replacement)
-            })
-            .collect())
+    fn decode(&self, tokens: Vec<String>) -> Result<String> {
+        Ok(tokens.join("").replace(&self.suffix, " ").trim().to_owned())
     }
 }
diff --git a/tokenizers/src/decoders/ctc.rs b/tokenizers/src/decoders/ctc.rs
index 06667d5f..17f7ba16 100644
--- a/tokenizers/src/decoders/ctc.rs
+++ b/tokenizers/src/decoders/ctc.rs
@@ -42,23 +42,16 @@ impl Default for CTC {
 }
 
 impl Decoder for CTC {
-    fn decode(&self, tokens: Vec<String>) -> Result<Vec<String>> {
-        Ok(tokens
+    fn decode(&self, tokens: Vec<String>) -> Result<String> {
+        let mut output = tokens
             .into_iter()
             .dedup()
-            .filter_map(|token| {
-                let mut replaced = token.replace(&self.pad_token, "");
-                if self.cleanup {
-                    replaced =
-                        wordpiece::cleanup(&replaced).replace(&self.word_delimiter_token, " ");
-                }
-                if replaced.is_empty() {
-                    None
-                } else {
-                    Some(replaced)
-                }
-            })
-            .collect())
+            .join("")
+            .replace(&self.pad_token, "");
+        if self.cleanup {
+            output = wordpiece::cleanup(output).replace(&self.word_delimiter_token, " ");
+        }
+        Ok(output)
     }
 }
 
@@ -74,7 +67,7 @@ mod tests {
             .collect();
         assert_eq!(
             ctc_decoder.decode(id_to_string_result).unwrap(),
-            vec!["h", "e", "l", "l", "o"]
+            "hello".to_string()
         );
     }
     #[test]
@@ -86,7 +79,7 @@ mod tests {
            .collect();
         assert_eq!(
             ctc_decoder.decode(id_to_string_result).unwrap(),
-            vec!["h", "e", "l", "l", "o", " ", "w", "o", "r", "l", "d"]
+            "hello world".to_string()
         );
     }
     #[test]
@@ -95,11 +88,7 @@ mod tests {
         let id_to_string_result = " A | | M A N | | | S A I D D | | T T O | | T H E E | | | U U N N I V E R R S E E | | S S I R R | | | I | E X I S T | | ".split(' ').map(|s| s.to_string()).collect();
         assert_eq!(
             ctc_decoder.decode(id_to_string_result).unwrap(),
-            vec![
-                "A", " ", "M", "A", "N", " ", "S", "A", "I", "D", " ", "T", "O", " ", "T", "H",
-                "E", " ", "U", "N", "I", "V", "E", "R", "S", "E", " ", "S", "I", "R", " ", "I",
-                " ", "E", "X", "I", "S", "T", " "
-            ]
+            "A MAN SAID TO THE UNIVERSE SIR I EXIST ".to_string()
         );
     }
     #[test]
@@ -108,13 +97,7 @@ mod tests {
         let id_to_string_result = " H I S S | | I N S T T A N C C T | | | | | P A N N N I C | | W A S | | F O L L L O O W E E D | | B Y | | | A | | S S S M M A L L L | | | S H H A R R P | B L L O W W | | | H I G H H | | O N | | H I S S | | C H H E S S T T | | | ".split(' ').map(|s| s.to_string()).collect();
         assert_eq!(
            ctc_decoder.decode(id_to_string_result).unwrap(),
-            vec![
-                "H", "I", "S", " ", "I", "N", "S", "T", "A", "N", "C", "T", " ", "P", "A", "N",
-                "I", "C", " ", "W", "A", "S", " ", "F", "O", "L", "L", "O", "W", "E", "D", " ",
-                "B", "Y", " ", "A", " ", "S", "M", "A", "L", "L", " ", "S", "H", "A", "R", "P",
-                " ", "B", "L", "O", "W", " ", "H", "I", "G", "H", " ", "O", "N", " ", "H", "I",
-                "S", " ", "C", "H", "E", "S", "T", " "
-            ]
+            "HIS INSTANCT PANIC WAS FOLLOWED BY A SMALL SHARP BLOW HIGH ON HIS CHEST ".to_string()
         );
     }
 }
diff --git a/tokenizers/src/decoders/mod.rs b/tokenizers/src/decoders/mod.rs
index f93ccc9a..a571ef5b 100644
--- a/tokenizers/src/decoders/mod.rs
+++ b/tokenizers/src/decoders/mod.rs
@@ -26,7 +26,7 @@ pub enum DecoderWrapper {
 }
 
 impl Decoder for DecoderWrapper {
-    fn decode(&self, tokens: Vec<String>) -> Result<Vec<String>> {
+    fn decode(&self, tokens: Vec<String>) -> Result<String> {
         match self {
             Self::BPE(bpe) => bpe.decode(tokens),
             Self::ByteLevel(bl) => bl.decode(tokens),
diff --git a/tokenizers/src/decoders/wordpiece.rs b/tokenizers/src/decoders/wordpiece.rs
index ab62caf1..c9b92d92 100644
--- a/tokenizers/src/decoders/wordpiece.rs
+++ b/tokenizers/src/decoders/wordpiece.rs
@@ -28,7 +28,7 @@ impl Default for WordPiece {
         }
     }
 }
-pub fn cleanup(dirty_input: &str) -> String {
+pub fn cleanup(dirty_input: String) -> String {
     dirty_input
         .replace(" .", ".")
         .replace(" ?", "?")
@@ -44,21 +44,12 @@
 }
 
 impl Decoder for WordPiece {
-    fn decode(&self, mut tokens: Vec<String>) -> Result<Vec<String>> {
-        tokens
-            .iter_mut()
-            .enumerate()
-            .map(|(i, token)| {
-                if token.starts_with(&self.prefix) {
-                    *token = token.replacen(&self.prefix, "", 1);
-                } else if i != 0 {
-                    *token = format!(" {}", token);
-                }
-                if self.cleanup {
-                    *token = cleanup(token);
-                }
-                Ok(token.to_string())
-            })
-            .collect::<Result<_>>()
+    fn decode(&self, tokens: Vec<String>) -> Result<String> {
+        let mut output = tokens.join(" ").replace(&format!(" {}", self.prefix), "");
+        if self.cleanup {
+            output = cleanup(output);
+        }
+
+        Ok(output)
     }
 }
diff --git a/tokenizers/src/pre_tokenizers/byte_level.rs b/tokenizers/src/pre_tokenizers/byte_level.rs
index 4bb9c07c..8f3d5fa8 100644
--- a/tokenizers/src/pre_tokenizers/byte_level.rs
+++ b/tokenizers/src/pre_tokenizers/byte_level.rs
@@ -145,11 +145,8 @@ impl PreTokenizer for ByteLevel {
 
 /// As a `Decoder`, `ByteLevel` is in charge of converting any byte-level characters to their
 /// unicode counterpart, before merging everything back into a single String.
-/// This decoder will consume the tokens and merge them in one step to alleviate
-/// the fact that single token decoded might be a byte not representable as
-/// as String.
 impl Decoder for ByteLevel {
-    fn decode(&self, tokens: Vec<String>) -> Result<Vec<String>> {
+    fn decode(&self, tokens: Vec<String>) -> Result<String> {
         let toks = tokens
             .into_iter()
             .flat_map(|t| {
@@ -162,8 +159,8 @@ impl Decoder for ByteLevel {
                 })
                 .unwrap_or_else(|| t.as_bytes().to_vec())
             })
-            .collect::<Vec<u8>>();
-        Ok(vec![String::from_utf8_lossy(&toks).to_string()])
+            .collect::<Vec<u8>>();
+        Ok(String::from_utf8_lossy(&toks).into_owned())
     }
 }
 
@@ -287,6 +284,7 @@ mod tests {
     fn decoding() {
         let bytelevel = ByteLevel::default().add_prefix_space(false);
         assert_eq!(
+            "Hello my friend, how is your day going?",
             bytelevel
                 .decode(
                     vec![
@@ -297,8 +295,7 @@ mod tests {
                    .map(|s| s.into())
                    .collect::<Vec<String>>()
                 )
-                .unwrap(),
-            vec!["Hello my friend, how is your day going?"]
+                .unwrap()
         );
     }
 
@@ -350,7 +347,7 @@ mod tests {
            .iter()
            .flat_map(|(s, _, _)| s.split("").map(|t| t.into()))
            .collect::<Vec<String>>();
-        assert_eq!(sample, bytelevel.decode(separated_tokens).unwrap().join(""));
+        assert_eq!(sample, bytelevel.decode(separated_tokens).unwrap());
     }
 
@@ -546,7 +543,7 @@ mod tests {
                 "[PA D]".into()
             ])
             .unwrap(),
[PA D]" ); } diff --git a/tokenizers/src/pre_tokenizers/metaspace.rs b/tokenizers/src/pre_tokenizers/metaspace.rs index e8338c5b..6df63df3 100644 --- a/tokenizers/src/pre_tokenizers/metaspace.rs +++ b/tokenizers/src/pre_tokenizers/metaspace.rs @@ -77,27 +77,23 @@ impl PreTokenizer for Metaspace { } impl Decoder for Metaspace { - fn decode(&self, tokens: Vec) -> Result> { + fn decode(&self, tokens: Vec) -> Result { Ok(tokens .iter() + .flat_map(|t| t.chars()) .enumerate() - .map(|(i, token)| { - token - .chars() - .flat_map(|c| { - if c == self.replacement { - if i == 0 && self.add_prefix_space { - None - } else { - Some(' ') - } - } else { - Some(c) - } - }) - .collect::() + .filter_map(|(i, c)| { + if c == self.replacement { + if i == 0 && self.add_prefix_space { + None + } else { + Some(' ') + } + } else { + Some(c) + } }) - .collect()) + .collect::()) } } @@ -194,6 +190,6 @@ mod tests { let res = decoder .decode(vec!["▁Hey".into(), "▁friend!".into()]) .unwrap(); - assert_eq!(res, vec!["Hey", " friend!"]) + assert_eq!(&res, "Hey friend!") } } diff --git a/tokenizers/src/tokenizer/mod.rs b/tokenizers/src/tokenizer/mod.rs index 36da79ef..a7672509 100644 --- a/tokenizers/src/tokenizer/mod.rs +++ b/tokenizers/src/tokenizer/mod.rs @@ -119,9 +119,9 @@ impl dyn PostProcessor { } } -/// A `Decoder` changes the raw tokens into its more readable form. +/// A `Decoder` has the responsibility to merge the given `Vec` in a `String`. pub trait Decoder { - fn decode(&self, tokens: Vec) -> Result>; + fn decode(&self, tokens: Vec) -> Result; } /// A `Trainer` has the responsibility to train a model. We feed it with lines/sentences @@ -769,8 +769,7 @@ where .collect::>(); if let Some(decoder) = &self.decoder { - let tokens = decoder.decode(tokens)?; - Ok(tokens.join("")) + decoder.decode(tokens) } else { Ok(tokens.join(" ")) }