extern crate tokenizers as tk;

use pyo3::exceptions;
use pyo3::prelude::*;
use pyo3::types::*;
use pyo3::PyObjectProtocol;
use std::collections::HashMap;

use super::decoders::Decoder;
use super::encoding::Encoding;
use super::error::{PyError, ToPyResult};
use super::models::Model;
use super::normalizers::Normalizer;
use super::pre_tokenizers::PreTokenizer;
use super::processors::PostProcessor;
use super::trainers::Trainer;
use super::utils::Container;
use tk::tokenizer::{
    PaddingDirection, PaddingParams, PaddingStrategy, TruncationParams, TruncationStrategy,
};

/// Python-facing wrapper around `tk::tokenizer::AddedToken`.
/// Construction accepts the optional kwargs `single_word`, `lstrip` and `rstrip`.
#[pyclass(dict)]
pub struct AddedToken {
    pub token: tk::tokenizer::AddedToken,
}

#[pymethods]
impl AddedToken {
    #[new]
    #[args(kwargs = "**")]
    fn new(content: &str, kwargs: Option<&PyDict>) -> PyResult<Self> {
        let mut token = tk::tokenizer::AddedToken::from(content.to_owned());

        if let Some(kwargs) = kwargs {
            for (key, value) in kwargs {
                let key: &str = key.extract()?;
                match key {
                    "single_word" => token = token.single_word(value.extract()?),
                    "lstrip" => token = token.lstrip(value.extract()?),
                    "rstrip" => token = token.rstrip(value.extract()?),
                    _ => println!("Ignored unknown kwarg option {}", key),
                }
            }
        }

        Ok(AddedToken { token })
    }

    #[getter]
    fn get_content(&self) -> &str {
        &self.token.content
    }

    #[getter]
    fn get_rstrip(&self) -> bool {
        self.token.rstrip
    }

    #[getter]
    fn get_lstrip(&self) -> bool {
        self.token.lstrip
    }

    #[getter]
    fn get_single_word(&self) -> bool {
        self.token.single_word
    }
}

#[pyproto]
impl PyObjectProtocol for AddedToken {
    fn __str__(&'p self) -> PyResult<&'p str> {
        Ok(&self.token.content)
    }

    fn __repr__(&self) -> PyResult<String> {
        let bool_to_python = |p| match p {
            true => "True",
            false => "False",
        };

        Ok(format!(
            "AddedToken(\"{}\", rstrip={}, lstrip={}, single_word={})",
            self.token.content,
            bool_to_python(self.token.rstrip),
            bool_to_python(self.token.lstrip),
            bool_to_python(self.token.single_word)
        ))
    }
}

/// Python-facing wrapper around the core `tk::tokenizer::Tokenizer` pipeline.
#[pyclass(dict)]
pub struct Tokenizer {
    tokenizer: tk::tokenizer::Tokenizer,
}

#[pymethods]
impl Tokenizer {
    #[new]
    fn new(mut model: PyRefMut<Model>) -> PyResult<Self> {
        if let Some(model) = model.model.to_pointer() {
            let tokenizer = tk::tokenizer::Tokenizer::new(model);
            Ok(Tokenizer { tokenizer })
        } else {
            Err(exceptions::Exception::py_err(
                "The Model is already being used in another Tokenizer",
            ))
        }
    }

    fn num_special_tokens_to_add(&self, is_pair: bool) -> PyResult<usize> {
        Ok(self
            .tokenizer
            .get_post_processor()
            .map_or(0, |p| p.as_ref().added_tokens(is_pair)))
    }

    #[args(with_added_tokens = true)]
    fn get_vocab(&self, with_added_tokens: bool) -> PyResult<HashMap<String, u32>> {
        Ok(self.tokenizer.get_vocab(with_added_tokens))
    }

    #[args(with_added_tokens = true)]
    fn get_vocab_size(&self, with_added_tokens: bool) -> PyResult<usize> {
        Ok(self.tokenizer.get_vocab_size(with_added_tokens))
    }

    /// Enables truncation at `max_length`. Optional kwargs: `stride` and
    /// `strategy` (`longest_first`, `only_first` or `only_second`).
    #[args(kwargs = "**")]
    fn enable_truncation(&mut self, max_length: usize, kwargs: Option<&PyDict>) -> PyResult<()> {
        let mut stride = 0;
        let mut strategy = TruncationStrategy::LongestFirst;

        if let Some(kwargs) = kwargs {
            for (key, value) in kwargs {
                let key: &str = key.extract()?;
                match key {
                    "stride" => stride = value.extract()?,
                    "strategy" => {
                        let value: &str = value.extract()?;
                        strategy = match value {
                            "longest_first" => Ok(TruncationStrategy::LongestFirst),
                            "only_first" => Ok(TruncationStrategy::OnlyFirst),
                            "only_second" => Ok(TruncationStrategy::OnlySecond),
                            _ => Err(PyError(format!(
                                "Unknown `strategy`: `{}`. Use \
                                 one of `longest_first`, `only_first`, or `only_second`",
                                value
                            ))
                            .into_pyerr()),
                        }?
                    }
                    _ => println!("Ignored unknown kwarg option {}", key),
                }
            }
        }

        self.tokenizer.with_truncation(Some(TruncationParams {
            max_length,
            stride,
            strategy,
        }));

        Ok(())
    }

    fn no_truncation(&mut self) {
        self.tokenizer.with_truncation(None);
    }

    /// Enables padding. Optional kwargs: `direction`, `pad_id`, `pad_type_id`,
    /// `pad_token` and `max_length`. Without `max_length`, each batch is padded
    /// to its longest sequence.
    #[args(kwargs = "**")]
    fn enable_padding(&mut self, kwargs: Option<&PyDict>) -> PyResult<()> {
        let mut direction = PaddingDirection::Right;
        let mut pad_id: u32 = 0;
        let mut pad_type_id: u32 = 0;
        let mut pad_token = String::from("[PAD]");
        let mut max_length: Option<usize> = None;

        if let Some(kwargs) = kwargs {
            for (key, value) in kwargs {
                let key: &str = key.extract()?;
                match key {
                    "direction" => {
                        let value: &str = value.extract()?;
                        direction = match value {
                            "left" => Ok(PaddingDirection::Left),
                            "right" => Ok(PaddingDirection::Right),
                            other => Err(PyError(format!(
                                "Unknown `direction`: `{}`. Use \
                                 one of `left` or `right`",
                                other
                            ))
                            .into_pyerr()),
                        }?;
                    }
                    "pad_id" => pad_id = value.extract()?,
                    "pad_type_id" => pad_type_id = value.extract()?,
                    "pad_token" => pad_token = value.extract()?,
                    "max_length" => max_length = value.extract()?,
                    _ => println!("Ignored unknown kwarg option {}", key),
                }
            }
        }

        let strategy = if let Some(max_length) = max_length {
            PaddingStrategy::Fixed(max_length)
        } else {
            PaddingStrategy::BatchLongest
        };

        self.tokenizer.with_padding(Some(PaddingParams {
            strategy,
            direction,
            pad_id,
            pad_type_id,
            pad_token: pad_token.to_owned(),
        }));

        Ok(())
    }

    fn no_padding(&mut self) {
        self.tokenizer.with_padding(None);
    }

    fn normalize(&self, sentence: &str) -> PyResult<String> {
        ToPyResult(
            self.tokenizer
                .normalize(sentence)
                .map(|s| s.get().to_owned()),
        )
        .into()
    }

    /// Encodes a single sentence, or a sentence pair when `pair` is provided.
    #[args(add_special_tokens = true)]
    fn encode(
        &self,
        sentence: &str,
        pair: Option<&str>,
        add_special_tokens: bool,
    ) -> PyResult<Encoding> {
        ToPyResult(
            self.tokenizer
                .encode(
                    if let Some(pair) = pair {
                        tk::tokenizer::EncodeInput::Dual(sentence.to_owned(), pair.to_owned())
                    } else {
                        tk::tokenizer::EncodeInput::Single(sentence.to_owned())
                    },
                    add_special_tokens,
                )
                .map(Encoding::new),
        )
        .into()
    }

    /// Encodes a batch given as a list of strings or of `(str, str)` pairs.
    #[args(add_special_tokens = true)]
    fn encode_batch(
        &self,
        sentences: &PyList,
        add_special_tokens: bool,
    ) -> PyResult<Vec<Encoding>> {
        let inputs = sentences
            .into_iter()
            .map(|item| {
                if let Ok(s1) = item.extract::<String>() {
                    Ok(tk::tokenizer::EncodeInput::Single(s1))
                } else if let Ok((s1, s2)) = item.extract::<(String, String)>() {
                    Ok(tk::tokenizer::EncodeInput::Dual(s1, s2))
                } else {
                    Err(exceptions::Exception::py_err(
                        "Input must be a list[str] or list[(str, str)]",
                    ))
                }
            })
            .collect::<PyResult<Vec<_>>>()?;

        ToPyResult(
            self.tokenizer
                .encode_batch(inputs, add_special_tokens)
                .map(|encodings| encodings.into_iter().map(Encoding::new).collect()),
        )
        .into()
    }

    fn decode(&self, ids: Vec<u32>, skip_special_tokens: Option<bool>) -> PyResult<String> {
        ToPyResult(
            self.tokenizer
                .decode(ids, skip_special_tokens.unwrap_or(true)),
        )
        .into()
    }

    fn decode_batch(
        &self,
        sentences: Vec<Vec<u32>>,
        skip_special_tokens: Option<bool>,
    ) -> PyResult<Vec<String>> {
        ToPyResult(
            self.tokenizer
                .decode_batch(sentences, skip_special_tokens.unwrap_or(true)),
        )
        .into()
    }

    fn token_to_id(&self, token: &str) -> Option<u32> {
        self.tokenizer.token_to_id(token)
    }

    fn id_to_token(&self, id: u32) -> Option<String> {
        self.tokenizer.id_to_token(id)
    }

    /// Adds the given tokens (given as `str` or `AddedToken`) to the vocabulary
    /// and returns the number of tokens that were actually added.
    fn add_tokens(&mut self, tokens: &PyList) -> PyResult<usize> {
        let tokens = tokens
            .into_iter()
            .map(|token| {
                if let Ok(content) = token.extract::<String>() {
                    Ok(tk::tokenizer::AddedToken {
                        content,
                        ..Default::default()
                    })
                } else if let Ok(token) = token.extract::<PyRef<AddedToken>>() {
                    Ok(token.token.clone())
                } else {
                    Err(exceptions::Exception::py_err(
                        "Input must be a List[Union[str, AddedToken]]",
                    ))
                }
            })
            .collect::<PyResult<Vec<_>>>()?;
        Ok(self.tokenizer.add_tokens(&tokens))
    }

    /// Adds the given special tokens (given as `str` or `AddedToken`) to the
    /// vocabulary and returns the number of tokens that were actually added.
    fn add_special_tokens(&mut self, tokens: &PyList) -> PyResult<usize> {
        let tokens = tokens
            .into_iter()
            .map(|token| {
                if let Ok(content) = token.extract::<String>() {
                    Ok(tk::tokenizer::AddedToken {
                        content,
                        ..Default::default()
                    })
                } else if let Ok(token) = token.extract::<PyRef<AddedToken>>() {
                    Ok(token.token.clone())
                } else {
                    Err(exceptions::Exception::py_err(
                        "Input must be a List[Union[str, AddedToken]]",
                    ))
                }
            })
            .collect::<PyResult<Vec<_>>>()?;

        Ok(self.tokenizer.add_special_tokens(&tokens))
    }

    /// Trains the model of this Tokenizer on the given files.
    fn train(&mut self, trainer: &Trainer, files: Vec<String>) -> PyResult<()> {
        trainer.trainer.execute(|trainer| {
            if let Err(e) = self.tokenizer.train(trainer, files) {
                Err(exceptions::Exception::py_err(format!("{}", e)))
            } else {
                Ok(())
            }
        })
    }

    #[args(pair = "None", add_special_tokens = true)]
    fn post_process(
        &self,
        encoding: &Encoding,
        pair: Option<&Encoding>,
        add_special_tokens: bool,
    ) -> PyResult<Encoding> {
        ToPyResult(
            self.tokenizer
                .post_process(
                    encoding.encoding.clone(),
                    pair.map(|p| p.encoding.clone()),
                    add_special_tokens,
                )
                .map(Encoding::new),
        )
        .into()
    }

    // Pipeline components are exchanged through `Container`: if ownership was
    // already transferred (`to_pointer()` returns `None`), the setters below
    // report that the component is used by another Tokenizer.
    #[getter]
    fn get_model(&self) -> PyResult<Model> {
        Ok(Model {
            model: Container::from_ref(self.tokenizer.get_model()),
        })
    }

    #[setter]
    fn set_model(&mut self, mut model: PyRefMut<Model>) -> PyResult<()> {
        if let Some(model) = model.model.to_pointer() {
            self.tokenizer.with_model(model);
            Ok(())
        } else {
            Err(exceptions::Exception::py_err(
                "The Model is already being used in another Tokenizer",
            ))
        }
    }

    #[getter]
    fn get_normalizer(&self) -> PyResult<Option<Normalizer>> {
        Ok(self
            .tokenizer
            .get_normalizer()
            .map(|normalizer| Normalizer {
                normalizer: Container::from_ref(normalizer),
            }))
    }

    #[setter]
    fn set_normalizer(&mut self, mut normalizer: PyRefMut<Normalizer>) -> PyResult<()> {
        if let Some(normalizer) = normalizer.normalizer.to_pointer() {
            self.tokenizer.with_normalizer(normalizer);
            Ok(())
        } else {
            Err(exceptions::Exception::py_err(
                "The Normalizer is already being used in another Tokenizer",
            ))
        }
    }

    #[getter]
    fn get_pre_tokenizer(&self) -> PyResult<Option<PreTokenizer>> {
        Ok(self
            .tokenizer
            .get_pre_tokenizer()
            .map(|pretok| PreTokenizer {
                pretok: Container::from_ref(pretok),
            }))
    }

    #[setter]
    fn set_pre_tokenizer(&mut self, mut pretok: PyRefMut<PreTokenizer>) -> PyResult<()> {
        if let Some(pretok) = pretok.pretok.to_pointer() {
            self.tokenizer.with_pre_tokenizer(pretok);
            Ok(())
        } else {
            Err(exceptions::Exception::py_err(
                "The PreTokenizer is already being used in another Tokenizer",
            ))
        }
    }

    #[getter]
    fn get_post_processor(&self) -> PyResult<Option<PostProcessor>> {
        Ok(self
            .tokenizer
            .get_post_processor()
            .map(|processor| PostProcessor {
                processor: Container::from_ref(processor),
            }))
    }

    #[setter]
    fn set_post_processor(&mut self, mut processor: PyRefMut<PostProcessor>) -> PyResult<()> {
        if let Some(processor) = processor.processor.to_pointer() {
            self.tokenizer.with_post_processor(processor);
            Ok(())
        } else {
            Err(exceptions::Exception::py_err(
                "The Processor is already being used in another Tokenizer",
            ))
        }
    }

    #[getter]
    fn get_decoder(&self) -> PyResult<Option<Decoder>> {
        Ok(self.tokenizer.get_decoder().map(|decoder| Decoder {
            decoder: Container::from_ref(decoder),
        }))
    }

    #[setter]
    fn set_decoder(&mut self, mut decoder: PyRefMut<Decoder>) -> PyResult<()> {
        if let Some(decoder) = decoder.decoder.to_pointer() {
            self.tokenizer.with_decoder(decoder);
            Ok(())
        } else {
            Err(exceptions::Exception::py_err(
                "The Decoder is already being used in another Tokenizer",
            ))
        }
    }
}
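
// A minimal sketch, assuming only the `tk::tokenizer::AddedToken` API already
// used above (`from`, `single_word`, `lstrip`, `rstrip`, and the public fields
// read by the #[getter] methods). The test name and the "[MASK]" content are
// illustrative assumptions, not something mandated by the bindings.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn added_token_builder_flags() {
        // Build a core AddedToken the same way the Python constructor does.
        let token = tk::tokenizer::AddedToken::from("[MASK]".to_owned())
            .single_word(true)
            .lstrip(true)
            .rstrip(false);

        // These fields mirror what the #[getter] methods above expose to Python.
        assert_eq!(token.content, "[MASK]");
        assert!(token.single_word);
        assert!(token.lstrip);
        assert!(!token.rstrip);
    }
}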