mirror of
https://github.com/mii443/tokenizers.git
synced 2025-12-13 05:48:36 +00:00
Python - Add ability to create custom Decoder
This commit is contained in:
@@ -1,12 +1,24 @@
|
|||||||
extern crate tokenizers as tk;
|
extern crate tokenizers as tk;
|
||||||
|
|
||||||
use super::utils::Container;
|
use super::utils::Container;
|
||||||
|
use pyo3::exceptions;
|
||||||
use pyo3::prelude::*;
|
use pyo3::prelude::*;
|
||||||
|
use pyo3::types::*;
|
||||||
|
|
||||||
#[pyclass]
|
#[pyclass]
|
||||||
pub struct Decoder {
|
pub struct Decoder {
|
||||||
pub decoder: Container<dyn tk::tokenizer::Decoder + Sync>,
|
pub decoder: Container<dyn tk::tokenizer::Decoder + Sync>,
|
||||||
}
|
}
|
||||||
|
#[pymethods]
|
||||||
|
impl Decoder {
|
||||||
|
#[staticmethod]
|
||||||
|
fn custom(decoder: PyObject) -> PyResult<Self> {
|
||||||
|
let decoder = PyDecoder::new(decoder)?;
|
||||||
|
Ok(Decoder {
|
||||||
|
decoder: Container::Owned(Box::new(decoder)),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
#[pyclass]
|
#[pyclass]
|
||||||
pub struct ByteLevel {}
|
pub struct ByteLevel {}
|
||||||
@@ -19,3 +31,53 @@ impl ByteLevel {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
struct PyDecoder {
|
||||||
|
class: PyObject,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl PyDecoder {
|
||||||
|
pub fn new(class: PyObject) -> PyResult<Self> {
|
||||||
|
let decoder = PyDecoder { class };
|
||||||
|
|
||||||
|
// Quickly test the PyDecoder
|
||||||
|
decoder._decode(vec![
|
||||||
|
"This".into(),
|
||||||
|
"is".into(),
|
||||||
|
"a".into(),
|
||||||
|
"sentence".into(),
|
||||||
|
])?;
|
||||||
|
|
||||||
|
Ok(decoder)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn _decode(&self, tokens: Vec<String>) -> PyResult<String> {
|
||||||
|
let gil = Python::acquire_gil();
|
||||||
|
let py = gil.python();
|
||||||
|
|
||||||
|
let args = PyTuple::new(py, &[tokens]);
|
||||||
|
let res = self.class.call_method(py, "decode", args, None)?;
|
||||||
|
|
||||||
|
let decoded = res
|
||||||
|
.cast_as::<PyString>(py)
|
||||||
|
.map_err(|_| exceptions::TypeError::py_err("`decode` is expected to return a str"))?;
|
||||||
|
|
||||||
|
Ok(decoded.to_string()?.into_owned())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl tk::tokenizer::Decoder for PyDecoder {
|
||||||
|
fn decode(&self, tokens: Vec<String>) -> String {
|
||||||
|
match self._decode(tokens) {
|
||||||
|
Ok(res) => res,
|
||||||
|
Err(e) => {
|
||||||
|
let gil = Python::acquire_gil();
|
||||||
|
let py = gil.python();
|
||||||
|
e.print(py);
|
||||||
|
|
||||||
|
// Return an empty string as fallback
|
||||||
|
String::from("")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -74,7 +74,7 @@ impl tk::tokenizer::PreTokenizer for PyPreTokenizer {
|
|||||||
let py = gil.python();
|
let py = gil.python();
|
||||||
e.print(py);
|
e.print(py);
|
||||||
|
|
||||||
// Return an empty string as fallback
|
// Return an empty Vec as fallback
|
||||||
vec![]
|
vec![]
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user