diff --git a/bindings/python/src/decoders.rs b/bindings/python/src/decoders.rs
index fc3c81a9..28b2d585 100644
--- a/bindings/python/src/decoders.rs
+++ b/bindings/python/src/decoders.rs
@@ -57,6 +57,41 @@ impl WordPiece {
     }
 }
 
+#[pyclass]
+pub struct Metaspace {}
+#[pymethods]
+impl Metaspace {
+    #[staticmethod]
+    #[args(kwargs = "**")]
+    fn new(kwargs: Option<&PyDict>) -> PyResult<Decoder> {
+        let mut replacement = '▁';
+        let mut add_prefix_space = true;
+
+        if let Some(kwargs) = kwargs {
+            for (key, value) in kwargs {
+                let key: &str = key.extract()?;
+                match key {
+                    "replacement" => {
+                        let s: &str = value.extract()?;
+                        replacement = s.chars().nth(0).ok_or(exceptions::Exception::py_err(
+                            "replacement must be a character",
+                        ))?;
+                    }
+                    "add_prefix_space" => add_prefix_space = value.extract()?,
+                    _ => println!("Ignored unknown kwarg option {}", key),
+                }
+            }
+        }
+
+        Ok(Decoder {
+            decoder: Container::Owned(Box::new(tk::decoders::metaspace::Metaspace::new(
+                replacement,
+                add_prefix_space,
+            ))),
+        })
+    }
+}
+
 struct PyDecoder {
     class: PyObject,
 }
diff --git a/bindings/python/src/lib.rs b/bindings/python/src/lib.rs
index 9786880c..a21977da 100644
--- a/bindings/python/src/lib.rs
+++ b/bindings/python/src/lib.rs
@@ -49,6 +49,7 @@ fn decoders(_py: Python, m: &PyModule) -> PyResult<()> {
     m.add_class::<decoders::Decoder>()?;
     m.add_class::<decoders::ByteLevel>()?;
     m.add_class::<decoders::WordPiece>()?;
+    m.add_class::<decoders::Metaspace>()?;
     Ok(())
 }
 
diff --git a/bindings/python/tokenizers/decoders/__init__.py b/bindings/python/tokenizers/decoders/__init__.py
index fac8392a..12e5eb99 100644
--- a/bindings/python/tokenizers/decoders/__init__.py
+++ b/bindings/python/tokenizers/decoders/__init__.py
@@ -3,3 +3,4 @@ from .. import decoders
 Decoder = decoders.Decoder
 ByteLevel = decoders.ByteLevel
 WordPiece = decoders.WordPiece
+Metaspace = decoders.Metaspace
diff --git a/bindings/python/tokenizers/decoders/__init__.pyi b/bindings/python/tokenizers/decoders/__init__.pyi
index 21f941f3..42627b0d 100644
--- a/bindings/python/tokenizers/decoders/__init__.pyi
+++ b/bindings/python/tokenizers/decoders/__init__.pyi
@@ -31,3 +31,22 @@
             The prefix to use for subwords that are not a beginning-of-word
         """
         pass
+
+class Metaspace:
+    """ Metaspace decoder """
+
+    @staticmethod
+    def new(replacement: str="▁",
+            add_prefix_space: bool=True) -> Decoder:
+        """ Instantiate a new Metaspace
+
+        Args:
+            replacement: str:
+                The replacement character. Must be exactly one character. By default we
+                use the `▁` (U+2581) meta symbol (Same as in SentencePiece).
+
+            add_prefix_space: boolean:
+                Whether to add a space to the first word if there isn't already one. This
+                lets us treat `hello` exactly like `say hello`.
+        """
+        pass
diff --git a/tokenizers/src/decoders/mod.rs b/tokenizers/src/decoders/mod.rs
index 47e13402..731444bb 100644
--- a/tokenizers/src/decoders/mod.rs
+++ b/tokenizers/src/decoders/mod.rs
@@ -1,3 +1,5 @@
-// Re-export this as a decoder
-pub use super::pre_tokenizers::byte_level;
 pub mod wordpiece;
+
+// Re-export these as decoders
+pub use super::pre_tokenizers::byte_level;
+pub use super::pre_tokenizers::metaspace;
diff --git a/tokenizers/src/pre_tokenizers/metaspace.rs b/tokenizers/src/pre_tokenizers/metaspace.rs
index 3cfa89d2..c0d7a89b 100644
--- a/tokenizers/src/pre_tokenizers/metaspace.rs
+++ b/tokenizers/src/pre_tokenizers/metaspace.rs
@@ -1,5 +1,7 @@
-use crate::tokenizer::{Offsets, PreTokenizer, Result};
+use crate::tokenizer::{Decoder, Offsets, PreTokenizer, Result};
 
+/// Replaces all the whitespaces by the provided meta character and then
+/// splits on this character
 pub struct Metaspace {
     replacement: char,
     add_prefix_space: bool,
@@ -52,13 +54,37 @@ impl PreTokenizer for Metaspace {
     }
 }
 
+impl Decoder for Metaspace {
+    fn decode(&self, tokens: Vec<String>) -> Result<String> {
+        Ok(tokens
+            .into_iter()
+            .map(|t| t.chars().collect::<Vec<_>>())
+            .flatten()
+            .enumerate()
+            .map(|(i, c)| {
+                if c == self.replacement {
+                    if i == 0 && self.add_prefix_space {
+                        None
+                    } else {
+                        Some(' ')
+                    }
+                } else {
+                    Some(c)
+                }
+            })
+            .filter(|c| c.is_some())
+            .map(|c| c.unwrap())
+            .collect::<String>())
+    }
+}
+
 #[cfg(test)]
 mod tests {
     use super::*;
 
     #[test]
     fn basic() {
-        let pretok = Metaspace::default();
+        let pretok = Metaspace::new('▁', true);
         let res = pretok.pre_tokenize("Hey friend!").unwrap();
         assert_eq!(
             &res,
@@ -68,7 +94,7 @@
     #[test]
     fn multiple_spaces() {
-        let pretok = Metaspace::default();
+        let pretok = Metaspace::new('▁', true);
         let res = pretok.pre_tokenize("Hey   friend!").unwrap();
         assert_eq!(
             &res,
             &[
@@ -80,4 +106,13 @@
             ]
         );
     }
+
+    #[test]
+    fn decode() {
+        let decoder = Metaspace::new('▁', true);
+        let res = decoder
+            .decode(vec!["▁Hey".into(), "▁friend!".into()])
+            .unwrap();
+        assert_eq!(&res, "Hey friend!")
+    }
 }