diff --git a/bindings/python/src/encoding.rs b/bindings/python/src/encoding.rs
new file mode 100644
index 00000000..84e41be5
--- /dev/null
+++ b/bindings/python/src/encoding.rs
@@ -0,0 +1,15 @@
+extern crate tokenizers as tk;
+
+use pyo3::prelude::*;
+
+#[pyclass]
+#[repr(transparent)]
+pub struct Encoding {
+    encoding: tk::tokenizer::Encoding,
+}
+
+impl Encoding {
+    pub fn new(encoding: tk::tokenizer::Encoding) -> Self {
+        Encoding { encoding }
+    }
+}
diff --git a/bindings/python/src/lib.rs b/bindings/python/src/lib.rs
index f1d64d7f..7fd6b3da 100644
--- a/bindings/python/src/lib.rs
+++ b/bindings/python/src/lib.rs
@@ -1,4 +1,5 @@
 mod decoders;
+mod encoding;
 mod models;
 mod pre_tokenizers;
 mod token;
diff --git a/bindings/python/src/tokenizer.rs b/bindings/python/src/tokenizer.rs
index 2fc7800c..bc30e2de 100644
--- a/bindings/python/src/tokenizer.rs
+++ b/bindings/python/src/tokenizer.rs
@@ -2,8 +2,10 @@ extern crate tokenizers as tk;
 
 use pyo3::exceptions;
 use pyo3::prelude::*;
+use pyo3::types::*;
 
 use super::decoders::Decoder;
+use super::encoding::Encoding;
 use super::models::Model;
 use super::pre_tokenizers::PreTokenizer;
 use super::token::Token;
@@ -62,25 +64,36 @@ impl Tokenizer {
         }
     }
 
-    fn encode(&self, sentence: &str) -> Vec<Token> {
-        self.tokenizer
-            .encode(sentence)
-            .into_iter()
-            .map(|token| Token::new(token))
-            .collect()
+    fn encode(&self, sentence: &str, pair: Option<&str>) -> Encoding {
+        Encoding::new(self.tokenizer.encode(if pair.is_some() {
+            tk::tokenizer::EncodeInput::Dual(sentence.to_owned(), pair.unwrap().to_owned())
+        } else {
+            tk::tokenizer::EncodeInput::Single(sentence.to_owned())
+        }))
     }
 
-    fn encode_batch(&self, sentences: Vec<&str>) -> Vec<Vec<Token>> {
-        self.tokenizer
-            .encode_batch(sentences)
+    fn encode_batch(&self, sentences: &PyList) -> PyResult<Vec<Encoding>> {
+        let inputs = sentences
             .into_iter()
-            .map(|sentence| {
-                sentence
-                    .into_iter()
-                    .map(|token| Token::new(token))
-                    .collect()
+            .map(|item| {
+                if let Ok(s1) = item.extract::<String>() {
+                    Ok(tk::tokenizer::EncodeInput::Single(s1))
+                } else if let Ok((s1, s2)) = item.extract::<(String, String)>() {
+                    Ok(tk::tokenizer::EncodeInput::Dual(s1, s2))
+                } else {
+                    Err(exceptions::Exception::py_err(
+                        "Input must be a list[str] or list[(str, str)]",
+                    ))
+                }
             })
-            .collect()
+            .collect::<PyResult<Vec<_>>>()?;
+
+        Ok(self
+            .tokenizer
+            .encode_batch(inputs)
+            .into_iter()
+            .map(|encoding| Encoding::new(encoding))
+            .collect())
    }
 
     fn decode(&self, ids: Vec<u32>) -> String {
@@ -109,4 +122,3 @@ impl Tokenizer {
         })
     }
 }
-
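
Usage sketch (not part of the diff): assuming the compiled module is importable from Python and a `Tokenizer` instance has already been constructed (how it is built is outside this change), the new `encode` and `encode_batch` signatures would be exercised roughly as follows. The `tokenizer` variable and the sample strings are illustrative assumptions.

    # Hypothetical Python-side usage of the new bindings.
    # `tokenizer` is an already-constructed Tokenizer instance.

    # Single sequence: now returns one Encoding object instead of a list of Tokens.
    encoding = tokenizer.encode("Hello, world!")

    # Sequence pair: the new optional `pair` argument maps to EncodeInput::Dual.
    pair_encoding = tokenizer.encode("Hello, world!", "How are you?")

    # Batches may mix plain strings and (str, str) tuples; any other element
    # type makes encode_batch raise
    # "Input must be a list[str] or list[(str, str)]".
    encodings = tokenizer.encode_batch([
        "A single sentence",
        ("A first sentence", "paired with a second one"),
    ])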