WordPiece decoder with customizable prefix

This commit is contained in:
Anthony MOI
2020-01-06 20:20:42 -05:00
parent 742974f0c9
commit 4b9ae66419
3 changed files with 35 additions and 6 deletions

View File

@ -42,9 +42,17 @@ pub struct WordPiece {}
#[pymethods] #[pymethods]
impl WordPiece { impl WordPiece {
#[staticmethod] #[staticmethod]
fn new() -> PyResult<Decoder> { fn new(kwargs: Option<&PyDict>) -> PyResult<Decoder> {
let mut prefix = String::from("##");
if let Some(kwargs) = kwargs {
if let Some(p) = kwargs.get_item("prefix") {
prefix = p.extract()?;
}
}
Ok(Decoder { Ok(Decoder {
decoder: Container::Owned(Box::new(tk::decoders::wordpiece::WordPiece)), decoder: Container::Owned(Box::new(tk::decoders::wordpiece::WordPiece::new(prefix))),
}) })
} }
} }

View File

@ -24,6 +24,11 @@ class WordPiece:
""" """
@staticmethod @staticmethod
def new() -> Decoder: def new(prefix: str="##") -> Decoder:
""" Instantiate a new WordPiece Decoder """ """ Instantiate a new WordPiece Decoder
Args:
prefix: str:
The prefix to use for subwords that are not a beginning-of-word
"""
pass pass

View File

@ -1,9 +1,25 @@
use crate::tokenizer::{Decoder, Result}; use crate::tokenizer::{Decoder, Result};
pub struct WordPiece; pub struct WordPiece {
prefix: String,
}
impl WordPiece {
pub fn new(prefix: String) -> Self {
Self { prefix }
}
}
impl Default for WordPiece {
fn default() -> Self {
Self {
prefix: String::from("##"),
}
}
}
impl Decoder for WordPiece { impl Decoder for WordPiece {
fn decode(&self, tokens: Vec<String>) -> Result<String> { fn decode(&self, tokens: Vec<String>) -> Result<String> {
Ok(tokens.join(" ").replace(" ##", "")) Ok(tokens.join(" ").replace(&format!(" {}", self.prefix), ""))
} }
} }