mirror of
https://github.com/mii443/tokenizers.git
synced 2025-08-22 16:25:30 +00:00
Convert word counts to u64 (#1433)
* Convert word counts to u64
* More spots needed to compile
This commit is contained in:
@ -183,12 +183,12 @@ impl PyBpeTrainer {
|
||||
}
|
||||
|
||||
#[getter]
|
||||
fn get_min_frequency(self_: PyRef<Self>) -> u32 {
|
||||
fn get_min_frequency(self_: PyRef<Self>) -> u64 {
|
||||
getter!(self_, BpeTrainer, min_frequency)
|
||||
}
|
||||
|
||||
#[setter]
|
||||
fn set_min_frequency(self_: PyRef<Self>, freq: u32) {
|
||||
fn set_min_frequency(self_: PyRef<Self>, freq: u64) {
|
||||
setter!(self_, BpeTrainer, min_frequency, freq);
|
||||
}
|
||||
|
||||
@ -397,12 +397,12 @@ impl PyWordPieceTrainer {
|
||||
}
|
||||
|
||||
#[getter]
|
||||
fn get_min_frequency(self_: PyRef<Self>) -> u32 {
|
||||
fn get_min_frequency(self_: PyRef<Self>) -> u64 {
|
||||
getter!(self_, WordPieceTrainer, min_frequency())
|
||||
}
|
||||
|
||||
#[setter]
|
||||
fn set_min_frequency(self_: PyRef<Self>, freq: u32) {
|
||||
fn set_min_frequency(self_: PyRef<Self>, freq: u64) {
|
||||
setter!(self_, WordPieceTrainer, @set_min_frequency, freq);
|
||||
}
|
||||
|
||||
@ -589,12 +589,12 @@ impl PyWordLevelTrainer {
|
||||
}
|
||||
|
||||
#[getter]
|
||||
fn get_min_frequency(self_: PyRef<Self>) -> u32 {
|
||||
fn get_min_frequency(self_: PyRef<Self>) -> u64 {
|
||||
getter!(self_, WordLevelTrainer, min_frequency)
|
||||
}
|
||||
|
||||
#[setter]
|
||||
fn set_min_frequency(self_: PyRef<Self>, freq: u32) {
|
||||
fn set_min_frequency(self_: PyRef<Self>, freq: u64) {
|
||||
setter!(self_, WordLevelTrainer, min_frequency, freq);
|
||||
}
|
||||
|
||||
|
@ -11,7 +11,7 @@ use std::collections::{BinaryHeap, HashMap, HashSet};
|
||||
#[derive(Debug, Eq)]
|
||||
struct Merge {
|
||||
pair: Pair,
|
||||
count: u32,
|
||||
count: u64,
|
||||
pos: HashSet<usize>,
|
||||
}
|
||||
impl PartialEq for Merge {
|
||||
@ -36,7 +36,7 @@ impl Ord for Merge {
|
||||
}
|
||||
|
||||
struct Config {
|
||||
min_frequency: u32,
|
||||
min_frequency: u64,
|
||||
vocab_size: usize,
|
||||
show_progress: bool,
|
||||
special_tokens: Vec<AddedToken>,
|
||||
@ -79,7 +79,7 @@ impl BpeTrainerBuilder {
|
||||
|
||||
/// Set the expected minimum frequency
|
||||
#[must_use]
|
||||
pub fn min_frequency(mut self, frequency: u32) -> Self {
|
||||
pub fn min_frequency(mut self, frequency: u64) -> Self {
|
||||
self.config.min_frequency = frequency;
|
||||
self
|
||||
}
|
||||
@ -176,7 +176,7 @@ impl BpeTrainerBuilder {
|
||||
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, Eq)]
|
||||
pub struct BpeTrainer {
|
||||
/// The minimum frequency a pair must have to produce a merge operation
|
||||
pub min_frequency: u32,
|
||||
pub min_frequency: u64,
|
||||
/// The target vocabulary size
|
||||
pub vocab_size: usize,
|
||||
/// Whether to show progress while training
|
||||
@ -195,7 +195,7 @@ pub struct BpeTrainer {
|
||||
/// An optional parameter to limit the max length of any single token
|
||||
pub max_token_length: Option<usize>,
|
||||
|
||||
words: HashMap<String, u32>,
|
||||
words: HashMap<String, u64>,
|
||||
}
|
||||
|
||||
impl Default for BpeTrainer {
|
||||
@ -205,7 +205,7 @@ impl Default for BpeTrainer {
|
||||
}
|
||||
|
||||
impl BpeTrainer {
|
||||
pub fn new(min_frequency: u32, vocab_size: usize) -> Self {
|
||||
pub fn new(min_frequency: u64, vocab_size: usize) -> Self {
|
||||
Self {
|
||||
min_frequency,
|
||||
vocab_size,
|
||||
@ -263,7 +263,7 @@ impl BpeTrainer {
|
||||
/// Compute the initial alphabet and limit it if relevant
|
||||
fn compute_alphabet(
|
||||
&self,
|
||||
wc: &HashMap<String, u32>,
|
||||
wc: &HashMap<String, u64>,
|
||||
w2id: &mut HashMap<String, u32>,
|
||||
id2w: &mut Vec<String>,
|
||||
) {
|
||||
@ -322,13 +322,13 @@ impl BpeTrainer {
|
||||
/// Tokenize words and add subwords to the vocabulary when relevant
|
||||
fn tokenize_words(
|
||||
&self,
|
||||
wc: &HashMap<String, u32>,
|
||||
wc: &HashMap<String, u64>,
|
||||
w2id: &mut HashMap<String, u32>,
|
||||
id2w: &mut Vec<String>,
|
||||
p: &Option<ProgressBar>,
|
||||
) -> (Vec<Word>, Vec<u32>) {
|
||||
) -> (Vec<Word>, Vec<u64>) {
|
||||
let mut words: Vec<Word> = Vec::with_capacity(wc.len());
|
||||
let mut counts: Vec<u32> = Vec::with_capacity(wc.len());
|
||||
let mut counts: Vec<u64> = Vec::with_capacity(wc.len());
|
||||
|
||||
for (word, count) in wc {
|
||||
let mut current_word = Word::new();
|
||||
@ -373,7 +373,7 @@ impl BpeTrainer {
|
||||
fn count_pairs(
|
||||
&self,
|
||||
words: &[Word],
|
||||
counts: &[u32],
|
||||
counts: &[u64],
|
||||
p: &Option<ProgressBar>,
|
||||
) -> (HashMap<Pair, i32>, HashMap<Pair, HashSet<usize>>) {
|
||||
words
|
||||
@ -431,7 +431,7 @@ impl BpeTrainer {
|
||||
|
||||
pub fn do_train(
|
||||
&self,
|
||||
word_counts: &HashMap<String, u32>,
|
||||
word_counts: &HashMap<String, u64>,
|
||||
model: &mut BPE,
|
||||
) -> Result<Vec<AddedToken>> {
|
||||
let mut word_to_id: HashMap<String, u32> = HashMap::with_capacity(self.vocab_size);
|
||||
@ -470,7 +470,7 @@ impl BpeTrainer {
|
||||
if count > 0 {
|
||||
queue.push(Merge {
|
||||
pair,
|
||||
count: count as u32,
|
||||
count: count as u64,
|
||||
pos,
|
||||
});
|
||||
}
|
||||
@ -493,8 +493,8 @@ impl BpeTrainer {
|
||||
}
|
||||
|
||||
let mut top = queue.pop().unwrap();
|
||||
if top.count != pair_counts[&top.pair] as u32 {
|
||||
top.count = pair_counts[&top.pair] as u32;
|
||||
if top.count != pair_counts[&top.pair] as u64 {
|
||||
top.count = pair_counts[&top.pair] as u64;
|
||||
queue.push(top);
|
||||
continue;
|
||||
}
|
||||
@ -573,7 +573,7 @@ impl BpeTrainer {
|
||||
if count > 0 {
|
||||
queue.push(Merge {
|
||||
pair,
|
||||
count: count as u32,
|
||||
count: count as u64,
|
||||
pos,
|
||||
});
|
||||
}
|
||||
@ -632,7 +632,7 @@ impl Trainer for BpeTrainer {
|
||||
S: AsRef<str> + Send,
|
||||
F: Fn(&str) -> Result<Vec<String>> + Sync,
|
||||
{
|
||||
let words: Result<HashMap<String, u32>> = iterator
|
||||
let words: Result<HashMap<String, u64>> = iterator
|
||||
.maybe_par_bridge()
|
||||
.map(|sequence| {
|
||||
let words = process(sequence.as_ref())?;
|
||||
@ -665,7 +665,7 @@ mod tests {
|
||||
|
||||
#[test]
|
||||
fn test_train() {
|
||||
let word_counts: HashMap<String, u32> = [
|
||||
let word_counts: HashMap<String, u64> = [
|
||||
("roses".into(), 1),
|
||||
("are".into(), 2),
|
||||
("red".into(), 1),
|
||||
@ -744,7 +744,7 @@ mod tests {
|
||||
*/
|
||||
|
||||
let max_token_length = 16;
|
||||
let long_word_counts: HashMap<String, u32> = [
|
||||
let long_word_counts: HashMap<String, u64> = [
|
||||
("singlelongtokenwithoutcasechange", 2),
|
||||
("singleLongTokenWithCamelCaseChange", 2),
|
||||
("Longsingletokenwithpunctu@t!onwithin", 2),
|
||||
@ -784,7 +784,7 @@ mod tests {
|
||||
// directly compares tokens with known expected values.
|
||||
// maybe unstable depending on specific settings or changes.
|
||||
*/
|
||||
let long_word_counts: HashMap<String, u32> = [
|
||||
let long_word_counts: HashMap<String, u64> = [
|
||||
("sin", 2),
|
||||
("Sin", 2),
|
||||
("Lon", 2),
|
||||
|
@ -10,7 +10,7 @@ use std::collections::HashMap;
|
||||
pub struct WordLevelTrainer {
|
||||
/// The minimum frequency a word must have to be part of the vocabulary
|
||||
#[builder(default = "0")]
|
||||
pub min_frequency: u32,
|
||||
pub min_frequency: u64,
|
||||
/// The target vocabulary size
|
||||
#[builder(default = "30_000")]
|
||||
pub vocab_size: usize,
|
||||
@ -22,7 +22,7 @@ pub struct WordLevelTrainer {
|
||||
pub special_tokens: Vec<AddedToken>,
|
||||
|
||||
#[builder(default, private)]
|
||||
words: HashMap<String, u32>,
|
||||
words: HashMap<String, u64>,
|
||||
}
|
||||
|
||||
impl Default for WordLevelTrainer {
|
||||
@ -38,14 +38,14 @@ impl WordLevelTrainer {
|
||||
|
||||
fn do_train(
|
||||
&self,
|
||||
word_counts: &HashMap<String, u32>,
|
||||
word_counts: &HashMap<String, u64>,
|
||||
model: &mut WordLevel,
|
||||
) -> Result<Vec<AddedToken>> {
|
||||
let mut ordered_counts = word_counts.iter().collect::<Vec<_>>();
|
||||
|
||||
//sort the word counts first by inverse counts and then by word, in order
|
||||
//to keep the sorting deterministic in case of equal counts
|
||||
let cmp = |l: &(&String, &u32), r: &(&String, &u32)| -> Ordering {
|
||||
let cmp = |l: &(&String, &u64), r: &(&String, &u64)| -> Ordering {
|
||||
let count_comp: Ordering = l.1.cmp(r.1);
|
||||
if count_comp != Ordering::Equal {
|
||||
return count_comp.reverse();
|
||||
@ -100,7 +100,7 @@ impl Trainer for WordLevelTrainer {
|
||||
S: AsRef<str> + Send,
|
||||
F: Fn(&str) -> Result<Vec<String>> + Sync,
|
||||
{
|
||||
let words: Result<HashMap<String, u32>> = iterator
|
||||
let words: Result<HashMap<String, u64>> = iterator
|
||||
.maybe_par_bridge()
|
||||
.map(|sequence| {
|
||||
let words = process(sequence.as_ref())?;
|
||||
@ -132,7 +132,7 @@ mod tests {
|
||||
|
||||
#[test]
|
||||
fn test_train() {
|
||||
let word_counts: HashMap<String, u32> = [
|
||||
let word_counts: HashMap<String, u64> = [
|
||||
("the".into(), 25),
|
||||
("roses".into(), 22),
|
||||
("are".into(), 24),
|
||||
|
@ -26,7 +26,7 @@ impl WordPieceTrainerBuilder {
|
||||
|
||||
/// Set the expected minimum frequency
|
||||
#[must_use]
|
||||
pub fn min_frequency(mut self, frequency: u32) -> Self {
|
||||
pub fn min_frequency(mut self, frequency: u64) -> Self {
|
||||
self.bpe_trainer_builder = self.bpe_trainer_builder.min_frequency(frequency);
|
||||
self
|
||||
}
|
||||
@ -94,11 +94,11 @@ pub struct WordPieceTrainer {
|
||||
}
|
||||
|
||||
impl WordPieceTrainer {
|
||||
pub fn min_frequency(&self) -> u32 {
|
||||
pub fn min_frequency(&self) -> u64 {
|
||||
self.bpe_trainer.min_frequency
|
||||
}
|
||||
|
||||
pub fn set_min_frequency(&mut self, freq: u32) {
|
||||
pub fn set_min_frequency(&mut self, freq: u64) {
|
||||
self.bpe_trainer.min_frequency = freq;
|
||||
}
|
||||
|
||||
|
Reference in New Issue
Block a user