Mirror of https://github.com/mii443/tokenizers.git, synced 2025-08-22 16:25:30 +00:00
Update Python bindings for Encoding
bindings/python/src/encoding.rs (new file, 15 lines added)
@@ -0,0 +1,15 @@
+extern crate tokenizers as tk;
+
+use pyo3::prelude::*;
+
+#[pyclass]
+#[repr(transparent)]
+pub struct Encoding {
+    encoding: tk::tokenizer::Encoding,
+}
+
+impl Encoding {
+    pub fn new(encoding: tk::tokenizer::Encoding) -> Self {
+        Encoding { encoding }
+    }
+}
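The new wrapper stores the core struct but does not yet expose any of its data to Python. A minimal sketch of how Python-visible getters could later be added, assuming the core tk::tokenizer::Encoding offers accessors such as get_ids() and get_tokens() (names assumed here, not part of this commit):

#[pymethods]
impl Encoding {
    /// Token ids, copied out of the wrapped core Encoding.
    /// (get_ids() is an assumed accessor on the core struct.)
    #[getter]
    fn ids(&self) -> Vec<u32> {
        self.encoding.get_ids().to_vec()
    }

    /// String tokens, copied out of the wrapped core Encoding.
    /// (get_tokens() is an assumed accessor on the core struct.)
    #[getter]
    fn tokens(&self) -> Vec<String> {
        self.encoding.get_tokens().to_vec()
    }
}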
@@ -1,4 +1,5 @@
 mod decoders;
+mod encoding;
 mod models;
 mod pre_tokenizers;
 mod token;
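Declaring the module compiles it into the extension, but the class also has to be registered with the Python module object to be importable. A sketch of the usual pyo3 registration, assuming a #[pymodule] entry point named tokenizers (the registration itself is not part of this hunk):

use pyo3::prelude::*;

// Hypothetical module entry point; the actual name and contents
// are not shown in this commit.
#[pymodule]
fn tokenizers(_py: Python, m: &PyModule) -> PyResult<()> {
    // Expose the new Encoding wrapper to Python code.
    m.add_class::<encoding::Encoding>()?;
    Ok(())
}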
@@ -2,8 +2,10 @@ extern crate tokenizers as tk;
 
+use pyo3::exceptions;
 use pyo3::prelude::*;
+use pyo3::types::*;
 
 use super::decoders::Decoder;
+use super::encoding::Encoding;
 use super::models::Model;
 use super::pre_tokenizers::PreTokenizer;
 use super::token::Token;
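Both encode paths below build tk::tokenizer::EncodeInput values, one variant for a single sentence and one for a sentence pair. The diff implies the core enum has roughly this shape (a sketch only; the real definition lives in the tokenizers crate, not in the bindings):

// As implied by the usage in this diff: a single sentence, or a pair.
pub enum EncodeInput {
    Single(String),
    Dual(String, String),
}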
@@ -62,25 +64,36 @@ impl Tokenizer {
         }
     }
 
-    fn encode(&self, sentence: &str) -> Vec<Token> {
-        self.tokenizer
-            .encode(sentence)
-            .into_iter()
-            .map(|token| Token::new(token))
-            .collect()
+    fn encode(&self, sentence: &str, pair: Option<&str>) -> Encoding {
+        Encoding::new(self.tokenizer.encode(if pair.is_some() {
+            tk::tokenizer::EncodeInput::Dual(sentence.to_owned(), pair.unwrap().to_owned())
+        } else {
+            tk::tokenizer::EncodeInput::Single(sentence.to_owned())
+        }))
     }
 
-    fn encode_batch(&self, sentences: Vec<&str>) -> Vec<Vec<Token>> {
-        self.tokenizer
-            .encode_batch(sentences)
+    fn encode_batch(&self, sentences: &PyList) -> PyResult<Vec<Encoding>> {
+        let inputs = sentences
             .into_iter()
-            .map(|sentence| {
-                sentence
-                    .into_iter()
-                    .map(|token| Token::new(token))
-                    .collect()
+            .map(|item| {
+                if let Ok(s1) = item.extract::<String>() {
+                    Ok(tk::tokenizer::EncodeInput::Single(s1))
+                } else if let Ok((s1, s2)) = item.extract::<(String, String)>() {
+                    Ok(tk::tokenizer::EncodeInput::Dual(s1, s2))
+                } else {
+                    Err(exceptions::Exception::py_err(
+                        "Input must be a list[str] or list[(str, str)]",
+                    ))
+                }
             })
-            .collect()
+            .collect::<PyResult<Vec<_>>>()?;
+
+        Ok(self
+            .tokenizer
+            .encode_batch(inputs)
+            .into_iter()
+            .map(|encoding| Encoding::new(encoding))
+            .collect())
     }
 
     fn decode(&self, ids: Vec<u32>) -> String {
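The pair dispatch in encode above uses is_some()/unwrap(); a match expresses the same logic without the unwrap. An equivalent sketch, unchanged in behavior:

fn encode(&self, sentence: &str, pair: Option<&str>) -> Encoding {
    // Choose the input variant up front instead of testing and unwrapping.
    let input = match pair {
        Some(pair) => tk::tokenizer::EncodeInput::Dual(sentence.to_owned(), pair.to_owned()),
        None => tk::tokenizer::EncodeInput::Single(sentence.to_owned()),
    };
    Encoding::new(self.tokenizer.encode(input))
}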
@@ -109,4 +122,3 @@
         })
     }
 }
-