mirror of
https://github.com/mii443/tokenizers.git
synced 2025-08-22 16:25:30 +00:00
Add more support for tiktoken based tokenizers (#1493)
* first commit * update * clippy * lint * clippy and lint * fmt * revert print * 😈 * style * add a test * more fmt * Use ignore_merges * stub * fix * update * Update tokenizers/src/models/bpe/model.rs Co-authored-by: Nicolas Patry <patry.nicolas@protonmail.com> * update * rust lint * dob; t repeat yourself --------- Co-authored-by: Nicolas Patry <patry.nicolas@protonmail.com>
This commit is contained in:
@ -112,6 +112,9 @@ class BPE(Model):
|
|||||||
|
|
||||||
byte_fallback (:obj:`bool`, `optional`):
|
byte_fallback (:obj:`bool`, `optional`):
|
||||||
Whether to use spm byte-fallback trick (defaults to False)
|
Whether to use spm byte-fallback trick (defaults to False)
|
||||||
|
|
||||||
|
ignore_merges (:obj:`bool`, `optional`):
|
||||||
|
Whether or not to match tokens with the vocab before using merges.
|
||||||
"""
|
"""
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
@ -124,6 +127,7 @@ class BPE(Model):
|
|||||||
end_of_word_suffix=None,
|
end_of_word_suffix=None,
|
||||||
fuse_unk=None,
|
fuse_unk=None,
|
||||||
byte_fallback=False,
|
byte_fallback=False,
|
||||||
|
ignore_merges=False,
|
||||||
):
|
):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
@ -253,6 +253,9 @@ impl PyModel {
|
|||||||
///
|
///
|
||||||
/// byte_fallback (:obj:`bool`, `optional`):
|
/// byte_fallback (:obj:`bool`, `optional`):
|
||||||
/// Whether to use spm byte-fallback trick (defaults to False)
|
/// Whether to use spm byte-fallback trick (defaults to False)
|
||||||
|
///
|
||||||
|
/// ignore_merges (:obj:`bool`, `optional`):
|
||||||
|
/// Whether or not to match tokens with the vocab before using merges.
|
||||||
#[pyclass(extends=PyModel, module = "tokenizers.models", name = "BPE")]
|
#[pyclass(extends=PyModel, module = "tokenizers.models", name = "BPE")]
|
||||||
pub struct PyBPE {}
|
pub struct PyBPE {}
|
||||||
|
|
||||||
@ -279,6 +282,7 @@ impl PyBPE {
|
|||||||
"end_of_word_suffix" => builder = builder.end_of_word_suffix(value.extract()?),
|
"end_of_word_suffix" => builder = builder.end_of_word_suffix(value.extract()?),
|
||||||
"fuse_unk" => builder = builder.fuse_unk(value.extract()?),
|
"fuse_unk" => builder = builder.fuse_unk(value.extract()?),
|
||||||
"byte_fallback" => builder = builder.byte_fallback(value.extract()?),
|
"byte_fallback" => builder = builder.byte_fallback(value.extract()?),
|
||||||
|
"ignore_merges" => builder = builder.ignore_merges(value.extract()?),
|
||||||
_ => println!("Ignored unknown kwarg option {}", key),
|
_ => println!("Ignored unknown kwarg option {}", key),
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
@ -396,11 +400,19 @@ impl PyBPE {
|
|||||||
fn set_byte_fallback(self_: PyRef<Self>, byte_fallback: bool) {
|
fn set_byte_fallback(self_: PyRef<Self>, byte_fallback: bool) {
|
||||||
setter!(self_, BPE, byte_fallback, byte_fallback);
|
setter!(self_, BPE, byte_fallback, byte_fallback);
|
||||||
}
|
}
|
||||||
|
#[getter]
|
||||||
|
fn get_ignore_merges(self_: PyRef<Self>) -> bool {
|
||||||
|
getter!(self_, BPE, ignore_merges)
|
||||||
|
}
|
||||||
|
|
||||||
|
#[setter]
|
||||||
|
fn set_ignore_merges(self_: PyRef<Self>, ignore_merges: bool) {
|
||||||
|
setter!(self_, BPE, ignore_merges, ignore_merges);
|
||||||
|
}
|
||||||
#[new]
|
#[new]
|
||||||
#[pyo3(
|
#[pyo3(
|
||||||
signature = (vocab=None, merges=None, **kwargs),
|
signature = (vocab=None, merges=None, **kwargs),
|
||||||
text_signature = "(self, vocab=None, merges=None, cache_capacity=None, dropout=None, unk_token=None, continuing_subword_prefix=None, end_of_word_suffix=None, fuse_unk=None, byte_fallback=False)")]
|
text_signature = "(self, vocab=None, merges=None, cache_capacity=None, dropout=None, unk_token=None, continuing_subword_prefix=None, end_of_word_suffix=None, fuse_unk=None, byte_fallback=False, ignore_merges=False)")]
|
||||||
fn new(
|
fn new(
|
||||||
py: Python<'_>,
|
py: Python<'_>,
|
||||||
vocab: Option<PyVocab>,
|
vocab: Option<PyVocab>,
|
||||||
|
@ -28,6 +28,7 @@ struct Config {
|
|||||||
end_of_word_suffix: Option<String>,
|
end_of_word_suffix: Option<String>,
|
||||||
fuse_unk: bool,
|
fuse_unk: bool,
|
||||||
byte_fallback: bool,
|
byte_fallback: bool,
|
||||||
|
ignore_merges: bool,
|
||||||
}
|
}
|
||||||
|
|
||||||
/// A `BpeBuilder` can be used to create a `BPE` model with a custom configuration.
|
/// A `BpeBuilder` can be used to create a `BPE` model with a custom configuration.
|
||||||
@ -49,6 +50,7 @@ impl Default for BpeBuilder {
|
|||||||
end_of_word_suffix: None,
|
end_of_word_suffix: None,
|
||||||
fuse_unk: false,
|
fuse_unk: false,
|
||||||
byte_fallback: false,
|
byte_fallback: false,
|
||||||
|
ignore_merges: false,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -123,6 +125,12 @@ impl BpeBuilder {
|
|||||||
self.config.byte_fallback = byte_fallback;
|
self.config.byte_fallback = byte_fallback;
|
||||||
self
|
self
|
||||||
}
|
}
|
||||||
|
/// Set the `ignore_merges` option.
|
||||||
|
#[must_use]
|
||||||
|
pub fn ignore_merges(mut self, ignore_merges: bool) -> Self {
|
||||||
|
self.config.ignore_merges = ignore_merges;
|
||||||
|
self
|
||||||
|
}
|
||||||
|
|
||||||
/// Returns a `BPE` model that uses the `BpeBuilder`'s configuration.
|
/// Returns a `BPE` model that uses the `BpeBuilder`'s configuration.
|
||||||
pub fn build(mut self) -> Result<BPE> {
|
pub fn build(mut self) -> Result<BPE> {
|
||||||
@ -190,6 +198,7 @@ impl BpeBuilder {
|
|||||||
end_of_word_suffix: self.config.end_of_word_suffix,
|
end_of_word_suffix: self.config.end_of_word_suffix,
|
||||||
fuse_unk: self.config.fuse_unk,
|
fuse_unk: self.config.fuse_unk,
|
||||||
byte_fallback: self.config.byte_fallback,
|
byte_fallback: self.config.byte_fallback,
|
||||||
|
ignore_merges: self.config.ignore_merges,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -219,6 +228,8 @@ pub struct BPE {
|
|||||||
/// Byte fallback from sentence pieces, instead of UNK, uses `"<0x00>"`
|
/// Byte fallback from sentence pieces, instead of UNK, uses `"<0x00>"`
|
||||||
/// for each byte in the unk token
|
/// for each byte in the unk token
|
||||||
pub byte_fallback: bool,
|
pub byte_fallback: bool,
|
||||||
|
/// Whether or not to direct output words if they are part of the vocab.
|
||||||
|
pub ignore_merges: bool,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl std::fmt::Debug for BPE {
|
impl std::fmt::Debug for BPE {
|
||||||
@ -232,6 +243,7 @@ impl std::fmt::Debug for BPE {
|
|||||||
.field("byte_fallback", &self.byte_fallback)
|
.field("byte_fallback", &self.byte_fallback)
|
||||||
.field("vocab", &self.vocab.len())
|
.field("vocab", &self.vocab.len())
|
||||||
.field("merges", &self.merges.len())
|
.field("merges", &self.merges.len())
|
||||||
|
.field("ignore_merges", &self.ignore_merges)
|
||||||
.finish()
|
.finish()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -258,6 +270,7 @@ impl Clone for BPE {
|
|||||||
end_of_word_suffix: self.end_of_word_suffix.clone(),
|
end_of_word_suffix: self.end_of_word_suffix.clone(),
|
||||||
fuse_unk: self.fuse_unk,
|
fuse_unk: self.fuse_unk,
|
||||||
byte_fallback: self.byte_fallback,
|
byte_fallback: self.byte_fallback,
|
||||||
|
ignore_merges: self.ignore_merges,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -448,15 +461,19 @@ impl BPE {
|
|||||||
|
|
||||||
fn tokenize_with_cache(&self, sequence: &str) -> Result<Vec<Token>> {
|
fn tokenize_with_cache(&self, sequence: &str) -> Result<Vec<Token>> {
|
||||||
if let Some(ref hit) = self.cache.as_ref().and_then(|c| c.get(sequence)) {
|
if let Some(ref hit) = self.cache.as_ref().and_then(|c| c.get(sequence)) {
|
||||||
Ok(self.word_to_tokens(hit).collect())
|
return Ok(self.word_to_tokens(hit).collect());
|
||||||
} else {
|
|
||||||
let word = self.merge_word(sequence)?;
|
|
||||||
let ret = self.word_to_tokens(&word).collect();
|
|
||||||
if let Some(ref cache) = self.cache {
|
|
||||||
cache.set(sequence.to_owned(), word);
|
|
||||||
}
|
|
||||||
Ok(ret)
|
|
||||||
}
|
}
|
||||||
|
if self.ignore_merges {
|
||||||
|
if let Some(id) = self.vocab.get(sequence) {
|
||||||
|
return Ok(vec![Token::new(*id, sequence.to_string().clone(), (0, 0))]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
let word = self.merge_word(sequence)?;
|
||||||
|
let ret = self.word_to_tokens(&word).collect();
|
||||||
|
if let Some(ref cache) = self.cache {
|
||||||
|
cache.set(sequence.to_owned(), word);
|
||||||
|
}
|
||||||
|
Ok(ret)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -862,4 +879,97 @@ mod tests {
|
|||||||
let tokens = bpe.tokenize("\n").unwrap();
|
let tokens = bpe.tokenize("\n").unwrap();
|
||||||
assert_eq!(tokens, vec![Token::new(1u32, "<0x0A>".into(), (0, 1)),]);
|
assert_eq!(tokens, vec![Token::new(1u32, "<0x0A>".into(), (0, 1)),]);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_ignore_merges() {
|
||||||
|
// 0x0A == '\n' in bytes
|
||||||
|
let vocab: Vocab = [
|
||||||
|
(".:.:".into(), 0),
|
||||||
|
("Ġbelirtilen".into(), 1),
|
||||||
|
(".".into(), 2),
|
||||||
|
(":".into(), 3),
|
||||||
|
("bel".into(), 4),
|
||||||
|
("irtilen".into(), 5),
|
||||||
|
("Ġ".into(), 6),
|
||||||
|
(".:".into(), 7),
|
||||||
|
("belirtilen".into(), 8),
|
||||||
|
(".:.".into(), 9),
|
||||||
|
("be".into(), 10),
|
||||||
|
("l".into(), 11),
|
||||||
|
("ir".into(), 12),
|
||||||
|
("ti".into(), 13),
|
||||||
|
("en".into(), 14),
|
||||||
|
("irtil".into(), 15),
|
||||||
|
("irti".into(), 16),
|
||||||
|
("i".into(), 17),
|
||||||
|
("r".into(), 18),
|
||||||
|
("t".into(), 19),
|
||||||
|
("b".into(), 20),
|
||||||
|
("e".into(), 21),
|
||||||
|
("n".into(), 22),
|
||||||
|
]
|
||||||
|
.iter()
|
||||||
|
.cloned()
|
||||||
|
.collect();
|
||||||
|
let mut bpe = BpeBuilder::default()
|
||||||
|
.vocab_and_merges(
|
||||||
|
vocab,
|
||||||
|
vec![
|
||||||
|
(".".into(), ":".into()),
|
||||||
|
("b".into(), "e".into()),
|
||||||
|
("be".into(), "l".into()),
|
||||||
|
("i".into(), "r".into()),
|
||||||
|
("t".into(), "i".into()),
|
||||||
|
("ir".into(), "ti".into()),
|
||||||
|
("e".into(), "n".into()),
|
||||||
|
("irti".into(), "l".into()),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
.ignore_merges(true)
|
||||||
|
.build()
|
||||||
|
.unwrap();
|
||||||
|
let tokens = bpe.tokenize(".:.:").unwrap();
|
||||||
|
assert_eq!(tokens, vec![Token::new(0u32, ".:.:".into(), (0, 0))]);
|
||||||
|
|
||||||
|
let tokens = bpe.tokenize("Ġbelirtilen").unwrap();
|
||||||
|
assert_eq!(tokens, vec![Token::new(1u32, "Ġbelirtilen".into(), (0, 0))]);
|
||||||
|
|
||||||
|
bpe.ignore_merges = false;
|
||||||
|
|
||||||
|
let tokens = bpe.tokenize(".:.:").unwrap();
|
||||||
|
assert_eq!(
|
||||||
|
tokens,
|
||||||
|
vec![
|
||||||
|
Token::new(7u32, ".:".into(), (0, 2)),
|
||||||
|
Token::new(7u32, ".:".into(), (2, 4))
|
||||||
|
]
|
||||||
|
);
|
||||||
|
|
||||||
|
let tokens = bpe.tokenize("Ġbelirtilen").unwrap();
|
||||||
|
assert_eq!(
|
||||||
|
tokens,
|
||||||
|
vec![
|
||||||
|
Token {
|
||||||
|
id: 6,
|
||||||
|
value: "Ġ".into(),
|
||||||
|
offsets: (0, 2)
|
||||||
|
},
|
||||||
|
Token {
|
||||||
|
id: 4,
|
||||||
|
value: "bel".into(),
|
||||||
|
offsets: (2, 5)
|
||||||
|
},
|
||||||
|
Token {
|
||||||
|
id: 15,
|
||||||
|
value: "irtil".into(),
|
||||||
|
offsets: (5, 10)
|
||||||
|
},
|
||||||
|
Token {
|
||||||
|
id: 14,
|
||||||
|
value: "en".into(),
|
||||||
|
offsets: (10, 12)
|
||||||
|
}
|
||||||
|
]
|
||||||
|
)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
Reference in New Issue
Block a user