Mirror of https://github.com/mii443/tokenizers.git (synced 2025-08-22 16:25:30 +00:00)
implement a simple max_sentencepiece_length into BPE (#1228)
* implement a simple max_sentencepiece_length into BPE

  Add a way for the BPE trainer to behave like the unigram trainer, where tokens longer than a certain length (default 16 in SentencePiece) are skipped. This is already implemented in the unigram trainer, but in a different way. If this code were to be actually integrated, some work remains: document the behavior and how it should be set, default it to 0/None so it does nothing unless set, and provide a way in the Python bindings for the user to set the max token length. I was trying to find a way to implement max_sentencepiece_length through pre-tokenizer split rules and, to be honest, it is very difficult, and regexes can be really slow when operating on the whole training corpus.

* utilize Option<u16> for safer code.

* Other version.

* Update trainer.rs: clarify with type usize; propagate the max_length option.

* change max_length into a more descriptive name: in the documentation (https://huggingface.co/docs/tokenizers/api/trainers) UnigramTrainer uses max_piece_length for a similar function. Since the underlying concept in BPE is merges, max_merge_length could be a more descriptive variable name.

* change variable name in trainer.rs: change max_merge_length into max_token_length.

* Update trainer.rs: add several max_token_length declarations that were missing (impl BpeTrainerBuilder, struct BpeTrainer); add an explanation for the variable shadowing.

* Update trainer.rs: move the default definition of max_token_length to its proper location and adjust downstream variable initializations accordingly.

* add max_token_length test

* Add bpe direct assert test

* Update trainer.rs: clarified test documentation.

* Creating the bindings.

* Fix the default.

* Re-adding missing package-lock which I accidentally removed.

* ..

* Fixing trainer test.

* Fix.

---------

Co-authored-by: Nicolas Patry <patry.nicolas@protonmail.com>
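For context, here is a minimal sketch (not part of the commit) of how the new option is meant to be used from the Python bindings added below; the corpus, vocab_size, and other values are arbitrary illustration choices:

from tokenizers import Tokenizer, models, pre_tokenizers, trainers

# Build a throwaway BPE tokenizer and trainer.
tokenizer = Tokenizer(models.BPE())
tokenizer.pre_tokenizer = pre_tokenizers.Whitespace()

trainer = trainers.BpeTrainer(
    vocab_size=100,
    show_progress=False,
    max_token_length=16,  # option added by this PR; leaving it unset keeps the old behavior
)

# Highly repetitive strings like "======" would otherwise tend to become very long tokens.
corpus = ["====== wikipedia style separators ======", "plain text with ordinary words"] * 50
tokenizer.train_from_iterator(corpus, trainer=trainer)

# No learned token should be longer than 16 characters.
assert all(len(token) <= 16 for token in tokenizer.get_vocab())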
bindings/node/package-lock.json | 6593 lines (generated, new file; diff suppressed because it is too large)
@@ -38,6 +38,12 @@ class BpeTrainer(Trainer):

        end_of_word_suffix (:obj:`str`, `optional`):
            A suffix to be used for every subword that is a end-of-word.

        max_token_length (:obj:`int`, `optional`):
            Prevents creating tokens longer than the specified size.
            This can help with reducing polluting your vocabulary with
            highly repetitive tokens like `======` for wikipedia

    """

class UnigramTrainer(Trainer):
@@ -162,6 +162,12 @@ macro_rules! setter {
///
/// end_of_word_suffix (:obj:`str`, `optional`):
///     A suffix to be used for every subword that is a end-of-word.
///
/// max_token_length (:obj:`int`, `optional`):
///     Prevents creating tokens longer than the specified size.
///     This can help with reducing polluting your vocabulary with
///     highly repetitive tokens like `======` for wikipedia
///
#[pyclass(extends=PyTrainer, module = "tokenizers.trainers", name = "BpeTrainer")]
pub struct PyBpeTrainer {}
#[pymethods]

@@ -243,6 +249,16 @@ impl PyBpeTrainer {
        setter!(self_, BpeTrainer, limit_alphabet, limit);
    }

    #[getter]
    fn get_max_token_length(self_: PyRef<Self>) -> Option<usize> {
        getter!(self_, BpeTrainer, max_token_length)
    }

    #[setter]
    fn set_max_token_length(self_: PyRef<Self>, limit: Option<usize>) {
        setter!(self_, BpeTrainer, max_token_length, limit);
    }

    #[getter]
    fn get_initial_alphabet(self_: PyRef<Self>) -> Vec<String> {
        getter!(

@@ -315,6 +331,7 @@ impl PyBpeTrainer {
                );
            }
            "limit_alphabet" => builder = builder.limit_alphabet(val.extract()?),
            "max_token_length" => builder = builder.max_token_length(val.extract()?),
            "initial_alphabet" => {
                let alphabet: Vec<String> = val.extract()?;
                builder = builder.initial_alphabet(
@@ -63,7 +63,7 @@ class TestBpeTrainer:
    def test_can_pickle(self):
        assert (
            trainers.BpeTrainer(min_frequency=12).__getstate__()
-           == b"""{"BpeTrainer":{"min_frequency":12,"vocab_size":30000,"show_progress":true,"special_tokens":[],"limit_alphabet":null,"initial_alphabet":[],"continuing_subword_prefix":null,"end_of_word_suffix":null,"words":{}}}"""
+           == b"""{"BpeTrainer":{"min_frequency":12,"vocab_size":30000,"show_progress":true,"special_tokens":[],"limit_alphabet":null,"initial_alphabet":[],"continuing_subword_prefix":null,"end_of_word_suffix":null,"max_token_length":null,"words":{}}}"""
        )
        assert isinstance(pickle.loads(pickle.dumps(trainers.BpeTrainer(min_frequency=12))), trainers.BpeTrainer)
@@ -44,6 +44,7 @@ struct Config {
    initial_alphabet: HashSet<char>,
    continuing_subword_prefix: Option<String>,
    end_of_word_suffix: Option<String>,
    max_token_length: Option<usize>,
}

/// A `BpeTrainerBuilder` can be used to create a `BpeTrainer` with a custom

@@ -64,6 +65,7 @@ impl Default for BpeTrainerBuilder {
                initial_alphabet: HashSet::new(),
                continuing_subword_prefix: None,
                end_of_word_suffix: None,
                max_token_length: None,
            },
        }
    }

@@ -130,6 +132,12 @@ impl BpeTrainerBuilder {
        self.config.end_of_word_suffix = Some(suffix);
        self
    }

    /// Set max_token_length
    #[must_use]
    pub fn max_token_length(mut self, max_token_length: Option<usize>) -> Self {
        self.config.max_token_length = max_token_length;
        self
    }

    /// Constructs the final BpeTrainer
    pub fn build(self) -> BpeTrainer {

@@ -142,6 +150,7 @@ impl BpeTrainerBuilder {
            initial_alphabet: self.config.initial_alphabet,
            continuing_subword_prefix: self.config.continuing_subword_prefix,
            end_of_word_suffix: self.config.end_of_word_suffix,
            max_token_length: self.config.max_token_length,
            words: HashMap::new(),
        }
    }

@@ -183,6 +192,8 @@ pub struct BpeTrainer {
    pub continuing_subword_prefix: Option<String>,
    /// An optional suffix to caracterize and end-of-word subword
    pub end_of_word_suffix: Option<String>,
    /// An optional parameter to limit the max length of any single token
    pub max_token_length: Option<usize>,

    words: HashMap<String, u32>,
}

@@ -425,6 +436,7 @@ impl BpeTrainer {
    ) -> Result<Vec<AddedToken>> {
        let mut word_to_id: HashMap<String, u32> = HashMap::with_capacity(self.vocab_size);
        let mut id_to_word: Vec<String> = Vec::with_capacity(self.vocab_size);
        let max_token_length: usize = self.max_token_length.unwrap_or(usize::MAX);

        let progress = self.setup_progress();

@@ -502,6 +514,9 @@ impl BpeTrainer {
                }
            }
            let new_token = format!("{}{}", part_a, part_b);
            // implement sentencepiece-like merge.
            // if this code were to be merged, integrate a way in the python bindings to communicate this variable
            // default should be 0/None to maintain previous behavior. 16 is the spm default.

            // Insert new token if it does not already exist
            let new_token_id = word_to_id

@@ -524,7 +539,7 @@ impl BpeTrainer {
                    // can be there only once (HashSet). So this is safe.
                    unsafe {
                        let word: &mut Word = &mut (*w);
-                       word.merge(top.pair.0, top.pair.1, new_token_id)
+                       word.merge(top.pair.0, top.pair.1, new_token_id, max_token_length)
                            .into_iter()
                            .map(|c| (c, *i))
                            .collect::<Vec<_>>()

@@ -720,4 +735,115 @@ mod tests {
            .collect();
        assert_eq!(model.merges, expected_merges);
    }

    #[test]
    fn bpe_test_max_token_length_16() {
        /* bpe_test_max_token_length series of tests test the max_token_length flag of bpetrainer
        // this is the more robust version that only tests max length of learned tokens
        // (pre) tokenizer settings or vocab can be easily modified when necessary
        */
        let max_token_length = 16;
        let long_word_counts: HashMap<String, u32> = [
            ("singlelongtokenwithoutcasechange", 2),
            ("singleLongTokenWithCamelCaseChange", 2),
            ("Longsingletokenwithpunctu@t!onwithin", 2),
            ("Anotherlongsingletokenwithnumberw1th1n", 2),
            ("짧은한글문자열짧은한", 2),             // korean 10 char
            ("긴한글문자열긴한글문자열긴한글문", 2), // korean 16 char
            ("短字符串短字符串短字", 2),             // simplified chinese 10 char
            ("长字符串长字符串长字符串长字符串", 2), // simp. chinese 16 char
            ("短い文字列短い文字列", 2),             // japanese 10 char
            ("長い文字列長い文字列長い文字列長", 2), // japanese 16 char
            ("so", 2),
            ("GPT-2", 2),
        ]
        .iter()
        .map(|(key, value)| (key.to_string(), *value))
        .collect();
        let trainer = BpeTrainer::builder()
            .max_token_length(Some(max_token_length))
            .show_progress(false)
            .min_frequency(0)
            .build();
        let mut model = BPE::default();
        trainer.do_train(&long_word_counts, &mut model).unwrap();
        let vocab = model.get_vocab();
        for token in vocab.keys() {
            assert!(
                token.chars().count() <= max_token_length,
                "token too long : {} , chars().count() = {}",
                token,
                token.chars().count()
            )
        }
    }

    #[test]
    fn bpe_test_max_token_length_direct_assert() {
        /* more direct version of bpe_test_max_token_length test
        // directly compares tokens with known expected values.
        // maybe unstable depending on specific settings or changes.
        */
        let long_word_counts: HashMap<String, u32> = [
            ("sin", 2),
            ("Sin", 2),
            ("Lon", 2),
            ("Ano", 2),
            ("짧은한", 2),
            ("긴한글", 2),
            ("短字符", 2),
            ("长字符", 2),
            ("短い文", 2),
            ("長い文", 2),
            ("so", 2),
            ("GP", 2),
        ]
        .iter()
        .map(|(key, value)| (key.to_string(), *value))
        .collect();
        let trainer = BpeTrainer::builder()
            .max_token_length(Some(2))
            .show_progress(false)
            .min_frequency(0)
            .build();
        let mut model = BPE::default();
        trainer.do_train(&long_word_counts, &mut model).unwrap();
        let trained_vocab: HashMap<String, u32> = model.get_vocab();
        let expected_vocab: HashMap<String, u32> = [
            ("短", 12),
            ("n", 6),
            ("i", 5),
            ("s", 8),
            ("字符", 23),
            ("長", 14),
            ("긴", 17),
            ("い文", 22),
            ("L", 2),
            ("in", 21),
            ("o", 7),
            ("은한", 29),
            ("S", 4),
            ("P", 3),
            ("so", 27),
            ("符", 13),
            ("文", 11),
            ("字", 10),
            ("짧", 19),
            ("GP", 25),
            ("글", 16),
            ("G", 1),
            ("An", 24),
            ("长", 15),
            ("A", 0),
            ("Lo", 26),
            ("긴한", 28),
            ("い", 9),
            ("한", 20),
            ("은", 18),
        ]
        .iter()
        .cloned()
        .map(|(k, v)| (k.to_string(), v))
        .collect();
        assert_eq!(trained_vocab, expected_vocab)
    }
}
@@ -103,7 +103,13 @@ impl Word {
        });
    }

-   pub(super) fn merge(&mut self, c1: u32, c2: u32, replacement: u32) -> Vec<(Pair, i32)> {
+   pub(super) fn merge(
+       &mut self,
+       c1: u32,
+       c2: u32,
+       replacement: u32,
+       max_length: usize,
+   ) -> Vec<(Pair, i32)> {
        let mut changes: Vec<(Pair, i32)> = vec![];
        let mut i = 0;
        loop {

@@ -117,12 +123,6 @@ impl Word {
            let first = self.symbols[i];
            let second = self.symbols[i + 1];

-           // If there are other characters before the pair
-           if i > 0 {
-               changes.push(((self.symbols[i - 1].c, first.c), -1));
-               changes.push(((self.symbols[i - 1].c, replacement), 1));
-           }

            // Remove in place
            let new_s = Symbol {
                c: replacement,

@@ -130,6 +130,15 @@ impl Word {
                next: second.next,
                len: first.len + second.len,
            };

+           // If there are other characters before the pair
+           if i > 0 {
+               changes.push(((self.symbols[i - 1].c, first.c), -1));
+               if self.symbols[i - 1].len + new_s.len < max_length {
+                   changes.push(((self.symbols[i - 1].c, replacement), 1));
+               }
+           }

            self.symbols.insert(i, new_s); // Insert replacement before first char of pair
            self.symbols.remove(i + 1); // Remove first char of pair
            self.symbols.remove(i + 1); // And then the second

@@ -137,9 +146,11 @@ impl Word {
            // If there are other characters after the pair
            if i < self.symbols.len() - 1 {
                changes.push(((second.c, self.symbols[i + 1].c), -1));
+               if self.symbols[i + 1].len + new_s.len < max_length {
                    changes.push(((replacement, self.symbols[i + 1].c), 1));
+               }
            }
        }

        i += 1;
    }

@@ -276,7 +287,7 @@ mod tests {

        // We're going to perform a merge on the pair ('l', 'l') ~= (2, 2). Let's
        // say that 'll' has the ID of 4 in the updated word-to-id vocab.
-       let changes = word.merge(2, 2, 4);
+       let changes = word.merge(2, 2, 4, usize::MAX);

        // So the word should now look like this:
        assert_eq!(

@@ -306,4 +317,39 @@ mod tests {
            ]
        );
    }

    #[test]
    fn test_merge_max_length() {
        // Let's say we have the word 'hello' and a word-to-id vocab that looks
        // like this: {'h': 0, 'e': 1, 'l': 2, 'o': 3}.
        let mut word = Word::new();
        word.add(0, 1); // 'h'
        word.add(1, 1); // 'e'
        word.add(2, 1); // 'l'
        word.add(2, 1); // 'l'
        word.add(3, 1); // 'o'

        // We're going to perform a merge on the pair ('l', 'l') ~= (2, 2). Let's
        // say that 'll' has the ID of 4 in the updated word-to-id vocab.
        let changes = word.merge(2, 2, 4, 2);
        assert_eq!(
            word.get_chars(),
            &[
                0u32, // 'h'
                1u32, // 'e'
                4u32, // 'll'
                3u32, // 'o'
            ]
        );

        assert_eq!(
            changes,
            &[
                ((1u32, 2u32), -1i32), // count for ('e', 'l') should be decreased by 1.
                // ((1u32, 4u32), 1i32), Missing since this would be larger than 2
                ((2u32, 3u32), -1i32), // count for ('l', 'o') should be decreased by 1.
                // ((4u32, 3u32), 1i32), Missing since this would be larger than 2
            ]
        );
    }
}