Convert word counts to u64 (#1433)

* Convert word counts to u64

* More spots needed to compile
Author: Stephen Roller
Date: 2024-02-05 21:39:12 -05:00
Committed by: GitHub
Parent: 67fe59c88d
Commit: 4a8105c366

4 changed files with 35 additions and 35 deletions
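Context for the change, stated as an assumption since the commit message does not spell it out: a u32 count tops out at 2^32 - 1 = 4,294,967,295, a total that frequent tokens in web-scale corpora can exceed. A minimal standalone sketch of the headroom difference:

```rust
fn main() {
    // A plausible count for a very frequent token in a web-scale corpus.
    let big_count: u64 = 5_000_000_000;

    // u32 cannot hold it: the checked conversion fails.
    assert!(u32::try_from(big_count).is_err());

    // u64 has headroom to spare (u64::MAX is about 1.8e19).
    assert!(big_count < u64::MAX);
}
```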

View File (1 of 4)

@@ -183,12 +183,12 @@ impl PyBpeTrainer {
     }

     #[getter]
-    fn get_min_frequency(self_: PyRef<Self>) -> u32 {
+    fn get_min_frequency(self_: PyRef<Self>) -> u64 {
         getter!(self_, BpeTrainer, min_frequency)
     }

     #[setter]
-    fn set_min_frequency(self_: PyRef<Self>, freq: u32) {
+    fn set_min_frequency(self_: PyRef<Self>, freq: u64) {
         setter!(self_, BpeTrainer, min_frequency, freq);
     }

@@ -397,12 +397,12 @@ impl PyWordPieceTrainer {
     }

     #[getter]
-    fn get_min_frequency(self_: PyRef<Self>) -> u32 {
+    fn get_min_frequency(self_: PyRef<Self>) -> u64 {
         getter!(self_, WordPieceTrainer, min_frequency())
     }

     #[setter]
-    fn set_min_frequency(self_: PyRef<Self>, freq: u32) {
+    fn set_min_frequency(self_: PyRef<Self>, freq: u64) {
         setter!(self_, WordPieceTrainer, @set_min_frequency, freq);
     }

@@ -589,12 +589,12 @@ impl PyWordLevelTrainer {
     }

     #[getter]
-    fn get_min_frequency(self_: PyRef<Self>) -> u32 {
+    fn get_min_frequency(self_: PyRef<Self>) -> u64 {
         getter!(self_, WordLevelTrainer, min_frequency)
     }

     #[setter]
-    fn set_min_frequency(self_: PyRef<Self>, freq: u32) {
+    fn set_min_frequency(self_: PyRef<Self>, freq: u64) {
         setter!(self_, WordLevelTrainer, min_frequency, freq);
     }
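The getters and setters above use PyO3's `#[getter]`/`#[setter]` attributes, and PyO3 converts `u64` to and from a Python `int`, so `min_frequency` values above 2^32 - 1 now round-trip through the bindings. A self-contained sketch of the pattern, with a hypothetical `Trainer` type standing in for the wrapped trainers:

```rust
use pyo3::prelude::*;

// Hypothetical stand-in for PyBpeTrainer and friends.
#[pyclass]
struct Trainer {
    min_frequency: u64,
}

#[pymethods]
impl Trainer {
    #[getter]
    fn get_min_frequency(&self) -> u64 {
        self.min_frequency
    }

    #[setter]
    fn set_min_frequency(&mut self, freq: u64) {
        self.min_frequency = freq;
    }
}
```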

View File (2 of 4)

@@ -11,7 +11,7 @@ use std::collections::{BinaryHeap, HashMap, HashSet};
 #[derive(Debug, Eq)]
 struct Merge {
     pair: Pair,
-    count: u32,
+    count: u64,
     pos: HashSet<usize>,
 }

 impl PartialEq for Merge {

@@ -36,7 +36,7 @@ impl Ord for Merge {
 }

 struct Config {
-    min_frequency: u32,
+    min_frequency: u64,
     vocab_size: usize,
     show_progress: bool,
     special_tokens: Vec<AddedToken>,

@@ -79,7 +79,7 @@ impl BpeTrainerBuilder {
     /// Set the expected minimum frequency
     #[must_use]
-    pub fn min_frequency(mut self, frequency: u32) -> Self {
+    pub fn min_frequency(mut self, frequency: u64) -> Self {
         self.config.min_frequency = frequency;
         self
     }

@@ -176,7 +176,7 @@ impl BpeTrainerBuilder {
 #[derive(Debug, Clone, PartialEq, Serialize, Deserialize, Eq)]
 pub struct BpeTrainer {
     /// The minimum frequency a pair must have to produce a merge operation
-    pub min_frequency: u32,
+    pub min_frequency: u64,
     /// The target vocabulary size
     pub vocab_size: usize,
     /// Whether to show progress while training

@@ -195,7 +195,7 @@ pub struct BpeTrainer {
     /// An optional parameter to limit the max length of any single token
     pub max_token_length: Option<usize>,
-    words: HashMap<String, u32>,
+    words: HashMap<String, u64>,
 }

 impl Default for BpeTrainer {

@@ -205,7 +205,7 @@ impl Default for BpeTrainer {
 }

 impl BpeTrainer {
-    pub fn new(min_frequency: u32, vocab_size: usize) -> Self {
+    pub fn new(min_frequency: u64, vocab_size: usize) -> Self {
         Self {
             min_frequency,
             vocab_size,

@@ -263,7 +263,7 @@ impl BpeTrainer {
     /// Compute the initial alphabet and limit it if relevant
     fn compute_alphabet(
         &self,
-        wc: &HashMap<String, u32>,
+        wc: &HashMap<String, u64>,
         w2id: &mut HashMap<String, u32>,
         id2w: &mut Vec<String>,
     ) {

@@ -322,13 +322,13 @@ impl BpeTrainer {
     /// Tokenize words and add subwords to the vocabulary when relevant
     fn tokenize_words(
         &self,
-        wc: &HashMap<String, u32>,
+        wc: &HashMap<String, u64>,
         w2id: &mut HashMap<String, u32>,
         id2w: &mut Vec<String>,
         p: &Option<ProgressBar>,
-    ) -> (Vec<Word>, Vec<u32>) {
+    ) -> (Vec<Word>, Vec<u64>) {
         let mut words: Vec<Word> = Vec::with_capacity(wc.len());
-        let mut counts: Vec<u32> = Vec::with_capacity(wc.len());
+        let mut counts: Vec<u64> = Vec::with_capacity(wc.len());
         for (word, count) in wc {
             let mut current_word = Word::new();

@@ -373,7 +373,7 @@ impl BpeTrainer {
     fn count_pairs(
         &self,
         words: &[Word],
-        counts: &[u32],
+        counts: &[u64],
         p: &Option<ProgressBar>,
     ) -> (HashMap<Pair, i32>, HashMap<Pair, HashSet<usize>>) {
         words

@@ -431,7 +431,7 @@ impl BpeTrainer {
     pub fn do_train(
         &self,
-        word_counts: &HashMap<String, u32>,
+        word_counts: &HashMap<String, u64>,
         model: &mut BPE,
     ) -> Result<Vec<AddedToken>> {
         let mut word_to_id: HashMap<String, u32> = HashMap::with_capacity(self.vocab_size);

@@ -470,7 +470,7 @@ impl BpeTrainer {
             if count > 0 {
                 queue.push(Merge {
                     pair,
-                    count: count as u32,
+                    count: count as u64,
                     pos,
                 });
             }

@@ -493,8 +493,8 @@ impl BpeTrainer {
             }
             let mut top = queue.pop().unwrap();
-            if top.count != pair_counts[&top.pair] as u32 {
-                top.count = pair_counts[&top.pair] as u32;
+            if top.count != pair_counts[&top.pair] as u64 {
+                top.count = pair_counts[&top.pair] as u64;
                 queue.push(top);
                 continue;
             }
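The `@@ -493,8` hunk is the trainer's lazy heap maintenance: pair counts change as merges are applied, so a popped entry may be stale, in which case its count is refreshed and it is re-pushed rather than the whole heap being rebuilt. A minimal standalone sketch of that pattern, with a simplified `Entry` in place of `Merge`:

```rust
use std::collections::{BinaryHeap, HashMap};

// Simplified stand-in for Merge: ordered by count first.
#[derive(PartialEq, Eq, PartialOrd, Ord)]
struct Entry {
    count: u64,
    pair: (u32, u32),
}

fn main() {
    // The "true" counts, which mutate as merges are applied.
    let live_counts: HashMap<(u32, u32), u64> =
        HashMap::from([((0, 1), 7), ((1, 2), 9)]);

    // The heap still holds a stale count of 12 for (1, 2).
    let mut queue = BinaryHeap::from([
        Entry { count: 12, pair: (1, 2) },
        Entry { count: 7, pair: (0, 1) },
    ]);

    while let Some(mut top) = queue.pop() {
        // Stale entry: refresh the count and re-push, as in do_train.
        if top.count != live_counts[&top.pair] {
            top.count = live_counts[&top.pair];
            queue.push(top);
            continue;
        }
        // The first fresh pop is the true best pair: (1, 2) with count 9.
        assert_eq!((top.pair, top.count), ((1, 2), 9));
        break;
    }
}
```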
@@ -573,7 +573,7 @@ impl BpeTrainer {
             if count > 0 {
                 queue.push(Merge {
                     pair,
-                    count: count as u32,
+                    count: count as u64,
                     pos,
                 });
             }

@@ -632,7 +632,7 @@ impl Trainer for BpeTrainer {
         S: AsRef<str> + Send,
         F: Fn(&str) -> Result<Vec<String>> + Sync,
     {
-        let words: Result<HashMap<String, u32>> = iterator
+        let words: Result<HashMap<String, u64>> = iterator
             .maybe_par_bridge()
             .map(|sequence| {
                 let words = process(sequence.as_ref())?;

@@ -665,7 +665,7 @@ mod tests {
     #[test]
     fn test_train() {
-        let word_counts: HashMap<String, u32> = [
+        let word_counts: HashMap<String, u64> = [
             ("roses".into(), 1),
             ("are".into(), 2),
             ("red".into(), 1),

@@ -744,7 +744,7 @@ mod tests {
         */
         let max_token_length = 16;
-        let long_word_counts: HashMap<String, u32> = [
+        let long_word_counts: HashMap<String, u64> = [
             ("singlelongtokenwithoutcasechange", 2),
             ("singleLongTokenWithCamelCaseChange", 2),
             ("Longsingletokenwithpunctu@t!onwithin", 2),

@@ -784,7 +784,7 @@ mod tests {
         // directly compares tokens with known expected values.
         // maybe unstable depending on specific settings or changes.
         */
-        let long_word_counts: HashMap<String, u32> = [
+        let long_word_counts: HashMap<String, u64> = [
            ("sin", 2),
            ("Sin", 2),
            ("Lon", 2),

View File (3 of 4)

@@ -10,7 +10,7 @@ use std::collections::HashMap;
 pub struct WordLevelTrainer {
     /// The minimum frequency a word must have to be part of the vocabulary
     #[builder(default = "0")]
-    pub min_frequency: u32,
+    pub min_frequency: u64,
     /// The target vocabulary size
     #[builder(default = "30_000")]
     pub vocab_size: usize,

@@ -22,7 +22,7 @@ pub struct WordLevelTrainer {
     pub special_tokens: Vec<AddedToken>,
     #[builder(default, private)]
-    words: HashMap<String, u32>,
+    words: HashMap<String, u64>,
 }

 impl Default for WordLevelTrainer {

@@ -38,14 +38,14 @@ impl WordLevelTrainer {
     fn do_train(
         &self,
-        word_counts: &HashMap<String, u32>,
+        word_counts: &HashMap<String, u64>,
         model: &mut WordLevel,
     ) -> Result<Vec<AddedToken>> {
         let mut ordered_counts = word_counts.iter().collect::<Vec<_>>();
         //sort the word counts first by inverse counts and then by word, in order
         //to keep the sorting deterministic in case of equal counts
-        let cmp = |l: &(&String, &u32), r: &(&String, &u32)| -> Ordering {
+        let cmp = |l: &(&String, &u64), r: &(&String, &u64)| -> Ordering {
             let count_comp: Ordering = l.1.cmp(r.1);
             if count_comp != Ordering::Equal {
                 return count_comp.reverse();

@@ -100,7 +100,7 @@ impl Trainer for WordLevelTrainer {
         S: AsRef<str> + Send,
         F: Fn(&str) -> Result<Vec<String>> + Sync,
     {
-        let words: Result<HashMap<String, u32>> = iterator
+        let words: Result<HashMap<String, u64>> = iterator
             .maybe_par_bridge()
             .map(|sequence| {
                 let words = process(sequence.as_ref())?;

@@ -132,7 +132,7 @@ mod tests {
     #[test]
     fn test_train() {
-        let word_counts: HashMap<String, u32> = [
+        let word_counts: HashMap<String, u64> = [
             ("the".into(), 25),
             ("roses".into(), 22),
             ("are".into(), 24),

View File (4 of 4)

@@ -26,7 +26,7 @@ impl WordPieceTrainerBuilder {
     /// Set the expected minimum frequency
     #[must_use]
-    pub fn min_frequency(mut self, frequency: u32) -> Self {
+    pub fn min_frequency(mut self, frequency: u64) -> Self {
         self.bpe_trainer_builder = self.bpe_trainer_builder.min_frequency(frequency);
         self
     }

@@ -94,11 +94,11 @@ pub struct WordPieceTrainer {
 }

 impl WordPieceTrainer {
-    pub fn min_frequency(&self) -> u32 {
+    pub fn min_frequency(&self) -> u64 {
         self.bpe_trainer.min_frequency
     }

-    pub fn set_min_frequency(&mut self, freq: u32) {
+    pub fn set_min_frequency(&mut self, freq: u64) {
         self.bpe_trainer.min_frequency = freq;
     }