diff --git a/bindings/python/examples/example.py b/bindings/python/examples/example.py index 91f00d3e..e779fd98 100644 --- a/bindings/python/examples/example.py +++ b/bindings/python/examples/example.py @@ -61,8 +61,8 @@ elif args.type == "bert": print("Running Bert tokenizer") tok_p = BertTokenizer.from_pretrained('bert-base-uncased') - tok_r = Tokenizer(models.WordPiece.from_files(args.vocab)) - tok_r.with_pre_tokenizer(pre_tokenizers.BasicPreTokenizer.new()) + tok_r = Tokenizer(models.WordPiece.from_files(args.vocab, unk_token="[UNK]", max_input_chars_per_word=100)) + tok_r.with_pre_tokenizer(pre_tokenizers.BasicPreTokenizer.new(do_lower_case=True, tokenize_chinese_chars=True, never_split=[])) tok_r.with_decoder(decoders.WordPiece.new()) else: raise Exception(f"Unknown type {args.type}") diff --git a/bindings/python/src/models.rs b/bindings/python/src/models.rs index 13e7b0e6..cd47931c 100644 --- a/bindings/python/src/models.rs +++ b/bindings/python/src/models.rs @@ -4,6 +4,7 @@ use super::utils::Container; use pyo3::exceptions; use pyo3::prelude::*; +use pyo3::types::*; /// A Model represents some tokenization algorithm like BPE or Word /// This class cannot be constructed directly. Please use one of the concrete models. 
@@ -71,10 +72,21 @@ impl WordPiece { /// /// Instantiate a new WordPiece model using the provided vocabulary file #[staticmethod] - fn from_files(vocab: &str) -> PyResult<WordPiece> { - // TODO: Parse kwargs for these - let unk_token = String::from("[UNK]"); - let max_input_chars_per_word = Some(100); + #[args(kwargs = "**")] + fn from_files(vocab: &str, kwargs: Option<&PyDict>) -> PyResult<WordPiece> { + let mut unk_token = String::from("[UNK]"); + let mut max_input_chars_per_word = Some(100); + + if let Some(kwargs) = kwargs { + for (key, val) in kwargs { + let key: &str = key.extract()?; + match key { + "unk_token" => unk_token = val.extract()?, + "max_input_chars_per_word" => max_input_chars_per_word = Some(val.extract()?), + _ => println!("Ignored unknown kwargs option {}", key), + } + } + } match tk::models::wordpiece::WordPiece::from_files( vocab, diff --git a/bindings/python/src/pre_tokenizers.rs b/bindings/python/src/pre_tokenizers.rs index 350ba5e6..3a3cd99b 100644 --- a/bindings/python/src/pre_tokenizers.rs +++ b/bindings/python/src/pre_tokenizers.rs @@ -43,12 +43,27 @@ pub struct BasicPreTokenizer {} #[pymethods] impl BasicPreTokenizer { #[staticmethod] - fn new() -> PyResult<PreTokenizer> { - // TODO: Parse kwargs for these + #[args(kwargs = "**")] + fn new(kwargs: Option<&PyDict>) -> PyResult<PreTokenizer> { let mut do_lower_case = true; let mut never_split = HashSet::new(); let mut tokenize_chinese_chars = true; + if let Some(kwargs) = kwargs { + for (key, val) in kwargs { + let key: &str = key.extract()?; + match key { + "do_lower_case" => do_lower_case = val.extract()?, + "tokenize_chinese_chars" => tokenize_chinese_chars = val.extract()?, + "never_split" => { + let values: Vec<String> = val.extract()?; + never_split = values.into_iter().collect(); + } + _ => println!("Ignored unknown kwargs option {}", key), + } + } + } + Ok(PreTokenizer { pretok: Container::Owned(Box::new(tk::pre_tokenizers::basic::BasicPreTokenizer::new( do_lower_case,