diff --git a/bindings/python/Cargo.lock b/bindings/python/Cargo.lock
index 793e2631..ca6eb48d 100644
--- a/bindings/python/Cargo.lock
+++ b/bindings/python/Cargo.lock
@@ -456,14 +456,14 @@ dependencies = [
  "regex 1.3.1 (registry+https://github.com/rust-lang/crates.io-index)",
  "regex-syntax 0.6.12 (registry+https://github.com/rust-lang/crates.io-index)",
  "serde_json 1.0.44 (registry+https://github.com/rust-lang/crates.io-index)",
- "unicode-normalization 0.1.11 (registry+https://github.com/rust-lang/crates.io-index)",
+ "unicode-normalization 0.1.11 (git+https://github.com/n1t0/unicode-normalization)",
  "unicode_categories 0.1.1 (registry+https://github.com/rust-lang/crates.io-index)",
 ]
 
 [[package]]
 name = "unicode-normalization"
 version = "0.1.11"
-source = "registry+https://github.com/rust-lang/crates.io-index"
+source = "git+https://github.com/n1t0/unicode-normalization#894053d92493c55c89fe9b188c0fb2babaa9a84c"
 dependencies = [
  "smallvec 1.1.0 (registry+https://github.com/rust-lang/crates.io-index)",
 ]
@@ -570,7 +570,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
 "checksum syn 1.0.11 (registry+https://github.com/rust-lang/crates.io-index)" = "dff0acdb207ae2fe6d5976617f887eb1e35a2ba52c13c7234c790960cdad9238"
 "checksum textwrap 0.11.0 (registry+https://github.com/rust-lang/crates.io-index)" = "d326610f408c7a4eb6f51c37c330e496b08506c9457c9d34287ecc38809fb060"
 "checksum thread_local 0.3.6 (registry+https://github.com/rust-lang/crates.io-index)" = "c6b53e329000edc2b34dbe8545fd20e55a333362d0a321909685a19bd28c3f1b"
-"checksum unicode-normalization 0.1.11 (registry+https://github.com/rust-lang/crates.io-index)" = "b561e267b2326bb4cebfc0ef9e68355c7abe6c6f522aeac2f5bf95d56c59bdcf"
+"checksum unicode-normalization 0.1.11 (git+https://github.com/n1t0/unicode-normalization)" = ""
 "checksum unicode-width 0.1.7 (registry+https://github.com/rust-lang/crates.io-index)" = "caaa9d531767d1ff2150b9332433f32a24622147e5ebb1f26409d5da67afd479"
 "checksum unicode-xid 0.2.0 (registry+https://github.com/rust-lang/crates.io-index)" = "826e7639553986605ec5979c7dd957c7895e93eabed50ab2ffa7f6128a75097c"
 "checksum unicode_categories 0.1.1 (registry+https://github.com/rust-lang/crates.io-index)" = "39ec24b3121d976906ece63c9daad25b85969647682eee313cb5779fdd69e14e"
diff --git a/bindings/python/src/pre_tokenizers.rs b/bindings/python/src/pre_tokenizers.rs
index 8befd28f..c05da22a 100644
--- a/bindings/python/src/pre_tokenizers.rs
+++ b/bindings/python/src/pre_tokenizers.rs
@@ -5,7 +5,7 @@ use super::utils::Container;
 use pyo3::prelude::*;
 use pyo3::types::*;
 use std::collections::HashSet;
-use tk::tokenizer::Result;
+use tk::tokenizer::{Offsets, Result};
 
 #[pyclass(dict)]
 pub struct PreTokenizer {
@@ -21,7 +21,7 @@ impl PreTokenizer {
         })
     }
 
-    fn pre_tokenize(&self, s: &str) -> PyResult<Vec<String>> {
+    fn pre_tokenize(&self, s: &str) -> PyResult<Vec<(String, Offsets)>> {
         ToPyResult(self.pretok.execute(|pretok| pretok.pre_tokenize(s))).into()
     }
 }
@@ -58,36 +58,9 @@ pub struct BertPreTokenizer {}
 #[pymethods]
 impl BertPreTokenizer {
     #[staticmethod]
-    #[args(kwargs = "**")]
-    fn new(kwargs: Option<&PyDict>) -> PyResult<PreTokenizer> {
-        let mut do_basic_tokenize = true;
-        let mut do_lower_case = true;
-        let mut never_split = HashSet::new();
-        let mut tokenize_chinese_chars = true;
-
-        if let Some(kwargs) = kwargs {
-            for (key, val) in kwargs {
-                let key: &str = key.extract()?;
-                match key {
-                    "do_basic_tokenize" => do_basic_tokenize = val.extract()?,
-                    "do_lower_case" => do_lower_case = val.extract()?,
-                    "tokenize_chinese_chars" => tokenize_chinese_chars = val.extract()?,
-                    "never_split" => {
-                        let values: Vec<String> = val.extract()?;
-                        never_split = values.into_iter().collect();
-                    }
-                    _ => println!("Ignored unknown kwargs option {}", key),
-                }
-            }
-        }
-
+    fn new() -> PyResult<PreTokenizer> {
         Ok(PreTokenizer {
-            pretok: Container::Owned(Box::new(tk::pre_tokenizers::bert::BertPreTokenizer::new(
-                do_basic_tokenize,
-                do_lower_case,
-                never_split,
-                tokenize_chinese_chars,
-            ))),
+            pretok: Container::Owned(Box::new(tk::pre_tokenizers::bert::BertPreTokenizer)),
         })
     }
 }
@@ -104,7 +77,7 @@ impl PyPreTokenizer {
 }
 
 impl tk::tokenizer::PreTokenizer for PyPreTokenizer {
-    fn pre_tokenize(&self, sentence: &str) -> Result<Vec<String>> {
+    fn pre_tokenize(&self, sentence: &str) -> Result<Vec<(String, Offsets)>> {
         let gil = Python::acquire_gil();
         let py = gil.python();
 
@@ -112,9 +85,15 @@ impl tk::tokenizer::PreTokenizer for PyPreTokenizer {
         match self.class.call_method(py, "pre_tokenize", args, None) {
             Ok(res) => Ok(res
                 .cast_as::<PyList>(py)
-                .map_err(|_| PyError::from("`pre_tokenize is expected to return a List[str]"))?
-                .extract::<Vec<String>>()
-                .map_err(|_| PyError::from("`pre_tokenize` is expected to return a List[str]"))?),
+                .map_err(|_| {
+                    PyError::from("`pre_tokenize is expected to return a List[(str, (uint, uint))]")
+                })?
+                .extract::<Vec<(String, Offsets)>>()
+                .map_err(|_| {
+                    PyError::from(
+                        "`pre_tokenize` is expected to return a List[(str, (uint, uint))]",
+                    )
+                })?),
             Err(e) => {
                 e.print(py);
                 Err(Box::new(PyError::from(
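
Note on the custom pre-tokenizer contract changed above: `PyPreTokenizer` now extracts the value returned by the Python object's `pre_tokenize` method as `Vec<(String, Offsets)>`, i.e. a `List[(str, (uint, uint))]` on the Python side, instead of a plain `List[str]`. A minimal sketch of a conforming user class follows; the class name and the whitespace-splitting logic are illustrative assumptions, not part of this diff — only the `pre_tokenize` method name and its return shape come from the Rust code above.

```python
import re

# Hypothetical example class; only the `pre_tokenize` method name and the
# List[(str, (uint, uint))] return shape are implied by the diff above.
class WhitespacePreTokenizer:
    def pre_tokenize(self, sentence):
        # Pair each whitespace-separated token with its (start, end)
        # character offsets instead of returning bare strings.
        return [(m.group(), (m.start(), m.end()))
                for m in re.finditer(r"\S+", sentence)]
```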