diff --git a/tokenizers/benches/bert_benchmark.rs b/tokenizers/benches/bert_benchmark.rs
index cea1485b..af5610d3 100644
--- a/tokenizers/benches/bert_benchmark.rs
+++ b/tokenizers/benches/bert_benchmark.rs
@@ -70,7 +70,7 @@ pub fn bench_bert(c: &mut Criterion) {
 }
 
 fn bench_train(c: &mut Criterion) {
-    let trainer = WordPieceTrainerBuilder::default()
+    let mut trainer = WordPieceTrainerBuilder::default()
         .show_progress(false)
         .build();
     type Tok = TokenizerImpl<
@@ -87,7 +87,7 @@ fn bench_train(c: &mut Criterion) {
             iter_bench_train(
                 iters,
                 &mut tokenizer,
-                &trainer,
+                &mut trainer,
                 vec!["data/small.txt".to_string()],
             )
         })
@@ -100,7 +100,7 @@ fn bench_train(c: &mut Criterion) {
             iter_bench_train(
                 iters,
                 &mut tokenizer,
-                &trainer,
+                &mut trainer,
                 vec!["data/big.txt".to_string()],
             )
         })
diff --git a/tokenizers/benches/bpe_benchmark.rs b/tokenizers/benches/bpe_benchmark.rs
index 2e99b378..cf86dfef 100644
--- a/tokenizers/benches/bpe_benchmark.rs
+++ b/tokenizers/benches/bpe_benchmark.rs
@@ -69,7 +69,7 @@ fn bench_gpt2(c: &mut Criterion) {
 }
 
 fn bench_train(c: &mut Criterion) {
-    let trainer: TrainerWrapper = BpeTrainerBuilder::default()
+    let mut trainer: TrainerWrapper = BpeTrainerBuilder::default()
         .show_progress(false)
         .build()
         .into();
@@ -80,7 +80,7 @@ fn bench_train(c: &mut Criterion) {
             iter_bench_train(
                 iters,
                 &mut tokenizer,
-                &trainer,
+                &mut trainer,
                 vec!["data/small.txt".to_string()],
             )
         })
@@ -93,7 +93,7 @@ fn bench_train(c: &mut Criterion) {
             iter_bench_train(
                 iters,
                 &mut tokenizer,
-                &trainer,
+                &mut trainer,
                 vec!["data/big.txt".to_string()],
             )
         })
diff --git a/tokenizers/benches/common/mod.rs b/tokenizers/benches/common/mod.rs
index 1c2453df..7f4cb933 100644
--- a/tokenizers/benches/common/mod.rs
+++ b/tokenizers/benches/common/mod.rs
@@ -61,7 +61,7 @@ where
 pub fn iter_bench_train<T, M, N, PT, PP, D>(
     iters: u64,
     tokenizer: &mut TokenizerImpl<M, N, PT, PP, D>,
-    trainer: &T,
+    trainer: &mut T,
     files: Vec<String>,
 ) -> Duration
 where
@@ -75,7 +75,7 @@ where
     let mut duration = Duration::new(0, 0);
     for _i in 0..iters {
         let start = Instant::now();
-        tokenizer.train(trainer, files.clone()).unwrap();
+        tokenizer.train_from_files(trainer, files.clone()).unwrap();
         duration = duration.checked_add(start.elapsed()).unwrap();
     }
     duration
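
For context (not part of the patch): this updates the benchmarks to the API where trainers are mutated during training, so they are bound with `let mut` and passed as `&mut T`, and `Tokenizer::train` is renamed to `train_from_files`. Below is a minimal sketch of the new calling convention, assuming the `tokenizers` crate at this revision and reusing the benchmarks' `data/small.txt` path; it mirrors the bpe_benchmark.rs setup rather than being part of the diff itself.

    use tokenizers::models::bpe::{BpeTrainerBuilder, BPE};
    use tokenizers::models::TrainerWrapper;
    use tokenizers::Tokenizer;

    fn main() -> tokenizers::Result<()> {
        // An empty BPE model wrapped in the high-level Tokenizer type.
        let mut tokenizer = Tokenizer::new(BPE::default());

        // Trainers now accumulate state while training, so bind them
        // mutably and pass `&mut trainer` instead of `&trainer`.
        let mut trainer: TrainerWrapper = BpeTrainerBuilder::default()
            .show_progress(false)
            .build()
            .into();

        // `train` was renamed to `train_from_files`; the file list is
        // the same one the benchmarks use.
        tokenizer.train_from_files(&mut trainer, vec!["data/small.txt".to_string()])?;
        Ok(())
    }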