Mirror of https://github.com/mii443/tokenizers.git, synced 2025-09-04 08:19:21 +00:00

Basic python bindings
@@ -1,20 +1,162 @@
 extern crate tokenizers as tk;
+use tk::models::bpe::Error as BpeError;
 
+use pyo3::exceptions;
 use pyo3::prelude::*;
+use pyo3::types::{PyDict, PyList};
 
 #[pyclass]
-struct WhitespaceTokenizer {}
+#[repr(transparent)]
+struct Token {
+    tok: tk::tokenizer::Token,
+}
+impl Token {
+    pub fn new(tok: tk::tokenizer::Token) -> Self {
+        Token { tok }
+    }
+}
 
 #[pymethods]
-impl WhitespaceTokenizer {
+impl Token {
+    #[getter]
+    fn get_id(&self) -> PyResult<u32> {
+        Ok(self.tok.id)
+    }
+
+    #[getter]
+    fn get_value(&self) -> PyResult<&str> {
+        Ok(&self.tok.value)
+    }
+
+    #[getter]
+    fn get_offsets(&self) -> PyResult<(usize, usize)> {
+        Ok(self.tok.offsets)
+    }
+
+    fn as_tuple(&self) -> PyResult<(u32, &str, (usize, usize))> {
+        Ok((self.tok.id, &self.tok.value, self.tok.offsets))
+    }
+}
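On the Python side, PyO3 turns each `#[getter] fn get_*` into a read-only attribute (the `get_` prefix is stripped), and `as_tuple` is exposed as an ordinary method. A minimal sketch of what attribute access might look like, assuming `token` is one element of the list returned by `Tokenizer.encode` further down in this diff:

    # Hypothetical usage sketch; `token` comes from Tokenizer.encode below.
    print(token.id)          # vocabulary id of the token (u32 on the Rust side)
    print(token.value)       # the token's string value
    print(token.offsets)     # (start, end) offsets into the input, as a 2-tuple
    print(token.as_tuple())  # all three at once: (id, value, (start, end))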
+
+fn get_pre_tokenizer(name: &str) -> Option<Box<dyn tk::tokenizer::PreTokenizer + Sync>> {
+    match name {
+        "ByteLevel" => Some(Box::new(tk::pre_tokenizers::byte_level::ByteLevel)),
+        "Whitespace" => Some(Box::new(tk::pre_tokenizers::whitespace::Whitespace)),
+        _ => None,
+    }
+}
+
+fn get_normalizer(_name: &str) -> Option<Box<dyn tk::tokenizer::Normalizer + Sync>> {
+    None
+}
+
+fn get_post_processor(_name: &str) -> Option<Box<dyn tk::tokenizer::PostProcessor + Sync>> {
+    None
+}
+
+#[pyclass]
+struct Tokenizer {
+    tokenizer: tk::tokenizer::Tokenizer,
+}
+#[pymethods]
+impl Tokenizer {
     #[staticmethod]
-    fn tokenize(s: String) -> PyResult<Vec<String>> {
-        Ok(tk::WhitespaceTokenizer::tokenize(&s))
+    #[args(kwargs = "**")]
+    fn bpe_from_files(vocab: &str, merges: &str, kwargs: Option<&PyDict>) -> PyResult<Self> {
+        let model = match tk::models::bpe::BPE::from_files(vocab, merges) {
+            Ok(bpe) => Ok(Box::new(bpe)),
+            Err(e) => match e {
+                BpeError::BadVocabulary => {
+                    Err(exceptions::Exception::py_err("Bad vocab.json format"))
+                }
+                BpeError::Io(io) => Err(PyErr::from(io)),
+                BpeError::JsonError(_) => Err(exceptions::Exception::py_err(
+                    "Error while parsing vocab json file",
+                )),
+                BpeError::MergeTokenOutOfVocabulary(token) => Err(exceptions::Exception::py_err(
+                    format!("Merge token out of vocabulary: {}", token),
+                )),
+            },
+        }?;
+
+        let mut tokenizer = tk::tokenizer::Tokenizer::new(model);
+
+        if let Some(kwargs) = kwargs {
+            for (option, value) in kwargs {
+                match option.to_string().as_ref() {
+                    "pre_tokenizer" => {
+                        let value = value.to_string();
+                        if let Some(pre_tokenizer) = get_pre_tokenizer(&value) {
+                            tokenizer.with_pre_tokenizer(pre_tokenizer);
+                        } else {
+                            return Err(exceptions::Exception::py_err(format!(
+                                "PreTokenizer `{}` not found",
+                                value
+                            )));
+                        }
+                    }
+                    "normalizers" => {
+                        let mut normalizers = vec![];
+                        let values = value.cast_as::<PyList>()?;
+                        for value in values {
+                            let value = value.to_string();
+                            if let Some(normalizer) = get_normalizer(&value) {
+                                normalizers.push(normalizer);
+                            } else {
+                                return Err(exceptions::Exception::py_err(format!(
+                                    "Normalizer `{}` not found",
+                                    value
+                                )));
+                            }
+                        }
+                    }
+                    "post_processors" => {
+                        let mut processors = vec![];
+                        let values = value.cast_as::<PyList>()?;
+                        for value in values {
+                            let value = value.to_string();
+                            if let Some(processor) = get_post_processor(&value) {
+                                processors.push(processor);
+                            } else {
+                                return Err(exceptions::Exception::py_err(format!(
+                                    "PostProcessor `{}` not found",
+                                    value
+                                )));
+                            }
+                        }
+                    }
+                    _ => println!("Ignored unknown kwarg `{}`", option),
+                }
+            }
+        }
+
+        Ok(Tokenizer { tokenizer })
     }
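`bpe_from_files` is the new constructor: it loads a BPE model from vocab/merges files, maps each `BpeError` variant to a Python exception, then walks the optional kwargs and resolves component names through the lookup helpers above. Note that in this commit the collected `normalizers` and `processors` vectors are never attached to the tokenizer, and since `get_normalizer`/`get_post_processor` always return `None`, any name passed for those options is rejected. A hedged sketch of the resulting Python-side behavior, assuming the extension imports as `tokenizers` (file paths are hypothetical):

    import tokenizers

    # Happy path: a BPE model plus a pre-tokenizer name that get_pre_tokenizer knows.
    tok = tokenizers.Tokenizer.bpe_from_files(
        "vocab.json", "merges.txt",  # hypothetical paths
        pre_tokenizer="ByteLevel",
    )

    # Unrecognized component names surface as plain exceptions.
    try:
        tokenizers.Tokenizer.bpe_from_files(
            "vocab.json", "merges.txt",
            pre_tokenizer="NoSuchPreTokenizer",
        )
    except Exception as e:
        print(e)  # PreTokenizer `NoSuchPreTokenizer` not found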
+
+    fn encode(&self, sentence: &str) -> Vec<Token> {
+        self.tokenizer
+            .encode(sentence)
+            .into_iter()
+            .map(|token| Token::new(token))
+            .collect()
+    }
+
+    fn encode_batch(&self, sentences: Vec<&str>) -> Vec<Vec<Token>> {
+        self.tokenizer
+            .encode_batch(sentences)
+            .into_iter()
+            .map(|sentence| {
+                sentence
+                    .into_iter()
+                    .map(|token| Token::new(token))
+                    .collect()
+            })
+            .collect()
+    }
 }
 
 #[pymodule]
-fn tokenizers(py: Python, m: &PyModule) -> PyResult<()> {
-    m.add_class::<WhitespaceTokenizer>()?;
+fn tokenizers(_py: Python, m: &PyModule) -> PyResult<()> {
+    m.add_class::<Tokenizer>()?;
     Ok(())
 }
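With only `Tokenizer` registered on the module, a minimal end-to-end sketch of the new API might look as follows, assuming the extension builds as the `tokenizers` module and the model files exist (paths are hypothetical):

    import tokenizers

    tok = tokenizers.Tokenizer.bpe_from_files(
        "vocab.json", "merges.txt", pre_tokenizer="Whitespace"
    )

    # encode returns a list of Token objects.
    for token in tok.encode("Hello world"):
        print(token.id, token.value, token.offsets)

    # encode_batch takes several sentences and returns a list of lists.
    batch = tok.encode_batch(["Hello world", "How are you?"])
    print([[t.as_tuple() for t in sentence] for sentence in batch])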