Python - add Metaspace decoder

This commit is contained in:
Anthony MOI
2020-01-07 18:40:18 -05:00
parent 43acdcfacf
commit cbdd2cf423
6 changed files with 98 additions and 5 deletions

View File

@ -57,6 +57,41 @@ impl WordPiece {
} }
} }
#[pyclass]
pub struct Metaspace {}
#[pymethods]
impl Metaspace {
#[staticmethod]
#[args(kwargs = "**")]
fn new(kwargs: Option<&PyDict>) -> PyResult<Decoder> {
let mut replacement = '▁';
let mut add_prefix_space = true;
if let Some(kwargs) = kwargs {
for (key, value) in kwargs {
let key: &str = key.extract()?;
match key {
"replacement" => {
let s: &str = value.extract()?;
replacement = s.chars().nth(0).ok_or(exceptions::Exception::py_err(
"replacement must be a character",
))?;
}
"add_prefix_space" => add_prefix_space = value.extract()?,
_ => println!("Ignored unknown kwarg option {}", key),
}
}
}
Ok(Decoder {
decoder: Container::Owned(Box::new(tk::decoder::metaspace::Metaspace::new(
replacement,
add_prefix_space,
))),
})
}
}
struct PyDecoder { struct PyDecoder {
class: PyObject, class: PyObject,
} }

View File

@ -49,6 +49,7 @@ fn decoders(_py: Python, m: &PyModule) -> PyResult<()> {
m.add_class::<decoders::Decoder>()?; m.add_class::<decoders::Decoder>()?;
m.add_class::<decoders::ByteLevel>()?; m.add_class::<decoders::ByteLevel>()?;
m.add_class::<decoders::WordPiece>()?; m.add_class::<decoders::WordPiece>()?;
m.add_class::<decoders::Metaspace>()?;
Ok(()) Ok(())
} }

View File

@ -3,3 +3,4 @@ from .. import decoders
Decoder = decoders.Decoder Decoder = decoders.Decoder
ByteLevel = decoders.ByteLevel ByteLevel = decoders.ByteLevel
WordPiece = decoders.WordPiece WordPiece = decoders.WordPiece
Metaspace = decoders.Metaspace

View File

@ -31,3 +31,22 @@ class WordPiece:
The prefix to use for subwords that are not a beginning-of-word The prefix to use for subwords that are not a beginning-of-word
""" """
pass pass
class Metaspace:
""" Metaspace decoder """
@staticmethod
def new(replacement: str="",
add_prefix_space: bool=True) -> Decoder:
""" Instantiate a new Metaspace
Args:
replacement: str:
The replacement character. Must be exactly one character. By default we
use the `▁` (U+2581) meta symbol (Same as in SentencePiece).
add_prefix_space: boolean:
Whether to add a space to the first word if there isn't already one. This
lets us treat `hello` exactly like `say hello`.
"""
pass

View File

@ -1,3 +1,5 @@
// Re-export this as a decoder
pub use super::pre_tokenizers::byte_level;
pub mod wordpiece; pub mod wordpiece;
// Re-export these as decoders
pub use super::pre_tokenizers::byte_level;
pub use super::pre_tokenizers::metaspace;

View File

@ -1,5 +1,7 @@
use crate::tokenizer::{Offsets, PreTokenizer, Result}; use crate::tokenizer::{Decoder, Offsets, PreTokenizer, Result};
/// Replaces all the whitespaces by the provided meta character and then
/// splits on this character
pub struct Metaspace { pub struct Metaspace {
replacement: char, replacement: char,
add_prefix_space: bool, add_prefix_space: bool,
@ -52,13 +54,37 @@ impl PreTokenizer for Metaspace {
} }
} }
impl Decoder for Metaspace {
fn decode(&self, tokens: Vec<String>) -> Result<String> {
Ok(tokens
.into_iter()
.map(|t| t.chars().collect::<Vec<_>>())
.flatten()
.enumerate()
.map(|(i, c)| {
if c == self.replacement {
if i == 0 && self.add_prefix_space {
None
} else {
Some(' ')
}
} else {
Some(c)
}
})
.filter(|c| c.is_some())
.map(|c| c.unwrap())
.collect::<String>())
}
}
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::*; use super::*;
#[test] #[test]
fn basic() { fn basic() {
let pretok = Metaspace::default(); let pretok = Metaspace::new('▁', true);
let res = pretok.pre_tokenize("Hey friend!").unwrap(); let res = pretok.pre_tokenize("Hey friend!").unwrap();
assert_eq!( assert_eq!(
&res, &res,
@ -68,7 +94,7 @@ mod tests {
#[test] #[test]
fn multiple_spaces() { fn multiple_spaces() {
let pretok = Metaspace::default(); let pretok = Metaspace::new('▁', true);
let res = pretok.pre_tokenize("Hey friend!").unwrap(); let res = pretok.pre_tokenize("Hey friend!").unwrap();
assert_eq!( assert_eq!(
&res, &res,
@ -80,4 +106,13 @@ mod tests {
] ]
); );
} }
#[test]
fn decode() {
let decoder = Metaspace::new('▁', true);
let res = decoder
.decode(vec!["▁Hey".into(), "▁friend!".into()])
.unwrap();
assert_eq!(&res, "Hey friend!")
}
} }