From fee1d4e8a37d87d551a2a5e8886c1a91707eb1d6 Mon Sep 17 00:00:00 2001 From: Anthony MOI Date: Thu, 10 Sep 2020 14:51:39 -0400 Subject: [PATCH] TemplateProcessing - Add @narsil suggestions Co-authored-by: Nicolas Patry --- .../py_src/tokenizers/processors/__init__.pyi | 5 ++ tokenizers/src/processors/template.rs | 50 ++++++++++++++----- 2 files changed, 42 insertions(+), 13 deletions(-) diff --git a/bindings/python/py_src/tokenizers/processors/__init__.pyi b/bindings/python/py_src/tokenizers/processors/__init__.pyi index 5cde2b90..7e713a89 100644 --- a/bindings/python/py_src/tokenizers/processors/__init__.pyi +++ b/bindings/python/py_src/tokenizers/processors/__init__.pyi @@ -129,6 +129,11 @@ class TemplateProcessing(PostProcessor): Note that we are saying the "default" type_id because each SpecialToken can define its own type_id which would override the provided default. + + **Warning**: You must ensure that you are giving the correct tokens/ids as these + will be added to the Encoding without any further check. If the given ids correspond + to something totally different in a `Tokenizer` using this `PostProcessor`, it + might lead to unexpected results. """ def __init__(self, seq_a: Template, seq_b: Template, special_tokens: Tokens) -> None: diff --git a/tokenizers/src/processors/template.rs b/tokenizers/src/processors/template.rs index df816a48..63c34c9d 100644 --- a/tokenizers/src/processors/template.rs +++ b/tokenizers/src/processors/template.rs @@ -38,11 +38,17 @@ //! Note that we are saying the "default" `type_id` because each `SpecialToken` can define //! its own `type_id` which would override the provided default. //! +//! **Warning**: You must ensure that you are giving the correct tokens/ids as these will +//! be added to the `Encoding` without any further check. If the given ids correspond to +//! something totally different in a `Tokenizer` using this `PostProcessor`, it might lead +//! to unexpected results. +//! //! [`TemplateProcessing`]: struct.TemplateProcessing.html //! use crate::{Encoding, PostProcessor, Result}; +use itertools::Itertools; use serde::{Deserialize, Serialize}; -use std::collections::HashMap; +use std::collections::{HashMap, HashSet}; /// Represents the different kind of pieces that constitute a template. /// It can be either the input sequence or a [`SpecialToken`]: @@ -337,14 +343,15 @@ impl TemplateProcessingBuilder { .special_tokens .as_ref() .map_or(false, |map| map.0.contains_key(sp)); + match exist { - true => Ok(()), - false => Err(format!("The special token `{}` is missing", sp)), + false => Some(sp), + true => None, } }; let empty = vec![]; - let seq = self + let missing: HashSet<&str> = self .sequence_a .as_ref() .map_or(empty.iter(), |s| s.0.iter()) @@ -352,16 +359,21 @@ impl TemplateProcessingBuilder { self.sequence_b .as_ref() .map_or(empty.iter(), |s| s.0.iter()), - ); + ) + .filter_map(|piece| match piece { + Piece::Sequence { .. } => None, + Piece::SpecialToken { id } => check(id.as_ref()), + }) + .collect::>(); - for piece in seq { - match piece { - Piece::Sequence { .. } => {} - Piece::SpecialToken { id } => check(id)?, - } + if missing.is_empty() { + Ok(()) + } else { + Err(format!( + "Missing SpecialToken(s) with id(s) `{}`", + missing.iter().join(", ") + )) } - - Ok(()) } } @@ -402,7 +414,7 @@ impl TemplateProcessing { self.special_tokens .0 .get(id) - .ok_or(format!("Missing SpecialToken with id {}", id))? + .ok_or_else(|| format!("Missing SpecialToken with id {}", id))? .ids .len() } else { @@ -647,6 +659,18 @@ mod tests { ); } + #[test] + fn missing_special_tokens() { + let processor = TemplateProcessing::builder() + .sequence_a("[CLS] $0 [SEP]") + .sequence_b("$1 [SEP]") + .build(); + + let err_a = Err("Missing SpecialToken(s) with id(s) `[SEP], [CLS]`".into()); + let err_b = Err("Missing SpecialToken(s) with id(s) `[CLS], [SEP]`".into()); + assert!(processor == err_a || processor == err_b); + } + #[test] fn template_processing() { let processor = tests::get_bert_template();