Perf improvement 16% by removing offsets. (#1587)

* [Breaking Change] Perf improvement 16% by removing offsets.

Offsets are always calculated in Python land. By skipping that
calculation, we win 16% of the runtime.

This is not the full extent of the possible gain, because offsets are
still calculated in bytes.
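
At the API level, the shape this PR ends up shipping is non-breaking:
`encode_batch` keeps returning offsets as before, and a new
`encode_batch_fast` skips them. A minimal sketch, assuming a tokenizers
build that includes this change; `gpt2` is just an example checkpoint:

    from tokenizers import Tokenizer

    tokenizer = Tokenizer.from_pretrained("gpt2")  # example checkpoint
    docs = ["Some text to encode."] * 1_000

    # Existing path, unchanged: every Encoding still carries offsets.
    encodings = tokenizer.encode_batch(docs)
    print(encodings[0].offsets[:3])

    # New fast path: skips offset tracking, so don't rely on offsets here.
    fast = tokenizer.encode_batch_fast(docs)
    assert fast[0].ids == encodings[0].ids  # token ids are identical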

* Required features.

* Remove clippy error.

* Make it non-breaking and still show the perf improvement.

* Even faster without offsets.

* Update doc.

* Fmt.

* Apply suggestions from code review

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* fmt.

---------

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
Author: Nicolas Patry
Date:   2024-08-08 14:56:13 +02:00 (committed by GitHub)
Parent: bd27fa56d6
Commit: bfd9cdeefb

8 changed files with 247 additions and 6 deletions


@@ -60,7 +60,12 @@ def benchmark_batch(model: str, documents: list[str], num_threads: int, document
         mergeable_ranks=mergeable_ranks,
         special_tokens=special_tokens,
     )
     enc.encode("warmup")
+    out = enc.encode("This is a test")
+
+    hf_enc = Tokenizer.from_pretrained(model)
+    out2 = hf_enc.encode("This is a test", add_special_tokens=False).ids
+    assert out == out2, "sanity check"
 
     start = time.perf_counter_ns()
     enc.encode_ordinary_batch(documents, num_threads=num_threads)
@@ -69,11 +74,9 @@ def benchmark_batch(model: str, documents: list[str], num_threads: int, document
     readable_size, unit = format_byte_size(num_bytes / (end - start) * 1e9)
     print(f"tiktoken \t{readable_size} / s")
 
-    hf_enc = Tokenizer.from_pretrained(model)
-    hf_enc.encode("warmup")
     start = time.perf_counter_ns()
-    hf_enc.encode_batch(documents)
+    hf_enc.encode_batch_fast(documents)
     end = time.perf_counter_ns()
 
     readable_size, unit = format_byte_size(num_bytes / (end - start) * 1e9)
     print(f"huggingface \t{readable_size} / s")
@@ -82,8 +85,7 @@ def benchmark_batch(model: str, documents: list[str], num_threads: int, document
 def test(model: str, dataset: str, dataset_config: str, threads: List[int]):
     dataset_xnli = load_dataset(dataset, dataset_config)
 
-    # input_lengths = [(10, False), (10_000, False), (10_000, True)] # Example input lengths
-    input_lengths = [(10_000, False, True), (10_000, False, False)]
+    input_lengths = [(10, False, True), (10_000, False, True), (10_000, False, False)]
 
     for num_threads in threads:
         for length, fuse, long in input_lengths:
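
The measurement pattern in the hunks above, as a self-contained sketch
(not the benchmark file itself; the `gpt2` checkpoint and the document
list are placeholders):

    import time

    from tokenizers import Tokenizer

    tokenizer = Tokenizer.from_pretrained("gpt2")  # placeholder checkpoint
    documents = ["An example document."] * 10_000
    num_bytes = sum(len(doc.encode("utf-8")) for doc in documents)

    start = time.perf_counter_ns()
    tokenizer.encode_batch_fast(documents)
    end = time.perf_counter_ns()

    # perf_counter_ns measures nanoseconds; * 1e9 converts to bytes/second.
    mb_per_s = num_bytes / (end - start) * 1e9 / 1e6
    print(f"huggingface \t{mb_per_s:.1f} MB / s")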