Rust - Fix offsets when there are added tokens

This commit is contained in:
Anthony MOI
2020-03-19 12:53:03 -04:00
parent 2aeae555e2
commit d953d58cee
6 changed files with 87 additions and 11 deletions

View File

@@ -18,6 +18,7 @@ special tokens. This is activated by default. ([#193](https://github.com/hugging
- Fix some issues with the offsets being wrong with the `ByteLevel` BPE ([#193](https://github.com/huggingface/tokenizers/pull/193)):
- when `add_prefix_space=True`
- when a Unicode character gets split-up in multiple byte-level characters ([#156](https://github.com/huggingface/tokenizers/issues/156))
- Fix a bug where offsets were wrong when there was any added tokens in the sequence being encoded.
## How to migrate:
- Add the `ByteLevel` `PostProcessor` to your byte-level BPE tokenizers if relevant. If you are

View File

@@ -21,6 +21,7 @@ one anymore. ([#197](https://github.com/huggingface/tokenizers/pull/197))
- Fix some issues with the offsets being wrong with the `ByteLevel` BPE:
- when `add_prefix_space` is activated
- when a Unicode character gets split-up in multiple byte-level characters ([#156](https://github.com/huggingface/tokenizers/issues/156))
- Fix a bug where offsets were wrong when there was any added tokens in the sequence being encoded.
## How to migrate:
- Add the `ByteLevel` `PostProcessor` to your byte-level BPE tokenizers if relevant.

View File

@@ -6,7 +6,7 @@ dir_guard=@mkdir -p $(@D)
SHARED_RESOURCES = $(DATA_DIR)/gpt2-vocab.json $(DATA_DIR)/gpt2-merges.txt
BENCHMARK_RESOURCES = $(SHARED_RESOURCES) $(DATA_DIR)/big.txt
TESTS_RESOURCES = $(SHARED_RESOURCES)
TESTS_RESOURCES = $(SHARED_RESOURCES) $(DATA_DIR)/bert-base-uncased-vocab.txt
.PHONY : build
build :
@@ -49,6 +49,10 @@ $(DATA_DIR)/gpt2-% :
$(dir_guard)
wget https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-$* -O $@
$(DATA_DIR)/bert-% :
$(dir_guard)
wget https://s3.amazonaws.com/models.huggingface.co/bert/bert-$* -O $@
$(DATA_DIR)/big.txt :
$(dir_guard)
wget https://norvig.com/big.txt -O $@

View File

@@ -60,6 +60,17 @@ pub struct BertNormalizer {
lowercase: bool,
}
impl Default for BertNormalizer {
fn default() -> Self {
Self {
clean_text: true,
handle_chinese_chars: true,
strip_accents: true,
lowercase: true,
}
}
}
impl BertNormalizer {
pub fn new(
clean_text: bool,

View File

@@ -390,20 +390,29 @@ impl Tokenizer {
return Ok((Encoding::default(), NormalizedString::from("")));
}
// Merge encodings and normalized strings
let others = encodings.split_off(1);
let mut first: Encoding = encodings.into_iter().next().unwrap();
let n_others = normalized.split_off(1);
for encoding in others {
first.merge_with(encoding, true);
let mut final_encoding: Encoding = encodings.into_iter().next().unwrap();
let mut final_normalized: NormalizedString = normalized.into_iter().next().unwrap();
let mut offset = final_normalized.len_original();
for (mut encoding, normalized) in others.into_iter().zip(n_others) {
encoding
.get_offsets_mut()
.iter_mut()
.for_each(|(start, end)| {
*start += offset;
*end += offset;
});
offset += normalized.len();
final_encoding.merge_with(encoding, false);
final_normalized.merge_with(&normalized);
}
let others = normalized.split_off(1);
let mut normalized: NormalizedString = normalized.into_iter().next().unwrap();
for n in others {
normalized.merge_with(&n);
}
Ok((first, normalized))
Ok((final_encoding, final_normalized))
};
let (sentence, pair) = match input {

View File

@@ -1,5 +1,10 @@
use tokenizers::decoders::wordpiece::WordPiece as WordPieceDecoder;
use tokenizers::models::bpe::BPE;
use tokenizers::models::wordpiece::WordPiece;
use tokenizers::normalizers::bert::BertNormalizer;
use tokenizers::pre_tokenizers::bert::BertPreTokenizer;
use tokenizers::pre_tokenizers::byte_level::ByteLevel;
use tokenizers::processors::bert::BertProcessing;
use tokenizers::tokenizer::{get_range_of, EncodeInput, Tokenizer};
fn get_byte_level(add_prefix_space: bool, trim_offsets: bool) -> Tokenizer {
@@ -17,6 +22,29 @@ fn get_byte_level(add_prefix_space: bool, trim_offsets: bool) -> Tokenizer {
tokenizer
}
fn get_bert() -> Tokenizer {
let mut tokenizer = Tokenizer::new(Box::new(
WordPiece::from_files("data/bert-base-uncased-vocab.txt")
.build()
.expect("Files not found, run `make test` to download these files"),
));
tokenizer.with_normalizer(Box::new(BertNormalizer::default()));
tokenizer.with_pre_tokenizer(Box::new(BertPreTokenizer));
tokenizer.with_decoder(Box::new(WordPieceDecoder::default()));
tokenizer.with_post_processor(Box::new(BertProcessing::new(
(
String::from("[SEP]"),
tokenizer.get_model().token_to_id("[SEP]").unwrap(),
),
(
String::from("[CLS]"),
tokenizer.get_model().token_to_id("[CLS]").unwrap(),
),
)));
tokenizer
}
#[inline]
fn offset_as_range(offset: (usize, usize)) -> std::ops::Range<usize> {
offset.0..offset.1
@@ -152,3 +180,25 @@ fn byte_level_double_sequence() {
]
);
}
#[test]
fn split_on_added_tokens_bert() {
let input = String::from("Yesterday I saw a [MASK] far away");
let mut tokenizer = get_bert();
tokenizer.add_special_tokens(&["[MASK]"]);
let output = tokenizer.encode(EncodeInput::Single(input), false).unwrap();
assert_eq!(
output.get_offsets(),
&[
(0, 9),
(10, 11),
(12, 15),
(16, 17),
(18, 24),
(25, 28),
(29, 33)
]
);
}