Mirror of https://github.com/mii443/tokenizers.git (synced 2025-08-22 16:25:30 +00:00)
implement a simple max_sentencepiece_length into BPE (#1228)
* Implement a simple max_sentencepiece_length in BPE. Add a way for the BPE trainer to behave like the Unigram trainer, where tokens longer than a certain length (default 16 in SentencePiece) are skipped. The Unigram trainer already implements this, but in a different way. If this code were to be actually integrated, some work remains to be done:
  - Document the behavior and how it should be set.
  - Set the default to 0/None so it does not act unless explicitly set.
  - Provide ways in the Python bindings for the user to set the max token length.
  I tried to implement max_sentencepiece_length through pre-tokenizer split rules, and honestly it is very difficult; regexes can be really slow when operating on the whole training corpus.
* Use Option<u16> for safer code.
* Other version.
* Update trainer.rs: clarify with type usize; propagate the max_length option.
* Change max_length to a more descriptive name. In the documentation (https://huggingface.co/docs/tokenizers/api/trainers), UnigramTrainer uses max_piece_length for a similar function. Since the underlying concept in BPE is merges, max_merge_length could be a more descriptive variable name.
* Change the variable name in trainer.rs from max_merge_length to max_token_length.
* Update trainer.rs: add several max_token_length declarations that were missing (impl BpeTrainerBuilder, struct BpeTrainer); add an explanation for variable shadowing.
* Update trainer.rs: move the default definition of max_token_length to the proper location; adjust downstream variable initializations accordingly.
* Add max_token_length test.
* Add BPE direct-assert test.
* Update trainer.rs: clarified test documentation.
* Create the bindings.
* Fix the default.
* Re-add the missing package-lock which was accidentally removed.
* ..
* Fix trainer test.
* Fix.

---------

Co-authored-by: Nicolas Patry <patry.nicolas@protonmail.com>
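For orientation, here is a minimal sketch of how the new option is meant to be used, adapted from the trainer test added in this diff. The word counts are invented, and the u32 count type and the do_train call follow the in-crate tests, so details may differ from the public API:

    // A minimal sketch adapted from the test added in this PR: train a BPE model
    // while capping learned tokens at 16 characters. Word counts are invented.
    use std::collections::HashMap;

    use tokenizers::models::bpe::{BpeTrainer, BPE};
    use tokenizers::Model; // for get_vocab()

    fn main() {
        let word_counts: HashMap<String, u32> = [
            ("singlelongtokenwithoutcasechange", 2u32),
            ("short", 5),
        ]
        .iter()
        .map(|(k, v)| (k.to_string(), *v))
        .collect();

        // max_token_length takes an Option<usize>; None (the default) keeps the old behavior.
        let trainer = BpeTrainer::builder()
            .max_token_length(Some(16))
            .show_progress(false)
            .min_frequency(0)
            .build();

        let mut model = BPE::default();
        trainer.do_train(&word_counts, &mut model).unwrap();

        // No learned token should exceed 16 characters.
        for token in model.get_vocab().keys() {
            assert!(token.chars().count() <= 16);
        }
    }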
bindings/node/package-lock.json (generated file, +6593 lines): diff suppressed because it is too large.
@@ -38,6 +38,12 @@ class BpeTrainer(Trainer):
         end_of_word_suffix (:obj:`str`, `optional`):
             A suffix to be used for every subword that is a end-of-word.
 
+        max_token_length (:obj:`int`, `optional`):
+            Prevents creating tokens longer than the specified size.
+            This can help with reducing polluting your vocabulary with
+            highly repetitive tokens like `======` for wikipedia
+
     """
 
 class UnigramTrainer(Trainer):
@@ -162,6 +162,12 @@ macro_rules! setter {
 ///
 ///     end_of_word_suffix (:obj:`str`, `optional`):
 ///         A suffix to be used for every subword that is a end-of-word.
+///
+///     max_token_length (:obj:`int`, `optional`):
+///         Prevents creating tokens longer than the specified size.
+///         This can help with reducing polluting your vocabulary with
+///         highly repetitive tokens like `======` for wikipedia
+///
 #[pyclass(extends=PyTrainer, module = "tokenizers.trainers", name = "BpeTrainer")]
 pub struct PyBpeTrainer {}
 #[pymethods]
@@ -243,6 +249,16 @@ impl PyBpeTrainer {
         setter!(self_, BpeTrainer, limit_alphabet, limit);
     }
 
+    #[getter]
+    fn get_max_token_length(self_: PyRef<Self>) -> Option<usize> {
+        getter!(self_, BpeTrainer, max_token_length)
+    }
+
+    #[setter]
+    fn set_max_token_length(self_: PyRef<Self>, limit: Option<usize>) {
+        setter!(self_, BpeTrainer, max_token_length, limit);
+    }
+
     #[getter]
     fn get_initial_alphabet(self_: PyRef<Self>) -> Vec<String> {
         getter!(
@@ -315,6 +331,7 @@ impl PyBpeTrainer {
                         );
                     }
                     "limit_alphabet" => builder = builder.limit_alphabet(val.extract()?),
+                    "max_token_length" => builder = builder.max_token_length(val.extract()?),
                     "initial_alphabet" => {
                         let alphabet: Vec<String> = val.extract()?;
                         builder = builder.initial_alphabet(
@@ -63,7 +63,7 @@ class TestBpeTrainer:
     def test_can_pickle(self):
         assert (
             trainers.BpeTrainer(min_frequency=12).__getstate__()
-            == b"""{"BpeTrainer":{"min_frequency":12,"vocab_size":30000,"show_progress":true,"special_tokens":[],"limit_alphabet":null,"initial_alphabet":[],"continuing_subword_prefix":null,"end_of_word_suffix":null,"words":{}}}"""
+            == b"""{"BpeTrainer":{"min_frequency":12,"vocab_size":30000,"show_progress":true,"special_tokens":[],"limit_alphabet":null,"initial_alphabet":[],"continuing_subword_prefix":null,"end_of_word_suffix":null,"max_token_length":null,"words":{}}}"""
         )
         assert isinstance(pickle.loads(pickle.dumps(trainers.BpeTrainer(min_frequency=12))), trainers.BpeTrainer)
 
@@ -44,6 +44,7 @@ struct Config {
     initial_alphabet: HashSet<char>,
     continuing_subword_prefix: Option<String>,
     end_of_word_suffix: Option<String>,
+    max_token_length: Option<usize>,
 }
 
 /// A `BpeTrainerBuilder` can be used to create a `BpeTrainer` with a custom
@@ -64,6 +65,7 @@ impl Default for BpeTrainerBuilder {
                 initial_alphabet: HashSet::new(),
                 continuing_subword_prefix: None,
                 end_of_word_suffix: None,
+                max_token_length: None,
             },
         }
     }
@@ -130,6 +132,12 @@ impl BpeTrainerBuilder {
         self.config.end_of_word_suffix = Some(suffix);
         self
     }
+    /// Set max_token_length
+    #[must_use]
+    pub fn max_token_length(mut self, max_token_length: Option<usize>) -> Self {
+        self.config.max_token_length = max_token_length;
+        self
+    }
 
     /// Constructs the final BpeTrainer
     pub fn build(self) -> BpeTrainer {
@@ -142,6 +150,7 @@ impl BpeTrainerBuilder {
             initial_alphabet: self.config.initial_alphabet,
             continuing_subword_prefix: self.config.continuing_subword_prefix,
             end_of_word_suffix: self.config.end_of_word_suffix,
+            max_token_length: self.config.max_token_length,
             words: HashMap::new(),
         }
     }
@@ -183,6 +192,8 @@ pub struct BpeTrainer {
     pub continuing_subword_prefix: Option<String>,
     /// An optional suffix to caracterize and end-of-word subword
    pub end_of_word_suffix: Option<String>,
+    /// An optional parameter to limit the max length of any single token
+    pub max_token_length: Option<usize>,
 
     words: HashMap<String, u32>,
 }
@@ -425,6 +436,7 @@ impl BpeTrainer {
     ) -> Result<Vec<AddedToken>> {
         let mut word_to_id: HashMap<String, u32> = HashMap::with_capacity(self.vocab_size);
         let mut id_to_word: Vec<String> = Vec::with_capacity(self.vocab_size);
+        let max_token_length: usize = self.max_token_length.unwrap_or(usize::MAX);
 
         let progress = self.setup_progress();
 
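One note on the default, which the added line above makes concrete: because the limit is stored as an Option and unwrapped to usize::MAX, an unset limit can never reject a merge. A tiny illustrative check (plain Rust, not part of the crate):

    fn main() {
        // Mirrors `self.max_token_length.unwrap_or(usize::MAX)` from the hunk above:
        // when no limit is configured, the effective limit is usize::MAX, so the
        // `combined_len < limit` check used during merging never fails.
        let configured: Option<usize> = None;
        let effective = configured.unwrap_or(usize::MAX);
        let combined_len = 1_000_000_usize; // an absurdly long candidate token
        assert!(combined_len < effective);
        println!("an unset limit never rejects a merge");
    }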
@@ -502,6 +514,9 @@ impl BpeTrainer {
                     }
                 }
                 let new_token = format!("{}{}", part_a, part_b);
+                // implement sentencepiece-like merge.
+                // if this code were to be merged, integrate a way in the python bindings to communicate this variable
+                // default should be 0/None to maintain previous behavior. 16 is the spm default.
 
                 // Insert new token if it does not already exist
                 let new_token_id = word_to_id
@@ -524,7 +539,7 @@ impl BpeTrainer {
                         // can be there only once (HashSet). So this is safe.
                         unsafe {
                             let word: &mut Word = &mut (*w);
-                            word.merge(top.pair.0, top.pair.1, new_token_id)
+                            word.merge(top.pair.0, top.pair.1, new_token_id, max_token_length)
                                 .into_iter()
                                 .map(|c| (c, *i))
                                 .collect::<Vec<_>>()
@@ -720,4 +735,115 @@ mod tests {
         .collect();
         assert_eq!(model.merges, expected_merges);
     }
+
+    #[test]
+    fn bpe_test_max_token_length_16() {
+        /* bpe_test_max_token_length series of tests test the max_token_length flag of bpetrainer
+        // this is the more robust version that only tests max length of learned tokens
+        // (pre) tokenizer settings or vocab can be easily modified when necessary
+        */
+
+        let max_token_length = 16;
+        let long_word_counts: HashMap<String, u32> = [
+            ("singlelongtokenwithoutcasechange", 2),
+            ("singleLongTokenWithCamelCaseChange", 2),
+            ("Longsingletokenwithpunctu@t!onwithin", 2),
+            ("Anotherlongsingletokenwithnumberw1th1n", 2),
+            ("짧은한글문자열짧은한", 2),             // korean 10 char
+            ("긴한글문자열긴한글문자열긴한글문", 2), // korean 16 char
+            ("短字符串短字符串短字", 2),             // simplified chinese 10 char
+            ("长字符串长字符串长字符串长字符串", 2), // simp. chinese 16 char
+            ("短い文字列短い文字列", 2),             // japanese 10 char
+            ("長い文字列長い文字列長い文字列長", 2), // japanese 16 char
+            ("so", 2),
+            ("GPT-2", 2),
+        ]
+        .iter()
+        .map(|(key, value)| (key.to_string(), *value))
+        .collect();
+        let trainer = BpeTrainer::builder()
+            .max_token_length(Some(max_token_length))
+            .show_progress(false)
+            .min_frequency(0)
+            .build();
+        let mut model = BPE::default();
+        trainer.do_train(&long_word_counts, &mut model).unwrap();
+        let vocab = model.get_vocab();
+        for token in vocab.keys() {
+            assert!(
+                token.chars().count() <= max_token_length,
+                "token too long : {} , chars().count() = {}",
+                token,
+                token.chars().count()
+            )
+        }
+    }
+
+    #[test]
+    fn bpe_test_max_token_length_direct_assert() {
+        /* more direct version of bpe_test_max_token_length test
+        // directly compares tokens with known expected values.
+        // maybe unstable depending on specific settings or changes.
+        */
+        let long_word_counts: HashMap<String, u32> = [
+            ("sin", 2),
+            ("Sin", 2),
+            ("Lon", 2),
+            ("Ano", 2),
+            ("짧은한", 2),
+            ("긴한글", 2),
+            ("短字符", 2),
+            ("长字符", 2),
+            ("短い文", 2),
+            ("長い文", 2),
+            ("so", 2),
+            ("GP", 2),
+        ]
+        .iter()
+        .map(|(key, value)| (key.to_string(), *value))
+        .collect();
+        let trainer = BpeTrainer::builder()
+            .max_token_length(Some(2))
+            .show_progress(false)
+            .min_frequency(0)
+            .build();
+        let mut model = BPE::default();
+        trainer.do_train(&long_word_counts, &mut model).unwrap();
+        let trained_vocab: HashMap<String, u32> = model.get_vocab();
+        let expected_vocab: HashMap<String, u32> = [
+            ("短", 12),
+            ("n", 6),
+            ("i", 5),
+            ("s", 8),
+            ("字符", 23),
+            ("長", 14),
+            ("긴", 17),
+            ("い文", 22),
+            ("L", 2),
+            ("in", 21),
+            ("o", 7),
+            ("은한", 29),
+            ("S", 4),
+            ("P", 3),
+            ("so", 27),
+            ("符", 13),
+            ("文", 11),
+            ("字", 10),
+            ("짧", 19),
+            ("GP", 25),
+            ("글", 16),
+            ("G", 1),
+            ("An", 24),
+            ("长", 15),
+            ("A", 0),
+            ("Lo", 26),
+            ("긴한", 28),
+            ("い", 9),
+            ("한", 20),
+            ("은", 18),
+        ]
+        .iter()
+        .cloned()
+        .map(|(k, v)| (k.to_string(), v))
+        .collect();
+        assert_eq!(trained_vocab, expected_vocab)
+    }
 }
@@ -103,7 +103,13 @@ impl Word {
         });
     }
 
-    pub(super) fn merge(&mut self, c1: u32, c2: u32, replacement: u32) -> Vec<(Pair, i32)> {
+    pub(super) fn merge(
+        &mut self,
+        c1: u32,
+        c2: u32,
+        replacement: u32,
+        max_length: usize,
+    ) -> Vec<(Pair, i32)> {
         let mut changes: Vec<(Pair, i32)> = vec![];
         let mut i = 0;
         loop {
@@ -117,12 +123,6 @@ impl Word {
             let first = self.symbols[i];
             let second = self.symbols[i + 1];
 
-            // If there are other characters before the pair
-            if i > 0 {
-                changes.push(((self.symbols[i - 1].c, first.c), -1));
-                changes.push(((self.symbols[i - 1].c, replacement), 1));
-            }
-
             // Remove in place
             let new_s = Symbol {
                 c: replacement,
@@ -130,6 +130,15 @@ impl Word {
                 next: second.next,
                 len: first.len + second.len,
             };
+
+            // If there are other characters before the pair
+            if i > 0 {
+                changes.push(((self.symbols[i - 1].c, first.c), -1));
+                if self.symbols[i - 1].len + new_s.len < max_length {
+                    changes.push(((self.symbols[i - 1].c, replacement), 1));
+                }
+            }
+
             self.symbols.insert(i, new_s); // Insert replacement before first char of pair
             self.symbols.remove(i + 1); // Remove first char of pair
             self.symbols.remove(i + 1); // And then the second
@@ -137,7 +146,9 @@ impl Word {
             // If there are other characters after the pair
             if i < self.symbols.len() - 1 {
                 changes.push(((second.c, self.symbols[i + 1].c), -1));
-                changes.push(((replacement, self.symbols[i + 1].c), 1));
+                if self.symbols[i + 1].len + new_s.len < max_length {
+                    changes.push(((replacement, self.symbols[i + 1].c), 1));
+                }
             }
         }
 
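Taken together, these word.rs hunks gate the pair counts rather than the merge itself: when a neighboring pair would produce a token whose length reaches max_length, its count is simply never proposed, so the trainer can never select it. A standalone sketch of that gating rule (the struct and function names here are illustrative, not the crate's types):

    // Illustrative only: `SymbolLen` and `should_count_pair` are not crate types.
    #[derive(Clone, Copy)]
    struct SymbolLen {
        len: usize, // length, in characters, of the underlying subword
    }

    // Mirrors the condition added above: a pair with a neighbor is only counted
    // while the combined length stays strictly below `max_length`
    // (max_length is usize::MAX when no limit is configured).
    fn should_count_pair(neighbor: SymbolLen, new_s: SymbolLen, max_length: usize) -> bool {
        neighbor.len + new_s.len < max_length
    }

    fn main() {
        let new_s = SymbolLen { len: 10 }; // a freshly merged 10-char symbol
        let left = SymbolLen { len: 8 }; // its left neighbor, 8 chars

        // With max_length = 16 the 18-char candidate is never counted...
        assert!(!should_count_pair(left, new_s, 16));
        // ...while usize::MAX preserves the old, unlimited behavior.
        assert!(should_count_pair(left, new_s, usize::MAX));
    }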
@@ -276,7 +287,7 @@ mod tests {
 
         // We're going to perform a merge on the pair ('l', 'l') ~= (2, 2). Let's
         // say that 'll' has the ID of 4 in the updated word-to-id vocab.
-        let changes = word.merge(2, 2, 4);
+        let changes = word.merge(2, 2, 4, usize::MAX);
 
         // So the word should now look like this:
         assert_eq!(
@@ -306,4 +317,39 @@ mod tests {
             ]
         );
     }
+
+    #[test]
+    fn test_merge_max_length() {
+        // Let's say we have the word 'hello' and a word-to-id vocab that looks
+        // like this: {'h': 0, 'e': 1, 'l': 2, 'o': 3}.
+        let mut word = Word::new();
+        word.add(0, 1); // 'h'
+        word.add(1, 1); // 'e'
+        word.add(2, 1); // 'l'
+        word.add(2, 1); // 'l'
+        word.add(3, 1); // 'o'
+
+        // We're going to perform a merge on the pair ('l', 'l') ~= (2, 2). Let's
+        // say that 'll' has the ID of 4 in the updated word-to-id vocab.
+        let changes = word.merge(2, 2, 4, 2);
+        assert_eq!(
+            word.get_chars(),
+            &[
+                0u32, // 'h'
+                1u32, // 'e'
+                4u32, // 'll'
+                3u32, // 'o'
+            ]
+        );
+
+        assert_eq!(
+            changes,
+            &[
+                ((1u32, 2u32), -1i32), // count for ('e', 'l') should be decreased by 1.
+                // ((1u32, 4u32), 1i32), Missing since this would be larger than 2
+                ((2u32, 3u32), -1i32), // count for ('l', 'o') should be decreased by 1.
+                // ((4u32, 3u32), 1i32), Missing since this would be larger than 2
+            ]
+        );
+    }
 }