WordPiece decoder with customizable prefix

This commit is contained in:
Anthony MOI
2020-01-06 20:20:42 -05:00
parent 742974f0c9
commit 4b9ae66419
3 changed files with 35 additions and 6 deletions

View File

@ -42,9 +42,17 @@ pub struct WordPiece {}
#[pymethods]
impl WordPiece {
// Static constructor exposed to Python: builds a `Decoder` that owns a boxed
// `tk::decoders::wordpiece::WordPiece`.
#[staticmethod]
// NOTE(review): this diff view shows the old and the new signature back to back;
// the first `fn new()` line is the pre-change version.
fn new() -> PyResult<Decoder> {
// New signature: `kwargs` is an optional dict that may carry a "prefix" entry.
// NOTE(review): no `#[args(kwargs = "**")]` attribute is visible here, so it is
// unclear whether a Python call like `new(prefix="__")` reaches `kwargs` — confirm
// against the PyO3 version in use.
fn new(kwargs: Option<&PyDict>) -> PyResult<Decoder> {
// Start from "##"; it is only overwritten when the caller supplies "prefix".
let mut prefix = String::from("##");
if let Some(kwargs) = kwargs {
if let Some(p) = kwargs.get_item("prefix") {
// A failed extraction into `String` is returned to the caller via `?`.
prefix = p.extract()?;
}
}
Ok(Decoder {
// Pre-change line: the decoder was a unit struct with no configuration.
decoder: Container::Owned(Box::new(tk::decoders::wordpiece::WordPiece)),
// Post-change line: the decoder is constructed with the chosen prefix.
decoder: Container::Owned(Box::new(tk::decoders::wordpiece::WordPiece::new(prefix))),
})
}
}

View File

@ -24,6 +24,11 @@ class WordPiece:
"""
@staticmethod
# NOTE(review): the next two lines are the pre-change signature and docstring as
# rendered by this diff view; the parameterized `new` below replaces them.
def new() -> Decoder:
""" Instantiate a new WordPiece Decoder """
def new(prefix: str="##") -> Decoder:
""" Instantiate a new WordPiece Decoder

Args:
prefix: str:
The prefix to use for subwords that are not a beginning-of-word.
Defaults to "##".

Returns:
Decoder: the new WordPiece decoder.
"""
pass

View File

@ -1,9 +1,25 @@
use crate::tokenizer::{Decoder, Result};
pub struct WordPiece;
/// WordPiece decoder configuration.
///
/// Stores the marker that identifies a token as a subword which is not the
/// beginning of a word.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct WordPiece {
    /// Prefix carried by subword tokens that are not a beginning-of-word (e.g. `"##"`).
    prefix: String,
}

impl WordPiece {
    /// Prefix used when none is supplied; single source of truth for [`Default`].
    const DEFAULT_PREFIX: &'static str = "##";

    /// Creates a decoder that uses `prefix` as its subword marker.
    ///
    /// Accepts anything convertible into a `String`, so both owned strings and
    /// string literals can be passed.
    pub fn new(prefix: impl Into<String>) -> Self {
        Self {
            prefix: prefix.into(),
        }
    }
}

impl Default for WordPiece {
    /// Equivalent to `WordPiece::new("##")`.
    fn default() -> Self {
        // Delegate to `new` so construction logic lives in one place.
        Self::new(Self::DEFAULT_PREFIX)
    }
}
impl Decoder for WordPiece {
// Rebuilds a string from `tokens`: the tokens are joined with single spaces, then
// every occurrence of a space immediately followed by the prefix is removed, which
// attaches each prefixed token directly to the text before it.
//
// Consequences of that expression worth knowing:
// - a prefix at the very start of the first token is kept, because no space
//   precedes it in the joined string;
// - with an empty prefix the pattern is a single space, so all joining spaces
//   are removed.
fn decode(&self, tokens: Vec<String>) -> Result<String> {
// NOTE(review): this diff view shows both versions of the return expression;
// the first is the pre-change line with the prefix hard-coded to "##".
Ok(tokens.join(" ").replace(" ##", ""))
// Post-change line: same operation, using the configured `self.prefix`.
Ok(tokens.join(" ").replace(&format!(" {}", self.prefix), ""))
}
}