Files
tokenizers/bindings/python/src/decoders.rs
Nicolas Patry 2e2e7558f7 Add CTC Decoder for Wave2Vec models (#693)
* Rust - add a CTCDecoder as a separate mod

* Adding bindings to Node + Python.

* Clippy update.

* Stub.

* Fixing roberta.json URLs.

* Moving test files to hf.co.

* Update cargo check and clippy to 1.52.

* Inner ':' actually is used for domains in sphinx.

Making `domain` work correctly was just too much work so I went the easy
way and have global roles for the custom rust extension.

* Update struct naming and docs

* Update changelog

Co-authored-by: Thomaub <github.thomaub@gmail.com>
Co-authored-by: Anthony MOI <m.anthony.moi@gmail.com>
2021-05-20 09:30:09 -04:00

453 lines
14 KiB
Rust

use std::sync::{Arc, RwLock};
use crate::utils::PyChar;
use pyo3::exceptions;
use pyo3::prelude::*;
use pyo3::types::*;
use serde::de::Error;
use serde::{Deserialize, Deserializer, Serialize, Serializer};
use tk::decoders::bpe::BPEDecoder;
use tk::decoders::byte_level::ByteLevel;
use tk::decoders::ctc::CTC;
use tk::decoders::metaspace::Metaspace;
use tk::decoders::wordpiece::WordPiece;
use tk::decoders::DecoderWrapper;
use tk::Decoder;
use tokenizers as tk;
use super::error::ToPyResult;
/// Base class for all decoders
///
/// This class is not supposed to be instantiated directly. Instead, any implementation of
/// a Decoder will return an instance of this class when instantiated.
#[pyclass(dict, module = "tokenizers.decoders", name=Decoder)]
#[derive(Clone, Deserialize, Serialize)]
pub struct PyDecoder {
    // Flattened so this struct (de)serializes to exactly the same JSON as the
    // inner wrapper (the `serialize` test at the bottom of this file relies on it).
    #[serde(flatten)]
    pub(crate) decoder: PyDecoderWrapper,
}
impl PyDecoder {
    /// Wrap a `PyDecoderWrapper` in the base `PyDecoder` pyclass.
    pub(crate) fn new(decoder: PyDecoderWrapper) -> Self {
        PyDecoder { decoder }
    }

    /// Expose the inner decoder as its concrete Python subclass, so Python
    /// users see e.g. `tokenizers.decoders.Metaspace` rather than the opaque
    /// base `Decoder` class. Custom decoders stay as the base class.
    pub(crate) fn get_as_subtype(&self) -> PyResult<PyObject> {
        let base = self.clone();
        // Use `Python::with_gil` (already used elsewhere in this file) instead
        // of the deprecated `Python::acquire_gil` guard.
        Python::with_gil(|py| {
            Ok(match &self.decoder {
                PyDecoderWrapper::Custom(_) => Py::new(py, base)?.into_py(py),
                PyDecoderWrapper::Wrapped(inner) => match &*inner.as_ref().read().unwrap() {
                    DecoderWrapper::Metaspace(_) => {
                        Py::new(py, (PyMetaspaceDec {}, base))?.into_py(py)
                    }
                    DecoderWrapper::WordPiece(_) => {
                        Py::new(py, (PyWordPieceDec {}, base))?.into_py(py)
                    }
                    DecoderWrapper::ByteLevel(_) => {
                        Py::new(py, (PyByteLevelDec {}, base))?.into_py(py)
                    }
                    DecoderWrapper::BPE(_) => Py::new(py, (PyBPEDecoder {}, base))?.into_py(py),
                    DecoderWrapper::CTC(_) => Py::new(py, (PyCTCDecoder {}, base))?.into_py(py),
                },
            })
        })
    }
}
impl Decoder for PyDecoder {
    // Implements the native `Decoder` trait by delegating to the inner
    // wrapper, so a `PyDecoder` can be used wherever a Rust decoder is expected.
    fn decode(&self, tokens: Vec<String>) -> tk::Result<String> {
        self.decoder.decode(tokens)
    }
}
#[pymethods]
impl PyDecoder {
    /// Wrap an arbitrary Python object (exposing a `decode` method) as a Decoder.
    #[staticmethod]
    fn custom(decoder: PyObject) -> Self {
        let decoder = PyDecoderWrapper::Custom(Arc::new(RwLock::new(CustomDecoder::new(decoder))));
        PyDecoder::new(decoder)
    }

    /// Pickle support: serialize the inner decoder to JSON bytes.
    /// Fails for custom (Python-defined) decoders, which are not serializable.
    fn __getstate__(&self, py: Python) -> PyResult<PyObject> {
        let data = serde_json::to_string(&self.decoder).map_err(|e| {
            exceptions::PyException::new_err(format!(
                "Error while attempting to pickle Decoder: {}",
                e
            ))
        })?;
        Ok(PyBytes::new(py, data.as_bytes()).to_object(py))
    }

    /// Pickle support: restore the inner decoder from JSON bytes.
    fn __setstate__(&mut self, py: Python, state: PyObject) -> PyResult<()> {
        // Propagate extraction errors with `?` instead of matching just to
        // re-return the same error.
        let s = state.extract::<&PyBytes>(py)?;
        self.decoder = serde_json::from_slice(s.as_bytes()).map_err(|e| {
            exceptions::PyException::new_err(format!(
                "Error while attempting to unpickle Decoder: {}",
                e
            ))
        })?;
        Ok(())
    }

    /// Decode the given list of tokens to a final string
    ///
    /// Args:
    ///     tokens (:obj:`List[str]`):
    ///         The list of tokens to decode
    ///
    /// Returns:
    ///     :obj:`str`: The decoded string
    #[text_signature = "(self, tokens)"]
    fn decode(&self, tokens: Vec<String>) -> PyResult<String> {
        ToPyResult(self.decoder.decode(tokens)).into()
    }
}
/// Read a field (or call a getter) of the wrapped native decoder.
///
/// `$self` is a `PyRef` to a pyclass extending `PyDecoder`; `$variant` names
/// the `DecoderWrapper` variant that pyclass always wraps. The pairing is
/// fixed at construction time, so the fallback branches cannot be hit.
macro_rules! getter {
    ($self: ident, $variant: ident, $($name: tt)+) => {{
        let super_ = $self.as_ref();
        if let PyDecoderWrapper::Wrapped(ref wrap) = super_.decoder {
            if let DecoderWrapper::$variant(ref dec) = *wrap.read().unwrap() {
                dec.$($name)+
            } else {
                unreachable!()
            }
        } else {
            unreachable!()
        }
    }};
}
/// Mutate the wrapped native decoder under its write lock.
///
/// The plain form assigns a field directly (`dec.$name = $value`); the
/// `@`-prefixed form calls a setter method instead (`dec.$name($value)`).
/// Unlike `getter!`, a variant mismatch is silently ignored here.
macro_rules! setter {
    ($self: ident, $variant: ident, $name: ident, $value: expr) => {{
        let super_ = $self.as_ref();
        if let PyDecoderWrapper::Wrapped(ref wrap) = super_.decoder {
            if let DecoderWrapper::$variant(ref mut dec) = *wrap.write().unwrap() {
                dec.$name = $value;
            }
        }
    }};
    ($self: ident, $variant: ident, @$name: ident, $value: expr) => {{
        let super_ = $self.as_ref();
        if let PyDecoderWrapper::Wrapped(ref wrap) = super_.decoder {
            if let DecoderWrapper::$variant(ref mut dec) = *wrap.write().unwrap() {
                dec.$name($value);
            }
        }
    }};
}
/// ByteLevel Decoder
///
/// This decoder is to be used in tandem with the :class:`~tokenizers.pre_tokenizers.ByteLevel`
/// :class:`~tokenizers.pre_tokenizers.PreTokenizer`.
#[pyclass(extends=PyDecoder, module = "tokenizers.decoders", name=ByteLevel)]
#[text_signature = "(self)"]
pub struct PyByteLevelDec {}
#[pymethods]
impl PyByteLevelDec {
    // Wraps the default Rust ByteLevel decoder; it has no parameters.
    #[new]
    fn new() -> (Self, PyDecoder) {
        let decoder = ByteLevel::default();
        (Self {}, decoder.into())
    }
}
/// WordPiece Decoder
///
/// Args:
///     prefix (:obj:`str`, `optional`, defaults to :obj:`##`):
///         The prefix to use for subwords that are not a beginning-of-word
///
///     cleanup (:obj:`bool`, `optional`, defaults to :obj:`True`):
///         Whether to cleanup some tokenization artifacts. Mainly spaces before punctuation,
///         and some abbreviated english forms.
#[pyclass(extends=PyDecoder, module = "tokenizers.decoders", name=WordPiece)]
#[text_signature = "(self, prefix=\"##\", cleanup=True)"]
pub struct PyWordPieceDec {}
#[pymethods]
impl PyWordPieceDec {
    // Construct from a subword prefix and a cleanup flag.
    #[new]
    #[args(prefix = "String::from(\"##\")", cleanup = "true")]
    fn new(prefix: String, cleanup: bool) -> (Self, PyDecoder) {
        let decoder = WordPiece::new(prefix, cleanup);
        (Self {}, decoder.into())
    }

    #[getter]
    fn get_prefix(self_: PyRef<Self>) -> String {
        getter!(self_, WordPiece, prefix.clone())
    }

    #[setter]
    fn set_prefix(self_: PyRef<Self>, prefix: String) {
        setter!(self_, WordPiece, prefix, prefix);
    }

    #[getter]
    fn get_cleanup(self_: PyRef<Self>) -> bool {
        getter!(self_, WordPiece, cleanup)
    }

    #[setter]
    fn set_cleanup(self_: PyRef<Self>, cleanup: bool) {
        setter!(self_, WordPiece, cleanup, cleanup);
    }
}
/// Metaspace Decoder
///
/// Args:
///     replacement (:obj:`str`, `optional`, defaults to :obj:`▁`):
///         The replacement character. Must be exactly one character. By default we
///         use the `▁` (U+2581) meta symbol (Same as in SentencePiece).
///
///     add_prefix_space (:obj:`bool`, `optional`, defaults to :obj:`True`):
///         Whether to add a space to the first word if there isn't already one. This
///         lets us treat `hello` exactly like `say hello`.
#[pyclass(extends=PyDecoder, module = "tokenizers.decoders", name=Metaspace)]
// The advertised default must match the real default in `new` below ('▁');
// it previously showed an empty string.
#[text_signature = "(self, replacement = \"▁\", add_prefix_space = True)"]
pub struct PyMetaspaceDec {}
#[pymethods]
impl PyMetaspaceDec {
    #[getter]
    fn get_replacement(self_: PyRef<Self>) -> String {
        getter!(self_, Metaspace, get_replacement().to_string())
    }

    // Uses the `@` setter form because Metaspace exposes a setter method
    // rather than a public field.
    #[setter]
    fn set_replacement(self_: PyRef<Self>, replacement: PyChar) {
        setter!(self_, Metaspace, @set_replacement, replacement.0);
    }

    #[getter]
    fn get_add_prefix_space(self_: PyRef<Self>) -> bool {
        getter!(self_, Metaspace, add_prefix_space)
    }

    #[setter]
    fn set_add_prefix_space(self_: PyRef<Self>, add_prefix_space: bool) {
        setter!(self_, Metaspace, add_prefix_space, add_prefix_space);
    }

    /// Construct from a one-character replacement and a prefix-space flag.
    #[new]
    #[args(replacement = "PyChar('▁')", add_prefix_space = "true")]
    fn new(replacement: PyChar, add_prefix_space: bool) -> (Self, PyDecoder) {
        (
            PyMetaspaceDec {},
            Metaspace::new(replacement.0, add_prefix_space).into(),
        )
    }
}
/// BPEDecoder Decoder
///
/// Args:
///     suffix (:obj:`str`, `optional`, defaults to :obj:`</w>`):
///         The suffix that was used to characterize an end-of-word. This suffix will
///         be replaced by whitespaces during the decoding
#[pyclass(extends=PyDecoder, module = "tokenizers.decoders", name=BPEDecoder)]
#[text_signature = "(self, suffix=\"</w>\")"]
pub struct PyBPEDecoder {}
#[pymethods]
impl PyBPEDecoder {
    #[getter]
    fn get_suffix(self_: PyRef<Self>) -> String {
        getter!(self_, BPE, suffix.clone())
    }

    #[setter]
    fn set_suffix(self_: PyRef<Self>, suffix: String) {
        setter!(self_, BPE, suffix, suffix);
    }

    // Construct from the end-of-word suffix.
    #[new]
    #[args(suffix = "String::from(\"</w>\")")]
    fn new(suffix: String) -> (Self, PyDecoder) {
        (PyBPEDecoder {}, BPEDecoder::new(suffix).into())
    }
}
/// CTC Decoder
///
/// Args:
///     pad_token (:obj:`str`, `optional`, defaults to :obj:`<pad>`):
///         The pad token used by CTC to delimit a new token.
///     word_delimiter_token (:obj:`str`, `optional`, defaults to :obj:`|`):
///         The word delimiter token. It will be replaced by a <space>
///     cleanup (:obj:`bool`, `optional`, defaults to :obj:`True`):
///         Whether to cleanup some tokenization artifacts.
///         Mainly spaces before punctuation, and some abbreviated english forms.
#[pyclass(extends=PyDecoder, module = "tokenizers.decoders", name=CTC)]
#[text_signature = "(self, pad_token=\"<pad>\", word_delimiter_token=\"|\", cleanup=True)"]
pub struct PyCTCDecoder {}
#[pymethods]
impl PyCTCDecoder {
    // Construct from the pad token, the word delimiter token and a cleanup flag.
    #[new]
    #[args(
        pad_token = "String::from(\"<pad>\")",
        word_delimiter_token = "String::from(\"|\")",
        cleanup = "true"
    )]
    fn new(pad_token: String, word_delimiter_token: String, cleanup: bool) -> (Self, PyDecoder) {
        let decoder = CTC::new(pad_token, word_delimiter_token, cleanup);
        (Self {}, decoder.into())
    }

    #[getter]
    fn get_pad_token(self_: PyRef<Self>) -> String {
        getter!(self_, CTC, pad_token.clone())
    }

    #[setter]
    fn set_pad_token(self_: PyRef<Self>, pad_token: String) {
        setter!(self_, CTC, pad_token, pad_token);
    }

    #[getter]
    fn get_word_delimiter_token(self_: PyRef<Self>) -> String {
        getter!(self_, CTC, word_delimiter_token.clone())
    }

    #[setter]
    fn set_word_delimiter_token(self_: PyRef<Self>, word_delimiter_token: String) {
        setter!(self_, CTC, word_delimiter_token, word_delimiter_token);
    }

    #[getter]
    fn get_cleanup(self_: PyRef<Self>) -> bool {
        getter!(self_, CTC, cleanup)
    }

    #[setter]
    fn set_cleanup(self_: PyRef<Self>, cleanup: bool) {
        setter!(self_, CTC, cleanup, cleanup);
    }
}
// A Decoder backed by an arbitrary Python object supplied by the user.
#[derive(Clone)]
pub(crate) struct CustomDecoder {
    // The Python object; it is expected to expose a `decode(tokens)` method
    // (see the `Decoder` impl below).
    inner: PyObject,
}
impl CustomDecoder {
    pub(crate) fn new(inner: PyObject) -> Self {
        CustomDecoder { inner }
    }
}
impl Decoder for CustomDecoder {
    // Delegates decoding to the wrapped Python object's `decode` method,
    // taking the GIL for the duration of the call.
    fn decode(&self, tokens: Vec<String>) -> tk::Result<String> {
        Python::with_gil(|py| {
            let result = self.inner.call_method(py, "decode", (tokens,), None)?;
            Ok(result.extract::<String>(py)?)
        })
    }
}
impl Serialize for CustomDecoder {
    // Python-defined decoders cannot be serialized; always error so that
    // pickling/saving a custom decoder fails with a clear message.
    fn serialize<S>(&self, _serializer: S) -> std::result::Result<S::Ok, S::Error>
    where
        S: Serializer,
    {
        Err(serde::ser::Error::custom(
            "Custom PyDecoder cannot be serialized",
        ))
    }
}
impl<'de> Deserialize<'de> for CustomDecoder {
    // Mirror of `Serialize` above: a custom decoder can never be restored
    // from serialized data, so this always errors.
    fn deserialize<D>(_deserializer: D) -> std::result::Result<Self, D::Error>
    where
        D: Deserializer<'de>,
    {
        Err(D::Error::custom("PyDecoder cannot be deserialized"))
    }
}
// The decoder held by a `PyDecoder`: either a user-provided Python object or
// one of the native Rust decoders. `untagged` keeps the JSON identical to the
// native `DecoderWrapper` format; since `CustomDecoder` always fails to
// (de)serialize, only the `Wrapped` variant ever round-trips.
#[derive(Clone, Deserialize, Serialize)]
#[serde(untagged)]
pub(crate) enum PyDecoderWrapper {
    // Arc<RwLock<..>> gives shared, mutable access from multiple Python handles.
    Custom(Arc<RwLock<CustomDecoder>>),
    Wrapped(Arc<RwLock<DecoderWrapper>>),
}
impl<I> From<I> for PyDecoderWrapper
where
    I: Into<DecoderWrapper>,
{
    /// Wrap any native decoder type in a shared, lockable `PyDecoderWrapper`.
    // Renamed the parameter from `norm` (a leftover from the normalizers
    // bindings) to `dec`, matching the `From` impl for `PyDecoder` below.
    fn from(dec: I) -> Self {
        PyDecoderWrapper::Wrapped(Arc::new(RwLock::new(dec.into())))
    }
}
impl<I> From<I> for PyDecoder
where
I: Into<DecoderWrapper>,
{
fn from(dec: I) -> Self {
PyDecoder {
decoder: dec.into().into(),
}
}
}
impl Decoder for PyDecoderWrapper {
    // Dispatch `decode` to whichever variant is held, under a read lock.
    fn decode(&self, tokens: Vec<String>) -> tk::Result<String> {
        match self {
            PyDecoderWrapper::Custom(custom) => custom.read().unwrap().decode(tokens),
            PyDecoderWrapper::Wrapped(native) => native.read().unwrap().decode(tokens),
        }
    }
}
#[cfg(test)]
mod test {
    use std::sync::{Arc, RwLock};

    use pyo3::prelude::*;
    use tk::decoders::metaspace::Metaspace;
    use tk::decoders::DecoderWrapper;

    use crate::decoders::{CustomDecoder, PyDecoder, PyDecoderWrapper};

    #[test]
    fn get_subtype() {
        // A wrapped Metaspace must surface as the Metaspace pyclass in Python.
        let py_dec = PyDecoder::new(Metaspace::default().into());
        let py_meta = py_dec.get_as_subtype().unwrap();
        // Use `with_gil`, consistent with `serialize` below, rather than the
        // deprecated `Python::acquire_gil` guard.
        Python::with_gil(|py| {
            assert_eq!(
                "tokenizers.decoders.Metaspace",
                py_meta.as_ref(py).get_type().name()
            );
        })
    }

    #[test]
    fn serialize() {
        // Wrapped decoders must serialize exactly like the native wrapper...
        let py_wrapped: PyDecoderWrapper = Metaspace::default().into();
        let py_ser = serde_json::to_string(&py_wrapped).unwrap();
        let rs_wrapped = DecoderWrapper::Metaspace(Metaspace::default());
        let rs_ser = serde_json::to_string(&rs_wrapped).unwrap();
        assert_eq!(py_ser, rs_ser);

        // ...and deserialize back into the matching `Wrapped` variant.
        let py_dec: PyDecoder = serde_json::from_str(&rs_ser).unwrap();
        match py_dec.decoder {
            PyDecoderWrapper::Wrapped(msp) => match *msp.as_ref().read().unwrap() {
                DecoderWrapper::Metaspace(_) => {}
                _ => panic!("Expected Metaspace"),
            },
            _ => panic!("Expected wrapped, not custom."),
        }

        // Custom (Python-defined) decoders must refuse to serialize.
        let obj = Python::with_gil(|py| {
            let py_msp = PyDecoder::new(Metaspace::default().into());
            let obj: PyObject = Py::new(py, py_msp).unwrap().into_py(py);
            obj
        });
        let py_seq = PyDecoderWrapper::Custom(Arc::new(RwLock::new(CustomDecoder::new(obj))));
        assert!(serde_json::to_string(&py_seq).is_err());
    }
}