[pre_tokenizers] Fix sentencepiece based Metaspace (#1357)

* nits

* allow for legacy behaviour without making any breaking changes

* add a todo

* set to legacy by default

* skip legacy serialization

* push correct update

* lint

* add deserialization test

* add a python test as well

* updates

* fix serialization tests

* nits

* python styling of the tests

* better tests

* fix offsets

* fix imports

* fmt

* update metaspace

* remove TODO

* use enum

* fix some tests

* nits

* use enum

* update tests

* styling

* remove impl From for PrependScheme

* use simple getters and setters

* lint

* update tests

* add test new == new_with_prepend_scheme

* revert a change

* use setters and getters

* Update bindings/python/src/pre_tokenizers.rs

Co-authored-by: Nicolas Patry <patry.nicolas@protonmail.com>

* nits

* use copy rather than ref

* nits format

* more nits

* allow option string

* enforce First Never Always camel cased

* nits

* refactor

* update test as well

* fmt

* nits

* properly error out

* Update bindings/python/src/pre_tokenizers.rs

Co-authored-by: Nicolas Patry <patry.nicolas@protonmail.com>

* suggestion changes

---------

Co-authored-by: Nicolas Patry <patry.nicolas@protonmail.com>
This commit is contained in:
Arthur
2023-11-14 18:05:07 +01:00
committed by GitHub
parent ee2af9e99a
commit f55822baea
5 changed files with 257 additions and 17 deletions

View File

@ -11,7 +11,7 @@ use tk::pre_tokenizers::bert::BertPreTokenizer;
use tk::pre_tokenizers::byte_level::ByteLevel; use tk::pre_tokenizers::byte_level::ByteLevel;
use tk::pre_tokenizers::delimiter::CharDelimiterSplit; use tk::pre_tokenizers::delimiter::CharDelimiterSplit;
use tk::pre_tokenizers::digits::Digits; use tk::pre_tokenizers::digits::Digits;
use tk::pre_tokenizers::metaspace::Metaspace; use tk::pre_tokenizers::metaspace::{Metaspace, PrependScheme};
use tk::pre_tokenizers::punctuation::Punctuation; use tk::pre_tokenizers::punctuation::Punctuation;
use tk::pre_tokenizers::split::Split; use tk::pre_tokenizers::split::Split;
use tk::pre_tokenizers::unicode_scripts::UnicodeScripts; use tk::pre_tokenizers::unicode_scripts::UnicodeScripts;
@ -452,6 +452,21 @@ impl PySequence {
} }
} }
fn from_string(string: String) -> Result<PrependScheme, PyErr> {
let scheme = match string.as_str() {
"first" => PrependScheme::First,
"never" => PrependScheme::Never,
"always" => PrependScheme::Always,
_ => {
return Err(exceptions::PyValueError::new_err(format!(
"{} is an unknown variant, should be one of ['first', 'never', 'always']",
string
)));
}
};
Ok(scheme)
}
/// Metaspace pre-tokenizer /// Metaspace pre-tokenizer
/// ///
/// This pre-tokenizer replaces any whitespace by the provided replacement character. /// This pre-tokenizer replaces any whitespace by the provided replacement character.
@ -489,17 +504,44 @@ impl PyMetaspace {
setter!(self_, Metaspace, add_prefix_space, add_prefix_space); setter!(self_, Metaspace, add_prefix_space, add_prefix_space);
} }
#[getter]
fn get_prepend_scheme(self_: PyRef<Self>) -> String {
// Assuming Metaspace has a method to get the prepend_scheme as a string
let scheme: PrependScheme = getter!(self_, Metaspace, get_prepend_scheme());
match scheme {
PrependScheme::First => "first",
PrependScheme::Never => "never",
PrependScheme::Always => "always",
}
.to_string()
}
#[setter]
fn set_prepend_scheme(self_: PyRef<Self>, prepend_scheme: String) -> PyResult<()> {
let scheme = from_string(prepend_scheme)?;
setter!(self_, Metaspace, @set_prepend_scheme, scheme);
Ok(())
}
#[new] #[new]
#[pyo3(signature = (replacement = PyChar('▁'), add_prefix_space = true, **_kwargs), text_signature = "(self, replacement=\"_\", add_prefix_space=True)")] #[pyo3(signature = (replacement = PyChar('▁'), add_prefix_space = true, prepend_scheme=None, **_kwargs), text_signature = "(self, replacement=\"_\", add_prefix_space=True)")]
fn new( fn new(
replacement: PyChar, replacement: PyChar,
add_prefix_space: bool, add_prefix_space: bool,
prepend_scheme: Option<String>,
_kwargs: Option<&PyDict>, _kwargs: Option<&PyDict>,
) -> (Self, PyPreTokenizer) { ) -> PyResult<(Self, PyPreTokenizer)> {
( // Create a new Metaspace instance
PyMetaspace {}, let mut new_instance: Metaspace = Metaspace::new(replacement.0, add_prefix_space);
Metaspace::new(replacement.0, add_prefix_space).into(),
) // If a prepend scheme is provided, set it
if let Some(prepend_scheme) = prepend_scheme {
match from_string(prepend_scheme) {
Ok(prepend_scheme_enum) => new_instance.set_prepend_scheme(prepend_scheme_enum),
Err(err) => return Err(err),
}
}
Ok((PyMetaspace {}, new_instance.into()))
} }
} }

View File

@ -110,6 +110,8 @@ class TestMetaspace:
assert pretok.replacement == "%" assert pretok.replacement == "%"
pretok.add_prefix_space = True pretok.add_prefix_space = True
assert pretok.add_prefix_space == True assert pretok.add_prefix_space == True
pretok.prepend_scheme = "never"
assert pretok.prepend_scheme == "never"
class TestCharDelimiterSplit: class TestCharDelimiterSplit:

View File

@ -73,15 +73,14 @@ mod tests {
#[test] #[test]
fn decoder_serialization() { fn decoder_serialization() {
let json = r#"{"type":"Sequence","decoders":[{"type":"ByteFallback"},{"type":"Metaspace","replacement":"▁","add_prefix_space":true}]}"#; let json = r#"{"type":"Sequence","decoders":[{"type":"ByteFallback"},{"type":"Metaspace","replacement":"▁","add_prefix_space":true,"prepend_scheme":"always"}]}"#;
let decoder: DecoderWrapper = serde_json::from_str(json).unwrap(); let decoder: DecoderWrapper = serde_json::from_str(json).unwrap();
let serialized = serde_json::to_string(&decoder).unwrap(); let serialized = serde_json::to_string(&decoder).unwrap();
assert_eq!(serialized, json); assert_eq!(serialized, json);
} }
#[test] #[test]
fn decoder_serialization_other_no_arg() { fn decoder_serialization_other_no_arg() {
let json = r#"{"type":"Sequence","decoders":[{"type":"Fuse"},{"type":"Metaspace","replacement":"▁","add_prefix_space":true}]}"#; let json = r#"{"type":"Sequence","decoders":[{"type":"Fuse"},{"type":"Metaspace","replacement":"▁","add_prefix_space":true,"prepend_scheme":"always"}]}"#;
let decoder: DecoderWrapper = serde_json::from_str(json).unwrap(); let decoder: DecoderWrapper = serde_json::from_str(json).unwrap();
let serialized = serde_json::to_string(&decoder).unwrap(); let serialized = serde_json::to_string(&decoder).unwrap();
assert_eq!(serialized, json); assert_eq!(serialized, json);
@ -89,7 +88,7 @@ mod tests {
#[test] #[test]
fn decoder_serialization_no_decode() { fn decoder_serialization_no_decode() {
let json = r#"{"type":"Sequence","decoders":[{},{"type":"Metaspace","replacement":"▁","add_prefix_space":true}]}"#; let json = r#"{"type":"Sequence","decoders":[{},{"type":"Metaspace","replacement":"▁","add_prefix_space":true,"prepend_scheme":"always"}]}"#;
assert!(serde_json::from_str::<DecoderWrapper>(json).is_err()); assert!(serde_json::from_str::<DecoderWrapper>(json).is_err());
} }
} }

View File

@ -1,6 +1,17 @@
use crate::tokenizer::{Decoder, PreTokenizedString, PreTokenizer, Result, SplitDelimiterBehavior};
use serde::{Deserialize, Deserializer, Serialize}; use serde::{Deserialize, Deserializer, Serialize};
use crate::tokenizer::{Decoder, PreTokenizedString, PreTokenizer, Result, SplitDelimiterBehavior}; /// Enum representing options for the metaspace prepending scheme.
#[derive(Debug, Clone, PartialEq, Serialize, Eq, Deserialize, Copy)]
#[serde(rename_all = "snake_case")]
pub enum PrependScheme {
/// Specifies that the scheme should be prepended only once, on the first split.
First,
/// Specifies that the space should not be prepended.
Never,
/// Specifies that the scheme should always be prepended.
Always,
}
#[derive(Debug, Clone, PartialEq, Serialize, Eq)] #[derive(Debug, Clone, PartialEq, Serialize, Eq)]
/// Replaces all the whitespaces by the provided meta character and then /// Replaces all the whitespaces by the provided meta character and then
@ -9,6 +20,7 @@ use crate::tokenizer::{Decoder, PreTokenizedString, PreTokenizer, Result, SplitD
pub struct Metaspace { pub struct Metaspace {
replacement: char, replacement: char,
pub add_prefix_space: bool, pub add_prefix_space: bool,
pub prepend_scheme: PrependScheme,
#[serde(skip)] #[serde(skip)]
str_rep: String, str_rep: String,
} }
@ -23,27 +35,51 @@ impl<'de> Deserialize<'de> for Metaspace {
Metaspace, Metaspace,
} }
fn default_prepend_scheme_value() -> PrependScheme {
PrependScheme::Always
}
#[derive(Deserialize)] #[derive(Deserialize)]
pub struct MetaspaceHelper { pub struct MetaspaceHelper {
#[serde(rename = "type")] #[serde(rename = "type")]
_type: Type, _type: Type,
replacement: char, replacement: char,
pub add_prefix_space: bool, pub add_prefix_space: bool,
#[serde(default = "default_prepend_scheme_value")]
pub prepend_scheme: PrependScheme,
#[serde(skip, rename = "str_rep")] #[serde(skip, rename = "str_rep")]
_str_rep: String, _str_rep: String,
} }
let helper = MetaspaceHelper::deserialize(deserializer)?; let helper = MetaspaceHelper::deserialize(deserializer)?;
Ok(Self::new(helper.replacement, helper.add_prefix_space)) let instance = Self::new_with_prepend_scheme(
helper.replacement,
helper.add_prefix_space,
helper.prepend_scheme,
);
Ok(instance)
} }
} }
impl Metaspace { impl Metaspace {
pub fn new(replacement: char, add_prefix_space: bool) -> Self { pub fn new(replacement: char, add_prefix_space: bool) -> Self {
Self::new_with_prepend_scheme(
replacement,
add_prefix_space,
PrependScheme::Always, // always prepend for legacy purpose
)
}
pub fn new_with_prepend_scheme(
replacement: char,
add_prefix_space: bool,
prepend_scheme: PrependScheme,
) -> Self {
Self { Self {
replacement, replacement,
str_rep: replacement.to_string(), str_rep: replacement.to_string(),
add_prefix_space, add_prefix_space,
prepend_scheme,
} }
} }
@ -55,6 +91,14 @@ impl Metaspace {
self.replacement = replacement; self.replacement = replacement;
self.str_rep = replacement.to_string(); self.str_rep = replacement.to_string();
} }
pub fn get_prepend_scheme(&self) -> PrependScheme {
self.prepend_scheme
}
pub fn set_prepend_scheme(&mut self, scheme: PrependScheme) {
self.prepend_scheme = scheme;
}
} }
impl Default for Metaspace { impl Default for Metaspace {
@ -65,10 +109,19 @@ impl Default for Metaspace {
impl PreTokenizer for Metaspace { impl PreTokenizer for Metaspace {
fn pre_tokenize(&self, pretokenized: &mut PreTokenizedString) -> Result<()> { fn pre_tokenize(&self, pretokenized: &mut PreTokenizedString) -> Result<()> {
let mut first_split = true;
pretokenized.split(|_, mut normalized| { pretokenized.split(|_, mut normalized| {
normalized.replace(' ', &self.str_rep)?; normalized.replace(' ', &self.str_rep)?;
if self.add_prefix_space && !normalized.get().starts_with(self.replacement) { if self.add_prefix_space && !normalized.get().starts_with(self.replacement) {
normalized.prepend(&self.str_rep); if self.prepend_scheme == PrependScheme::Always {
normalized.prepend(&self.str_rep);
} else if self.prepend_scheme == PrependScheme::First && first_split {
normalized.prepend(&self.str_rep);
first_split = false;
}
} else {
first_split = false;
} }
normalized.split(self.replacement, SplitDelimiterBehavior::MergedWithNext) normalized.split(self.replacement, SplitDelimiterBehavior::MergedWithNext)
@ -103,13 +156,15 @@ impl Decoder for Metaspace {
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use regex::Regex;
use super::*; use super::*;
use crate::{OffsetReferential, OffsetType}; use crate::{OffsetReferential, OffsetType};
#[test] #[test]
fn serialization() { fn serialization() {
let metaspace = Metaspace::new('_', true); let metaspace = Metaspace::new('_', true);
let metaspace_s = r#"{"type":"Metaspace","replacement":"_","add_prefix_space":true}"#; let metaspace_s = r#"{"type":"Metaspace","replacement":"_","add_prefix_space":true,"prepend_scheme":"always"}"#;
assert_eq!(serde_json::to_string(&metaspace).unwrap(), metaspace_s); assert_eq!(serde_json::to_string(&metaspace).unwrap(), metaspace_s);
assert_eq!( assert_eq!(
serde_json::from_str::<Metaspace>(metaspace_s).unwrap(), serde_json::from_str::<Metaspace>(metaspace_s).unwrap(),
@ -118,8 +173,7 @@ mod tests {
// Also check it can deserialize previous versions // Also check it can deserialize previous versions
let metaspace = Metaspace::new('_', true); let metaspace = Metaspace::new('_', true);
let metaspace_s = let metaspace_s = r#"{"type":"Metaspace","str_rep":"_","replacement":"_","add_prefix_space":true,"prepend_scheme":"always"}"#;
r#"{"type":"Metaspace","str_rep":"_","replacement":"_","add_prefix_space":true}"#;
assert_eq!( assert_eq!(
serde_json::from_str::<Metaspace>(metaspace_s).unwrap(), serde_json::from_str::<Metaspace>(metaspace_s).unwrap(),
metaspace metaspace
@ -188,6 +242,121 @@ mod tests {
); );
} }
#[test]
fn non_legacy_meta_space() {
assert_eq!(
Metaspace::new('▁', true),
Metaspace::new_with_prepend_scheme('▁', true, PrependScheme::Always)
);
let mut pretok = Metaspace::new('▁', true);
pretok.set_prepend_scheme(PrependScheme::Always);
assert_eq!(
pretok,
Metaspace::new_with_prepend_scheme('▁', true, PrependScheme::Always)
);
pretok.set_prepend_scheme(PrependScheme::Never);
assert_eq!(
pretok,
Metaspace::new_with_prepend_scheme('▁', true, PrependScheme::Never)
);
pretok.set_prepend_scheme(PrependScheme::First);
assert_eq!(
pretok,
Metaspace::new_with_prepend_scheme('▁', true, PrependScheme::First)
);
let mut pretokenized = PreTokenizedString::from("Hey my friend <s>how▁are you");
let re_ref = Regex::new(r"(<s>)").unwrap();
pretokenized
.split(|_, sequence| sequence.split(&re_ref, SplitDelimiterBehavior::Isolated))
.expect("Bad split");
pretok.pre_tokenize(&mut pretokenized).unwrap();
assert_eq!(
pretokenized
.get_splits(OffsetReferential::Normalized, OffsetType::Byte)
.into_iter()
.map(|(s, o, _)| (s, o))
.collect::<Vec<_>>(),
vec![
("▁Hey", (0, 6)),
("▁my", (6, 11)),
("▁friend", (11, 20)),
("", (20, 23)),
("<s>", (23, 26)),
("how", (26, 29)),
("▁are", (29, 35)),
("▁you", (35, 41))
]
);
pretok.set_prepend_scheme(PrependScheme::Always);
pretok.pre_tokenize(&mut pretokenized).unwrap();
assert_eq!(
pretokenized
.get_splits(OffsetReferential::Normalized, OffsetType::Byte)
.into_iter()
.map(|(s, o, _)| (s, o))
.collect::<Vec<_>>(),
vec![
("▁Hey", (0, 6)),
("▁my", (6, 11)),
("▁friend", (11, 20)),
("", (20, 23)),
("▁<s>", (23, 29)),
("▁how", (29, 35)),
("▁are", (35, 41)),
("▁you", (41, 47))
]
);
pretok.set_prepend_scheme(PrependScheme::First);
let mut pretokenized = PreTokenizedString::from(" Hey <s>how"); // test with prefix
pretokenized
.split(|_, sequence| sequence.split(&re_ref, SplitDelimiterBehavior::Isolated))
.expect("Bad split");
pretok.pre_tokenize(&mut pretokenized).unwrap();
assert_eq!(
pretokenized
.get_splits(OffsetReferential::Normalized, OffsetType::Byte)
.into_iter()
.map(|(s, o, _)| (s, o))
.collect::<Vec<_>>(),
vec![
("▁Hey", (0, 6)),
("", (6, 9)),
("<s>", (9, 12)),
("how", (12, 15))
]
);
let mut pretokenized = PreTokenizedString::from(" Hey <s>how <s>are <s> you"); // test with many splits
pretokenized
.split(|_, sequence| sequence.split(&re_ref, SplitDelimiterBehavior::Isolated))
.expect("Bad split");
pretok.pre_tokenize(&mut pretokenized).unwrap();
assert_eq!(
pretokenized
.get_splits(OffsetReferential::Normalized, OffsetType::Byte)
.into_iter()
.map(|(s, o, _)| (s, o))
.collect::<Vec<_>>(),
vec![
("▁Hey", (0, 6)),
("", (6, 9)),
("<s>", (9, 12)),
("how", (12, 15)),
("", (15, 18)),
("<s>", (18, 21)),
("are", (21, 24)),
("", (24, 27)),
("<s>", (27, 30)),
("▁you", (30, 36))
]
);
}
#[test] #[test]
fn decode() { fn decode() {
let decoder = Metaspace::new('▁', true); let decoder = Metaspace::new('▁', true);

View File

@ -104,6 +104,34 @@ mod tests {
PreTokenizerWrapper::Metaspace(Metaspace::new('▁', true)) PreTokenizerWrapper::Metaspace(Metaspace::new('▁', true))
])) ]))
); );
let pre_tokenizer: PreTokenizerWrapper = serde_json::from_str(
r#"{"type":"Metaspace","replacement":"▁","add_prefix_space":true, "prepend_scheme":"first"}"#,
)
.unwrap();
assert_eq!(
pre_tokenizer,
PreTokenizerWrapper::Metaspace(Metaspace::new_with_prepend_scheme(
'▁',
true,
metaspace::PrependScheme::First
))
);
let pre_tokenizer: PreTokenizerWrapper = serde_json::from_str(
r#"{"type":"Metaspace","replacement":"▁","add_prefix_space":true, "prepend_scheme":"always"}"#,
)
.unwrap();
assert_eq!(
pre_tokenizer,
PreTokenizerWrapper::Metaspace(Metaspace::new_with_prepend_scheme(
'▁',
true,
metaspace::PrependScheme::Always
))
);
} }
#[test] #[test]