Update Python bindings for Encoding

This commit is contained in:
Anthony MOI
2019-12-10 12:38:36 -05:00
parent 132a0fc4b4
commit 8cedc5f1f6
3 changed files with 44 additions and 16 deletions

View File

@ -0,0 +1,15 @@
extern crate tokenizers as tk;
use pyo3::prelude::*;
/// Python-visible wrapper around the core tokenizer's `Encoding`.
///
/// `#[pyclass]` exposes this type to Python via PyO3; instances are
/// created on the Rust side (see `Encoding::new`) and handed back to
/// Python from `Tokenizer::encode` / `encode_batch`.
#[pyclass]
// Single non-zero-sized field, so the wrapper can share the exact
// memory layout of the wrapped `tk::tokenizer::Encoding`.
#[repr(transparent)]
pub struct Encoding {
    // The underlying encoding produced by the tokenizer core.
    encoding: tk::tokenizer::Encoding,
}
impl Encoding {
    /// Wraps a core tokenizer `Encoding` so it can be returned to Python.
    pub fn new(encoding: tk::tokenizer::Encoding) -> Self {
        Self { encoding }
    }
}

View File

@ -1,4 +1,5 @@
mod decoders;
mod encoding;
mod models;
mod pre_tokenizers;
mod token;

View File

@ -2,8 +2,10 @@ extern crate tokenizers as tk;
use pyo3::exceptions;
use pyo3::prelude::*;
use pyo3::types::*;
use super::decoders::Decoder;
use super::encoding::Encoding;
use super::models::Model;
use super::pre_tokenizers::PreTokenizer;
use super::token::Token;
@ -62,25 +64,36 @@ impl Tokenizer {
} }
} }
fn encode(&self, sentence: &str) -> Vec<Token> { fn encode(&self, sentence: &str, pair: Option<&str>) -> Encoding {
self.tokenizer Encoding::new(self.tokenizer.encode(if pair.is_some() {
.encode(sentence) tk::tokenizer::EncodeInput::Dual(sentence.to_owned(), pair.unwrap().to_owned())
.into_iter() } else {
.map(|token| Token::new(token)) tk::tokenizer::EncodeInput::Single(sentence.to_owned())
.collect() }))
} }
fn encode_batch(&self, sentences: Vec<&str>) -> Vec<Vec<Token>> { fn encode_batch(&self, sentences: &PyList) -> PyResult<Vec<Encoding>> {
self.tokenizer let inputs = sentences
.encode_batch(sentences)
.into_iter() .into_iter()
.map(|sentence| { .map(|item| {
sentence if let Ok(s1) = item.extract::<String>() {
.into_iter() Ok(tk::tokenizer::EncodeInput::Single(s1))
.map(|token| Token::new(token)) } else if let Ok((s1, s2)) = item.extract::<(String, String)>() {
.collect() Ok(tk::tokenizer::EncodeInput::Dual(s1, s2))
} else {
Err(exceptions::Exception::py_err(
"Input must be a list[str] or list[(str, str)]",
))
}
}) })
.collect() .collect::<PyResult<Vec<_>>>()?;
Ok(self
.tokenizer
.encode_batch(inputs)
.into_iter()
.map(|encoding| Encoding::new(encoding))
.collect())
} }
fn decode(&self, ids: Vec<u32>) -> String { fn decode(&self, ids: Vec<u32>) -> String {
@ -109,4 +122,3 @@ impl Tokenizer {
}) })
} }
} }