From 5f6e9784526a4cd5e4f6dcdcc045cdceba5463e1 Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Mon, 26 Sep 2022 18:00:41 +0200 Subject: [PATCH] Fixing roberta type id (everything is zero). (#1072) * Fixing roberta type ids (everything is zero). * We need to fix type_ids for all sequences even when not changing anything else. * Fixing tests hopefully better. --- bindings/python/Cargo.lock | 4 ++-- bindings/python/Cargo.toml | 3 ++- bindings/python/Makefile | 2 +- tokenizers/src/processors/roberta.rs | 15 ++++++++++----- 4 files changed, 15 insertions(+), 9 deletions(-) diff --git a/bindings/python/Cargo.lock b/bindings/python/Cargo.lock index 9c626c7f..3500a5eb 100644 --- a/bindings/python/Cargo.lock +++ b/bindings/python/Cargo.lock @@ -1706,7 +1706,7 @@ checksum = "cda74da7e1a664f795bb1f8a87ec406fb89a02522cf6e50620d016add6dbbf5c" [[package]] name = "tokenizers" -version = "0.13.0" +version = "0.13.1" dependencies = [ "aho-corasick", "cached-path", @@ -1739,7 +1739,7 @@ dependencies = [ [[package]] name = "tokenizers-python" -version = "0.13.0" +version = "0.13.1" dependencies = [ "env_logger", "itertools 0.9.0", diff --git a/bindings/python/Cargo.toml b/bindings/python/Cargo.toml index ff983a4c..5dc63f80 100644 --- a/bindings/python/Cargo.toml +++ b/bindings/python/Cargo.toml @@ -14,7 +14,7 @@ serde = { version = "1.0", features = [ "rc", "derive" ]} serde_json = "1.0" libc = "0.2" env_logger = "0.7.1" -pyo3 = { version = "0.16.2", features = ["extension-module"] } +pyo3 = { version = "0.16.2" } numpy = "0.16.2" ndarray = "0.13" onig = { version = "6.0", default-features = false } @@ -28,5 +28,6 @@ path = "../../tokenizers" tempfile = "3.1" [features] +default = ["pyo3/extension-module"] test = ["pyo3/auto-initialize"] diff --git a/bindings/python/Makefile b/bindings/python/Makefile index ae588ed4..fca412f3 100644 --- a/bindings/python/Makefile +++ b/bindings/python/Makefile @@ -20,7 +20,7 @@ TESTS_RESOURCES = $(DATA_DIR)/small.txt $(DATA_DIR)/roberta.json test: 
$(TESTS_RESOURCES) pip install pytest requests setuptools_rust numpy pyarrow datasets python -m pytest -s -v tests - cargo test --features test + cargo test --no-default-features --features test $(DATA_DIR)/big.txt : $(dir_guard) diff --git a/tokenizers/src/processors/roberta.rs b/tokenizers/src/processors/roberta.rs index 74916418..3af9a8d6 100644 --- a/tokenizers/src/processors/roberta.rs +++ b/tokenizers/src/processors/roberta.rs @@ -70,6 +70,11 @@ impl PostProcessor for RobertaProcessing { } } + // Roberta is weird, and every encoding is type_id=0. + encodings + .iter_mut() + .for_each(|encoding| encoding.set_type_ids(vec![0; encoding.len()])); + if !add_special_tokens { return Ok(encodings); } @@ -110,7 +115,7 @@ impl PostProcessor for RobertaProcessing { .map(|encoding| { let ids = [&[self.cls.1], encoding.get_ids(), &[self.sep.1]].concat(); - let type_ids = [&[0], encoding.get_type_ids(), &[0]].concat(); + let type_ids = vec![0; encoding.get_ids().len() + 2]; let tokens = [ &[self.cls.0.clone()], encoding.get_tokens(), @@ -146,7 +151,7 @@ impl PostProcessor for RobertaProcessing { ) } else { let pair_ids = [&[self.sep.1], encoding.get_ids(), &[self.sep.1]].concat(); - let pair_type_ids = vec![1; encoding.get_ids().len() + 2]; + let pair_type_ids = vec![0; encoding.get_ids().len() + 2]; let pair_tokens = [ &[self.sep.0.clone()], encoding.get_tokens(), @@ -176,7 +181,7 @@ impl PostProcessor for RobertaProcessing { .map(|encoding| { let pair_ids = [&[self.sep.1], encoding.get_ids(), &[self.sep.1]].concat(); - let pair_type_ids = vec![1; encoding.get_ids().len() + 2]; + let pair_type_ids = vec![0; encoding.get_ids().len() + 2]; let pair_tokens = [ &[self.sep.0.clone()], encoding.get_tokens(), @@ -280,7 +285,7 @@ mod tests { pair_encoding, Encoding::new( vec![0, 12, 14, 2, 2, 15, 2], - vec![0, 0, 0, 0, 1, 1, 1], + vec![0, 0, 0, 0, 0, 0, 0], vec![ "".into(), "Hello".into(), @@ -310,7 +315,7 @@ mod tests { pair_encoding, Encoding::new( vec![12, 14, 15], - vec![0, 
0, 1], + vec![0, 0, 0], vec!["Hello".into(), "there".into(), "pair".into(),], vec![None, None, None], vec![(0, 5), (6, 11), (0, 4)],