mirror of
https://github.com/mii443/tokenizers.git
synced 2025-12-05 20:28:22 +00:00
TemplateProcessing - Add @narsil suggestions
Co-authored-by: Nicolas Patry <patry.nicolas@protonmail.com>
This commit is contained in:
@@ -129,6 +129,11 @@ class TemplateProcessing(PostProcessor):
|
|||||||
|
|
||||||
Note that we are saying the "default" type_id because each SpecialToken can define
|
Note that we are saying the "default" type_id because each SpecialToken can define
|
||||||
its own type_id which would override the provided default.
|
its own type_id which would override the provided default.
|
||||||
|
|
||||||
|
**Warning**: You must ensure that you are giving the correct tokens/ids as these
|
||||||
|
will be added to the Encoding without any further check. If the given ids correspond
|
||||||
|
to something totally different in a `Tokenizer` using this `PostProcessor`, it
|
||||||
|
might lead to unexpected results.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, seq_a: Template, seq_b: Template, special_tokens: Tokens) -> None:
|
def __init__(self, seq_a: Template, seq_b: Template, special_tokens: Tokens) -> None:
|
||||||
|
|||||||
@@ -38,11 +38,17 @@
|
|||||||
//! Note that we are saying the "default" `type_id` because each `SpecialToken` can define
|
//! Note that we are saying the "default" `type_id` because each `SpecialToken` can define
|
||||||
//! its own `type_id` which would override the provided default.
|
//! its own `type_id` which would override the provided default.
|
||||||
//!
|
//!
|
||||||
|
//! **Warning**: You must ensure that you are giving the correct tokens/ids as these will
|
||||||
|
//! be added to the `Encoding` without any further check. If the given ids correspond to
|
||||||
|
//! something totally different in a `Tokenizer` using this `PostProcessor`, it might lead
|
||||||
|
//! to unexpected results.
|
||||||
|
//!
|
||||||
//! [`TemplateProcessing`]: struct.TemplateProcessing.html
|
//! [`TemplateProcessing`]: struct.TemplateProcessing.html
|
||||||
//!
|
//!
|
||||||
use crate::{Encoding, PostProcessor, Result};
|
use crate::{Encoding, PostProcessor, Result};
|
||||||
|
use itertools::Itertools;
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
use std::collections::HashMap;
|
use std::collections::{HashMap, HashSet};
|
||||||
|
|
||||||
/// Represents the different kind of pieces that constitute a template.
|
/// Represents the different kind of pieces that constitute a template.
|
||||||
/// It can be either the input sequence or a [`SpecialToken`]:
|
/// It can be either the input sequence or a [`SpecialToken`]:
|
||||||
@@ -337,14 +343,15 @@ impl TemplateProcessingBuilder {
|
|||||||
.special_tokens
|
.special_tokens
|
||||||
.as_ref()
|
.as_ref()
|
||||||
.map_or(false, |map| map.0.contains_key(sp));
|
.map_or(false, |map| map.0.contains_key(sp));
|
||||||
|
|
||||||
match exist {
|
match exist {
|
||||||
true => Ok(()),
|
false => Some(sp),
|
||||||
false => Err(format!("The special token `{}` is missing", sp)),
|
true => None,
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
let empty = vec![];
|
let empty = vec![];
|
||||||
let seq = self
|
let missing: HashSet<&str> = self
|
||||||
.sequence_a
|
.sequence_a
|
||||||
.as_ref()
|
.as_ref()
|
||||||
.map_or(empty.iter(), |s| s.0.iter())
|
.map_or(empty.iter(), |s| s.0.iter())
|
||||||
@@ -352,16 +359,21 @@ impl TemplateProcessingBuilder {
|
|||||||
self.sequence_b
|
self.sequence_b
|
||||||
.as_ref()
|
.as_ref()
|
||||||
.map_or(empty.iter(), |s| s.0.iter()),
|
.map_or(empty.iter(), |s| s.0.iter()),
|
||||||
);
|
)
|
||||||
|
.filter_map(|piece| match piece {
|
||||||
for piece in seq {
|
Piece::Sequence { .. } => None,
|
||||||
match piece {
|
Piece::SpecialToken { id } => check(id.as_ref()),
|
||||||
Piece::Sequence { .. } => {}
|
})
|
||||||
Piece::SpecialToken { id } => check(id)?,
|
.collect::<HashSet<_>>();
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
|
if missing.is_empty() {
|
||||||
Ok(())
|
Ok(())
|
||||||
|
} else {
|
||||||
|
Err(format!(
|
||||||
|
"Missing SpecialToken(s) with id(s) `{}`",
|
||||||
|
missing.iter().join(", ")
|
||||||
|
))
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -402,7 +414,7 @@ impl TemplateProcessing {
|
|||||||
self.special_tokens
|
self.special_tokens
|
||||||
.0
|
.0
|
||||||
.get(id)
|
.get(id)
|
||||||
.ok_or(format!("Missing SpecialToken with id {}", id))?
|
.ok_or_else(|| format!("Missing SpecialToken with id {}", id))?
|
||||||
.ids
|
.ids
|
||||||
.len()
|
.len()
|
||||||
} else {
|
} else {
|
||||||
@@ -647,6 +659,18 @@ mod tests {
|
|||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn missing_special_tokens() {
|
||||||
|
let processor = TemplateProcessing::builder()
|
||||||
|
.sequence_a("[CLS] $0 [SEP]")
|
||||||
|
.sequence_b("$1 [SEP]")
|
||||||
|
.build();
|
||||||
|
|
||||||
|
let err_a = Err("Missing SpecialToken(s) with id(s) `[SEP], [CLS]`".into());
|
||||||
|
let err_b = Err("Missing SpecialToken(s) with id(s) `[CLS], [SEP]`".into());
|
||||||
|
assert!(processor == err_a || processor == err_b);
|
||||||
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn template_processing() {
|
fn template_processing() {
|
||||||
let processor = tests::get_bert_template();
|
let processor = tests::get_bert_template();
|
||||||
|
|||||||
Reference in New Issue
Block a user