mirror of
https://github.com/mii443/tokenizers.git
synced 2025-08-23 16:49:27 +00:00
WordPiece decoder with customizable prefix
This commit is contained in:
@ -42,9 +42,17 @@ pub struct WordPiece {}
|
|||||||
#[pymethods]
|
#[pymethods]
|
||||||
impl WordPiece {
|
impl WordPiece {
|
||||||
#[staticmethod]
|
#[staticmethod]
|
||||||
fn new() -> PyResult<Decoder> {
|
fn new(kwargs: Option<&PyDict>) -> PyResult<Decoder> {
|
||||||
|
let mut prefix = String::from("##");
|
||||||
|
|
||||||
|
if let Some(kwargs) = kwargs {
|
||||||
|
if let Some(p) = kwargs.get_item("prefix") {
|
||||||
|
prefix = p.extract()?;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
Ok(Decoder {
|
Ok(Decoder {
|
||||||
decoder: Container::Owned(Box::new(tk::decoders::wordpiece::WordPiece)),
|
decoder: Container::Owned(Box::new(tk::decoders::wordpiece::WordPiece::new(prefix))),
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -24,6 +24,11 @@ class WordPiece:
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def new() -> Decoder:
|
def new(prefix: str="##") -> Decoder:
|
||||||
""" Instantiate a new WordPiece Decoder """
|
""" Instantiate a new WordPiece Decoder
|
||||||
|
|
||||||
|
Args:
|
||||||
|
prefix: str:
|
||||||
|
The prefix to use for subwords that are not a beginning-of-word
|
||||||
|
"""
|
||||||
pass
|
pass
|
||||||
|
@ -1,9 +1,25 @@
|
|||||||
use crate::tokenizer::{Decoder, Result};
|
use crate::tokenizer::{Decoder, Result};
|
||||||
|
|
||||||
pub struct WordPiece;
|
pub struct WordPiece {
|
||||||
|
prefix: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl WordPiece {
|
||||||
|
pub fn new(prefix: String) -> Self {
|
||||||
|
Self { prefix }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Default for WordPiece {
|
||||||
|
fn default() -> Self {
|
||||||
|
Self {
|
||||||
|
prefix: String::from("##"),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
impl Decoder for WordPiece {
|
impl Decoder for WordPiece {
|
||||||
fn decode(&self, tokens: Vec<String>) -> Result<String> {
|
fn decode(&self, tokens: Vec<String>) -> Result<String> {
|
||||||
Ok(tokens.join(" ").replace(" ##", ""))
|
Ok(tokens.join(" ").replace(&format!(" {}", self.prefix), ""))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
Reference in New Issue
Block a user