mirror of
https://github.com/mii443/tokenizers.git
synced 2025-12-05 12:18:20 +00:00
TemplateProcessing - Add @narsil suggestions
Co-authored-by: Nicolas Patry <patry.nicolas@protonmail.com>
This commit is contained in:
@@ -129,6 +129,11 @@ class TemplateProcessing(PostProcessor):
|
||||
|
||||
Note that we are saying the "default" type_id because each SpecialToken can define
|
||||
its own type_id which would override the provided default.
|
||||
|
||||
**Warning**: You must ensure that you are giving the correct tokens/ids as these
|
||||
will be added to the Encoding without any further check. If the given ids correspond
|
||||
to something totally different in a `Tokenizer` using this `PostProcessor`, it
|
||||
might lead to unexpected results.
|
||||
"""
|
||||
|
||||
def __init__(self, seq_a: Template, seq_b: Template, special_tokens: Tokens) -> None:
|
||||
|
||||
@@ -38,11 +38,17 @@
|
||||
//! Note that we are saying the "default" `type_id` because each `SpecialToken` can define
|
||||
//! its own `type_id` which would override the provided default.
|
||||
//!
|
||||
//! **Warning**: You must ensure that you are giving the correct tokens/ids as these will
|
||||
//! be added to the `Encoding` without any further check. If the given ids correspond to
|
||||
//! something totally different in a `Tokenizer` using this `PostProcessor`, it might lead
|
||||
//! to unexpected results.
|
||||
//!
|
||||
//! [`TemplateProcessing`]: struct.TemplateProcessing.html
|
||||
//!
|
||||
use crate::{Encoding, PostProcessor, Result};
|
||||
use itertools::Itertools;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::collections::HashMap;
|
||||
use std::collections::{HashMap, HashSet};
|
||||
|
||||
/// Represents the different kinds of pieces that constitute a template.
|
||||
/// It can be either the input sequence or a [`SpecialToken`]:
|
||||
@@ -337,14 +343,15 @@ impl TemplateProcessingBuilder {
|
||||
.special_tokens
|
||||
.as_ref()
|
||||
.map_or(false, |map| map.0.contains_key(sp));
|
||||
|
||||
match exist {
|
||||
true => Ok(()),
|
||||
false => Err(format!("The special token `{}` is missing", sp)),
|
||||
false => Some(sp),
|
||||
true => None,
|
||||
}
|
||||
};
|
||||
|
||||
let empty = vec![];
|
||||
let seq = self
|
||||
let missing: HashSet<&str> = self
|
||||
.sequence_a
|
||||
.as_ref()
|
||||
.map_or(empty.iter(), |s| s.0.iter())
|
||||
@@ -352,16 +359,21 @@ impl TemplateProcessingBuilder {
|
||||
self.sequence_b
|
||||
.as_ref()
|
||||
.map_or(empty.iter(), |s| s.0.iter()),
|
||||
);
|
||||
)
|
||||
.filter_map(|piece| match piece {
|
||||
Piece::Sequence { .. } => None,
|
||||
Piece::SpecialToken { id } => check(id.as_ref()),
|
||||
})
|
||||
.collect::<HashSet<_>>();
|
||||
|
||||
for piece in seq {
|
||||
match piece {
|
||||
Piece::Sequence { .. } => {}
|
||||
Piece::SpecialToken { id } => check(id)?,
|
||||
}
|
||||
if missing.is_empty() {
|
||||
Ok(())
|
||||
} else {
|
||||
Err(format!(
|
||||
"Missing SpecialToken(s) with id(s) `{}`",
|
||||
missing.iter().join(", ")
|
||||
))
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
@@ -402,7 +414,7 @@ impl TemplateProcessing {
|
||||
self.special_tokens
|
||||
.0
|
||||
.get(id)
|
||||
.ok_or(format!("Missing SpecialToken with id {}", id))?
|
||||
.ok_or_else(|| format!("Missing SpecialToken with id {}", id))?
|
||||
.ids
|
||||
.len()
|
||||
} else {
|
||||
@@ -647,6 +659,18 @@ mod tests {
|
||||
);
|
||||
}
|
||||
|
||||
// Checks that building a template which references special tokens
// (`[CLS]`, `[SEP]`) without any `special_tokens` being provided to the
// builder yields an error naming every missing id.
#[test]
|
||||
fn missing_special_tokens() {
|
||||
let processor = TemplateProcessing::builder()
|
||||
.sequence_a("[CLS] $0 [SEP]")
|
||||
.sequence_b("$1 [SEP]")
|
||||
.build();
|
||||
|
||||
// The missing ids are gathered in a `HashSet` before being joined into the
// message, and `HashSet` iteration order is unspecified, so either ordering
// of the two ids is accepted below.
let err_a = Err("Missing SpecialToken(s) with id(s) `[SEP], [CLS]`".into());
|
||||
let err_b = Err("Missing SpecialToken(s) with id(s) `[CLS], [SEP]`".into());
|
||||
assert!(processor == err_a || processor == err_b);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn template_processing() {
|
||||
let processor = tests::get_bert_template();
|
||||
|
||||
Reference in New Issue
Block a user