mirror of
https://github.com/mii443/tokenizers.git
synced 2025-08-22 16:25:30 +00:00
Convert word counts to u64 (#1433)
* Convert word counts to u64 * More spots needed to compile
This commit is contained in:
@ -183,12 +183,12 @@ impl PyBpeTrainer {
|
|||||||
}
|
}
|
||||||
|
|
||||||
#[getter]
|
#[getter]
|
||||||
fn get_min_frequency(self_: PyRef<Self>) -> u32 {
|
fn get_min_frequency(self_: PyRef<Self>) -> u64 {
|
||||||
getter!(self_, BpeTrainer, min_frequency)
|
getter!(self_, BpeTrainer, min_frequency)
|
||||||
}
|
}
|
||||||
|
|
||||||
#[setter]
|
#[setter]
|
||||||
fn set_min_frequency(self_: PyRef<Self>, freq: u32) {
|
fn set_min_frequency(self_: PyRef<Self>, freq: u64) {
|
||||||
setter!(self_, BpeTrainer, min_frequency, freq);
|
setter!(self_, BpeTrainer, min_frequency, freq);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -397,12 +397,12 @@ impl PyWordPieceTrainer {
|
|||||||
}
|
}
|
||||||
|
|
||||||
#[getter]
|
#[getter]
|
||||||
fn get_min_frequency(self_: PyRef<Self>) -> u32 {
|
fn get_min_frequency(self_: PyRef<Self>) -> u64 {
|
||||||
getter!(self_, WordPieceTrainer, min_frequency())
|
getter!(self_, WordPieceTrainer, min_frequency())
|
||||||
}
|
}
|
||||||
|
|
||||||
#[setter]
|
#[setter]
|
||||||
fn set_min_frequency(self_: PyRef<Self>, freq: u32) {
|
fn set_min_frequency(self_: PyRef<Self>, freq: u64) {
|
||||||
setter!(self_, WordPieceTrainer, @set_min_frequency, freq);
|
setter!(self_, WordPieceTrainer, @set_min_frequency, freq);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -589,12 +589,12 @@ impl PyWordLevelTrainer {
|
|||||||
}
|
}
|
||||||
|
|
||||||
#[getter]
|
#[getter]
|
||||||
fn get_min_frequency(self_: PyRef<Self>) -> u32 {
|
fn get_min_frequency(self_: PyRef<Self>) -> u64 {
|
||||||
getter!(self_, WordLevelTrainer, min_frequency)
|
getter!(self_, WordLevelTrainer, min_frequency)
|
||||||
}
|
}
|
||||||
|
|
||||||
#[setter]
|
#[setter]
|
||||||
fn set_min_frequency(self_: PyRef<Self>, freq: u32) {
|
fn set_min_frequency(self_: PyRef<Self>, freq: u64) {
|
||||||
setter!(self_, WordLevelTrainer, min_frequency, freq);
|
setter!(self_, WordLevelTrainer, min_frequency, freq);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -11,7 +11,7 @@ use std::collections::{BinaryHeap, HashMap, HashSet};
|
|||||||
#[derive(Debug, Eq)]
|
#[derive(Debug, Eq)]
|
||||||
struct Merge {
|
struct Merge {
|
||||||
pair: Pair,
|
pair: Pair,
|
||||||
count: u32,
|
count: u64,
|
||||||
pos: HashSet<usize>,
|
pos: HashSet<usize>,
|
||||||
}
|
}
|
||||||
impl PartialEq for Merge {
|
impl PartialEq for Merge {
|
||||||
@ -36,7 +36,7 @@ impl Ord for Merge {
|
|||||||
}
|
}
|
||||||
|
|
||||||
struct Config {
|
struct Config {
|
||||||
min_frequency: u32,
|
min_frequency: u64,
|
||||||
vocab_size: usize,
|
vocab_size: usize,
|
||||||
show_progress: bool,
|
show_progress: bool,
|
||||||
special_tokens: Vec<AddedToken>,
|
special_tokens: Vec<AddedToken>,
|
||||||
@ -79,7 +79,7 @@ impl BpeTrainerBuilder {
|
|||||||
|
|
||||||
/// Set the expected minimum frequency
|
/// Set the expected minimum frequency
|
||||||
#[must_use]
|
#[must_use]
|
||||||
pub fn min_frequency(mut self, frequency: u32) -> Self {
|
pub fn min_frequency(mut self, frequency: u64) -> Self {
|
||||||
self.config.min_frequency = frequency;
|
self.config.min_frequency = frequency;
|
||||||
self
|
self
|
||||||
}
|
}
|
||||||
@ -176,7 +176,7 @@ impl BpeTrainerBuilder {
|
|||||||
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, Eq)]
|
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, Eq)]
|
||||||
pub struct BpeTrainer {
|
pub struct BpeTrainer {
|
||||||
/// The minimum frequency a pair must have to produce a merge operation
|
/// The minimum frequency a pair must have to produce a merge operation
|
||||||
pub min_frequency: u32,
|
pub min_frequency: u64,
|
||||||
/// The target vocabulary size
|
/// The target vocabulary size
|
||||||
pub vocab_size: usize,
|
pub vocab_size: usize,
|
||||||
/// Whether to show progress while training
|
/// Whether to show progress while training
|
||||||
@ -195,7 +195,7 @@ pub struct BpeTrainer {
|
|||||||
/// An optional parameter to limit the max length of any single token
|
/// An optional parameter to limit the max length of any single token
|
||||||
pub max_token_length: Option<usize>,
|
pub max_token_length: Option<usize>,
|
||||||
|
|
||||||
words: HashMap<String, u32>,
|
words: HashMap<String, u64>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Default for BpeTrainer {
|
impl Default for BpeTrainer {
|
||||||
@ -205,7 +205,7 @@ impl Default for BpeTrainer {
|
|||||||
}
|
}
|
||||||
|
|
||||||
impl BpeTrainer {
|
impl BpeTrainer {
|
||||||
pub fn new(min_frequency: u32, vocab_size: usize) -> Self {
|
pub fn new(min_frequency: u64, vocab_size: usize) -> Self {
|
||||||
Self {
|
Self {
|
||||||
min_frequency,
|
min_frequency,
|
||||||
vocab_size,
|
vocab_size,
|
||||||
@ -263,7 +263,7 @@ impl BpeTrainer {
|
|||||||
/// Compute the initial alphabet and limit it if relevant
|
/// Compute the initial alphabet and limit it if relevant
|
||||||
fn compute_alphabet(
|
fn compute_alphabet(
|
||||||
&self,
|
&self,
|
||||||
wc: &HashMap<String, u32>,
|
wc: &HashMap<String, u64>,
|
||||||
w2id: &mut HashMap<String, u32>,
|
w2id: &mut HashMap<String, u32>,
|
||||||
id2w: &mut Vec<String>,
|
id2w: &mut Vec<String>,
|
||||||
) {
|
) {
|
||||||
@ -322,13 +322,13 @@ impl BpeTrainer {
|
|||||||
/// Tokenize words and add subwords to the vocabulary when relevant
|
/// Tokenize words and add subwords to the vocabulary when relevant
|
||||||
fn tokenize_words(
|
fn tokenize_words(
|
||||||
&self,
|
&self,
|
||||||
wc: &HashMap<String, u32>,
|
wc: &HashMap<String, u64>,
|
||||||
w2id: &mut HashMap<String, u32>,
|
w2id: &mut HashMap<String, u32>,
|
||||||
id2w: &mut Vec<String>,
|
id2w: &mut Vec<String>,
|
||||||
p: &Option<ProgressBar>,
|
p: &Option<ProgressBar>,
|
||||||
) -> (Vec<Word>, Vec<u32>) {
|
) -> (Vec<Word>, Vec<u64>) {
|
||||||
let mut words: Vec<Word> = Vec::with_capacity(wc.len());
|
let mut words: Vec<Word> = Vec::with_capacity(wc.len());
|
||||||
let mut counts: Vec<u32> = Vec::with_capacity(wc.len());
|
let mut counts: Vec<u64> = Vec::with_capacity(wc.len());
|
||||||
|
|
||||||
for (word, count) in wc {
|
for (word, count) in wc {
|
||||||
let mut current_word = Word::new();
|
let mut current_word = Word::new();
|
||||||
@ -373,7 +373,7 @@ impl BpeTrainer {
|
|||||||
fn count_pairs(
|
fn count_pairs(
|
||||||
&self,
|
&self,
|
||||||
words: &[Word],
|
words: &[Word],
|
||||||
counts: &[u32],
|
counts: &[u64],
|
||||||
p: &Option<ProgressBar>,
|
p: &Option<ProgressBar>,
|
||||||
) -> (HashMap<Pair, i32>, HashMap<Pair, HashSet<usize>>) {
|
) -> (HashMap<Pair, i32>, HashMap<Pair, HashSet<usize>>) {
|
||||||
words
|
words
|
||||||
@ -431,7 +431,7 @@ impl BpeTrainer {
|
|||||||
|
|
||||||
pub fn do_train(
|
pub fn do_train(
|
||||||
&self,
|
&self,
|
||||||
word_counts: &HashMap<String, u32>,
|
word_counts: &HashMap<String, u64>,
|
||||||
model: &mut BPE,
|
model: &mut BPE,
|
||||||
) -> Result<Vec<AddedToken>> {
|
) -> Result<Vec<AddedToken>> {
|
||||||
let mut word_to_id: HashMap<String, u32> = HashMap::with_capacity(self.vocab_size);
|
let mut word_to_id: HashMap<String, u32> = HashMap::with_capacity(self.vocab_size);
|
||||||
@ -470,7 +470,7 @@ impl BpeTrainer {
|
|||||||
if count > 0 {
|
if count > 0 {
|
||||||
queue.push(Merge {
|
queue.push(Merge {
|
||||||
pair,
|
pair,
|
||||||
count: count as u32,
|
count: count as u64,
|
||||||
pos,
|
pos,
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
@ -493,8 +493,8 @@ impl BpeTrainer {
|
|||||||
}
|
}
|
||||||
|
|
||||||
let mut top = queue.pop().unwrap();
|
let mut top = queue.pop().unwrap();
|
||||||
if top.count != pair_counts[&top.pair] as u32 {
|
if top.count != pair_counts[&top.pair] as u64 {
|
||||||
top.count = pair_counts[&top.pair] as u32;
|
top.count = pair_counts[&top.pair] as u64;
|
||||||
queue.push(top);
|
queue.push(top);
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
@ -573,7 +573,7 @@ impl BpeTrainer {
|
|||||||
if count > 0 {
|
if count > 0 {
|
||||||
queue.push(Merge {
|
queue.push(Merge {
|
||||||
pair,
|
pair,
|
||||||
count: count as u32,
|
count: count as u64,
|
||||||
pos,
|
pos,
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
@ -632,7 +632,7 @@ impl Trainer for BpeTrainer {
|
|||||||
S: AsRef<str> + Send,
|
S: AsRef<str> + Send,
|
||||||
F: Fn(&str) -> Result<Vec<String>> + Sync,
|
F: Fn(&str) -> Result<Vec<String>> + Sync,
|
||||||
{
|
{
|
||||||
let words: Result<HashMap<String, u32>> = iterator
|
let words: Result<HashMap<String, u64>> = iterator
|
||||||
.maybe_par_bridge()
|
.maybe_par_bridge()
|
||||||
.map(|sequence| {
|
.map(|sequence| {
|
||||||
let words = process(sequence.as_ref())?;
|
let words = process(sequence.as_ref())?;
|
||||||
@ -665,7 +665,7 @@ mod tests {
|
|||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_train() {
|
fn test_train() {
|
||||||
let word_counts: HashMap<String, u32> = [
|
let word_counts: HashMap<String, u64> = [
|
||||||
("roses".into(), 1),
|
("roses".into(), 1),
|
||||||
("are".into(), 2),
|
("are".into(), 2),
|
||||||
("red".into(), 1),
|
("red".into(), 1),
|
||||||
@ -744,7 +744,7 @@ mod tests {
|
|||||||
*/
|
*/
|
||||||
|
|
||||||
let max_token_length = 16;
|
let max_token_length = 16;
|
||||||
let long_word_counts: HashMap<String, u32> = [
|
let long_word_counts: HashMap<String, u64> = [
|
||||||
("singlelongtokenwithoutcasechange", 2),
|
("singlelongtokenwithoutcasechange", 2),
|
||||||
("singleLongTokenWithCamelCaseChange", 2),
|
("singleLongTokenWithCamelCaseChange", 2),
|
||||||
("Longsingletokenwithpunctu@t!onwithin", 2),
|
("Longsingletokenwithpunctu@t!onwithin", 2),
|
||||||
@ -784,7 +784,7 @@ mod tests {
|
|||||||
// directly compares tokens with known expected values.
|
// directly compares tokens with known expected values.
|
||||||
// maybe unstable depending on specific settings or changes.
|
// maybe unstable depending on specific settings or changes.
|
||||||
*/
|
*/
|
||||||
let long_word_counts: HashMap<String, u32> = [
|
let long_word_counts: HashMap<String, u64> = [
|
||||||
("sin", 2),
|
("sin", 2),
|
||||||
("Sin", 2),
|
("Sin", 2),
|
||||||
("Lon", 2),
|
("Lon", 2),
|
||||||
|
@ -10,7 +10,7 @@ use std::collections::HashMap;
|
|||||||
pub struct WordLevelTrainer {
|
pub struct WordLevelTrainer {
|
||||||
/// The minimum frequency a word must have to be part of the vocabulary
|
/// The minimum frequency a word must have to be part of the vocabulary
|
||||||
#[builder(default = "0")]
|
#[builder(default = "0")]
|
||||||
pub min_frequency: u32,
|
pub min_frequency: u64,
|
||||||
/// The target vocabulary size
|
/// The target vocabulary size
|
||||||
#[builder(default = "30_000")]
|
#[builder(default = "30_000")]
|
||||||
pub vocab_size: usize,
|
pub vocab_size: usize,
|
||||||
@ -22,7 +22,7 @@ pub struct WordLevelTrainer {
|
|||||||
pub special_tokens: Vec<AddedToken>,
|
pub special_tokens: Vec<AddedToken>,
|
||||||
|
|
||||||
#[builder(default, private)]
|
#[builder(default, private)]
|
||||||
words: HashMap<String, u32>,
|
words: HashMap<String, u64>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Default for WordLevelTrainer {
|
impl Default for WordLevelTrainer {
|
||||||
@ -38,14 +38,14 @@ impl WordLevelTrainer {
|
|||||||
|
|
||||||
fn do_train(
|
fn do_train(
|
||||||
&self,
|
&self,
|
||||||
word_counts: &HashMap<String, u32>,
|
word_counts: &HashMap<String, u64>,
|
||||||
model: &mut WordLevel,
|
model: &mut WordLevel,
|
||||||
) -> Result<Vec<AddedToken>> {
|
) -> Result<Vec<AddedToken>> {
|
||||||
let mut ordered_counts = word_counts.iter().collect::<Vec<_>>();
|
let mut ordered_counts = word_counts.iter().collect::<Vec<_>>();
|
||||||
|
|
||||||
//sort the word counts first by inverse counts and then by word, in order
|
//sort the word counts first by inverse counts and then by word, in order
|
||||||
//to keep the sorting deterministic in case of equal counts
|
//to keep the sorting deterministic in case of equal counts
|
||||||
let cmp = |l: &(&String, &u32), r: &(&String, &u32)| -> Ordering {
|
let cmp = |l: &(&String, &u64), r: &(&String, &u64)| -> Ordering {
|
||||||
let count_comp: Ordering = l.1.cmp(r.1);
|
let count_comp: Ordering = l.1.cmp(r.1);
|
||||||
if count_comp != Ordering::Equal {
|
if count_comp != Ordering::Equal {
|
||||||
return count_comp.reverse();
|
return count_comp.reverse();
|
||||||
@ -100,7 +100,7 @@ impl Trainer for WordLevelTrainer {
|
|||||||
S: AsRef<str> + Send,
|
S: AsRef<str> + Send,
|
||||||
F: Fn(&str) -> Result<Vec<String>> + Sync,
|
F: Fn(&str) -> Result<Vec<String>> + Sync,
|
||||||
{
|
{
|
||||||
let words: Result<HashMap<String, u32>> = iterator
|
let words: Result<HashMap<String, u64>> = iterator
|
||||||
.maybe_par_bridge()
|
.maybe_par_bridge()
|
||||||
.map(|sequence| {
|
.map(|sequence| {
|
||||||
let words = process(sequence.as_ref())?;
|
let words = process(sequence.as_ref())?;
|
||||||
@ -132,7 +132,7 @@ mod tests {
|
|||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_train() {
|
fn test_train() {
|
||||||
let word_counts: HashMap<String, u32> = [
|
let word_counts: HashMap<String, u64> = [
|
||||||
("the".into(), 25),
|
("the".into(), 25),
|
||||||
("roses".into(), 22),
|
("roses".into(), 22),
|
||||||
("are".into(), 24),
|
("are".into(), 24),
|
||||||
|
@ -26,7 +26,7 @@ impl WordPieceTrainerBuilder {
|
|||||||
|
|
||||||
/// Set the expected minimum frequency
|
/// Set the expected minimum frequency
|
||||||
#[must_use]
|
#[must_use]
|
||||||
pub fn min_frequency(mut self, frequency: u32) -> Self {
|
pub fn min_frequency(mut self, frequency: u64) -> Self {
|
||||||
self.bpe_trainer_builder = self.bpe_trainer_builder.min_frequency(frequency);
|
self.bpe_trainer_builder = self.bpe_trainer_builder.min_frequency(frequency);
|
||||||
self
|
self
|
||||||
}
|
}
|
||||||
@ -94,11 +94,11 @@ pub struct WordPieceTrainer {
|
|||||||
}
|
}
|
||||||
|
|
||||||
impl WordPieceTrainer {
|
impl WordPieceTrainer {
|
||||||
pub fn min_frequency(&self) -> u32 {
|
pub fn min_frequency(&self) -> u64 {
|
||||||
self.bpe_trainer.min_frequency
|
self.bpe_trainer.min_frequency
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn set_min_frequency(&mut self, freq: u32) {
|
pub fn set_min_frequency(&mut self, freq: u64) {
|
||||||
self.bpe_trainer.min_frequency = freq;
|
self.bpe_trainer.min_frequency = freq;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Reference in New Issue
Block a user