mirror of
https://github.com/mii443/tokenizers.git
synced 2025-09-01 23:09:34 +00:00
[pre_tokenizers
] Fix sentencepiece based Metaspace (#1357)
* nits * allow for legacy beahaviour without making any breaking changes * add a todo * set to legacy by default * skip legacy serialization * push correct update * lint * add deserialization test * add a python test as well * updates * fix serialization tests * nits * python stylijng of the tests * better tests * fix offsets * fix imports * fmt * update metaspace * remove TODO * use enm * fix some tses * nits * use enum * update tests * syling * remove impl from for PrependScheme * use simple getters and setters * lint * update tests * add test new == new_with_prepend_scheme * revert a change * use setters and getterts * Update bindings/python/src/pre_tokenizers.rs Co-authored-by: Nicolas Patry <patry.nicolas@protonmail.com> * nits * use copy rather than ref * nits format * more nits * allow option string * enforce First Never Always camel cased * nits * refactor * update test as well * fmt * nits * properly error out * Update bindings/python/src/pre_tokenizers.rs Co-authored-by: Nicolas Patry <patry.nicolas@protonmail.com> * suggestion changes --------- Co-authored-by: Nicolas Patry <patry.nicolas@protonmail.com>
This commit is contained in:
@ -11,7 +11,7 @@ use tk::pre_tokenizers::bert::BertPreTokenizer;
|
||||
use tk::pre_tokenizers::byte_level::ByteLevel;
|
||||
use tk::pre_tokenizers::delimiter::CharDelimiterSplit;
|
||||
use tk::pre_tokenizers::digits::Digits;
|
||||
use tk::pre_tokenizers::metaspace::Metaspace;
|
||||
use tk::pre_tokenizers::metaspace::{Metaspace, PrependScheme};
|
||||
use tk::pre_tokenizers::punctuation::Punctuation;
|
||||
use tk::pre_tokenizers::split::Split;
|
||||
use tk::pre_tokenizers::unicode_scripts::UnicodeScripts;
|
||||
@ -452,6 +452,21 @@ impl PySequence {
|
||||
}
|
||||
}
|
||||
|
||||
fn from_string(string: String) -> Result<PrependScheme, PyErr> {
|
||||
let scheme = match string.as_str() {
|
||||
"first" => PrependScheme::First,
|
||||
"never" => PrependScheme::Never,
|
||||
"always" => PrependScheme::Always,
|
||||
_ => {
|
||||
return Err(exceptions::PyValueError::new_err(format!(
|
||||
"{} is an unknown variant, should be one of ['first', 'never', 'always']",
|
||||
string
|
||||
)));
|
||||
}
|
||||
};
|
||||
Ok(scheme)
|
||||
}
|
||||
|
||||
/// Metaspace pre-tokenizer
|
||||
///
|
||||
/// This pre-tokenizer replaces any whitespace by the provided replacement character.
|
||||
@ -489,17 +504,44 @@ impl PyMetaspace {
|
||||
setter!(self_, Metaspace, add_prefix_space, add_prefix_space);
|
||||
}
|
||||
|
||||
#[getter]
|
||||
fn get_prepend_scheme(self_: PyRef<Self>) -> String {
|
||||
// Assuming Metaspace has a method to get the prepend_scheme as a string
|
||||
let scheme: PrependScheme = getter!(self_, Metaspace, get_prepend_scheme());
|
||||
match scheme {
|
||||
PrependScheme::First => "first",
|
||||
PrependScheme::Never => "never",
|
||||
PrependScheme::Always => "always",
|
||||
}
|
||||
.to_string()
|
||||
}
|
||||
|
||||
#[setter]
|
||||
fn set_prepend_scheme(self_: PyRef<Self>, prepend_scheme: String) -> PyResult<()> {
|
||||
let scheme = from_string(prepend_scheme)?;
|
||||
setter!(self_, Metaspace, @set_prepend_scheme, scheme);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[new]
|
||||
#[pyo3(signature = (replacement = PyChar('▁'), add_prefix_space = true, **_kwargs), text_signature = "(self, replacement=\"_\", add_prefix_space=True)")]
|
||||
#[pyo3(signature = (replacement = PyChar('▁'), add_prefix_space = true, prepend_scheme=None, **_kwargs), text_signature = "(self, replacement=\"_\", add_prefix_space=True)")]
|
||||
fn new(
|
||||
replacement: PyChar,
|
||||
add_prefix_space: bool,
|
||||
prepend_scheme: Option<String>,
|
||||
_kwargs: Option<&PyDict>,
|
||||
) -> (Self, PyPreTokenizer) {
|
||||
(
|
||||
PyMetaspace {},
|
||||
Metaspace::new(replacement.0, add_prefix_space).into(),
|
||||
)
|
||||
) -> PyResult<(Self, PyPreTokenizer)> {
|
||||
// Create a new Metaspace instance
|
||||
let mut new_instance: Metaspace = Metaspace::new(replacement.0, add_prefix_space);
|
||||
|
||||
// If a prepend scheme is provided, set it
|
||||
if let Some(prepend_scheme) = prepend_scheme {
|
||||
match from_string(prepend_scheme) {
|
||||
Ok(prepend_scheme_enum) => new_instance.set_prepend_scheme(prepend_scheme_enum),
|
||||
Err(err) => return Err(err),
|
||||
}
|
||||
}
|
||||
Ok((PyMetaspace {}, new_instance.into()))
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -110,6 +110,8 @@ class TestMetaspace:
|
||||
assert pretok.replacement == "%"
|
||||
pretok.add_prefix_space = True
|
||||
assert pretok.add_prefix_space == True
|
||||
pretok.prepend_scheme = "never"
|
||||
assert pretok.prepend_scheme == "never"
|
||||
|
||||
|
||||
class TestCharDelimiterSplit:
|
||||
|
@ -73,15 +73,14 @@ mod tests {
|
||||
|
||||
#[test]
|
||||
fn decoder_serialization() {
|
||||
let json = r#"{"type":"Sequence","decoders":[{"type":"ByteFallback"},{"type":"Metaspace","replacement":"▁","add_prefix_space":true}]}"#;
|
||||
let json = r#"{"type":"Sequence","decoders":[{"type":"ByteFallback"},{"type":"Metaspace","replacement":"▁","add_prefix_space":true,"prepend_scheme":"always"}]}"#;
|
||||
let decoder: DecoderWrapper = serde_json::from_str(json).unwrap();
|
||||
let serialized = serde_json::to_string(&decoder).unwrap();
|
||||
assert_eq!(serialized, json);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn decoder_serialization_other_no_arg() {
|
||||
let json = r#"{"type":"Sequence","decoders":[{"type":"Fuse"},{"type":"Metaspace","replacement":"▁","add_prefix_space":true}]}"#;
|
||||
let json = r#"{"type":"Sequence","decoders":[{"type":"Fuse"},{"type":"Metaspace","replacement":"▁","add_prefix_space":true,"prepend_scheme":"always"}]}"#;
|
||||
let decoder: DecoderWrapper = serde_json::from_str(json).unwrap();
|
||||
let serialized = serde_json::to_string(&decoder).unwrap();
|
||||
assert_eq!(serialized, json);
|
||||
@ -89,7 +88,7 @@ mod tests {
|
||||
|
||||
#[test]
|
||||
fn decoder_serialization_no_decode() {
|
||||
let json = r#"{"type":"Sequence","decoders":[{},{"type":"Metaspace","replacement":"▁","add_prefix_space":true}]}"#;
|
||||
let json = r#"{"type":"Sequence","decoders":[{},{"type":"Metaspace","replacement":"▁","add_prefix_space":true,"prepend_scheme":"always"}]}"#;
|
||||
assert!(serde_json::from_str::<DecoderWrapper>(json).is_err());
|
||||
}
|
||||
}
|
||||
|
@ -1,6 +1,17 @@
|
||||
use crate::tokenizer::{Decoder, PreTokenizedString, PreTokenizer, Result, SplitDelimiterBehavior};
|
||||
use serde::{Deserialize, Deserializer, Serialize};
|
||||
|
||||
use crate::tokenizer::{Decoder, PreTokenizedString, PreTokenizer, Result, SplitDelimiterBehavior};
|
||||
/// Enum representing options for the metaspace prepending scheme.
|
||||
#[derive(Debug, Clone, PartialEq, Serialize, Eq, Deserialize, Copy)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
pub enum PrependScheme {
|
||||
/// Specifies that the scheme should be prepended only once, on the first split.
|
||||
First,
|
||||
/// Specifies that the space should not be prepended.
|
||||
Never,
|
||||
/// Specifies that the scheme should always be prepended.
|
||||
Always,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Serialize, Eq)]
|
||||
/// Replaces all the whitespaces by the provided meta character and then
|
||||
@ -9,6 +20,7 @@ use crate::tokenizer::{Decoder, PreTokenizedString, PreTokenizer, Result, SplitD
|
||||
pub struct Metaspace {
|
||||
replacement: char,
|
||||
pub add_prefix_space: bool,
|
||||
pub prepend_scheme: PrependScheme,
|
||||
#[serde(skip)]
|
||||
str_rep: String,
|
||||
}
|
||||
@ -23,27 +35,51 @@ impl<'de> Deserialize<'de> for Metaspace {
|
||||
Metaspace,
|
||||
}
|
||||
|
||||
fn default_prepend_scheme_value() -> PrependScheme {
|
||||
PrependScheme::Always
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
pub struct MetaspaceHelper {
|
||||
#[serde(rename = "type")]
|
||||
_type: Type,
|
||||
replacement: char,
|
||||
pub add_prefix_space: bool,
|
||||
#[serde(default = "default_prepend_scheme_value")]
|
||||
pub prepend_scheme: PrependScheme,
|
||||
#[serde(skip, rename = "str_rep")]
|
||||
_str_rep: String,
|
||||
}
|
||||
|
||||
let helper = MetaspaceHelper::deserialize(deserializer)?;
|
||||
Ok(Self::new(helper.replacement, helper.add_prefix_space))
|
||||
let instance = Self::new_with_prepend_scheme(
|
||||
helper.replacement,
|
||||
helper.add_prefix_space,
|
||||
helper.prepend_scheme,
|
||||
);
|
||||
Ok(instance)
|
||||
}
|
||||
}
|
||||
|
||||
impl Metaspace {
|
||||
pub fn new(replacement: char, add_prefix_space: bool) -> Self {
|
||||
Self::new_with_prepend_scheme(
|
||||
replacement,
|
||||
add_prefix_space,
|
||||
PrependScheme::Always, // always prepend for legacy purpose
|
||||
)
|
||||
}
|
||||
|
||||
pub fn new_with_prepend_scheme(
|
||||
replacement: char,
|
||||
add_prefix_space: bool,
|
||||
prepend_scheme: PrependScheme,
|
||||
) -> Self {
|
||||
Self {
|
||||
replacement,
|
||||
str_rep: replacement.to_string(),
|
||||
add_prefix_space,
|
||||
prepend_scheme,
|
||||
}
|
||||
}
|
||||
|
||||
@ -55,6 +91,14 @@ impl Metaspace {
|
||||
self.replacement = replacement;
|
||||
self.str_rep = replacement.to_string();
|
||||
}
|
||||
|
||||
pub fn get_prepend_scheme(&self) -> PrependScheme {
|
||||
self.prepend_scheme
|
||||
}
|
||||
|
||||
pub fn set_prepend_scheme(&mut self, scheme: PrependScheme) {
|
||||
self.prepend_scheme = scheme;
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for Metaspace {
|
||||
@ -65,10 +109,19 @@ impl Default for Metaspace {
|
||||
|
||||
impl PreTokenizer for Metaspace {
|
||||
fn pre_tokenize(&self, pretokenized: &mut PreTokenizedString) -> Result<()> {
|
||||
let mut first_split = true;
|
||||
|
||||
pretokenized.split(|_, mut normalized| {
|
||||
normalized.replace(' ', &self.str_rep)?;
|
||||
if self.add_prefix_space && !normalized.get().starts_with(self.replacement) {
|
||||
if self.prepend_scheme == PrependScheme::Always {
|
||||
normalized.prepend(&self.str_rep);
|
||||
} else if self.prepend_scheme == PrependScheme::First && first_split {
|
||||
normalized.prepend(&self.str_rep);
|
||||
first_split = false;
|
||||
}
|
||||
} else {
|
||||
first_split = false;
|
||||
}
|
||||
|
||||
normalized.split(self.replacement, SplitDelimiterBehavior::MergedWithNext)
|
||||
@ -103,13 +156,15 @@ impl Decoder for Metaspace {
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use regex::Regex;
|
||||
|
||||
use super::*;
|
||||
use crate::{OffsetReferential, OffsetType};
|
||||
|
||||
#[test]
|
||||
fn serialization() {
|
||||
let metaspace = Metaspace::new('_', true);
|
||||
let metaspace_s = r#"{"type":"Metaspace","replacement":"_","add_prefix_space":true}"#;
|
||||
let metaspace_s = r#"{"type":"Metaspace","replacement":"_","add_prefix_space":true,"prepend_scheme":"always"}"#;
|
||||
assert_eq!(serde_json::to_string(&metaspace).unwrap(), metaspace_s);
|
||||
assert_eq!(
|
||||
serde_json::from_str::<Metaspace>(metaspace_s).unwrap(),
|
||||
@ -118,8 +173,7 @@ mod tests {
|
||||
|
||||
// Also check it can deserialize previous versions
|
||||
let metaspace = Metaspace::new('_', true);
|
||||
let metaspace_s =
|
||||
r#"{"type":"Metaspace","str_rep":"_","replacement":"_","add_prefix_space":true}"#;
|
||||
let metaspace_s = r#"{"type":"Metaspace","str_rep":"_","replacement":"_","add_prefix_space":true,"prepend_scheme":"always"}"#;
|
||||
assert_eq!(
|
||||
serde_json::from_str::<Metaspace>(metaspace_s).unwrap(),
|
||||
metaspace
|
||||
@ -188,6 +242,121 @@ mod tests {
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn non_legacy_meta_space() {
|
||||
assert_eq!(
|
||||
Metaspace::new('▁', true),
|
||||
Metaspace::new_with_prepend_scheme('▁', true, PrependScheme::Always)
|
||||
);
|
||||
|
||||
let mut pretok = Metaspace::new('▁', true);
|
||||
pretok.set_prepend_scheme(PrependScheme::Always);
|
||||
assert_eq!(
|
||||
pretok,
|
||||
Metaspace::new_with_prepend_scheme('▁', true, PrependScheme::Always)
|
||||
);
|
||||
|
||||
pretok.set_prepend_scheme(PrependScheme::Never);
|
||||
assert_eq!(
|
||||
pretok,
|
||||
Metaspace::new_with_prepend_scheme('▁', true, PrependScheme::Never)
|
||||
);
|
||||
|
||||
pretok.set_prepend_scheme(PrependScheme::First);
|
||||
assert_eq!(
|
||||
pretok,
|
||||
Metaspace::new_with_prepend_scheme('▁', true, PrependScheme::First)
|
||||
);
|
||||
|
||||
let mut pretokenized = PreTokenizedString::from("Hey my friend <s>how▁are you");
|
||||
let re_ref = Regex::new(r"(<s>)").unwrap();
|
||||
pretokenized
|
||||
.split(|_, sequence| sequence.split(&re_ref, SplitDelimiterBehavior::Isolated))
|
||||
.expect("Bad split");
|
||||
|
||||
pretok.pre_tokenize(&mut pretokenized).unwrap();
|
||||
assert_eq!(
|
||||
pretokenized
|
||||
.get_splits(OffsetReferential::Normalized, OffsetType::Byte)
|
||||
.into_iter()
|
||||
.map(|(s, o, _)| (s, o))
|
||||
.collect::<Vec<_>>(),
|
||||
vec![
|
||||
("▁Hey", (0, 6)),
|
||||
("▁my", (6, 11)),
|
||||
("▁friend", (11, 20)),
|
||||
("▁", (20, 23)),
|
||||
("<s>", (23, 26)),
|
||||
("how", (26, 29)),
|
||||
("▁are", (29, 35)),
|
||||
("▁you", (35, 41))
|
||||
]
|
||||
);
|
||||
pretok.set_prepend_scheme(PrependScheme::Always);
|
||||
pretok.pre_tokenize(&mut pretokenized).unwrap();
|
||||
assert_eq!(
|
||||
pretokenized
|
||||
.get_splits(OffsetReferential::Normalized, OffsetType::Byte)
|
||||
.into_iter()
|
||||
.map(|(s, o, _)| (s, o))
|
||||
.collect::<Vec<_>>(),
|
||||
vec![
|
||||
("▁Hey", (0, 6)),
|
||||
("▁my", (6, 11)),
|
||||
("▁friend", (11, 20)),
|
||||
("▁", (20, 23)),
|
||||
("▁<s>", (23, 29)),
|
||||
("▁how", (29, 35)),
|
||||
("▁are", (35, 41)),
|
||||
("▁you", (41, 47))
|
||||
]
|
||||
);
|
||||
|
||||
pretok.set_prepend_scheme(PrependScheme::First);
|
||||
let mut pretokenized = PreTokenizedString::from(" Hey <s>how"); // test with prefix
|
||||
pretokenized
|
||||
.split(|_, sequence| sequence.split(&re_ref, SplitDelimiterBehavior::Isolated))
|
||||
.expect("Bad split");
|
||||
pretok.pre_tokenize(&mut pretokenized).unwrap();
|
||||
assert_eq!(
|
||||
pretokenized
|
||||
.get_splits(OffsetReferential::Normalized, OffsetType::Byte)
|
||||
.into_iter()
|
||||
.map(|(s, o, _)| (s, o))
|
||||
.collect::<Vec<_>>(),
|
||||
vec![
|
||||
("▁Hey", (0, 6)),
|
||||
("▁", (6, 9)),
|
||||
("<s>", (9, 12)),
|
||||
("how", (12, 15))
|
||||
]
|
||||
);
|
||||
|
||||
let mut pretokenized = PreTokenizedString::from(" Hey <s>how <s>are <s> you"); // test with many splits
|
||||
pretokenized
|
||||
.split(|_, sequence| sequence.split(&re_ref, SplitDelimiterBehavior::Isolated))
|
||||
.expect("Bad split");
|
||||
pretok.pre_tokenize(&mut pretokenized).unwrap();
|
||||
assert_eq!(
|
||||
pretokenized
|
||||
.get_splits(OffsetReferential::Normalized, OffsetType::Byte)
|
||||
.into_iter()
|
||||
.map(|(s, o, _)| (s, o))
|
||||
.collect::<Vec<_>>(),
|
||||
vec![
|
||||
("▁Hey", (0, 6)),
|
||||
("▁", (6, 9)),
|
||||
("<s>", (9, 12)),
|
||||
("how", (12, 15)),
|
||||
("▁", (15, 18)),
|
||||
("<s>", (18, 21)),
|
||||
("are", (21, 24)),
|
||||
("▁", (24, 27)),
|
||||
("<s>", (27, 30)),
|
||||
("▁you", (30, 36))
|
||||
]
|
||||
);
|
||||
}
|
||||
#[test]
|
||||
fn decode() {
|
||||
let decoder = Metaspace::new('▁', true);
|
||||
|
@ -104,6 +104,34 @@ mod tests {
|
||||
PreTokenizerWrapper::Metaspace(Metaspace::new('▁', true))
|
||||
]))
|
||||
);
|
||||
|
||||
let pre_tokenizer: PreTokenizerWrapper = serde_json::from_str(
|
||||
r#"{"type":"Metaspace","replacement":"▁","add_prefix_space":true, "prepend_scheme":"first"}"#,
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(
|
||||
pre_tokenizer,
|
||||
PreTokenizerWrapper::Metaspace(Metaspace::new_with_prepend_scheme(
|
||||
'▁',
|
||||
true,
|
||||
metaspace::PrependScheme::First
|
||||
))
|
||||
);
|
||||
|
||||
let pre_tokenizer: PreTokenizerWrapper = serde_json::from_str(
|
||||
r#"{"type":"Metaspace","replacement":"▁","add_prefix_space":true, "prepend_scheme":"always"}"#,
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(
|
||||
pre_tokenizer,
|
||||
PreTokenizerWrapper::Metaspace(Metaspace::new_with_prepend_scheme(
|
||||
'▁',
|
||||
true,
|
||||
metaspace::PrependScheme::Always
|
||||
))
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
|
Reference in New Issue
Block a user