Add more support for tiktoken based tokenizers (#1493)

* first commit

* update

* clippy

* lint

* clippy and lint

* fmt

* revert print

* 😈

* style

* add a test

* more fmt

* Use ignore_merges

* stub

* fix

* update

* Update tokenizers/src/models/bpe/model.rs

Co-authored-by: Nicolas Patry <patry.nicolas@protonmail.com>

* update

* rust lint

* don't repeat yourself

---------

Co-authored-by: Nicolas Patry <patry.nicolas@protonmail.com>
This commit is contained in:
Arthur
2024-04-15 17:26:36 +02:00
committed by GitHub
parent 6e58f838b3
commit 914576f7ed
3 changed files with 135 additions and 9 deletions

View File

@ -112,6 +112,9 @@ class BPE(Model):
byte_fallback (:obj:`bool`, `optional`): byte_fallback (:obj:`bool`, `optional`):
Whether to use spm byte-fallback trick (defaults to False) Whether to use spm byte-fallback trick (defaults to False)
ignore_merges (:obj:`bool`, `optional`):
Whether or not to match tokens with the vocab before using merges.
""" """
def __init__( def __init__(
self, self,
@ -124,6 +127,7 @@ class BPE(Model):
end_of_word_suffix=None, end_of_word_suffix=None,
fuse_unk=None, fuse_unk=None,
byte_fallback=False, byte_fallback=False,
ignore_merges=False,
): ):
pass pass

View File

@ -253,6 +253,9 @@ impl PyModel {
/// ///
/// byte_fallback (:obj:`bool`, `optional`): /// byte_fallback (:obj:`bool`, `optional`):
/// Whether to use spm byte-fallback trick (defaults to False) /// Whether to use spm byte-fallback trick (defaults to False)
///
/// ignore_merges (:obj:`bool`, `optional`):
/// Whether or not to match tokens with the vocab before using merges.
#[pyclass(extends=PyModel, module = "tokenizers.models", name = "BPE")] #[pyclass(extends=PyModel, module = "tokenizers.models", name = "BPE")]
pub struct PyBPE {} pub struct PyBPE {}
@ -279,6 +282,7 @@ impl PyBPE {
"end_of_word_suffix" => builder = builder.end_of_word_suffix(value.extract()?), "end_of_word_suffix" => builder = builder.end_of_word_suffix(value.extract()?),
"fuse_unk" => builder = builder.fuse_unk(value.extract()?), "fuse_unk" => builder = builder.fuse_unk(value.extract()?),
"byte_fallback" => builder = builder.byte_fallback(value.extract()?), "byte_fallback" => builder = builder.byte_fallback(value.extract()?),
"ignore_merges" => builder = builder.ignore_merges(value.extract()?),
_ => println!("Ignored unknown kwarg option {}", key), _ => println!("Ignored unknown kwarg option {}", key),
}; };
} }
@ -396,11 +400,19 @@ impl PyBPE {
fn set_byte_fallback(self_: PyRef<Self>, byte_fallback: bool) { fn set_byte_fallback(self_: PyRef<Self>, byte_fallback: bool) {
setter!(self_, BPE, byte_fallback, byte_fallback); setter!(self_, BPE, byte_fallback, byte_fallback);
} }
// Python-facing getter for the `ignore_merges` property of the BPE model.
// The read goes through the `getter!` macro, which is defined outside this view —
// presumably it reads the field off the wrapped `BPE` model; verify against the macro.
#[getter]
fn get_ignore_merges(self_: PyRef<Self>) -> bool {
getter!(self_, BPE, ignore_merges)
}
// Python-facing setter for the `ignore_merges` property of the BPE model.
// The write goes through the `setter!` macro, which is defined outside this view —
// presumably it assigns the field on the wrapped `BPE` model; verify against the macro.
#[setter]
fn set_ignore_merges(self_: PyRef<Self>, ignore_merges: bool) {
setter!(self_, BPE, ignore_merges, ignore_merges);
}
#[new] #[new]
#[pyo3( #[pyo3(
signature = (vocab=None, merges=None, **kwargs), signature = (vocab=None, merges=None, **kwargs),
text_signature = "(self, vocab=None, merges=None, cache_capacity=None, dropout=None, unk_token=None, continuing_subword_prefix=None, end_of_word_suffix=None, fuse_unk=None, byte_fallback=False)")] text_signature = "(self, vocab=None, merges=None, cache_capacity=None, dropout=None, unk_token=None, continuing_subword_prefix=None, end_of_word_suffix=None, fuse_unk=None, byte_fallback=False, ignore_merges=False)")]
fn new( fn new(
py: Python<'_>, py: Python<'_>,
vocab: Option<PyVocab>, vocab: Option<PyVocab>,

View File

@ -28,6 +28,7 @@ struct Config {
end_of_word_suffix: Option<String>, end_of_word_suffix: Option<String>,
fuse_unk: bool, fuse_unk: bool,
byte_fallback: bool, byte_fallback: bool,
ignore_merges: bool,
} }
/// A `BpeBuilder` can be used to create a `BPE` model with a custom configuration. /// A `BpeBuilder` can be used to create a `BPE` model with a custom configuration.
@ -49,6 +50,7 @@ impl Default for BpeBuilder {
end_of_word_suffix: None, end_of_word_suffix: None,
fuse_unk: false, fuse_unk: false,
byte_fallback: false, byte_fallback: false,
ignore_merges: false,
}, },
} }
} }
@ -123,6 +125,12 @@ impl BpeBuilder {
self.config.byte_fallback = byte_fallback; self.config.byte_fallback = byte_fallback;
self self
} }
/// Set the `ignore_merges` option.
///
/// The flag is stored in the builder's config and copied onto the `BPE` model by `build`.
/// When it is `true`, the model looks the whole input up in the vocab first and, on a hit,
/// returns it as a single token instead of applying merges.
#[must_use]
pub fn ignore_merges(mut self, ignore_merges: bool) -> Self {
self.config.ignore_merges = ignore_merges;
self
}
/// Returns a `BPE` model that uses the `BpeBuilder`'s configuration. /// Returns a `BPE` model that uses the `BpeBuilder`'s configuration.
pub fn build(mut self) -> Result<BPE> { pub fn build(mut self) -> Result<BPE> {
@ -190,6 +198,7 @@ impl BpeBuilder {
end_of_word_suffix: self.config.end_of_word_suffix, end_of_word_suffix: self.config.end_of_word_suffix,
fuse_unk: self.config.fuse_unk, fuse_unk: self.config.fuse_unk,
byte_fallback: self.config.byte_fallback, byte_fallback: self.config.byte_fallback,
ignore_merges: self.config.ignore_merges,
}) })
} }
} }
@ -219,6 +228,8 @@ pub struct BPE {
/// Byte fallback from sentence pieces, instead of UNK, uses `"<0x00>"` /// Byte fallback from sentence pieces, instead of UNK, uses `"<0x00>"`
/// for each byte in the unk token /// for each byte in the unk token
pub byte_fallback: bool, pub byte_fallback: bool,
/// Whether or not to directly output words if they are part of the vocab.
pub ignore_merges: bool,
} }
impl std::fmt::Debug for BPE { impl std::fmt::Debug for BPE {
@ -232,6 +243,7 @@ impl std::fmt::Debug for BPE {
.field("byte_fallback", &self.byte_fallback) .field("byte_fallback", &self.byte_fallback)
.field("vocab", &self.vocab.len()) .field("vocab", &self.vocab.len())
.field("merges", &self.merges.len()) .field("merges", &self.merges.len())
.field("ignore_merges", &self.ignore_merges)
.finish() .finish()
} }
} }
@ -258,6 +270,7 @@ impl Clone for BPE {
end_of_word_suffix: self.end_of_word_suffix.clone(), end_of_word_suffix: self.end_of_word_suffix.clone(),
fuse_unk: self.fuse_unk, fuse_unk: self.fuse_unk,
byte_fallback: self.byte_fallback, byte_fallback: self.byte_fallback,
ignore_merges: self.ignore_merges,
} }
} }
} }
@ -448,15 +461,19 @@ impl BPE {
fn tokenize_with_cache(&self, sequence: &str) -> Result<Vec<Token>> { fn tokenize_with_cache(&self, sequence: &str) -> Result<Vec<Token>> {
if let Some(ref hit) = self.cache.as_ref().and_then(|c| c.get(sequence)) { if let Some(ref hit) = self.cache.as_ref().and_then(|c| c.get(sequence)) {
Ok(self.word_to_tokens(hit).collect()) return Ok(self.word_to_tokens(hit).collect());
} else {
let word = self.merge_word(sequence)?;
let ret = self.word_to_tokens(&word).collect();
if let Some(ref cache) = self.cache {
cache.set(sequence.to_owned(), word);
}
Ok(ret)
} }
if self.ignore_merges {
if let Some(id) = self.vocab.get(sequence) {
return Ok(vec![Token::new(*id, sequence.to_string().clone(), (0, 0))]);
}
}
let word = self.merge_word(sequence)?;
let ret = self.word_to_tokens(&word).collect();
if let Some(ref cache) = self.cache {
cache.set(sequence.to_owned(), word);
}
Ok(ret)
} }
} }
@ -862,4 +879,97 @@ mod tests {
let tokens = bpe.tokenize("\n").unwrap(); let tokens = bpe.tokenize("\n").unwrap();
assert_eq!(tokens, vec![Token::new(1u32, "<0x0A>".into(), (0, 1)),]); assert_eq!(tokens, vec![Token::new(1u32, "<0x0A>".into(), (0, 1)),]);
} }
// Checks both settings of `ignore_merges` on the same two inputs:
// first with the flag on (whole-word vocab hits), then with it off (regular merges).
#[test]
fn test_ignore_merges() {
// The vocab holds the full inputs ".:.:" (id 0) and "Ġbelirtilen" (id 1) as entries,
// alongside the smaller pieces that the merges below can produce.
let vocab: Vocab = [
(".:.:".into(), 0),
("Ġbelirtilen".into(), 1),
(".".into(), 2),
(":".into(), 3),
("bel".into(), 4),
("irtilen".into(), 5),
("Ġ".into(), 6),
(".:".into(), 7),
("belirtilen".into(), 8),
(".:.".into(), 9),
("be".into(), 10),
("l".into(), 11),
("ir".into(), 12),
("ti".into(), 13),
("en".into(), 14),
("irtil".into(), 15),
("irti".into(), 16),
("i".into(), 17),
("r".into(), 18),
("t".into(), 19),
("b".into(), 20),
("e".into(), 21),
("n".into(), 22),
]
.iter()
.cloned()
.collect();
// `mut` because the test flips `ignore_merges` on the built model further down.
let mut bpe = BpeBuilder::default()
.vocab_and_merges(
vocab,
vec![
(".".into(), ":".into()),
("b".into(), "e".into()),
("be".into(), "l".into()),
("i".into(), "r".into()),
("t".into(), "i".into()),
("ir".into(), "ti".into()),
("e".into(), "n".into()),
("irti".into(), "l".into()),
],
)
.ignore_merges(true)
.build()
.unwrap();
// Flag on: each input is itself a vocab entry, so a single token with that entry's id
// is expected. Note the expected offsets are (0, 0), not the span of the input.
let tokens = bpe.tokenize(".:.:").unwrap();
assert_eq!(tokens, vec![Token::new(0u32, ".:.:".into(), (0, 0))]);
let tokens = bpe.tokenize("Ġbelirtilen").unwrap();
assert_eq!(tokens, vec![Token::new(1u32, "Ġbelirtilen".into(), (0, 0))]);
// Flag off: the same inputs are now expected to be split into merged sub-word pieces.
// NOTE(review): this relies on the flag-on calls above not leaving cached results for
// these inputs — confirm against the tokenization/cache path.
bpe.ignore_merges = false;
let tokens = bpe.tokenize(".:.:").unwrap();
assert_eq!(
tokens,
vec![
Token::new(7u32, ".:".into(), (0, 2)),
Token::new(7u32, ".:".into(), (2, 4))
]
);
// "Ġ" (U+0120) is two bytes in UTF-8, which matches its expected offsets of (0, 2).
let tokens = bpe.tokenize("Ġbelirtilen").unwrap();
assert_eq!(
tokens,
vec![
Token {
id: 6,
value: "Ġ".into(),
offsets: (0, 2)
},
Token {
id: 4,
value: "bel".into(),
offsets: (2, 5)
},
Token {
id: 15,
value: "irtil".into(),
offsets: (5, 10)
},
Token {
id: 14,
value: "en".into(),
offsets: (10, 12)
}
]
)
}
} }