mirror of
https://github.com/mii443/tokenizers.git
synced 2025-08-22 16:25:30 +00:00
Refactor metaspace (#1476)
* version = "0.15.3-dev-0” Improve performances of meta space, but also just fix it. (transformers) ➜ transformers git:(refactor-default-llama) ✗ python ../scripts/gemma-dummy.py Token indices sequence length is longer than the specified maximum sequence length for this model (14999 > 2048). Running this sequence through the model will result in indexing errors ['<REPR_END>', '▁inform', '<s>', '.', '▁Hey', '<unk>', '.', '▁', '▁', '▁', '▁', '▁', '▁', '▁.'] ['▁inform', '<s>', '.', '▁Hey', '<unk>', '.', '▁', '▁', '▁', '▁', '▁', '▁', '▁.'] [0.0006330013275146484, 0.0014591217041015625, 0.015890836715698242, 0.18584918975830078, 2.1726326942443848] (transformers) ➜ transformers git:(refactor-default-llama) ✗ python ../scripts/gemma-dummy.py Token indices sequence length is longer than the specified maximum sequence length for this model (10000 > 2048). Running this sequence through the model will result in indexing errors ['<REPR_END>', 'in', 'form', '<s>', '.', '▁Hey', '<unk>', '.', '▁▁▁▁▁▁', '▁.'] ['in', 'form', '<s>', '.', '▁Hey', '<unk>', '.', '▁▁▁▁▁▁', '▁.'] [0.0008409023284912109, 0.0008909702301025391, 0.00882411003112793, 0.10214710235595703, 1.187899112701416] * well what do we have * nit * be BC with non legacy * unrelated change for clippy * fix test * splitting is a must for word_ids * fmt and lint * Fixing everything (hopefully better). * Fixing node. * Including yarn.lock * Lint. * Stubs. * revert to use split * fix merge issues * fix tests * finish fixing tests * ruff --------- Co-authored-by: Nicolas Patry <patry.nicolas@protonmail.com>
This commit is contained in:
@ -73,14 +73,19 @@ mod tests {
|
||||
|
||||
#[test]
|
||||
fn decoder_serialization() {
|
||||
let json = r#"{"type":"Sequence","decoders":[{"type":"ByteFallback"},{"type":"Metaspace","replacement":"▁","add_prefix_space":true,"prepend_scheme":"always"}]}"#;
|
||||
let oldjson = r#"{"type":"Sequence","decoders":[{"type":"ByteFallback"},{"type":"Metaspace","replacement":"▁","add_prefix_space":true,"prepend_scheme":"always"}]}"#;
|
||||
let olddecoder: DecoderWrapper = serde_json::from_str(oldjson).unwrap();
|
||||
let oldserialized = serde_json::to_string(&olddecoder).unwrap();
|
||||
let json = r#"{"type":"Sequence","decoders":[{"type":"ByteFallback"},{"type":"Metaspace","replacement":"▁","prepend_scheme":"always","split":true}]}"#;
|
||||
assert_eq!(oldserialized, json);
|
||||
|
||||
let decoder: DecoderWrapper = serde_json::from_str(json).unwrap();
|
||||
let serialized = serde_json::to_string(&decoder).unwrap();
|
||||
assert_eq!(serialized, json);
|
||||
}
|
||||
#[test]
|
||||
fn decoder_serialization_other_no_arg() {
|
||||
let json = r#"{"type":"Sequence","decoders":[{"type":"Fuse"},{"type":"Metaspace","replacement":"▁","add_prefix_space":true,"prepend_scheme":"always"}]}"#;
|
||||
let json = r#"{"type":"Sequence","decoders":[{"type":"Fuse"},{"type":"Metaspace","replacement":"▁","prepend_scheme":"always","split":true}]}"#;
|
||||
let decoder: DecoderWrapper = serde_json::from_str(json).unwrap();
|
||||
let serialized = serde_json::to_string(&decoder).unwrap();
|
||||
assert_eq!(serialized, json);
|
||||
@ -88,7 +93,7 @@ mod tests {
|
||||
|
||||
#[test]
|
||||
fn decoder_serialization_no_decode() {
|
||||
let json = r#"{"type":"Sequence","decoders":[{},{"type":"Metaspace","replacement":"▁","add_prefix_space":true,"prepend_scheme":"always"}]}"#;
|
||||
let json = r#"{"type":"Sequence","decoders":[{},{"type":"Metaspace","replacement":"▁","prepend_scheme":"always"}]}"#;
|
||||
assert!(serde_json::from_str::<DecoderWrapper>(json).is_err());
|
||||
}
|
||||
}
|
||||
|
@ -1,5 +1,5 @@
|
||||
use crate::tokenizer::{Decoder, PreTokenizedString, PreTokenizer, Result, SplitDelimiterBehavior};
|
||||
use serde::{Deserialize, Deserializer, Serialize};
|
||||
use serde::{de, Deserialize, Deserializer, Serialize};
|
||||
|
||||
/// Enum representing options for the metaspace prepending scheme.
|
||||
#[derive(Debug, Clone, PartialEq, Serialize, Eq, Deserialize, Copy)]
|
||||
@ -19,8 +19,8 @@ pub enum PrependScheme {
|
||||
#[serde(tag = "type")]
|
||||
pub struct Metaspace {
|
||||
replacement: char,
|
||||
pub add_prefix_space: bool,
|
||||
pub prepend_scheme: PrependScheme,
|
||||
pub split: bool,
|
||||
#[serde(skip)]
|
||||
str_rep: String,
|
||||
}
|
||||
@ -44,42 +44,40 @@ impl<'de> Deserialize<'de> for Metaspace {
|
||||
#[serde(rename = "type")]
|
||||
_type: Type,
|
||||
replacement: char,
|
||||
pub add_prefix_space: bool,
|
||||
|
||||
pub add_prefix_space: Option<bool>,
|
||||
#[serde(default = "default_prepend_scheme_value")]
|
||||
pub prepend_scheme: PrependScheme,
|
||||
#[serde(skip, rename = "str_rep")]
|
||||
_str_rep: String,
|
||||
pub split: Option<bool>,
|
||||
#[serde(rename = "str_rep")]
|
||||
_str_rep: Option<String>,
|
||||
}
|
||||
|
||||
let helper = MetaspaceHelper::deserialize(deserializer)?;
|
||||
let instance = Self::new_with_prepend_scheme(
|
||||
let mut helper = MetaspaceHelper::deserialize(deserializer)?;
|
||||
if let Some(false) = helper.add_prefix_space {
|
||||
if helper.prepend_scheme != PrependScheme::Never {
|
||||
return Err(de::Error::custom(
|
||||
"add_prefix_space does not match declared prepend_scheme",
|
||||
));
|
||||
}
|
||||
helper.prepend_scheme = PrependScheme::Never;
|
||||
}
|
||||
let instance = Self::new(
|
||||
helper.replacement,
|
||||
helper.add_prefix_space,
|
||||
helper.prepend_scheme,
|
||||
helper.split.unwrap_or(true),
|
||||
);
|
||||
Ok(instance)
|
||||
}
|
||||
}
|
||||
|
||||
impl Metaspace {
|
||||
pub fn new(replacement: char, add_prefix_space: bool) -> Self {
|
||||
Self::new_with_prepend_scheme(
|
||||
replacement,
|
||||
add_prefix_space,
|
||||
PrependScheme::Always, // always prepend for legacy purpose
|
||||
)
|
||||
}
|
||||
|
||||
pub fn new_with_prepend_scheme(
|
||||
replacement: char,
|
||||
add_prefix_space: bool,
|
||||
prepend_scheme: PrependScheme,
|
||||
) -> Self {
|
||||
pub fn new(replacement: char, prepend_scheme: PrependScheme, split: bool) -> Self {
|
||||
Self {
|
||||
replacement,
|
||||
str_rep: replacement.to_string(),
|
||||
add_prefix_space,
|
||||
prepend_scheme,
|
||||
split,
|
||||
}
|
||||
}
|
||||
|
||||
@ -92,6 +90,14 @@ impl Metaspace {
|
||||
self.str_rep = replacement.to_string();
|
||||
}
|
||||
|
||||
pub fn get_split(&self) -> bool {
|
||||
self.split
|
||||
}
|
||||
|
||||
pub fn set_split(&mut self, split: bool) {
|
||||
self.split = split;
|
||||
}
|
||||
|
||||
pub fn get_prepend_scheme(&self) -> PrependScheme {
|
||||
self.prepend_scheme
|
||||
}
|
||||
@ -103,28 +109,34 @@ impl Metaspace {
|
||||
|
||||
impl Default for Metaspace {
|
||||
fn default() -> Self {
|
||||
Self::new('▁', true)
|
||||
Self::new('▁', PrependScheme::Always, true)
|
||||
}
|
||||
}
|
||||
|
||||
impl PreTokenizer for Metaspace {
|
||||
fn pre_tokenize(&self, pretokenized: &mut PreTokenizedString) -> Result<()> {
|
||||
let mut first_split = true;
|
||||
|
||||
pretokenized.split(|_, mut normalized| {
|
||||
normalized.replace(' ', &self.str_rep)?;
|
||||
if self.add_prefix_space && !normalized.get().starts_with(self.replacement) {
|
||||
if self.prepend_scheme == PrependScheme::Always {
|
||||
normalized.prepend(&self.str_rep);
|
||||
} else if self.prepend_scheme == PrependScheme::First && first_split {
|
||||
normalized.prepend(&self.str_rep);
|
||||
first_split = false;
|
||||
match self.prepend_scheme {
|
||||
PrependScheme::Always => {
|
||||
if !normalized.get().starts_with(self.replacement) {
|
||||
normalized.prepend(&self.str_rep);
|
||||
}
|
||||
}
|
||||
PrependScheme::First => {
|
||||
if !normalized.get().starts_with(self.replacement)
|
||||
&& normalized.offsets_original().0 == 0
|
||||
{
|
||||
normalized.prepend(&self.str_rep);
|
||||
}
|
||||
}
|
||||
PrependScheme::Never => {}
|
||||
};
|
||||
if self.split {
|
||||
normalized.split(self.replacement, SplitDelimiterBehavior::MergedWithNext)
|
||||
} else {
|
||||
first_split = false;
|
||||
Ok(vec![normalized])
|
||||
}
|
||||
|
||||
normalized.split(self.replacement, SplitDelimiterBehavior::MergedWithNext)
|
||||
})
|
||||
}
|
||||
}
|
||||
@ -139,7 +151,7 @@ impl Decoder for Metaspace {
|
||||
.chars()
|
||||
.flat_map(|c| {
|
||||
if c == self.replacement {
|
||||
if i == 0 && self.add_prefix_space {
|
||||
if i == 0 && self.prepend_scheme != PrependScheme::Never {
|
||||
None
|
||||
} else {
|
||||
Some(' ')
|
||||
@ -163,8 +175,9 @@ mod tests {
|
||||
|
||||
#[test]
|
||||
fn serialization() {
|
||||
let metaspace = Metaspace::new('_', true);
|
||||
let metaspace_s = r#"{"type":"Metaspace","replacement":"_","add_prefix_space":true,"prepend_scheme":"always"}"#;
|
||||
let metaspace = Metaspace::new('_', PrependScheme::Always, true);
|
||||
let metaspace_s =
|
||||
r#"{"type":"Metaspace","replacement":"_","prepend_scheme":"always","split":true}"#;
|
||||
assert_eq!(serde_json::to_string(&metaspace).unwrap(), metaspace_s);
|
||||
assert_eq!(
|
||||
serde_json::from_str::<Metaspace>(metaspace_s).unwrap(),
|
||||
@ -172,7 +185,10 @@ mod tests {
|
||||
);
|
||||
|
||||
// Also check it can deserialize previous versions
|
||||
let metaspace = Metaspace::new('_', true);
|
||||
let metaspace_s = r#"{"type":"Metaspace","replacement":"_","add_prefix_space":false,"prepend_scheme":"always"}"#;
|
||||
assert!(serde_json::from_str::<Metaspace>(metaspace_s).is_err(),);
|
||||
|
||||
let metaspace = Metaspace::new('_', PrependScheme::Always, true);
|
||||
let metaspace_s = r#"{"type":"Metaspace","str_rep":"_","replacement":"_","add_prefix_space":true,"prepend_scheme":"always"}"#;
|
||||
assert_eq!(
|
||||
serde_json::from_str::<Metaspace>(metaspace_s).unwrap(),
|
||||
@ -188,7 +204,7 @@ mod tests {
|
||||
|
||||
#[test]
|
||||
fn basic() {
|
||||
let pretok = Metaspace::new('▁', true);
|
||||
let pretok = Metaspace::new('▁', PrependScheme::Always, true);
|
||||
let mut pretokenized = PreTokenizedString::from("Hey friend!");
|
||||
pretok.pre_tokenize(&mut pretokenized).unwrap();
|
||||
assert_eq!(
|
||||
@ -211,7 +227,7 @@ mod tests {
|
||||
|
||||
#[test]
|
||||
fn multiple_spaces() {
|
||||
let pretok = Metaspace::new('▁', true);
|
||||
let pretok = Metaspace::new('▁', PrependScheme::Always, true);
|
||||
let mut pretokenized = PreTokenizedString::from("Hey friend!");
|
||||
pretok.pre_tokenize(&mut pretokenized).unwrap();
|
||||
assert_eq!(
|
||||
@ -244,30 +260,17 @@ mod tests {
|
||||
|
||||
#[test]
|
||||
fn non_legacy_meta_space() {
|
||||
assert_eq!(
|
||||
Metaspace::new('▁', true),
|
||||
Metaspace::new_with_prepend_scheme('▁', true, PrependScheme::Always)
|
||||
);
|
||||
|
||||
let mut pretok = Metaspace::new('▁', true);
|
||||
let mut pretok = Metaspace::new('▁', PrependScheme::Always, true);
|
||||
pretok.set_prepend_scheme(PrependScheme::Always);
|
||||
assert_eq!(
|
||||
pretok,
|
||||
Metaspace::new_with_prepend_scheme('▁', true, PrependScheme::Always)
|
||||
);
|
||||
assert_eq!(pretok, Metaspace::new('▁', PrependScheme::Always, true));
|
||||
|
||||
pretok.set_prepend_scheme(PrependScheme::Never);
|
||||
assert_eq!(
|
||||
pretok,
|
||||
Metaspace::new_with_prepend_scheme('▁', true, PrependScheme::Never)
|
||||
);
|
||||
assert_eq!(pretok, Metaspace::new('▁', PrependScheme::Never, true));
|
||||
|
||||
pretok.set_prepend_scheme(PrependScheme::First);
|
||||
assert_eq!(
|
||||
pretok,
|
||||
Metaspace::new_with_prepend_scheme('▁', true, PrependScheme::First)
|
||||
);
|
||||
assert_eq!(pretok, Metaspace::new('▁', PrependScheme::First, true));
|
||||
|
||||
let pretok = Metaspace::new('▁', PrependScheme::First, false);
|
||||
let mut pretokenized = PreTokenizedString::from("Hey my friend <s>how▁are you");
|
||||
let re_ref = Regex::new(r"(<s>)").unwrap();
|
||||
pretokenized
|
||||
@ -282,17 +285,12 @@ mod tests {
|
||||
.map(|(s, o, _)| (s, o))
|
||||
.collect::<Vec<_>>(),
|
||||
vec![
|
||||
("▁Hey", (0, 6)),
|
||||
("▁my", (6, 11)),
|
||||
("▁friend", (11, 20)),
|
||||
("▁", (20, 23)),
|
||||
("▁Hey▁my▁friend▁", (0, 23)),
|
||||
("<s>", (23, 26)),
|
||||
("how", (26, 29)),
|
||||
("▁are", (29, 35)),
|
||||
("▁you", (35, 41))
|
||||
("how▁are▁you", (26, 41))
|
||||
]
|
||||
);
|
||||
pretok.set_prepend_scheme(PrependScheme::Always);
|
||||
let pretok = Metaspace::new('▁', PrependScheme::Always, true);
|
||||
pretok.pre_tokenize(&mut pretokenized).unwrap();
|
||||
assert_eq!(
|
||||
pretokenized
|
||||
@ -312,7 +310,7 @@ mod tests {
|
||||
]
|
||||
);
|
||||
|
||||
pretok.set_prepend_scheme(PrependScheme::First);
|
||||
let pretok = Metaspace::new('▁', PrependScheme::First, false);
|
||||
let mut pretokenized = PreTokenizedString::from(" Hey <s>how"); // test with prefix
|
||||
pretokenized
|
||||
.split(|_, sequence| sequence.split(&re_ref, SplitDelimiterBehavior::Isolated))
|
||||
@ -324,12 +322,7 @@ mod tests {
|
||||
.into_iter()
|
||||
.map(|(s, o, _)| (s, o))
|
||||
.collect::<Vec<_>>(),
|
||||
vec![
|
||||
("▁Hey", (0, 6)),
|
||||
("▁", (6, 9)),
|
||||
("<s>", (9, 12)),
|
||||
("how", (12, 15))
|
||||
]
|
||||
vec![("▁Hey▁", (0, 9)), ("<s>", (9, 12)), ("how", (12, 15))]
|
||||
);
|
||||
|
||||
let mut pretokenized = PreTokenizedString::from(" Hey <s>how <s>are <s> you"); // test with many splits
|
||||
@ -344,14 +337,11 @@ mod tests {
|
||||
.map(|(s, o, _)| (s, o))
|
||||
.collect::<Vec<_>>(),
|
||||
vec![
|
||||
("▁Hey", (0, 6)),
|
||||
("▁", (6, 9)),
|
||||
("▁Hey▁", (0, 9)),
|
||||
("<s>", (9, 12)),
|
||||
("how", (12, 15)),
|
||||
("▁", (15, 18)),
|
||||
("how▁", (12, 18)),
|
||||
("<s>", (18, 21)),
|
||||
("are", (21, 24)),
|
||||
("▁", (24, 27)),
|
||||
("are▁", (21, 27)),
|
||||
("<s>", (27, 30)),
|
||||
("▁you", (30, 36))
|
||||
]
|
||||
@ -359,10 +349,16 @@ mod tests {
|
||||
}
|
||||
#[test]
|
||||
fn decode() {
|
||||
let decoder = Metaspace::new('▁', true);
|
||||
let decoder = Metaspace::new('▁', PrependScheme::Always, true);
|
||||
let res = decoder
|
||||
.decode_chain(vec!["▁Hey".into(), "▁friend!".into()])
|
||||
.unwrap();
|
||||
assert_eq!(res, vec!["Hey", " friend!"])
|
||||
assert_eq!(res, vec!["Hey", " friend!"]);
|
||||
|
||||
let decoder = Metaspace::new('▁', PrependScheme::Never, true);
|
||||
let res = decoder
|
||||
.decode_chain(vec!["▁Hey".into(), "▁friend!".into()])
|
||||
.unwrap();
|
||||
assert_eq!(res, vec![" Hey", " friend!"]);
|
||||
}
|
||||
}
|
||||
|
@ -71,6 +71,7 @@ impl_enum_from!(UnicodeScripts, PreTokenizerWrapper, UnicodeScripts);
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::metaspace::PrependScheme;
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
@ -81,7 +82,7 @@ mod tests {
|
||||
pre_tokenizer,
|
||||
PreTokenizerWrapper::Sequence(Sequence::new(vec![
|
||||
PreTokenizerWrapper::WhitespaceSplit(WhitespaceSplit {}),
|
||||
PreTokenizerWrapper::Metaspace(Metaspace::new('▁', true))
|
||||
PreTokenizerWrapper::Metaspace(Metaspace::new('▁', PrependScheme::Always, true))
|
||||
]))
|
||||
);
|
||||
|
||||
@ -92,7 +93,7 @@ mod tests {
|
||||
|
||||
assert_eq!(
|
||||
pre_tokenizer,
|
||||
PreTokenizerWrapper::Metaspace(Metaspace::new('▁', true))
|
||||
PreTokenizerWrapper::Metaspace(Metaspace::new('▁', PrependScheme::Always, true))
|
||||
);
|
||||
|
||||
let pre_tokenizer: PreTokenizerWrapper = serde_json::from_str(r#"{"type":"Sequence","pretokenizers":[{"type":"WhitespaceSplit"},{"type":"Metaspace","replacement":"▁","add_prefix_space":true}]}"#).unwrap();
|
||||
@ -101,7 +102,7 @@ mod tests {
|
||||
pre_tokenizer,
|
||||
PreTokenizerWrapper::Sequence(Sequence::new(vec![
|
||||
PreTokenizerWrapper::WhitespaceSplit(WhitespaceSplit {}),
|
||||
PreTokenizerWrapper::Metaspace(Metaspace::new('▁', true))
|
||||
PreTokenizerWrapper::Metaspace(Metaspace::new('▁', PrependScheme::Always, true))
|
||||
]))
|
||||
);
|
||||
|
||||
@ -112,10 +113,10 @@ mod tests {
|
||||
|
||||
assert_eq!(
|
||||
pre_tokenizer,
|
||||
PreTokenizerWrapper::Metaspace(Metaspace::new_with_prepend_scheme(
|
||||
PreTokenizerWrapper::Metaspace(Metaspace::new(
|
||||
'▁',
|
||||
true,
|
||||
metaspace::PrependScheme::First
|
||||
metaspace::PrependScheme::First,
|
||||
true
|
||||
))
|
||||
);
|
||||
|
||||
@ -126,10 +127,10 @@ mod tests {
|
||||
|
||||
assert_eq!(
|
||||
pre_tokenizer,
|
||||
PreTokenizerWrapper::Metaspace(Metaspace::new_with_prepend_scheme(
|
||||
PreTokenizerWrapper::Metaspace(Metaspace::new(
|
||||
'▁',
|
||||
true,
|
||||
metaspace::PrependScheme::Always
|
||||
metaspace::PrependScheme::Always,
|
||||
true
|
||||
))
|
||||
);
|
||||
}
|
||||
|
@ -448,7 +448,7 @@ impl TemplateProcessingBuilder {
|
||||
}
|
||||
};
|
||||
|
||||
let empty = vec![];
|
||||
let empty = [];
|
||||
let missing: HashSet<&str> = self
|
||||
.single
|
||||
.as_ref()
|
||||
|
@ -928,7 +928,6 @@ where
|
||||
pretokenized: P,
|
||||
) -> Result<PreTokenizedString> {
|
||||
let mut pretokenized: PreTokenizedString = pretokenized.into();
|
||||
|
||||
if let Some(ref pretok) = self.pre_tokenizer {
|
||||
pretok.pre_tokenize(&mut pretokenized)?;
|
||||
}
|
||||
|
Reference in New Issue
Block a user