Add a better single-threaded GPT-2 benchmark

This commit is contained in:
epwalsh
2020-01-01 15:48:53 -08:00
parent 722b61230d
commit b09511f5cf

View File

@ -1,11 +1,11 @@
#[macro_use]
extern crate criterion;
use criterion::{BatchSize, Criterion};
use criterion::{black_box, BatchSize, Criterion};
use std::fs::File;
use std::io::{BufRead, BufReader};
use std::io::{self, BufRead, BufReader};
use std::path::Path;
use std::time::Duration;
use std::time::{Duration, Instant};
use tokenizers::models::bpe::BPE;
use tokenizers::pre_tokenizers::byte_level::ByteLevel;
use tokenizers::tokenizer::{AddedToken, EncodeInput, Tokenizer};
@ -50,6 +50,10 @@ fn create_gpt2_tokenizer(bpe: &BPE) -> Tokenizer {
tokenizer
}
/// Wraps one line read from a buffered reader as a single-sequence encode input.
///
/// # Panics
///
/// Panics if `line` is an `Err`, i.e. the line could not be read.
fn line_to_input(line: io::Result<String>) -> EncodeInput {
    let text = line.unwrap();
    EncodeInput::Single(text)
}
fn bench_gpt2_encode(c: &mut Criterion) {
let bpe = BPE::from_files("benches/gpt2-vocab.json", "benches/gpt2-merges.txt").unwrap();
@ -65,13 +69,39 @@ fn bench_gpt2_encode(c: &mut Criterion) {
BatchSize::LargeInput,
)
});
// Encodes the corpus one line per iteration, timing only the `encode` call
// itself; reading the next line happens outside the measured region.
c.bench_function("BPE GPT2 encode many", |b| {
    b.iter_custom(|iters| {
        let tokenizer = create_gpt2_tokenizer(&bpe);
        // Opens a fresh line-by-line reader over the benchmark corpus.
        let open_corpus = || {
            let file = File::open(Path::new("benches/big.txt")).unwrap();
            BufReader::new(file).lines().map(line_to_input)
        };
        let mut inputs = open_corpus();
        let mut elapsed = Duration::new(0, 0);
        for _ in 0..iters {
            let input = inputs.next().unwrap_or_else(|| {
                // Corpus exhausted: start over from the first line.
                inputs = open_corpus();
                inputs.next().unwrap()
            });
            let timer = Instant::now();
            let _ = black_box(tokenizer.encode(input));
            elapsed = elapsed.checked_add(timer.elapsed()).unwrap();
        }
        // Total time spent inside `encode` across all requested iterations.
        elapsed
    })
});
}
fn bench_gpt2_encode_batch(c: &mut Criterion) {
let bpe = BPE::from_files("benches/gpt2-vocab.json", "benches/gpt2-merges.txt").unwrap();
let lines: Vec<EncodeInput> = BufReader::new(File::open(Path::new("benches/big.txt")).unwrap())
.lines()
.map(|l| EncodeInput::Single(l.unwrap()))
.map(line_to_input)
.collect();
c.bench_function("BPE GPT2 encode batch", |b| {