mirror of
https://github.com/mii443/tokenizers.git
synced 2025-12-09 14:18:30 +00:00
Python - Add ability to create custom Decoder
This commit is contained in:
@@ -1,12 +1,24 @@
|
||||
extern crate tokenizers as tk;
|
||||
|
||||
use super::utils::Container;
|
||||
use pyo3::exceptions;
|
||||
use pyo3::prelude::*;
|
||||
use pyo3::types::*;
|
||||
|
||||
#[pyclass]
|
||||
pub struct Decoder {
|
||||
pub decoder: Container<dyn tk::tokenizer::Decoder + Sync>,
|
||||
}
|
||||
#[pymethods]
|
||||
impl Decoder {
|
||||
#[staticmethod]
|
||||
fn custom(decoder: PyObject) -> PyResult<Self> {
|
||||
let decoder = PyDecoder::new(decoder)?;
|
||||
Ok(Decoder {
|
||||
decoder: Container::Owned(Box::new(decoder)),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[pyclass]
|
||||
pub struct ByteLevel {}
|
||||
@@ -19,3 +31,53 @@ impl ByteLevel {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
struct PyDecoder {
|
||||
class: PyObject,
|
||||
}
|
||||
|
||||
impl PyDecoder {
|
||||
pub fn new(class: PyObject) -> PyResult<Self> {
|
||||
let decoder = PyDecoder { class };
|
||||
|
||||
// Quickly test the PyDecoder
|
||||
decoder._decode(vec![
|
||||
"This".into(),
|
||||
"is".into(),
|
||||
"a".into(),
|
||||
"sentence".into(),
|
||||
])?;
|
||||
|
||||
Ok(decoder)
|
||||
}
|
||||
|
||||
fn _decode(&self, tokens: Vec<String>) -> PyResult<String> {
|
||||
let gil = Python::acquire_gil();
|
||||
let py = gil.python();
|
||||
|
||||
let args = PyTuple::new(py, &[tokens]);
|
||||
let res = self.class.call_method(py, "decode", args, None)?;
|
||||
|
||||
let decoded = res
|
||||
.cast_as::<PyString>(py)
|
||||
.map_err(|_| exceptions::TypeError::py_err("`decode` is expected to return a str"))?;
|
||||
|
||||
Ok(decoded.to_string()?.into_owned())
|
||||
}
|
||||
}
|
||||
|
||||
impl tk::tokenizer::Decoder for PyDecoder {
|
||||
fn decode(&self, tokens: Vec<String>) -> String {
|
||||
match self._decode(tokens) {
|
||||
Ok(res) => res,
|
||||
Err(e) => {
|
||||
let gil = Python::acquire_gil();
|
||||
let py = gil.python();
|
||||
e.print(py);
|
||||
|
||||
// Return an empty string as fallback
|
||||
String::from("")
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -74,7 +74,7 @@ impl tk::tokenizer::PreTokenizer for PyPreTokenizer {
|
||||
let py = gil.python();
|
||||
e.print(py);
|
||||
|
||||
// Return an empty string as fallback
|
||||
// Return an empty Vec as fallback
|
||||
vec![]
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user