implement a simple max_sentencepiece_length into BPE (#1228)

* implement a simple max_sentencepiece_length into BPE

Add a way for the BPE trainer to behave like the unigram trainer, where tokens longer than a certain length (default 16 in SPM) are skipped. This is already implemented in the unigram trainer, but in a different way.

If this code were to be actually integrated, some work remains to be done:

Documentation describing the behavior and how it should be set.
Set default == 0 so it doesn't act unless set.
Provide a way in the Python binding for the user to set the max token length.

I was trying to find a way to implement max_sentencepiece_length through pre-tokenizer split rules, and to be honest, it's very difficult, and regexes can be really slow when operating on the whole training corpus.
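
For illustration, a minimal usage sketch of the intended behavior through the Python binding (the max_token_length keyword and the Tokenizer/BPE/BpeTrainer classes are the ones in this diff; the toy corpus is made up):

from tokenizers import Tokenizer, models, trainers

# Cap learned tokens at 16 characters (the SentencePiece default);
# leaving max_token_length unset keeps the previous, unbounded behavior.
tokenizer = Tokenizer(models.BPE())
trainer = trainers.BpeTrainer(
    vocab_size=30000,
    max_token_length=16,
    show_progress=False,
)
corpus = ["=========== section header ===========", "plain text with ordinary words"]
tokenizer.train_from_iterator(corpus, trainer=trainer)

# No learned token should be longer than 16 characters.
assert all(len(token) <= 16 for token in tokenizer.get_vocab())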


* utilize Option<u16> for safer code.

* Other version.

* Update trainer.rs

Clarify with type usize; propagate the max_length option.

* change max_length to a more descriptive name

In the documentation
https://huggingface.co/docs/tokenizers/api/trainers
UnigramTrainer uses max_piece_length for a similar function.
Since the underlying concept in BPE is merges, using max_merge_length as the variable name could prove more descriptive.

* change variable name in trainer.rs

change max_merge_length into max_token_length

* Update trainer.rs

add several max_token_length declarations that were missing:
impl BpeTrainerBuilder
struct BpeTrainer

Add an explanation for variable shadowing.

* Update trainer.rs

Move the default definition of max_token_length to the proper location. Adjust downstream variable initializations accordingly.

* add max_token_length test

* Add bpe direct assert test

* Update trainer.rs

clarified test documentation

* Creating the bindings.

* Fix the default.

* Re-adding missing package-lock which I accidentally removed.

* ..

* Fixing trainer test.

* Fix.

---------

Co-authored-by: Nicolas Patry <patry.nicolas@protonmail.com>
Authored by Chris Ha on 2023-05-16 17:08:19 +09:00; committed by GitHub.
parent daf3fcc976
commit cefc41e8ec
6 changed files with 6799 additions and 11 deletions

bindings/node/package-lock.json (generated, 6593 lines added): diff suppressed because it is too large.

@@ -38,6 +38,12 @@ class BpeTrainer(Trainer):
end_of_word_suffix (:obj:`str`, `optional`):
A suffix to be used for every subword that is an end-of-word.
max_token_length (:obj:`int`, `optional`):
Prevents creating tokens longer than the specified size.
This can help reduce polluting your vocabulary with
highly repetitive tokens like `======` for Wikipedia.
"""
class UnigramTrainer(Trainer):


@@ -162,6 +162,12 @@ macro_rules! setter {
///
/// end_of_word_suffix (:obj:`str`, `optional`):
///     A suffix to be used for every subword that is an end-of-word.
///
/// max_token_length (:obj:`int`, `optional`):
/// Prevents creating tokens longer than the specified size.
///     This can help reduce polluting your vocabulary with
///     highly repetitive tokens like `======` for Wikipedia.
///
#[pyclass(extends=PyTrainer, module = "tokenizers.trainers", name = "BpeTrainer")]
pub struct PyBpeTrainer {}
#[pymethods]
@@ -243,6 +249,16 @@ impl PyBpeTrainer {
setter!(self_, BpeTrainer, limit_alphabet, limit);
}
#[getter]
fn get_max_token_length(self_: PyRef<Self>) -> Option<usize> {
getter!(self_, BpeTrainer, max_token_length)
}
#[setter]
fn set_max_token_length(self_: PyRef<Self>, limit: Option<usize>) {
setter!(self_, BpeTrainer, max_token_length, limit);
}
#[getter]
fn get_initial_alphabet(self_: PyRef<Self>) -> Vec<String> {
getter!(
@@ -315,6 +331,7 @@ impl PyBpeTrainer {
);
}
"limit_alphabet" => builder = builder.limit_alphabet(val.extract()?),
"max_token_length" => builder = builder.max_token_length(val.extract()?),
"initial_alphabet" => {
let alphabet: Vec<String> = val.extract()?;
builder = builder.initial_alphabet(


@@ -63,7 +63,7 @@ class TestBpeTrainer:
def test_can_pickle(self):
assert (
trainers.BpeTrainer(min_frequency=12).__getstate__()
== b"""{"BpeTrainer":{"min_frequency":12,"vocab_size":30000,"show_progress":true,"special_tokens":[],"limit_alphabet":null,"initial_alphabet":[],"continuing_subword_prefix":null,"end_of_word_suffix":null,"words":{}}}"""
== b"""{"BpeTrainer":{"min_frequency":12,"vocab_size":30000,"show_progress":true,"special_tokens":[],"limit_alphabet":null,"initial_alphabet":[],"continuing_subword_prefix":null,"end_of_word_suffix":null,"max_token_length":null,"words":{}}}"""
)
assert isinstance(pickle.loads(pickle.dumps(trainers.BpeTrainer(min_frequency=12))), trainers.BpeTrainer)
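
Besides the constructor keyword, the binding also adds a getter and setter for the option (get_max_token_length / set_max_token_length above); assuming pyo3's usual get_/set_ property naming, it is exposed as a max_token_length attribute, so the limit can be inspected or changed on an existing trainer:

from tokenizers import trainers

trainer = trainers.BpeTrainer()
print(trainer.max_token_length)  # None by default, i.e. no length limit

trainer.max_token_length = 16    # cap learned tokens at 16 characters
print(trainer.max_token_length)  # 16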


@@ -44,6 +44,7 @@ struct Config {
initial_alphabet: HashSet<char>,
continuing_subword_prefix: Option<String>,
end_of_word_suffix: Option<String>,
max_token_length: Option<usize>,
}
/// A `BpeTrainerBuilder` can be used to create a `BpeTrainer` with a custom
@@ -64,6 +65,7 @@ impl Default for BpeTrainerBuilder {
initial_alphabet: HashSet::new(),
continuing_subword_prefix: None,
end_of_word_suffix: None,
max_token_length: None,
},
}
}
@@ -130,6 +132,12 @@ impl BpeTrainerBuilder {
self.config.end_of_word_suffix = Some(suffix);
self
}
/// Set max_token_length
#[must_use]
pub fn max_token_length(mut self, max_token_length: Option<usize>) -> Self {
self.config.max_token_length = max_token_length;
self
}
/// Constructs the final BpeTrainer
pub fn build(self) -> BpeTrainer {
@@ -142,6 +150,7 @@ impl BpeTrainerBuilder {
initial_alphabet: self.config.initial_alphabet,
continuing_subword_prefix: self.config.continuing_subword_prefix,
end_of_word_suffix: self.config.end_of_word_suffix,
max_token_length: self.config.max_token_length,
words: HashMap::new(),
}
}
@@ -183,6 +192,8 @@ pub struct BpeTrainer {
pub continuing_subword_prefix: Option<String>,
/// An optional suffix to characterize an end-of-word subword
pub end_of_word_suffix: Option<String>,
/// An optional parameter to limit the max length of any single token
pub max_token_length: Option<usize>,
words: HashMap<String, u32>,
}
@@ -425,6 +436,7 @@ impl BpeTrainer {
) -> Result<Vec<AddedToken>> {
let mut word_to_id: HashMap<String, u32> = HashMap::with_capacity(self.vocab_size);
let mut id_to_word: Vec<String> = Vec::with_capacity(self.vocab_size);
let max_token_length: usize = self.max_token_length.unwrap_or(usize::MAX);
let progress = self.setup_progress();
@@ -502,6 +514,9 @@ impl BpeTrainer {
}
}
let new_token = format!("{}{}", part_a, part_b);
// Implement a SentencePiece-like merge.
// If this code were to be merged, integrate a way in the Python bindings to communicate this variable;
// the default should be 0/None to maintain previous behavior (16 is the SPM default).
// Insert new token if it does not already exist
let new_token_id = word_to_id
@@ -524,7 +539,7 @@ impl BpeTrainer {
// can be there only once (HashSet). So this is safe.
unsafe {
let word: &mut Word = &mut (*w);
word.merge(top.pair.0, top.pair.1, new_token_id)
word.merge(top.pair.0, top.pair.1, new_token_id, max_token_length)
.into_iter()
.map(|c| (c, *i))
.collect::<Vec<_>>()
@@ -720,4 +735,115 @@ mod tests {
.collect();
assert_eq!(model.merges, expected_merges);
}
#[test]
fn bpe_test_max_token_length_16() {
/* bpe_test_max_token_length series of tests test the max_token_length flag of bpetrainer
// this is the more robust version that only tests max length of learned tokens
// (pre) tokenizer settings or vocab can be easily modified when necessary
*/
let max_token_length = 16;
let long_word_counts: HashMap<String, u32> = [
("singlelongtokenwithoutcasechange", 2),
("singleLongTokenWithCamelCaseChange", 2),
("Longsingletokenwithpunctu@t!onwithin", 2),
("Anotherlongsingletokenwithnumberw1th1n", 2),
("짧은한글문자열짧은한", 2), // korean 10 char
("긴한글문자열긴한글문자열긴한글문", 2), // korean 16 char
("短字符串短字符串短字", 2), //simplified chinese 10 char
("长字符串长字符串长字符串长字符串", 2), // simp. chinese 16 char
("短い文字列短い文字列", 2), // japanese 10 char
("長い文字列長い文字列長い文字列長", 2), // japanese 16 char
("so", 2),
("GPT-2", 2),
]
.iter()
.map(|(key, value)| (key.to_string(), *value))
.collect();
let trainer = BpeTrainer::builder()
.max_token_length(Some(max_token_length))
.show_progress(false)
.min_frequency(0)
.build();
let mut model = BPE::default();
trainer.do_train(&long_word_counts, &mut model).unwrap();
let vocab = model.get_vocab();
for token in vocab.keys() {
assert!(
token.chars().count() <= max_token_length,
"token too long : {} , chars().count() = {}",
token,
token.chars().count()
)
}
}
#[test]
fn bpe_test_max_token_length_direct_assert() {
/* more direct version of bpe_test_max_token_length test
// directly compares tokens with known expected values.
// maybe unstable depending on specific settings or changes.
*/
let long_word_counts: HashMap<String, u32> = [
("sin", 2),
("Sin", 2),
("Lon", 2),
("Ano", 2),
("짧은한", 2),
("긴한글", 2),
("短字符", 2),
("长字符", 2),
("短い文", 2),
("長い文", 2),
("so", 2),
("GP", 2),
]
.iter()
.map(|(key, value)| (key.to_string(), *value))
.collect();
let trainer = BpeTrainer::builder()
.max_token_length(Some(2))
.show_progress(false)
.min_frequency(0)
.build();
let mut model = BPE::default();
trainer.do_train(&long_word_counts, &mut model).unwrap();
let trained_vocab: HashMap<String, u32> = model.get_vocab();
let expected_vocab: HashMap<String, u32> = [
("", 12),
("n", 6),
("i", 5),
("s", 8),
("字符", 23),
("", 14),
("", 17),
("い文", 22),
("L", 2),
("in", 21),
("o", 7),
("은한", 29),
("S", 4),
("P", 3),
("so", 27),
("", 13),
("", 11),
("", 10),
("", 19),
("GP", 25),
("", 16),
("G", 1),
("An", 24),
("", 15),
("A", 0),
("Lo", 26),
("긴한", 28),
("", 9),
("", 20),
("", 18),
]
.iter()
.cloned()
.map(|(k, v)| (k.to_string(), v))
.collect();
assert_eq!(trained_vocab, expected_vocab)
}
}


@@ -103,7 +103,13 @@ impl Word {
});
}
pub(super) fn merge(&mut self, c1: u32, c2: u32, replacement: u32) -> Vec<(Pair, i32)> {
pub(super) fn merge(
&mut self,
c1: u32,
c2: u32,
replacement: u32,
max_length: usize,
) -> Vec<(Pair, i32)> {
let mut changes: Vec<(Pair, i32)> = vec![];
let mut i = 0;
loop {
@@ -117,12 +123,6 @@ impl Word {
let first = self.symbols[i];
let second = self.symbols[i + 1];
// If there are other characters before the pair
if i > 0 {
changes.push(((self.symbols[i - 1].c, first.c), -1));
changes.push(((self.symbols[i - 1].c, replacement), 1));
}
// Remove in place
let new_s = Symbol {
c: replacement,
@@ -130,6 +130,15 @@ impl Word {
next: second.next,
len: first.len + second.len,
};
// If there are other characters before the pair
if i > 0 {
changes.push(((self.symbols[i - 1].c, first.c), -1));
if self.symbols[i - 1].len + new_s.len < max_length {
changes.push(((self.symbols[i - 1].c, replacement), 1));
}
}
self.symbols.insert(i, new_s); // Insert replacement before first char of pair
self.symbols.remove(i + 1); // Remove first char of pair
self.symbols.remove(i + 1); // And then the second
@@ -137,7 +146,9 @@ impl Word {
// If there are other characters after the pair
if i < self.symbols.len() - 1 {
changes.push(((second.c, self.symbols[i + 1].c), -1));
changes.push(((replacement, self.symbols[i + 1].c), 1));
if self.symbols[i + 1].len + new_s.len < max_length {
changes.push(((replacement, self.symbols[i + 1].c), 1));
}
}
}
@@ -276,7 +287,7 @@ mod tests {
// We're going to perform a merge on the pair ('l', 'l') ~= (2, 2). Let's
// say that 'll' has the ID of 4 in the updated word-to-id vocab.
let changes = word.merge(2, 2, 4);
let changes = word.merge(2, 2, 4, usize::MAX);
// So the word should now look like this:
assert_eq!(
@@ -306,4 +317,39 @@ mod tests {
]
);
}
#[test]
fn test_merge_max_length() {
// Let's say we have the word 'hello' and a word-to-id vocab that looks
// like this: {'h': 0, 'e': 1, 'l': 2, 'o': 3}.
let mut word = Word::new();
word.add(0, 1); // 'h'
word.add(1, 1); // 'e'
word.add(2, 1); // 'l'
word.add(2, 1); // 'l'
word.add(3, 1); // 'o'
// We're going to perform a merge on the pair ('l', 'l') ~= (2, 2). Let's
// say that 'll' has the ID of 4 in the updated word-to-id vocab.
let changes = word.merge(2, 2, 4, 2);
assert_eq!(
word.get_chars(),
&[
0u32, // 'h'
1u32, // 'e'
4u32, // 'll'
3u32, // 'o'
]
);
assert_eq!(
changes,
&[
((1u32, 2u32), -1i32), // count for ('e', 'l') should be decreased by 1.
// ((1u32, 4u32), 1i32), Missing since this would be larger than 2
((2u32, 3u32), -1i32), // count for ('l', 'o') should be decreased by 1.
// ((4u32, 3u32), 1i32), Missing since this would be larger than 2
]
);
}
}