diff --git a/bindings/python/src/decoders.rs b/bindings/python/src/decoders.rs
index 4dc762b7..8881bc78 100644
--- a/bindings/python/src/decoders.rs
+++ b/bindings/python/src/decoders.rs
@@ -43,18 +43,24 @@ pub struct WordPiece {}
 #[pymethods]
 impl WordPiece {
     #[new]
-    #[args(kwargs="**")]
+    #[args(kwargs = "**")]
     fn new(obj: &PyRawObject, kwargs: Option<&PyDict>) -> PyResult<()> {
         let mut prefix = String::from("##");
+        let mut cleanup = true;
 
         if let Some(kwargs) = kwargs {
             if let Some(p) = kwargs.get_item("prefix") {
                 prefix = p.extract()?;
             }
+            if let Some(c) = kwargs.get_item("cleanup") {
+                cleanup = c.extract()?;
+            }
         }
 
         Ok(obj.init(Decoder {
-            decoder: Container::Owned(Box::new(tk::decoders::wordpiece::WordPiece::new(prefix))),
+            decoder: Container::Owned(Box::new(tk::decoders::wordpiece::WordPiece::new(
+                prefix, cleanup,
+            ))),
         }))
     }
 }
diff --git a/bindings/python/tokenizers/decoders/__init__.pyi b/bindings/python/tokenizers/decoders/__init__.pyi
index e101aa25..706f306c 100644
--- a/bindings/python/tokenizers/decoders/__init__.pyi
+++ b/bindings/python/tokenizers/decoders/__init__.pyi
@@ -22,12 +22,15 @@ class WordPiece(Decoder):
     """ WordPiece Decoder """
 
     @staticmethod
-    def __init__(self, prefix: str = "##") -> Decoder:
+    def __init__(self, prefix: str = "##", cleanup: bool = True) -> Decoder:
         """ Instantiate a new WordPiece Decoder
 
         Args:
             prefix: str:
                 The prefix to use for subwords that are not a beginning-of-word
+            cleanup: bool:
+                Whether to cleanup some tokenization artifacts. Mainly spaces before punctuation,
+                and some abbreviated english forms.
         """
         pass
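
Usage sketch (not part of the patch): a minimal example of how the new `cleanup` kwarg would be passed from Python, assuming the updated bindings are built and importable as the `tokenizers` package. The decoded strings in the comments are illustrative of the behaviour described in the stub, not verified output.

    from tokenizers import decoders

    # cleanup=True (the default) strips detokenization artifacts such as
    # spaces before punctuation and re-joins abbreviated English forms.
    decoder = decoders.WordPiece(prefix="##", cleanup=True)
    print(decoder.decode(["I", "'m", "un", "##sure", "."]))      # e.g. "I'm unsure."

    # Opt out of the cleanup step to keep the raw space-joined output.
    raw_decoder = decoders.WordPiece(prefix="##", cleanup=False)
    print(raw_decoder.decode(["I", "'m", "un", "##sure", "."]))  # e.g. "I 'm unsure ."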