TemplateProcessing - Add @narsil suggestions

Co-authored-by: Nicolas Patry <patry.nicolas@protonmail.com>
This commit is contained in:
Anthony MOI
2020-09-10 14:51:39 -04:00
committed by Anthony MOI
parent b7df6539e6
commit fee1d4e8a3
2 changed files with 42 additions and 13 deletions

View File

@@ -129,6 +129,11 @@ class TemplateProcessing(PostProcessor):
Note that we are saying the "default" type_id because each SpecialToken can define Note that we are saying the "default" type_id because each SpecialToken can define
its own type_id which would override the provided default. its own type_id which would override the provided default.
**Warning**: You must ensure that you are giving the correct tokens/ids as these
will be added to the Encoding without any further check. If the given ids correspond
to something totally different in a `Tokenizer` using this `PostProcessor`, it
might lead to unexpected results.
""" """
def __init__(self, seq_a: Template, seq_b: Template, special_tokens: Tokens) -> None: def __init__(self, seq_a: Template, seq_b: Template, special_tokens: Tokens) -> None:

View File

@@ -38,11 +38,17 @@
//! Note that we are saying the "default" `type_id` because each `SpecialToken` can define //! Note that we are saying the "default" `type_id` because each `SpecialToken` can define
//! its own `type_id` which would override the provided default. //! its own `type_id` which would override the provided default.
//! //!
//! **Warning**: You must ensure that you are giving the correct tokens/ids as these will
//! be added to the `Encoding` without any further check. If the given ids correspond to
//! something totally different in a `Tokenizer` using this `PostProcessor`, it might lead
//! to unexpected results.
//!
//! [`TemplateProcessing`]: struct.TemplateProcessing.html //! [`TemplateProcessing`]: struct.TemplateProcessing.html
//! //!
use crate::{Encoding, PostProcessor, Result}; use crate::{Encoding, PostProcessor, Result};
use itertools::Itertools;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use std::collections::HashMap; use std::collections::{HashMap, HashSet};
/// Represents the different kind of pieces that constitute a template. /// Represents the different kind of pieces that constitute a template.
/// It can be either the input sequence or a [`SpecialToken`]: /// It can be either the input sequence or a [`SpecialToken`]:
@@ -337,14 +343,15 @@ impl TemplateProcessingBuilder {
.special_tokens .special_tokens
.as_ref() .as_ref()
.map_or(false, |map| map.0.contains_key(sp)); .map_or(false, |map| map.0.contains_key(sp));
match exist { match exist {
true => Ok(()), false => Some(sp),
false => Err(format!("The special token `{}` is missing", sp)), true => None,
} }
}; };
let empty = vec![]; let empty = vec![];
let seq = self let missing: HashSet<&str> = self
.sequence_a .sequence_a
.as_ref() .as_ref()
.map_or(empty.iter(), |s| s.0.iter()) .map_or(empty.iter(), |s| s.0.iter())
@@ -352,16 +359,21 @@ impl TemplateProcessingBuilder {
self.sequence_b self.sequence_b
.as_ref() .as_ref()
.map_or(empty.iter(), |s| s.0.iter()), .map_or(empty.iter(), |s| s.0.iter()),
); )
.filter_map(|piece| match piece {
Piece::Sequence { .. } => None,
Piece::SpecialToken { id } => check(id.as_ref()),
})
.collect::<HashSet<_>>();
for piece in seq { if missing.is_empty() {
match piece { Ok(())
Piece::Sequence { .. } => {} } else {
Piece::SpecialToken { id } => check(id)?, Err(format!(
} "Missing SpecialToken(s) with id(s) `{}`",
missing.iter().join(", ")
))
} }
Ok(())
} }
} }
@@ -402,7 +414,7 @@ impl TemplateProcessing {
self.special_tokens self.special_tokens
.0 .0
.get(id) .get(id)
.ok_or(format!("Missing SpecialToken with id {}", id))? .ok_or_else(|| format!("Missing SpecialToken with id {}", id))?
.ids .ids
.len() .len()
} else { } else {
@@ -647,6 +659,18 @@ mod tests {
); );
} }
#[test]
fn missing_special_tokens() {
let processor = TemplateProcessing::builder()
.sequence_a("[CLS] $0 [SEP]")
.sequence_b("$1 [SEP]")
.build();
let err_a = Err("Missing SpecialToken(s) with id(s) `[SEP], [CLS]`".into());
let err_b = Err("Missing SpecialToken(s) with id(s) `[CLS], [SEP]`".into());
assert!(processor == err_a || processor == err_b);
}
#[test] #[test]
fn template_processing() { fn template_processing() {
let processor = tests::get_bert_template(); let processor = tests::get_bert_template();