Refactor metaspace (#1476)

* version = "0.15.3-dev-0”

Improve performances of meta space, but also just fix it.

(transformers) ➜  transformers git:(refactor-default-llama) ✗ python ../scripts/gemma-dummy.py
Token indices sequence length is longer than the specified maximum sequence length for this model (14999 > 2048). Running this sequence through the model will result in indexing errors
['<REPR_END>', '▁inform', '<s>', '.', '▁Hey', '<unk>', '.', '▁', '▁', '▁', '▁', '▁', '▁', '▁.']
['▁inform', '<s>', '.', '▁Hey', '<unk>', '.', '▁', '▁', '▁', '▁', '▁', '▁', '▁.']
[0.0006330013275146484, 0.0014591217041015625, 0.015890836715698242, 0.18584918975830078, 2.1726326942443848]
(transformers) ➜  transformers git:(refactor-default-llama) ✗ python ../scripts/gemma-dummy.py
Token indices sequence length is longer than the specified maximum sequence length for this model (10000 > 2048). Running this sequence through the model will result in indexing errors
['<REPR_END>', 'in', 'form', '<s>', '.', '▁Hey', '<unk>', '.', '▁▁▁▁▁▁', '▁.']
['in', 'form', '<s>', '.', '▁Hey', '<unk>', '.', '▁▁▁▁▁▁', '▁.']
[0.0008409023284912109, 0.0008909702301025391, 0.00882411003112793, 0.10214710235595703, 1.187899112701416]

* well what do we have

* nit

* be BC with non legacy

* unrelated change for clippy

* fix test

* splitting is a must for word_ids

* fmt and lint

* Fixing everything (hopefully better).

* Fixing node.

* Including yarn.lock

* Lint.

* Stubs.

* revert to use split

* fix merge issues

* fix tests

* finish fixing tests

* ruff

---------

Co-authored-by: Nicolas Patry <patry.nicolas@protonmail.com>
This commit is contained in:
Arthur
2024-03-30 10:27:24 +01:00
committed by GitHub
parent 6153126b22
commit 09069717e9
21 changed files with 1672 additions and 1515 deletions

View File

@ -11,7 +11,11 @@ export function ctcDecoder(
cleanup?: boolean | undefined | null,
): Decoder
export function fuseDecoder(): Decoder
export function metaspaceDecoder(replacement?: string = '▁', addPrefixSpace?: bool = true): Decoder
export function metaspaceDecoder(
replacement?: string = '▁',
prependScheme?: prepend_scheme = 'always',
split?: split = true,
): Decoder
export function replaceDecoder(pattern: string, content: string): Decoder
export function sequenceDecoder(decoders: Array<Decoder>): Decoder
export function stripDecoder(content: string, left: number, right: number): Decoder
@ -89,7 +93,11 @@ export function byteLevelAlphabet(): Array<string>
export function whitespacePreTokenizer(): PreTokenizer
export function whitespaceSplitPreTokenizer(): PreTokenizer
export function bertPreTokenizer(): PreTokenizer
export function metaspacePreTokenizer(replacement?: string = '▁', addPrefixSpace?: bool = true): PreTokenizer
export function metaspacePreTokenizer(
replacement?: string = '▁',
prependScheme?: prepend_scheme = 'always',
split?: split = true,
): PreTokenizer
export function splitPreTokenizer(pattern: string, behavior: string, invert?: boolean | undefined | null): PreTokenizer
export function punctuationPreTokenizer(behavior?: string | undefined | null): PreTokenizer
export function sequencePreTokenizer(preTokenizers: Array<PreTokenizer>): PreTokenizer

View File

@ -219,6 +219,43 @@ switch (platform) {
loadError = e
}
break
case 'riscv64':
if (isMusl()) {
localFileExisted = existsSync(join(__dirname, 'tokenizers.linux-riscv64-musl.node'))
try {
if (localFileExisted) {
nativeBinding = require('./tokenizers.linux-riscv64-musl.node')
} else {
nativeBinding = require('tokenizers-linux-riscv64-musl')
}
} catch (e) {
loadError = e
}
} else {
localFileExisted = existsSync(join(__dirname, 'tokenizers.linux-riscv64-gnu.node'))
try {
if (localFileExisted) {
nativeBinding = require('./tokenizers.linux-riscv64-gnu.node')
} else {
nativeBinding = require('tokenizers-linux-riscv64-gnu')
}
} catch (e) {
loadError = e
}
}
break
case 's390x':
localFileExisted = existsSync(join(__dirname, 'tokenizers.linux-s390x-gnu.node'))
try {
if (localFileExisted) {
nativeBinding = require('./tokenizers.linux-s390x-gnu.node')
} else {
nativeBinding = require('tokenizers-linux-s390x-gnu')
}
} catch (e) {
loadError = e
}
break
default:
throw new Error(`Unsupported architecture on Linux: ${arch}`)
}

View File

@ -1,6 +1,6 @@
{
"name": "tokenizers",
"version": "0.14.0-dev0",
"version": "0.15.3-dev0",
"repository": {
"type": "git",
"url": "git+https://github.com/huggingface/tokenizers.git"

View File

@ -90,9 +90,11 @@ pub fn fuse_decoder() -> Decoder {
#[napi]
pub fn metaspace_decoder(
#[napi(ts_arg_type = "string = '▁'")] replacement: Option<String>,
#[napi(ts_arg_type = "bool = true")] add_prefix_space: Option<bool>,
#[napi(ts_arg_type = "prepend_scheme = 'always'")] prepend_scheme: Option<String>,
#[napi(ts_arg_type = "split = true")] split: Option<bool>,
) -> Result<Decoder> {
let add_prefix_space = add_prefix_space.unwrap_or(true);
use tk::pre_tokenizers::metaspace::PrependScheme;
let split = split.unwrap_or(true);
let replacement = replacement.unwrap_or("".to_string());
if replacement.chars().count() != 1 {
return Err(Error::from_reason(
@ -100,9 +102,20 @@ pub fn metaspace_decoder(
));
}
let replacement = replacement.chars().next().unwrap();
let prepend_scheme: PrependScheme =
match prepend_scheme.unwrap_or(String::from("always")).as_str() {
"always" => PrependScheme::Always,
"first" => PrependScheme::First,
"never" => PrependScheme::Never,
_ => {
return Err(Error::from_reason(
"prepend_scheme is supposed to be either 'always', 'first' or 'never'",
));
}
};
Ok(Decoder {
decoder: Some(Arc::new(RwLock::new(
tk::decoders::metaspace::Metaspace::new(replacement, add_prefix_space).into(),
tk::decoders::metaspace::Metaspace::new(replacement, prepend_scheme, split).into(),
))),
})
}

View File

@ -155,9 +155,11 @@ pub fn bert_pre_tokenizer() -> PreTokenizer {
#[napi]
pub fn metaspace_pre_tokenizer(
#[napi(ts_arg_type = "string = '▁'")] replacement: Option<String>,
#[napi(ts_arg_type = "bool = true")] add_prefix_space: Option<bool>,
#[napi(ts_arg_type = "prepend_scheme = 'always'")] prepend_scheme: Option<String>,
#[napi(ts_arg_type = "split = true")] split: Option<bool>,
) -> Result<PreTokenizer> {
let add_prefix_space = add_prefix_space.unwrap_or(true);
use tk::pre_tokenizers::metaspace::PrependScheme;
let split = split.unwrap_or(true);
let replacement = replacement.unwrap_or("".to_string());
if replacement.chars().count() != 1 {
return Err(Error::from_reason(
@ -165,10 +167,21 @@ pub fn metaspace_pre_tokenizer(
));
}
let replacement = replacement.chars().next().unwrap();
let prepend_scheme: PrependScheme =
match prepend_scheme.unwrap_or(String::from("always")).as_str() {
"always" => PrependScheme::Always,
"first" => PrependScheme::First,
"never" => PrependScheme::Never,
_ => {
return Err(Error::from_reason(
"prepend_scheme is supposed to be either 'always', 'first' or 'never'",
));
}
};
Ok(PreTokenizer {
pretok: Some(Arc::new(RwLock::new(
tk::pre_tokenizers::metaspace::Metaspace::new(replacement, add_prefix_space).into(),
tk::pre_tokenizers::metaspace::Metaspace::new(replacement, prepend_scheme, split).into(),
))),
})
}

File diff suppressed because it is too large Load Diff

View File

@ -9,7 +9,7 @@ check_dirs := examples py_src/tokenizers tests
style:
python stub.py
ruff check $(check_dirs) --fix
ruff format $(check_dirs)t
ruff format $(check_dirs)
# Check the source code is formatted correctly
check-style:

View File

@ -156,7 +156,7 @@ class Metaspace(Decoder):
Whether to add a space to the first word if there isn't already one. This
lets us treat `hello` exactly like `say hello`.
"""
def __init__(self, replacement="", add_prefix_space=True):
def __init__(self, replacement="", prepend_scheme="always", split=True):
pass
def decode(self, tokens):

View File

@ -32,8 +32,9 @@ class SentencePieceBPETokenizer(BaseTokenizer):
tokenizer.add_special_tokens([str(unk_token)])
tokenizer.normalizer = NFKC()
tokenizer.pre_tokenizer = pre_tokenizers.Metaspace(replacement=replacement, add_prefix_space=add_prefix_space)
tokenizer.decoder = decoders.Metaspace(replacement=replacement, add_prefix_space=add_prefix_space)
prepend_scheme = "always" if add_prefix_space else "never"
tokenizer.pre_tokenizer = pre_tokenizers.Metaspace(replacement=replacement, prepend_scheme=prepend_scheme)
tokenizer.decoder = decoders.Metaspace(replacement=replacement, prepend_scheme=prepend_scheme)
parameters = {
"model": "SentencePieceBPE",

View File

@ -29,8 +29,9 @@ class SentencePieceUnigramTokenizer(BaseTokenizer):
tokenizer.normalizer = normalizers.Sequence(
[normalizers.Nmt(), normalizers.NFKC(), normalizers.Replace(Regex(" {2,}"), " ")]
)
tokenizer.pre_tokenizer = pre_tokenizers.Metaspace(replacement=replacement, add_prefix_space=add_prefix_space)
tokenizer.decoder = decoders.Metaspace(replacement=replacement, add_prefix_space=add_prefix_space)
prepend_scheme = "always" if add_prefix_space else "never"
tokenizer.pre_tokenizer = pre_tokenizers.Metaspace(replacement=replacement, prepend_scheme=prepend_scheme)
tokenizer.decoder = decoders.Metaspace(replacement=replacement, prepend_scheme=prepend_scheme)
parameters = {
"model": "SentencePieceUnigram",

View File

@ -274,7 +274,7 @@ class Metaspace(PreTokenizer):
Whether to add a space to the first word if there isn't already one. This
lets us treat `hello` exactly like `say hello`.
"""
def __init__(self, replacement="_", add_prefix_space=True):
def __init__(self, replacement="_", prepend_scheme="always", split=True):
pass
def pre_tokenize(self, pretok):

View File

@ -1,5 +1,6 @@
use std::sync::{Arc, RwLock};
use crate::pre_tokenizers::from_string;
use crate::utils::PyChar;
use crate::utils::PyPattern;
use pyo3::exceptions;
@ -12,7 +13,7 @@ use tk::decoders::byte_fallback::ByteFallback;
use tk::decoders::byte_level::ByteLevel;
use tk::decoders::ctc::CTC;
use tk::decoders::fuse::Fuse;
use tk::decoders::metaspace::Metaspace;
use tk::decoders::metaspace::{Metaspace, PrependScheme};
use tk::decoders::sequence::Sequence;
use tk::decoders::strip::Strip;
use tk::decoders::wordpiece::WordPiece;
@ -322,22 +323,46 @@ impl PyMetaspaceDec {
}
#[getter]
fn get_add_prefix_space(self_: PyRef<Self>) -> bool {
getter!(self_, Metaspace, add_prefix_space)
fn get_split(self_: PyRef<Self>) -> bool {
getter!(self_, Metaspace, get_split())
}
#[setter]
fn set_add_prefix_space(self_: PyRef<Self>, add_prefix_space: bool) {
setter!(self_, Metaspace, add_prefix_space, add_prefix_space);
fn set_split(self_: PyRef<Self>, split: bool) {
setter!(self_, Metaspace, @set_split, split);
}
#[getter]
fn get_prepend_scheme(self_: PyRef<Self>) -> String {
// Assuming Metaspace has a method to get the prepend_scheme as a string
let scheme: PrependScheme = getter!(self_, Metaspace, get_prepend_scheme());
match scheme {
PrependScheme::First => "first",
PrependScheme::Never => "never",
PrependScheme::Always => "always",
}
.to_string()
}
#[setter]
fn set_prepend_scheme(self_: PyRef<Self>, prepend_scheme: String) -> PyResult<()> {
let scheme = from_string(prepend_scheme)?;
setter!(self_, Metaspace, @set_prepend_scheme, scheme);
Ok(())
}
#[new]
#[pyo3(signature = (replacement = PyChar('▁'), add_prefix_space = true), text_signature = "(self, replacement = \"\", add_prefix_space = True)")]
fn new(replacement: PyChar, add_prefix_space: bool) -> (Self, PyDecoder) {
(
#[pyo3(signature = (replacement = PyChar('▁'), prepend_scheme = String::from("always"), split = true), text_signature = "(self, replacement = \"\", prepend_scheme = \"always\", split = True)")]
fn new(
replacement: PyChar,
prepend_scheme: String,
split: bool,
) -> PyResult<(Self, PyDecoder)> {
let prepend_scheme = from_string(prepend_scheme)?;
Ok((
PyMetaspaceDec {},
Metaspace::new(replacement.0, add_prefix_space).into(),
)
Metaspace::new(replacement.0, prepend_scheme, split).into(),
))
}
}

View File

@ -452,7 +452,7 @@ impl PySequence {
}
}
fn from_string(string: String) -> Result<PrependScheme, PyErr> {
pub(crate) fn from_string(string: String) -> Result<PrependScheme, PyErr> {
let scheme = match string.as_str() {
"first" => PrependScheme::First,
"never" => PrependScheme::Never,
@ -495,13 +495,13 @@ impl PyMetaspace {
}
#[getter]
fn get_add_prefix_space(self_: PyRef<Self>) -> bool {
getter!(self_, Metaspace, add_prefix_space)
fn get_split(self_: PyRef<Self>) -> bool {
getter!(self_, Metaspace, get_split())
}
#[setter]
fn set_add_prefix_space(self_: PyRef<Self>, add_prefix_space: bool) {
setter!(self_, Metaspace, add_prefix_space, add_prefix_space);
fn set_split(self_: PyRef<Self>, split: bool) {
setter!(self_, Metaspace, @set_split, split);
}
#[getter]
@ -524,23 +524,15 @@ impl PyMetaspace {
}
#[new]
#[pyo3(signature = (replacement = PyChar('▁'), add_prefix_space = true, prepend_scheme=None, **_kwargs), text_signature = "(self, replacement=\"_\", add_prefix_space=True)")]
#[pyo3(signature = (replacement = PyChar('▁'), prepend_scheme=String::from("always"), split=true), text_signature = "(self, replacement=\"_\", prepend_scheme=\"always\", split=True)")]
fn new(
replacement: PyChar,
add_prefix_space: bool,
prepend_scheme: Option<String>,
_kwargs: Option<&PyDict>,
prepend_scheme: String,
split: bool,
) -> PyResult<(Self, PyPreTokenizer)> {
// Create a new Metaspace instance
let mut new_instance: Metaspace = Metaspace::new(replacement.0, add_prefix_space);
// If a prepend scheme is provided, set it
if let Some(prepend_scheme) = prepend_scheme {
match from_string(prepend_scheme) {
Ok(prepend_scheme_enum) => new_instance.set_prepend_scheme(prepend_scheme_enum),
Err(err) => return Err(err),
}
}
let prepend_scheme = from_string(prepend_scheme)?;
let new_instance: Metaspace = Metaspace::new(replacement.0, prepend_scheme, split);
Ok((PyMetaspace {}, new_instance.into()))
}
}

View File

@ -126,7 +126,7 @@ class TestMetaspace:
assert Metaspace(replacement="-") is not None
with pytest.raises(ValueError, match="expected a string of length 1"):
Metaspace(replacement="")
assert Metaspace(add_prefix_space=True) is not None
assert Metaspace(prepend_scheme="always") is not None
assert isinstance(Metaspace(), Decoder)
assert isinstance(Metaspace(), Metaspace)
assert isinstance(pickle.loads(pickle.dumps(Metaspace())), Metaspace)
@ -134,20 +134,20 @@ class TestMetaspace:
def test_decoding(self):
decoder = Metaspace()
assert decoder.decode(["▁My", "▁name", "▁is", "▁John"]) == "My name is John"
decoder = Metaspace(replacement="-", add_prefix_space=False)
decoder = Metaspace(replacement="-", prepend_scheme="never")
assert decoder.decode(["-My", "-name", "-is", "-John"]) == " My name is John"
def test_can_modify(self):
decoder = Metaspace(replacement="*", add_prefix_space=False)
decoder = Metaspace(replacement="*", prepend_scheme="never")
assert decoder.replacement == "*"
assert decoder.add_prefix_space == False
assert decoder.prepend_scheme == "never"
# Modify these
decoder.replacement = "&"
assert decoder.replacement == "&"
decoder.add_prefix_space = True
assert decoder.add_prefix_space == True
decoder.prepend_scheme = "first"
assert decoder.prepend_scheme == "first"
class TestBPEDecoder:

View File

@ -94,24 +94,27 @@ class TestMetaspace:
assert Metaspace(replacement="-") is not None
with pytest.raises(ValueError, match="expected a string of length 1"):
Metaspace(replacement="")
assert Metaspace(add_prefix_space=True) is not None
assert Metaspace(prepend_scheme="always") is not None
assert isinstance(Metaspace(), PreTokenizer)
assert isinstance(Metaspace(), Metaspace)
assert isinstance(pickle.loads(pickle.dumps(Metaspace())), Metaspace)
def test_can_modify(self):
pretok = Metaspace(replacement="$", add_prefix_space=False)
pretok = Metaspace(replacement="$", prepend_scheme="never")
assert pretok.replacement == "$"
assert pretok.add_prefix_space == False
assert pretok.prepend_scheme == "never"
assert pretok.split == True
# Modify these
pretok.replacement = "%"
assert pretok.replacement == "%"
pretok.add_prefix_space = True
assert pretok.add_prefix_space == True
pretok.prepend_scheme = "never"
assert pretok.prepend_scheme == "never"
pretok.prepend_scheme = "first"
assert pretok.prepend_scheme == "first"
pretok.split = True
assert pretok.split == True
class TestCharDelimiterSplit:

View File

@ -487,3 +487,51 @@ class TestTokenizer:
tokenizer.add_tokens(["of_text>"])
output = tokenizer.encode("Hey there<end_of_text> dear<eot>friend!", add_special_tokens=False)
assert output.tokens == ["▁Hey", "▁there", "<", "end", "_", "of_text>", "▁dear", "<eot>", "▁friend", "!"]
def test_splitting(self):
tokenizer = Tokenizer.from_pretrained("hf-internal-testing/llama-new-metaspace")
tokenizer.pre_tokenizer.split = False
tokenizer.add_tokens([AddedToken("<REPR_END>", rstrip=True, lstrip=True)])
assert tokenizer.encode("<REPR_END>inform<s>. Hey. .", add_special_tokens=False).tokens == [
"<REPR_END>",
"in",
"form",
"<s>",
".",
"▁Hey",
".",
"▁▁▁▁▁▁",
"▁.",
]
assert tokenizer.encode("<REPR_END>inform<s>. Hey. .", add_special_tokens=False).ids == [
32000,
262,
689,
1,
29889,
18637,
29889,
539,
869,
]
assert tokenizer.encode("inform<s>. Hey. .").tokens == [
"<s>",
"▁inform",
"<s>",
".",
"▁Hey",
".",
"▁▁▁▁▁▁",
"▁.",
]
assert tokenizer.encode("inform<s>. Hey. .", add_special_tokens=False).tokens == [
"▁inform",
"<s>",
".",
"▁Hey",
".",
"▁▁▁▁▁▁",
"▁.",
]

View File

@ -73,14 +73,19 @@ mod tests {
#[test]
fn decoder_serialization() {
let json = r#"{"type":"Sequence","decoders":[{"type":"ByteFallback"},{"type":"Metaspace","replacement":"▁","add_prefix_space":true,"prepend_scheme":"always"}]}"#;
let oldjson = r#"{"type":"Sequence","decoders":[{"type":"ByteFallback"},{"type":"Metaspace","replacement":"▁","add_prefix_space":true,"prepend_scheme":"always"}]}"#;
let olddecoder: DecoderWrapper = serde_json::from_str(oldjson).unwrap();
let oldserialized = serde_json::to_string(&olddecoder).unwrap();
let json = r#"{"type":"Sequence","decoders":[{"type":"ByteFallback"},{"type":"Metaspace","replacement":"▁","prepend_scheme":"always","split":true}]}"#;
assert_eq!(oldserialized, json);
let decoder: DecoderWrapper = serde_json::from_str(json).unwrap();
let serialized = serde_json::to_string(&decoder).unwrap();
assert_eq!(serialized, json);
}
#[test]
fn decoder_serialization_other_no_arg() {
let json = r#"{"type":"Sequence","decoders":[{"type":"Fuse"},{"type":"Metaspace","replacement":"▁","add_prefix_space":true,"prepend_scheme":"always"}]}"#;
let json = r#"{"type":"Sequence","decoders":[{"type":"Fuse"},{"type":"Metaspace","replacement":"▁","prepend_scheme":"always","split":true}]}"#;
let decoder: DecoderWrapper = serde_json::from_str(json).unwrap();
let serialized = serde_json::to_string(&decoder).unwrap();
assert_eq!(serialized, json);
@ -88,7 +93,7 @@ mod tests {
#[test]
fn decoder_serialization_no_decode() {
let json = r#"{"type":"Sequence","decoders":[{},{"type":"Metaspace","replacement":"▁","add_prefix_space":true,"prepend_scheme":"always"}]}"#;
let json = r#"{"type":"Sequence","decoders":[{},{"type":"Metaspace","replacement":"▁","prepend_scheme":"always"}]}"#;
assert!(serde_json::from_str::<DecoderWrapper>(json).is_err());
}
}

View File

@ -1,5 +1,5 @@
use crate::tokenizer::{Decoder, PreTokenizedString, PreTokenizer, Result, SplitDelimiterBehavior};
use serde::{Deserialize, Deserializer, Serialize};
use serde::{de, Deserialize, Deserializer, Serialize};
/// Enum representing options for the metaspace prepending scheme.
#[derive(Debug, Clone, PartialEq, Serialize, Eq, Deserialize, Copy)]
@ -19,8 +19,8 @@ pub enum PrependScheme {
#[serde(tag = "type")]
pub struct Metaspace {
replacement: char,
pub add_prefix_space: bool,
pub prepend_scheme: PrependScheme,
pub split: bool,
#[serde(skip)]
str_rep: String,
}
@ -44,42 +44,40 @@ impl<'de> Deserialize<'de> for Metaspace {
#[serde(rename = "type")]
_type: Type,
replacement: char,
pub add_prefix_space: bool,
pub add_prefix_space: Option<bool>,
#[serde(default = "default_prepend_scheme_value")]
pub prepend_scheme: PrependScheme,
#[serde(skip, rename = "str_rep")]
_str_rep: String,
pub split: Option<bool>,
#[serde(rename = "str_rep")]
_str_rep: Option<String>,
}
let helper = MetaspaceHelper::deserialize(deserializer)?;
let instance = Self::new_with_prepend_scheme(
let mut helper = MetaspaceHelper::deserialize(deserializer)?;
if let Some(false) = helper.add_prefix_space {
if helper.prepend_scheme != PrependScheme::Never {
return Err(de::Error::custom(
"add_prefix_space does not match declared prepend_scheme",
));
}
helper.prepend_scheme = PrependScheme::Never;
}
let instance = Self::new(
helper.replacement,
helper.add_prefix_space,
helper.prepend_scheme,
helper.split.unwrap_or(true),
);
Ok(instance)
}
}
impl Metaspace {
pub fn new(replacement: char, add_prefix_space: bool) -> Self {
Self::new_with_prepend_scheme(
replacement,
add_prefix_space,
PrependScheme::Always, // always prepend for legacy purpose
)
}
pub fn new_with_prepend_scheme(
replacement: char,
add_prefix_space: bool,
prepend_scheme: PrependScheme,
) -> Self {
pub fn new(replacement: char, prepend_scheme: PrependScheme, split: bool) -> Self {
Self {
replacement,
str_rep: replacement.to_string(),
add_prefix_space,
prepend_scheme,
split,
}
}
@ -92,6 +90,14 @@ impl Metaspace {
self.str_rep = replacement.to_string();
}
pub fn get_split(&self) -> bool {
self.split
}
pub fn set_split(&mut self, split: bool) {
self.split = split;
}
pub fn get_prepend_scheme(&self) -> PrependScheme {
self.prepend_scheme
}
@ -103,28 +109,34 @@ impl Metaspace {
impl Default for Metaspace {
fn default() -> Self {
Self::new('▁', true)
Self::new('▁', PrependScheme::Always, true)
}
}
impl PreTokenizer for Metaspace {
fn pre_tokenize(&self, pretokenized: &mut PreTokenizedString) -> Result<()> {
let mut first_split = true;
pretokenized.split(|_, mut normalized| {
normalized.replace(' ', &self.str_rep)?;
if self.add_prefix_space && !normalized.get().starts_with(self.replacement) {
if self.prepend_scheme == PrependScheme::Always {
normalized.prepend(&self.str_rep);
} else if self.prepend_scheme == PrependScheme::First && first_split {
normalized.prepend(&self.str_rep);
first_split = false;
match self.prepend_scheme {
PrependScheme::Always => {
if !normalized.get().starts_with(self.replacement) {
normalized.prepend(&self.str_rep);
}
}
PrependScheme::First => {
if !normalized.get().starts_with(self.replacement)
&& normalized.offsets_original().0 == 0
{
normalized.prepend(&self.str_rep);
}
}
PrependScheme::Never => {}
};
if self.split {
normalized.split(self.replacement, SplitDelimiterBehavior::MergedWithNext)
} else {
first_split = false;
Ok(vec![normalized])
}
normalized.split(self.replacement, SplitDelimiterBehavior::MergedWithNext)
})
}
}
@ -139,7 +151,7 @@ impl Decoder for Metaspace {
.chars()
.flat_map(|c| {
if c == self.replacement {
if i == 0 && self.add_prefix_space {
if i == 0 && self.prepend_scheme != PrependScheme::Never {
None
} else {
Some(' ')
@ -163,8 +175,9 @@ mod tests {
#[test]
fn serialization() {
let metaspace = Metaspace::new('_', true);
let metaspace_s = r#"{"type":"Metaspace","replacement":"_","add_prefix_space":true,"prepend_scheme":"always"}"#;
let metaspace = Metaspace::new('_', PrependScheme::Always, true);
let metaspace_s =
r#"{"type":"Metaspace","replacement":"_","prepend_scheme":"always","split":true}"#;
assert_eq!(serde_json::to_string(&metaspace).unwrap(), metaspace_s);
assert_eq!(
serde_json::from_str::<Metaspace>(metaspace_s).unwrap(),
@ -172,7 +185,10 @@ mod tests {
);
// Also check it can deserialize previous versions
let metaspace = Metaspace::new('_', true);
let metaspace_s = r#"{"type":"Metaspace","replacement":"_","add_prefix_space":false,"prepend_scheme":"always"}"#;
assert!(serde_json::from_str::<Metaspace>(metaspace_s).is_err(),);
let metaspace = Metaspace::new('_', PrependScheme::Always, true);
let metaspace_s = r#"{"type":"Metaspace","str_rep":"_","replacement":"_","add_prefix_space":true,"prepend_scheme":"always"}"#;
assert_eq!(
serde_json::from_str::<Metaspace>(metaspace_s).unwrap(),
@ -188,7 +204,7 @@ mod tests {
#[test]
fn basic() {
let pretok = Metaspace::new('▁', true);
let pretok = Metaspace::new('▁', PrependScheme::Always, true);
let mut pretokenized = PreTokenizedString::from("Hey friend!");
pretok.pre_tokenize(&mut pretokenized).unwrap();
assert_eq!(
@ -211,7 +227,7 @@ mod tests {
#[test]
fn multiple_spaces() {
let pretok = Metaspace::new('▁', true);
let pretok = Metaspace::new('▁', PrependScheme::Always, true);
let mut pretokenized = PreTokenizedString::from("Hey friend!");
pretok.pre_tokenize(&mut pretokenized).unwrap();
assert_eq!(
@ -244,30 +260,17 @@ mod tests {
#[test]
fn non_legacy_meta_space() {
assert_eq!(
Metaspace::new('▁', true),
Metaspace::new_with_prepend_scheme('▁', true, PrependScheme::Always)
);
let mut pretok = Metaspace::new('▁', true);
let mut pretok = Metaspace::new('▁', PrependScheme::Always, true);
pretok.set_prepend_scheme(PrependScheme::Always);
assert_eq!(
pretok,
Metaspace::new_with_prepend_scheme('▁', true, PrependScheme::Always)
);
assert_eq!(pretok, Metaspace::new('▁', PrependScheme::Always, true));
pretok.set_prepend_scheme(PrependScheme::Never);
assert_eq!(
pretok,
Metaspace::new_with_prepend_scheme('▁', true, PrependScheme::Never)
);
assert_eq!(pretok, Metaspace::new('▁', PrependScheme::Never, true));
pretok.set_prepend_scheme(PrependScheme::First);
assert_eq!(
pretok,
Metaspace::new_with_prepend_scheme('▁', true, PrependScheme::First)
);
assert_eq!(pretok, Metaspace::new('▁', PrependScheme::First, true));
let pretok = Metaspace::new('▁', PrependScheme::First, false);
let mut pretokenized = PreTokenizedString::from("Hey my friend <s>how▁are you");
let re_ref = Regex::new(r"(<s>)").unwrap();
pretokenized
@ -282,17 +285,12 @@ mod tests {
.map(|(s, o, _)| (s, o))
.collect::<Vec<_>>(),
vec![
("▁Hey", (0, 6)),
("▁my", (6, 11)),
("▁friend", (11, 20)),
("", (20, 23)),
("▁Hey▁my▁friend▁", (0, 23)),
("<s>", (23, 26)),
("how", (26, 29)),
("▁are", (29, 35)),
("▁you", (35, 41))
("how▁are▁you", (26, 41))
]
);
pretok.set_prepend_scheme(PrependScheme::Always);
let pretok = Metaspace::new('▁', PrependScheme::Always, true);
pretok.pre_tokenize(&mut pretokenized).unwrap();
assert_eq!(
pretokenized
@ -312,7 +310,7 @@ mod tests {
]
);
pretok.set_prepend_scheme(PrependScheme::First);
let pretok = Metaspace::new('▁', PrependScheme::First, false);
let mut pretokenized = PreTokenizedString::from(" Hey <s>how"); // test with prefix
pretokenized
.split(|_, sequence| sequence.split(&re_ref, SplitDelimiterBehavior::Isolated))
@ -324,12 +322,7 @@ mod tests {
.into_iter()
.map(|(s, o, _)| (s, o))
.collect::<Vec<_>>(),
vec![
("▁Hey", (0, 6)),
("", (6, 9)),
("<s>", (9, 12)),
("how", (12, 15))
]
vec![("▁Hey▁", (0, 9)), ("<s>", (9, 12)), ("how", (12, 15))]
);
let mut pretokenized = PreTokenizedString::from(" Hey <s>how <s>are <s> you"); // test with many splits
@ -344,14 +337,11 @@ mod tests {
.map(|(s, o, _)| (s, o))
.collect::<Vec<_>>(),
vec![
("▁Hey", (0, 6)),
("", (6, 9)),
("▁Hey", (0, 9)),
("<s>", (9, 12)),
("how", (12, 15)),
("", (15, 18)),
("how", (12, 18)),
("<s>", (18, 21)),
("are", (21, 24)),
("", (24, 27)),
("are", (21, 27)),
("<s>", (27, 30)),
("▁you", (30, 36))
]
@ -359,10 +349,16 @@ mod tests {
}
#[test]
fn decode() {
let decoder = Metaspace::new('▁', true);
let decoder = Metaspace::new('▁', PrependScheme::Always, true);
let res = decoder
.decode_chain(vec!["▁Hey".into(), "▁friend!".into()])
.unwrap();
assert_eq!(res, vec!["Hey", " friend!"])
assert_eq!(res, vec!["Hey", " friend!"]);
let decoder = Metaspace::new('▁', PrependScheme::Never, true);
let res = decoder
.decode_chain(vec!["▁Hey".into(), "▁friend!".into()])
.unwrap();
assert_eq!(res, vec![" Hey", " friend!"]);
}
}

View File

@ -71,6 +71,7 @@ impl_enum_from!(UnicodeScripts, PreTokenizerWrapper, UnicodeScripts);
#[cfg(test)]
mod tests {
use super::metaspace::PrependScheme;
use super::*;
#[test]
@ -81,7 +82,7 @@ mod tests {
pre_tokenizer,
PreTokenizerWrapper::Sequence(Sequence::new(vec![
PreTokenizerWrapper::WhitespaceSplit(WhitespaceSplit {}),
PreTokenizerWrapper::Metaspace(Metaspace::new('▁', true))
PreTokenizerWrapper::Metaspace(Metaspace::new('▁', PrependScheme::Always, true))
]))
);
@ -92,7 +93,7 @@ mod tests {
assert_eq!(
pre_tokenizer,
PreTokenizerWrapper::Metaspace(Metaspace::new('▁', true))
PreTokenizerWrapper::Metaspace(Metaspace::new('▁', PrependScheme::Always, true))
);
let pre_tokenizer: PreTokenizerWrapper = serde_json::from_str(r#"{"type":"Sequence","pretokenizers":[{"type":"WhitespaceSplit"},{"type":"Metaspace","replacement":"▁","add_prefix_space":true}]}"#).unwrap();
@ -101,7 +102,7 @@ mod tests {
pre_tokenizer,
PreTokenizerWrapper::Sequence(Sequence::new(vec![
PreTokenizerWrapper::WhitespaceSplit(WhitespaceSplit {}),
PreTokenizerWrapper::Metaspace(Metaspace::new('▁', true))
PreTokenizerWrapper::Metaspace(Metaspace::new('▁', PrependScheme::Always, true))
]))
);
@ -112,10 +113,10 @@ mod tests {
assert_eq!(
pre_tokenizer,
PreTokenizerWrapper::Metaspace(Metaspace::new_with_prepend_scheme(
PreTokenizerWrapper::Metaspace(Metaspace::new(
'▁',
true,
metaspace::PrependScheme::First
metaspace::PrependScheme::First,
true
))
);
@ -126,10 +127,10 @@ mod tests {
assert_eq!(
pre_tokenizer,
PreTokenizerWrapper::Metaspace(Metaspace::new_with_prepend_scheme(
PreTokenizerWrapper::Metaspace(Metaspace::new(
'▁',
true,
metaspace::PrependScheme::Always
metaspace::PrependScheme::Always,
true
))
);
}

View File

@ -448,7 +448,7 @@ impl TemplateProcessingBuilder {
}
};
let empty = vec![];
let empty = [];
let missing: HashSet<&str> = self
.single
.as_ref()

View File

@ -928,7 +928,6 @@ where
pretokenized: P,
) -> Result<PreTokenizedString> {
let mut pretokenized: PreTokenizedString = pretokenized.into();
if let Some(ref pretok) = self.pre_tokenizer {
pretok.pre_tokenize(&mut pretokenized)?;
}