mirror of
https://github.com/mii443/tokenizers.git
synced 2025-09-02 07:19:24 +00:00
125 lines
3.4 KiB
Rust
125 lines
3.4 KiB
Rust
extern crate tokenizers as tk;
|
|
|
|
use super::error::{PyError, ToPyResult};
|
|
use super::utils::Container;
|
|
use pyo3::prelude::*;
|
|
use pyo3::types::*;
|
|
use tk::tokenizer::Result;
|
|
|
|
#[pyclass(dict)]
|
|
pub struct Decoder {
|
|
pub decoder: Container<dyn tk::tokenizer::Decoder + Sync>,
|
|
}
|
|
#[pymethods]
|
|
impl Decoder {
|
|
#[staticmethod]
|
|
fn custom(decoder: PyObject) -> PyResult<Self> {
|
|
let decoder = PyDecoder::new(decoder)?;
|
|
Ok(Decoder {
|
|
decoder: Container::Owned(Box::new(decoder)),
|
|
})
|
|
}
|
|
|
|
fn decode(&self, tokens: Vec<String>) -> PyResult<String> {
|
|
ToPyResult(self.decoder.execute(|decoder| decoder.decode(tokens))).into()
|
|
}
|
|
}
|
|
|
|
#[pyclass]
|
|
pub struct ByteLevel {}
|
|
#[pymethods]
|
|
impl ByteLevel {
|
|
#[staticmethod]
|
|
fn new() -> PyResult<Decoder> {
|
|
Ok(Decoder {
|
|
decoder: Container::Owned(Box::new(tk::decoders::byte_level::ByteLevel::new(false))),
|
|
})
|
|
}
|
|
}
|
|
|
|
#[pyclass]
|
|
pub struct WordPiece {}
|
|
#[pymethods]
|
|
impl WordPiece {
|
|
#[staticmethod]
|
|
fn new(kwargs: Option<&PyDict>) -> PyResult<Decoder> {
|
|
let mut prefix = String::from("##");
|
|
|
|
if let Some(kwargs) = kwargs {
|
|
if let Some(p) = kwargs.get_item("prefix") {
|
|
prefix = p.extract()?;
|
|
}
|
|
}
|
|
|
|
Ok(Decoder {
|
|
decoder: Container::Owned(Box::new(tk::decoders::wordpiece::WordPiece::new(prefix))),
|
|
})
|
|
}
|
|
}
|
|
|
|
#[pyclass]
|
|
pub struct Metaspace {}
|
|
#[pymethods]
|
|
impl Metaspace {
|
|
#[staticmethod]
|
|
#[args(kwargs = "**")]
|
|
fn new(kwargs: Option<&PyDict>) -> PyResult<Decoder> {
|
|
let mut replacement = '▁';
|
|
let mut add_prefix_space = true;
|
|
|
|
if let Some(kwargs) = kwargs {
|
|
for (key, value) in kwargs {
|
|
let key: &str = key.extract()?;
|
|
match key {
|
|
"replacement" => {
|
|
let s: &str = value.extract()?;
|
|
replacement = s.chars().nth(0).ok_or(exceptions::Exception::py_err(
|
|
"replacement must be a character",
|
|
))?;
|
|
}
|
|
"add_prefix_space" => add_prefix_space = value.extract()?,
|
|
_ => println!("Ignored unknown kwarg option {}", key),
|
|
}
|
|
}
|
|
}
|
|
|
|
Ok(Decoder {
|
|
decoder: Container::Owned(Box::new(tk::decoder::metaspace::Metaspace::new(
|
|
replacement,
|
|
add_prefix_space,
|
|
))),
|
|
})
|
|
}
|
|
}
|
|
|
|
struct PyDecoder {
|
|
class: PyObject,
|
|
}
|
|
|
|
impl PyDecoder {
|
|
pub fn new(class: PyObject) -> PyResult<Self> {
|
|
Ok(PyDecoder { class })
|
|
}
|
|
}
|
|
|
|
impl tk::tokenizer::Decoder for PyDecoder {
|
|
fn decode(&self, tokens: Vec<String>) -> Result<String> {
|
|
let gil = Python::acquire_gil();
|
|
let py = gil.python();
|
|
|
|
let args = PyTuple::new(py, &[tokens]);
|
|
match self.class.call_method(py, "decode", args, None) {
|
|
Ok(res) => Ok(res
|
|
.cast_as::<PyString>(py)
|
|
.map_err(|_| PyError::from("`decode` is expected to return a str"))?
|
|
.to_string()
|
|
.map_err(|_| PyError::from("`decode` is expected to return a str"))?
|
|
.into_owned()),
|
|
Err(e) => {
|
|
e.print(py);
|
|
Err(Box::new(PyError::from("Error while calling `decode`")))
|
|
}
|
|
}
|
|
}
|
|
}
|