From 47cef0e13adec08358a945d2c6be4aeb3c08458d Mon Sep 17 00:00:00 2001 From: Anthony MOI Date: Fri, 6 Mar 2020 12:20:39 -0500 Subject: [PATCH] Python - Fix BPE and WordPiece builders usage --- bindings/python/src/models.rs | 25 ++++++++++++------------- 1 file changed, 12 insertions(+), 13 deletions(-) diff --git a/bindings/python/src/models.rs b/bindings/python/src/models.rs index 828aca0a..e5cfe6d1 100644 --- a/bindings/python/src/models.rs +++ b/bindings/python/src/models.rs @@ -51,10 +51,7 @@ impl BPE { #[staticmethod] #[args(kwargs = "**")] fn from_files(vocab: &str, merges: &str, kwargs: Option<&PyDict>) -> PyResult { - let builder: PyResult<_> = - ToPyResult(tk::models::bpe::BPE::from_files(vocab, merges)).into(); - let mut builder = builder?; - + let mut builder = tk::models::bpe::BPE::from_files(vocab, merges); if let Some(kwargs) = kwargs { for (key, value) in kwargs { let key: &str = key.extract()?; @@ -115,25 +112,27 @@ impl WordPiece { #[staticmethod] #[args(kwargs = "**")] fn from_files(vocab: &str, kwargs: Option<&PyDict>) -> PyResult { - let mut unk_token = String::from("[UNK]"); - let mut max_input_chars_per_word = Some(100); + let mut builder = tk::models::wordpiece::WordPiece::from_files(vocab); if let Some(kwargs) = kwargs { for (key, val) in kwargs { let key: &str = key.extract()?; match key { - "unk_token" => unk_token = val.extract()?, - "max_input_chars_per_word" => max_input_chars_per_word = Some(val.extract()?), + "unk_token" => { + builder = builder.unk_token(val.extract()?); + } + "max_input_chars_per_word" => { + builder = builder.max_input_chars_per_word(val.extract()?); + } + "continuing_subword_prefix" => { + builder = builder.continuing_subword_prefix(val.extract()?); + } _ => println!("Ignored unknown kwargs option {}", key), } } } - match tk::models::wordpiece::WordPiece::from_files( - vocab, - unk_token, - max_input_chars_per_word, - ) { + match builder.build() { Err(e) => { println!("Errors: {:?}", e); Err(exceptions::Exception::py_err(