extern crate tokenizers as tk; use super::error::{PyError, ToPyResult}; use super::utils::Container; use pyo3::prelude::*; use pyo3::types::*; use tk::tokenizer::Result; #[pyclass(dict)] pub struct Decoder { pub decoder: Container, } #[pymethods] impl Decoder { #[staticmethod] fn custom(decoder: PyObject) -> PyResult { let decoder = PyDecoder::new(decoder)?; Ok(Decoder { decoder: Container::Owned(Box::new(decoder)), }) } fn decode(&self, tokens: Vec) -> PyResult { ToPyResult(self.decoder.execute(|decoder| decoder.decode(tokens))).into() } } #[pyclass] pub struct ByteLevel {} #[pymethods] impl ByteLevel { #[staticmethod] fn new() -> PyResult { Ok(Decoder { decoder: Container::Owned(Box::new(tk::decoders::byte_level::ByteLevel::new(false))), }) } } #[pyclass] pub struct WordPiece {} #[pymethods] impl WordPiece { #[staticmethod] fn new(kwargs: Option<&PyDict>) -> PyResult { let mut prefix = String::from("##"); if let Some(kwargs) = kwargs { if let Some(p) = kwargs.get_item("prefix") { prefix = p.extract()?; } } Ok(Decoder { decoder: Container::Owned(Box::new(tk::decoders::wordpiece::WordPiece::new(prefix))), }) } } #[pyclass] pub struct Metaspace {} #[pymethods] impl Metaspace { #[staticmethod] #[args(kwargs = "**")] fn new(kwargs: Option<&PyDict>) -> PyResult { let mut replacement = '▁'; let mut add_prefix_space = true; if let Some(kwargs) = kwargs { for (key, value) in kwargs { let key: &str = key.extract()?; match key { "replacement" => { let s: &str = value.extract()?; replacement = s.chars().nth(0).ok_or(exceptions::Exception::py_err( "replacement must be a character", ))?; } "add_prefix_space" => add_prefix_space = value.extract()?, _ => println!("Ignored unknown kwarg option {}", key), } } } Ok(Decoder { decoder: Container::Owned(Box::new(tk::decoder::metaspace::Metaspace::new( replacement, add_prefix_space, ))), }) } } struct PyDecoder { class: PyObject, } impl PyDecoder { pub fn new(class: PyObject) -> PyResult { Ok(PyDecoder { class }) } } impl tk::tokenizer::Decoder for PyDecoder { fn decode(&self, tokens: Vec) -> Result { let gil = Python::acquire_gil(); let py = gil.python(); let args = PyTuple::new(py, &[tokens]); match self.class.call_method(py, "decode", args, None) { Ok(res) => Ok(res .cast_as::(py) .map_err(|_| PyError::from("`decode` is expected to return a str"))? .to_string() .map_err(|_| PyError::from("`decode` is expected to return a str"))? .into_owned()), Err(e) => { e.print(py); Err(Box::new(PyError::from("Error while calling `decode`"))) } } } }