From d953d58cee2b3b0d7946c79cd8cd83ad61350c0f Mon Sep 17 00:00:00 2001
From: Anthony MOI
Date: Thu, 19 Mar 2020 12:53:03 -0400
Subject: [PATCH] Rust - Fix offsets when there are added tokens

---
 bindings/python/CHANGELOG.md       |  1 +
 tokenizers/CHANGELOG.md            |  1 +
 tokenizers/Makefile                |  6 +++-
 tokenizers/src/normalizers/bert.rs | 11 +++++++
 tokenizers/src/tokenizer/mod.rs    | 29 +++++++++++------
 tokenizers/tests/offsets.rs        | 50 ++++++++++++++++++++++++++++++
 6 files changed, 87 insertions(+), 11 deletions(-)

diff --git a/bindings/python/CHANGELOG.md b/bindings/python/CHANGELOG.md
index 13a4a0fe..0460e76a 100644
--- a/bindings/python/CHANGELOG.md
+++ b/bindings/python/CHANGELOG.md
@@ -18,6 +18,7 @@ special tokens. This is activated by default. ([#193](https://github.com/hugging
 - Fix some issues with the offsets being wrong with the `ByteLevel` BPE ([#193](https://github.com/huggingface/tokenizers/pull/193)):
   - when `add_prefix_space=True`
   - when a Unicode character gets split-up in multiple byte-level characters ([#156](https://github.com/huggingface/tokenizers/issues/156))
+- Fix a bug where offsets were wrong when there were added tokens in the sequence being encoded.
 
 ## How to migrate:
 - Add the `ByteLevel` `PostProcessor` to your byte-level BPE tokenizers if relevant. If you are
diff --git a/tokenizers/CHANGELOG.md b/tokenizers/CHANGELOG.md
index c4b8c957..f046b02a 100644
--- a/tokenizers/CHANGELOG.md
+++ b/tokenizers/CHANGELOG.md
@@ -21,6 +21,7 @@ one anymore. ([#197](https://github.com/huggingface/tokenizers/pull/197))
 - Fix some issues with the offsets being wrong with the `ByteLevel` BPE:
   - when `add_prefix_space` is activated
   - when a Unicode character gets split-up in multiple byte-level characters ([#156](https://github.com/huggingface/tokenizers/issues/156))
+- Fix a bug where offsets were wrong when there were added tokens in the sequence being encoded.
 
 ## How to migrate:
 - Add the `ByteLevel` `PostProcessor` to your byte-level BPE tokenizers if relevant.
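Context for the `tokenizers/src/tokenizer/mod.rs` hunk below: the input is split around added tokens and each split is encoded on its own, so every split's offsets start back at 0. The fix shifts each split's offsets by the accumulated length of the text that precedes it before the per-split encodings are merged. Here is a minimal standalone sketch of that idea, using the test sentence added later in this patch; `shift_offsets` and the exact split boundaries are illustrative assumptions, not the crate's API:

```rust
// Illustration only: `shift_offsets` is a hypothetical helper, not part of the
// tokenizers crate. Each split of the input is encoded on its own, so its
// offsets are relative to that split; shifting them by the length of the text
// that precedes the split makes them index into the original, unsplit string.
fn shift_offsets(offsets: &mut [(usize, usize)], shift: usize) {
    for (start, end) in offsets.iter_mut() {
        *start += shift;
        *end += shift;
    }
}

fn main() {
    // "Yesterday I saw a [MASK] far away", split around the added token
    // (split boundaries assumed for illustration): "[MASK]" starts at
    // offset 18, the remainder " far away" at offset 24.
    let mut mask = vec![(0, 6)]; // "[MASK]" relative to its own split
    let mut tail = vec![(1, 4), (5, 9)]; // "far", "away" relative to their split

    shift_offsets(&mut mask, 18);
    shift_offsets(&mut tail, 24);

    assert_eq!(mask, vec![(18, 24)]);
    assert_eq!(tail, vec![(25, 28), (29, 33)]);
    println!("{:?} {:?}", mask, tail);
}
```

With the shifts applied, merging the per-split encodings yields offsets that index into the full original input, which is what the new `split_on_added_tokens_bert` test at the end of this patch asserts.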
diff --git a/tokenizers/Makefile b/tokenizers/Makefile
index 851661a8..e77de885 100644
--- a/tokenizers/Makefile
+++ b/tokenizers/Makefile
@@ -6,7 +6,7 @@ dir_guard=@mkdir -p $(@D)
 
 SHARED_RESOURCES = $(DATA_DIR)/gpt2-vocab.json $(DATA_DIR)/gpt2-merges.txt
 BENCHMARK_RESOURCES = $(SHARED_RESOURCES) $(DATA_DIR)/big.txt
-TESTS_RESOURCES = $(SHARED_RESOURCES)
+TESTS_RESOURCES = $(SHARED_RESOURCES) $(DATA_DIR)/bert-base-uncased-vocab.txt
 
 .PHONY : build
 build :
@@ -49,6 +49,10 @@ $(DATA_DIR)/gpt2-% :
 	$(dir_guard)
 	wget https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-$* -O $@
 
+$(DATA_DIR)/bert-% :
+	$(dir_guard)
+	wget https://s3.amazonaws.com/models.huggingface.co/bert/bert-$* -O $@
+
 $(DATA_DIR)/big.txt :
 	$(dir_guard)
 	wget https://norvig.com/big.txt -O $@
diff --git a/tokenizers/src/normalizers/bert.rs b/tokenizers/src/normalizers/bert.rs
index ad0e588c..370a858b 100644
--- a/tokenizers/src/normalizers/bert.rs
+++ b/tokenizers/src/normalizers/bert.rs
@@ -60,6 +60,17 @@ pub struct BertNormalizer {
     lowercase: bool,
 }
 
+impl Default for BertNormalizer {
+    fn default() -> Self {
+        Self {
+            clean_text: true,
+            handle_chinese_chars: true,
+            strip_accents: true,
+            lowercase: true,
+        }
+    }
+}
+
 impl BertNormalizer {
     pub fn new(
         clean_text: bool,
diff --git a/tokenizers/src/tokenizer/mod.rs b/tokenizers/src/tokenizer/mod.rs
index 31fb78fe..e08c7fa5 100644
--- a/tokenizers/src/tokenizer/mod.rs
+++ b/tokenizers/src/tokenizer/mod.rs
@@ -390,20 +390,29 @@ impl Tokenizer {
                 return Ok((Encoding::default(), NormalizedString::from("")));
             }
 
+            // Merge encodings and normalized strings
             let others = encodings.split_off(1);
-            let mut first: Encoding = encodings.into_iter().next().unwrap();
+            let n_others = normalized.split_off(1);
 
-            for encoding in others {
-                first.merge_with(encoding, true);
+            let mut final_encoding: Encoding = encodings.into_iter().next().unwrap();
+            let mut final_normalized: NormalizedString = normalized.into_iter().next().unwrap();
+
+            let mut offset = final_normalized.len_original();
+            for (mut encoding, normalized) in others.into_iter().zip(n_others) {
+                encoding
+                    .get_offsets_mut()
+                    .iter_mut()
+                    .for_each(|(start, end)| {
+                        *start += offset;
+                        *end += offset;
+                    });
+                offset += normalized.len();
+
+                final_encoding.merge_with(encoding, false);
+                final_normalized.merge_with(&normalized);
             }
 
-            let others = normalized.split_off(1);
-            let mut normalized: NormalizedString = normalized.into_iter().next().unwrap();
-            for n in others {
-                normalized.merge_with(&n);
-            }
-
-            Ok((first, normalized))
+            Ok((final_encoding, final_normalized))
         };
 
         let (sentence, pair) = match input {
diff --git a/tokenizers/tests/offsets.rs b/tokenizers/tests/offsets.rs
index 4b6a5bbf..3133398f 100644
--- a/tokenizers/tests/offsets.rs
+++ b/tokenizers/tests/offsets.rs
@@ -1,5 +1,10 @@
+use tokenizers::decoders::wordpiece::WordPiece as WordPieceDecoder;
 use tokenizers::models::bpe::BPE;
+use tokenizers::models::wordpiece::WordPiece;
+use tokenizers::normalizers::bert::BertNormalizer;
+use tokenizers::pre_tokenizers::bert::BertPreTokenizer;
 use tokenizers::pre_tokenizers::byte_level::ByteLevel;
+use tokenizers::processors::bert::BertProcessing;
 use tokenizers::tokenizer::{get_range_of, EncodeInput, Tokenizer};
 
 fn get_byte_level(add_prefix_space: bool, trim_offsets: bool) -> Tokenizer {
@@ -17,6 +22,29 @@ fn get_byte_level(add_prefix_space: bool, trim_offsets: bool) -> Tokenizer {
     tokenizer
 }
 
+fn get_bert() -> Tokenizer {
+    let mut tokenizer = Tokenizer::new(Box::new(
+        WordPiece::from_files("data/bert-base-uncased-vocab.txt")
+            .build()
+            .expect("Files not found, run `make test` to download these files"),
+    ));
+    tokenizer.with_normalizer(Box::new(BertNormalizer::default()));
+    tokenizer.with_pre_tokenizer(Box::new(BertPreTokenizer));
+    tokenizer.with_decoder(Box::new(WordPieceDecoder::default()));
+    tokenizer.with_post_processor(Box::new(BertProcessing::new(
+        (
+            String::from("[SEP]"),
+            tokenizer.get_model().token_to_id("[SEP]").unwrap(),
+        ),
+        (
+            String::from("[CLS]"),
+            tokenizer.get_model().token_to_id("[CLS]").unwrap(),
+        ),
+    )));
+
+    tokenizer
+}
+
 #[inline]
 fn offset_as_range(offset: (usize, usize)) -> std::ops::Range<usize> {
     offset.0..offset.1
@@ -152,3 +180,25 @@ fn byte_level_double_sequence() {
         ]
     );
 }
+
+#[test]
+fn split_on_added_tokens_bert() {
+    let input = String::from("Yesterday I saw a [MASK] far away");
+
+    let mut tokenizer = get_bert();
+    tokenizer.add_special_tokens(&["[MASK]"]);
+    let output = tokenizer.encode(EncodeInput::Single(input), false).unwrap();
+
+    assert_eq!(
+        output.get_offsets(),
+        &[
+            (0, 9),
+            (10, 11),
+            (12, 15),
+            (16, 17),
+            (18, 24),
+            (25, 28),
+            (29, 33)
+        ]
+    );
+}
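As a sanity check on the expected values above: the offsets index into the original input string (the sentence is ASCII-only, so character and byte positions coincide), so slicing the input with each pair recovers the surface form of every token, including the added `[MASK]`. A small sketch with the offsets copied from the test, independent of the crate:

```rust
fn main() {
    let input = "Yesterday I saw a [MASK] far away";
    let offsets = [
        (0, 9),
        (10, 11),
        (12, 15),
        (16, 17),
        (18, 24),
        (25, 28),
        (29, 33),
    ];

    // Each (start, end) pair is a range into the original input string.
    let pieces: Vec<&str> = offsets.iter().map(|&(s, e)| &input[s..e]).collect();
    assert_eq!(
        pieces,
        ["Yesterday", "I", "saw", "a", "[MASK]", "far", "away"]
    );
    println!("{:?}", pieces);
}
```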