Refactor metaspace (#1476)

* version = "0.15.3-dev-0"

Improve the performance of Metaspace, and also fix its behavior.

(transformers) ➜  transformers git:(refactor-default-llama) ✗ python ../scripts/gemma-dummy.py
Token indices sequence length is longer than the specified maximum sequence length for this model (14999 > 2048). Running this sequence through the model will result in indexing errors
['<REPR_END>', '▁inform', '<s>', '.', '▁Hey', '<unk>', '.', '▁', '▁', '▁', '▁', '▁', '▁', '▁.']
['▁inform', '<s>', '.', '▁Hey', '<unk>', '.', '▁', '▁', '▁', '▁', '▁', '▁', '▁.']
[0.0006330013275146484, 0.0014591217041015625, 0.015890836715698242, 0.18584918975830078, 2.1726326942443848]
(transformers) ➜  transformers git:(refactor-default-llama) ✗ python ../scripts/gemma-dummy.py
Token indices sequence length is longer than the specified maximum sequence length for this model (10000 > 2048). Running this sequence through the model will result in indexing errors
['<REPR_END>', 'in', 'form', '<s>', '.', '▁Hey', '<unk>', '.', '▁▁▁▁▁▁', '▁.']
['in', 'form', '<s>', '.', '▁Hey', '<unk>', '.', '▁▁▁▁▁▁', '▁.']
[0.0008409023284912109, 0.0008909702301025391, 0.00882411003112793, 0.10214710235595703, 1.187899112701416]

* well what do we have

* nit

* be backward compatible (BC) with the non-legacy behavior

* unrelated change for clippy

* fix test

* splitting is a must for word_ids

* fmt and lint

* Fixing everything (hopefully better).

* Fixing node.

* Including yarn.lock

* Lint.

* Stubs.

* revert to use split

* fix merge issues

* fix tests

* finish fixing tests

* ruff

---------

Co-authored-by: Nicolas Patry <patry.nicolas@protonmail.com>
This commit is contained in:
Arthur
2024-03-30 10:27:24 +01:00
committed by GitHub
parent 6153126b22
commit 09069717e9
21 changed files with 1672 additions and 1515 deletions

View File

@ -73,14 +73,19 @@ mod tests {
#[test]
fn decoder_serialization() {
let json = r#"{"type":"Sequence","decoders":[{"type":"ByteFallback"},{"type":"Metaspace","replacement":"▁","add_prefix_space":true,"prepend_scheme":"always"}]}"#;
let oldjson = r#"{"type":"Sequence","decoders":[{"type":"ByteFallback"},{"type":"Metaspace","replacement":"▁","add_prefix_space":true,"prepend_scheme":"always"}]}"#;
let olddecoder: DecoderWrapper = serde_json::from_str(oldjson).unwrap();
let oldserialized = serde_json::to_string(&olddecoder).unwrap();
let json = r#"{"type":"Sequence","decoders":[{"type":"ByteFallback"},{"type":"Metaspace","replacement":"▁","prepend_scheme":"always","split":true}]}"#;
assert_eq!(oldserialized, json);
let decoder: DecoderWrapper = serde_json::from_str(json).unwrap();
let serialized = serde_json::to_string(&decoder).unwrap();
assert_eq!(serialized, json);
}
#[test]
fn decoder_serialization_other_no_arg() {
let json = r#"{"type":"Sequence","decoders":[{"type":"Fuse"},{"type":"Metaspace","replacement":"▁","add_prefix_space":true,"prepend_scheme":"always"}]}"#;
let json = r#"{"type":"Sequence","decoders":[{"type":"Fuse"},{"type":"Metaspace","replacement":"▁","prepend_scheme":"always","split":true}]}"#;
let decoder: DecoderWrapper = serde_json::from_str(json).unwrap();
let serialized = serde_json::to_string(&decoder).unwrap();
assert_eq!(serialized, json);
@ -88,7 +93,7 @@ mod tests {
#[test]
fn decoder_serialization_no_decode() {
let json = r#"{"type":"Sequence","decoders":[{},{"type":"Metaspace","replacement":"▁","add_prefix_space":true,"prepend_scheme":"always"}]}"#;
let json = r#"{"type":"Sequence","decoders":[{},{"type":"Metaspace","replacement":"▁","prepend_scheme":"always"}]}"#;
assert!(serde_json::from_str::<DecoderWrapper>(json).is_err());
}
}

View File

@ -1,5 +1,5 @@
use crate::tokenizer::{Decoder, PreTokenizedString, PreTokenizer, Result, SplitDelimiterBehavior};
use serde::{Deserialize, Deserializer, Serialize};
use serde::{de, Deserialize, Deserializer, Serialize};
/// Enum representing options for the metaspace prepending scheme.
#[derive(Debug, Clone, PartialEq, Serialize, Eq, Deserialize, Copy)]
@ -19,8 +19,8 @@ pub enum PrependScheme {
#[serde(tag = "type")]
pub struct Metaspace {
replacement: char,
pub add_prefix_space: bool,
pub prepend_scheme: PrependScheme,
pub split: bool,
#[serde(skip)]
str_rep: String,
}
@ -44,42 +44,40 @@ impl<'de> Deserialize<'de> for Metaspace {
#[serde(rename = "type")]
_type: Type,
replacement: char,
pub add_prefix_space: bool,
pub add_prefix_space: Option<bool>,
#[serde(default = "default_prepend_scheme_value")]
pub prepend_scheme: PrependScheme,
#[serde(skip, rename = "str_rep")]
_str_rep: String,
pub split: Option<bool>,
#[serde(rename = "str_rep")]
_str_rep: Option<String>,
}
let helper = MetaspaceHelper::deserialize(deserializer)?;
let instance = Self::new_with_prepend_scheme(
let mut helper = MetaspaceHelper::deserialize(deserializer)?;
if let Some(false) = helper.add_prefix_space {
if helper.prepend_scheme != PrependScheme::Never {
return Err(de::Error::custom(
"add_prefix_space does not match declared prepend_scheme",
));
}
helper.prepend_scheme = PrependScheme::Never;
}
let instance = Self::new(
helper.replacement,
helper.add_prefix_space,
helper.prepend_scheme,
helper.split.unwrap_or(true),
);
Ok(instance)
}
}
impl Metaspace {
pub fn new(replacement: char, add_prefix_space: bool) -> Self {
Self::new_with_prepend_scheme(
replacement,
add_prefix_space,
PrependScheme::Always, // always prepend for legacy purpose
)
}
pub fn new_with_prepend_scheme(
replacement: char,
add_prefix_space: bool,
prepend_scheme: PrependScheme,
) -> Self {
pub fn new(replacement: char, prepend_scheme: PrependScheme, split: bool) -> Self {
Self {
replacement,
str_rep: replacement.to_string(),
add_prefix_space,
prepend_scheme,
split,
}
}
@ -92,6 +90,14 @@ impl Metaspace {
self.str_rep = replacement.to_string();
}
pub fn get_split(&self) -> bool {
self.split
}
pub fn set_split(&mut self, split: bool) {
self.split = split;
}
pub fn get_prepend_scheme(&self) -> PrependScheme {
self.prepend_scheme
}
@ -103,28 +109,34 @@ impl Metaspace {
impl Default for Metaspace {
fn default() -> Self {
Self::new('▁', true)
Self::new('▁', PrependScheme::Always, true)
}
}
impl PreTokenizer for Metaspace {
fn pre_tokenize(&self, pretokenized: &mut PreTokenizedString) -> Result<()> {
let mut first_split = true;
pretokenized.split(|_, mut normalized| {
normalized.replace(' ', &self.str_rep)?;
if self.add_prefix_space && !normalized.get().starts_with(self.replacement) {
if self.prepend_scheme == PrependScheme::Always {
normalized.prepend(&self.str_rep);
} else if self.prepend_scheme == PrependScheme::First && first_split {
normalized.prepend(&self.str_rep);
first_split = false;
match self.prepend_scheme {
PrependScheme::Always => {
if !normalized.get().starts_with(self.replacement) {
normalized.prepend(&self.str_rep);
}
}
PrependScheme::First => {
if !normalized.get().starts_with(self.replacement)
&& normalized.offsets_original().0 == 0
{
normalized.prepend(&self.str_rep);
}
}
PrependScheme::Never => {}
};
if self.split {
normalized.split(self.replacement, SplitDelimiterBehavior::MergedWithNext)
} else {
first_split = false;
Ok(vec![normalized])
}
normalized.split(self.replacement, SplitDelimiterBehavior::MergedWithNext)
})
}
}
@ -139,7 +151,7 @@ impl Decoder for Metaspace {
.chars()
.flat_map(|c| {
if c == self.replacement {
if i == 0 && self.add_prefix_space {
if i == 0 && self.prepend_scheme != PrependScheme::Never {
None
} else {
Some(' ')
@ -163,8 +175,9 @@ mod tests {
#[test]
fn serialization() {
let metaspace = Metaspace::new('_', true);
let metaspace_s = r#"{"type":"Metaspace","replacement":"_","add_prefix_space":true,"prepend_scheme":"always"}"#;
let metaspace = Metaspace::new('_', PrependScheme::Always, true);
let metaspace_s =
r#"{"type":"Metaspace","replacement":"_","prepend_scheme":"always","split":true}"#;
assert_eq!(serde_json::to_string(&metaspace).unwrap(), metaspace_s);
assert_eq!(
serde_json::from_str::<Metaspace>(metaspace_s).unwrap(),
@ -172,7 +185,10 @@ mod tests {
);
// Also check it can deserialize previous versions
let metaspace = Metaspace::new('_', true);
let metaspace_s = r#"{"type":"Metaspace","replacement":"_","add_prefix_space":false,"prepend_scheme":"always"}"#;
assert!(serde_json::from_str::<Metaspace>(metaspace_s).is_err(),);
let metaspace = Metaspace::new('_', PrependScheme::Always, true);
let metaspace_s = r#"{"type":"Metaspace","str_rep":"_","replacement":"_","add_prefix_space":true,"prepend_scheme":"always"}"#;
assert_eq!(
serde_json::from_str::<Metaspace>(metaspace_s).unwrap(),
@ -188,7 +204,7 @@ mod tests {
#[test]
fn basic() {
let pretok = Metaspace::new('▁', true);
let pretok = Metaspace::new('▁', PrependScheme::Always, true);
let mut pretokenized = PreTokenizedString::from("Hey friend!");
pretok.pre_tokenize(&mut pretokenized).unwrap();
assert_eq!(
@ -211,7 +227,7 @@ mod tests {
#[test]
fn multiple_spaces() {
let pretok = Metaspace::new('▁', true);
let pretok = Metaspace::new('▁', PrependScheme::Always, true);
let mut pretokenized = PreTokenizedString::from("Hey friend!");
pretok.pre_tokenize(&mut pretokenized).unwrap();
assert_eq!(
@ -244,30 +260,17 @@ mod tests {
#[test]
fn non_legacy_meta_space() {
assert_eq!(
Metaspace::new('▁', true),
Metaspace::new_with_prepend_scheme('▁', true, PrependScheme::Always)
);
let mut pretok = Metaspace::new('▁', true);
let mut pretok = Metaspace::new('▁', PrependScheme::Always, true);
pretok.set_prepend_scheme(PrependScheme::Always);
assert_eq!(
pretok,
Metaspace::new_with_prepend_scheme('▁', true, PrependScheme::Always)
);
assert_eq!(pretok, Metaspace::new('▁', PrependScheme::Always, true));
pretok.set_prepend_scheme(PrependScheme::Never);
assert_eq!(
pretok,
Metaspace::new_with_prepend_scheme('▁', true, PrependScheme::Never)
);
assert_eq!(pretok, Metaspace::new('▁', PrependScheme::Never, true));
pretok.set_prepend_scheme(PrependScheme::First);
assert_eq!(
pretok,
Metaspace::new_with_prepend_scheme('▁', true, PrependScheme::First)
);
assert_eq!(pretok, Metaspace::new('▁', PrependScheme::First, true));
let pretok = Metaspace::new('▁', PrependScheme::First, false);
let mut pretokenized = PreTokenizedString::from("Hey my friend <s>how▁are you");
let re_ref = Regex::new(r"(<s>)").unwrap();
pretokenized
@ -282,17 +285,12 @@ mod tests {
.map(|(s, o, _)| (s, o))
.collect::<Vec<_>>(),
vec![
("▁Hey", (0, 6)),
("▁my", (6, 11)),
("▁friend", (11, 20)),
("", (20, 23)),
("▁Hey▁my▁friend▁", (0, 23)),
("<s>", (23, 26)),
("how", (26, 29)),
("▁are", (29, 35)),
("▁you", (35, 41))
("how▁are▁you", (26, 41))
]
);
pretok.set_prepend_scheme(PrependScheme::Always);
let pretok = Metaspace::new('▁', PrependScheme::Always, true);
pretok.pre_tokenize(&mut pretokenized).unwrap();
assert_eq!(
pretokenized
@ -312,7 +310,7 @@ mod tests {
]
);
pretok.set_prepend_scheme(PrependScheme::First);
let pretok = Metaspace::new('▁', PrependScheme::First, false);
let mut pretokenized = PreTokenizedString::from(" Hey <s>how"); // test with prefix
pretokenized
.split(|_, sequence| sequence.split(&re_ref, SplitDelimiterBehavior::Isolated))
@ -324,12 +322,7 @@ mod tests {
.into_iter()
.map(|(s, o, _)| (s, o))
.collect::<Vec<_>>(),
vec![
("▁Hey", (0, 6)),
("", (6, 9)),
("<s>", (9, 12)),
("how", (12, 15))
]
vec![("▁Hey▁", (0, 9)), ("<s>", (9, 12)), ("how", (12, 15))]
);
let mut pretokenized = PreTokenizedString::from(" Hey <s>how <s>are <s> you"); // test with many splits
@ -344,14 +337,11 @@ mod tests {
.map(|(s, o, _)| (s, o))
.collect::<Vec<_>>(),
vec![
("▁Hey", (0, 6)),
("", (6, 9)),
("▁Hey", (0, 9)),
("<s>", (9, 12)),
("how", (12, 15)),
("", (15, 18)),
("how", (12, 18)),
("<s>", (18, 21)),
("are", (21, 24)),
("", (24, 27)),
("are", (21, 27)),
("<s>", (27, 30)),
("▁you", (30, 36))
]
@ -359,10 +349,16 @@ mod tests {
}
#[test]
fn decode() {
let decoder = Metaspace::new('▁', true);
let decoder = Metaspace::new('▁', PrependScheme::Always, true);
let res = decoder
.decode_chain(vec!["▁Hey".into(), "▁friend!".into()])
.unwrap();
assert_eq!(res, vec!["Hey", " friend!"])
assert_eq!(res, vec!["Hey", " friend!"]);
let decoder = Metaspace::new('▁', PrependScheme::Never, true);
let res = decoder
.decode_chain(vec!["▁Hey".into(), "▁friend!".into()])
.unwrap();
assert_eq!(res, vec![" Hey", " friend!"]);
}
}

View File

@ -71,6 +71,7 @@ impl_enum_from!(UnicodeScripts, PreTokenizerWrapper, UnicodeScripts);
#[cfg(test)]
mod tests {
use super::metaspace::PrependScheme;
use super::*;
#[test]
@ -81,7 +82,7 @@ mod tests {
pre_tokenizer,
PreTokenizerWrapper::Sequence(Sequence::new(vec![
PreTokenizerWrapper::WhitespaceSplit(WhitespaceSplit {}),
PreTokenizerWrapper::Metaspace(Metaspace::new('▁', true))
PreTokenizerWrapper::Metaspace(Metaspace::new('▁', PrependScheme::Always, true))
]))
);
@ -92,7 +93,7 @@ mod tests {
assert_eq!(
pre_tokenizer,
PreTokenizerWrapper::Metaspace(Metaspace::new('▁', true))
PreTokenizerWrapper::Metaspace(Metaspace::new('▁', PrependScheme::Always, true))
);
let pre_tokenizer: PreTokenizerWrapper = serde_json::from_str(r#"{"type":"Sequence","pretokenizers":[{"type":"WhitespaceSplit"},{"type":"Metaspace","replacement":"▁","add_prefix_space":true}]}"#).unwrap();
@ -101,7 +102,7 @@ mod tests {
pre_tokenizer,
PreTokenizerWrapper::Sequence(Sequence::new(vec![
PreTokenizerWrapper::WhitespaceSplit(WhitespaceSplit {}),
PreTokenizerWrapper::Metaspace(Metaspace::new('▁', true))
PreTokenizerWrapper::Metaspace(Metaspace::new('▁', PrependScheme::Always, true))
]))
);
@ -112,10 +113,10 @@ mod tests {
assert_eq!(
pre_tokenizer,
PreTokenizerWrapper::Metaspace(Metaspace::new_with_prepend_scheme(
PreTokenizerWrapper::Metaspace(Metaspace::new(
'▁',
true,
metaspace::PrependScheme::First
metaspace::PrependScheme::First,
true
))
);
@ -126,10 +127,10 @@ mod tests {
assert_eq!(
pre_tokenizer,
PreTokenizerWrapper::Metaspace(Metaspace::new_with_prepend_scheme(
PreTokenizerWrapper::Metaspace(Metaspace::new(
'▁',
true,
metaspace::PrependScheme::Always
metaspace::PrependScheme::Always,
true
))
);
}

View File

@ -448,7 +448,7 @@ impl TemplateProcessingBuilder {
}
};
let empty = vec![];
let empty = [];
let missing: HashSet<&str> = self
.single
.as_ref()

View File

@ -928,7 +928,6 @@ where
pretokenized: P,
) -> Result<PreTokenizedString> {
let mut pretokenized: PreTokenizedString = pretokenized.into();
if let Some(ref pretok) = self.pre_tokenizer {
pretok.pre_tokenize(&mut pretokenized)?;
}