Rust - Make parallelism optional
bindings/python/Cargo.lock (generated; 21 lines changed)
@@ -244,6 +244,14 @@ dependencies = [
  "syn 1.0.17 (registry+https://github.com/rust-lang/crates.io-index)",
 ]
 
+[[package]]
+name = "itertools"
+version = "0.8.2"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+dependencies = [
+ "either 1.5.3 (registry+https://github.com/rust-lang/crates.io-index)",
+]
+
 [[package]]
 name = "itoa"
 version = "0.4.5"
@@ -486,6 +494,16 @@ dependencies = [
  "rayon-core 1.7.0 (registry+https://github.com/rust-lang/crates.io-index)",
 ]
 
+[[package]]
+name = "rayon-cond"
+version = "0.1.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+dependencies = [
+ "either 1.5.3 (registry+https://github.com/rust-lang/crates.io-index)",
+ "itertools 0.8.2 (registry+https://github.com/rust-lang/crates.io-index)",
+ "rayon 1.3.0 (registry+https://github.com/rust-lang/crates.io-index)",
+]
+
 [[package]]
 name = "rayon-core"
 version = "1.7.0"
@@ -611,6 +629,7 @@ dependencies = [
  "onig 6.0.0 (registry+https://github.com/rust-lang/crates.io-index)",
  "rand 0.7.3 (registry+https://github.com/rust-lang/crates.io-index)",
  "rayon 1.3.0 (registry+https://github.com/rust-lang/crates.io-index)",
+ "rayon-cond 0.1.0 (registry+https://github.com/rust-lang/crates.io-index)",
  "regex 1.3.6 (registry+https://github.com/rust-lang/crates.io-index)",
  "regex-syntax 0.6.17 (registry+https://github.com/rust-lang/crates.io-index)",
  "serde 1.0.106 (registry+https://github.com/rust-lang/crates.io-index)",
@@ -744,6 +763,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
 "checksum indoc-impl 0.3.5 (registry+https://github.com/rust-lang/crates.io-index)" = "54554010aa3d17754e484005ea0022f1c93839aabc627c2c55f3d7b47206134c"
 "checksum inventory 0.1.6 (registry+https://github.com/rust-lang/crates.io-index)" = "82d3f4b90287725c97b17478c60dda0c6324e7c84ee1ed72fb9179d0fdf13956"
 "checksum inventory-impl 0.1.6 (registry+https://github.com/rust-lang/crates.io-index)" = "9092a4fefc9d503e9287ef137f03180a6e7d1b04c419563171ee14947c5e80ec"
+"checksum itertools 0.8.2 (registry+https://github.com/rust-lang/crates.io-index)" = "f56a2d0bc861f9165be4eb3442afd3c236d8a98afd426f65d92324ae1091a484"
 "checksum itoa 0.4.5 (registry+https://github.com/rust-lang/crates.io-index)" = "b8b7a7c0c47db5545ed3fef7468ee7bb5b74691498139e4b3f6a20685dc6dd8e"
 "checksum lazy_static 1.4.0 (registry+https://github.com/rust-lang/crates.io-index)" = "e2abad23fbc42b3700f2f279844dc832adb2b2eb069b2df918f455c4e18cc646"
 "checksum libc 0.2.68 (registry+https://github.com/rust-lang/crates.io-index)" = "dea0c0405123bba743ee3f91f49b1c7cfb684eef0da0a50110f758ccf24cdff0"
@@ -773,6 +793,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
 "checksum rand_core 0.5.1 (registry+https://github.com/rust-lang/crates.io-index)" = "90bde5296fc891b0cef12a6d03ddccc162ce7b2aff54160af9338f8d40df6d19"
 "checksum rand_hc 0.2.0 (registry+https://github.com/rust-lang/crates.io-index)" = "ca3129af7b92a17112d59ad498c6f81eaf463253766b90396d39ea7a39d6613c"
 "checksum rayon 1.3.0 (registry+https://github.com/rust-lang/crates.io-index)" = "db6ce3297f9c85e16621bb8cca38a06779ffc31bb8184e1be4bed2be4678a098"
+"checksum rayon-cond 0.1.0 (registry+https://github.com/rust-lang/crates.io-index)" = "fd1259362c9065e5ea39a789ef40b1e3fd934c94beb7b5ab3ac6629d3b5e7cb7"
 "checksum rayon-core 1.7.0 (registry+https://github.com/rust-lang/crates.io-index)" = "08a89b46efaf957e52b18062fb2f4660f8b8a4dde1807ca002690868ef2c85a9"
 "checksum redox_syscall 0.1.56 (registry+https://github.com/rust-lang/crates.io-index)" = "2439c63f3f6139d1b57529d16bc3b8bb855230c8efcc5d3a896c8bea7c3b1e84"
 "checksum regex 1.3.6 (registry+https://github.com/rust-lang/crates.io-index)" = "7f6946991529684867e47d86474e3a6d0c0ab9b82d5821e314b1ede31fa3a4b3"
bindings/python/src/models.rs

@@ -6,8 +6,8 @@ use super::utils::Container;
 use pyo3::exceptions;
 use pyo3::prelude::*;
 use pyo3::types::*;
-use rayon::prelude::*;
 use std::path::Path;
+use tk::parallelism::*;
 
 #[pyclass]
 struct EncodeInput {
@@ -154,7 +154,7 @@ impl Model {
     fn encode_batch(&self, sequences: Vec<EncodeInput>, type_id: u32) -> PyResult<Vec<Encoding>> {
         ToPyResult(self.model.execute(|model| {
             sequences
-                .into_par_iter()
+                .into_maybe_par_iter()
                 .map(|sequence| {
                     let sequence = sequence.into_input();
                     if sequence.is_empty() {
tokenizers/Cargo.toml

@@ -36,6 +36,7 @@ onig = { version = "6.0", default-features = false }
 regex = "1.3"
 regex-syntax = "0.6"
 rayon = "1.3"
+rayon-cond = "0.1"
 serde = { version = "1.0", features = [ "derive" ] }
 serde_json = "1.0"
 typetag = "0.1"
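The new rayon-cond dependency supplies CondIterator, the type the parallelism module added below is built on: it wraps either a rayon parallel iterator or a plain serial one behind a single API, selected by a runtime flag. A minimal standalone sketch (my illustration, not part of the commit; it only uses CondIterator methods that also appear in this diff):

use rayon_cond::CondIterator;

fn main() {
    let data = vec![1, 2, 3, 4];
    // Flip this flag to switch between rayon and serial execution
    // without touching the rest of the call chain.
    let parallel = true;
    CondIterator::new(data, parallel)
        .map(|x| x * 2)
        .for_each(|x| println!("{}", x));
}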
tokenizers/src/lib.rs

@@ -57,3 +57,6 @@ pub mod utils;
 
 // Re-export from tokenizer
 pub use tokenizer::*;
+
+// Re-export also parallelism utils
+pub use utils::parallelism;
tokenizers/src/models/bpe/trainer.rs

@@ -1,9 +1,9 @@
 #![allow(clippy::map_entry)]
 
 use super::{Pair, WithFirstLastIterator, Word, BPE};
+use crate::parallelism::*;
 use crate::tokenizer::{AddedToken, Model, Result, Trainer};
 use indicatif::{ProgressBar, ProgressStyle};
-use rayon::prelude::*;
 use std::cmp::Ordering;
 use std::collections::{BinaryHeap, HashMap, HashSet};
 
@@ -352,7 +352,7 @@ impl BpeTrainer {
         p: &Option<ProgressBar>,
     ) -> (HashMap<Pair, i32>, HashMap<Pair, HashSet<usize>>) {
         words
-            .par_iter()
+            .maybe_par_iter()
             .enumerate()
             .map(|(i, word)| {
                 let mut pair_counts = HashMap::new();
@@ -499,7 +499,7 @@ impl BpeTrainer {
         // Merge the new pair in every words
         let changes = top
             .pos
-            .par_iter()
+            .maybe_par_iter()
             .flat_map(|i| {
                 let w = &words[*i] as *const _ as *mut _;
                 // We can merge each of these words in parallel here because each position
tokenizers/src/pre_tokenizers/byte_level.rs

@@ -1,8 +1,8 @@
+use crate::parallelism::*;
 use crate::tokenizer::{
     Decoder, Encoding, NormalizedString, Offsets, PostProcessor, PreTokenizer, Result,
 };
 use onig::Regex;
-use rayon::prelude::*;
 use serde::{Deserialize, Serialize};
 use std::collections::{HashMap, HashSet};
 
@@ -97,7 +97,7 @@ impl PreTokenizer for ByteLevel {
             .collect::<Vec<_>>();
 
         let splits = positions
-            .into_par_iter()
+            .into_maybe_par_iter()
             .map(|range| {
                 // Process one of the splits
                 let slice = &normalized.get()[range];
tokenizers/src/tokenizer/encoding.rs

@@ -1,6 +1,6 @@
+use crate::parallelism::*;
 use crate::tokenizer::{Offsets, Token};
 use crate::utils::padding::PaddingDirection;
-use rayon::prelude::*;
 use serde::{Deserialize, Serialize};
 
 /// Represents the output of a `Tokenizer`.
@@ -362,7 +362,7 @@ impl Encoding {
         direction: PaddingDirection,
     ) {
         // Dispatch call to all the overflowings first
-        self.overflowing.par_iter_mut().for_each(|encoding| {
+        self.overflowing.maybe_par_iter_mut().for_each(|encoding| {
             encoding.pad(target_length, pad_id, pad_type_id, pad_token, direction)
         });
tokenizers/src/tokenizer/mod.rs

@@ -14,7 +14,6 @@ use crate::utils::iter::ResultShunt;
 pub use crate::utils::padding::{pad_encodings, PaddingDirection, PaddingParams, PaddingStrategy};
 pub use crate::utils::truncation::{truncate_encodings, TruncationParams, TruncationStrategy};
 use indicatif::{ProgressBar, ProgressStyle};
-use rayon::prelude::*;
 use std::{
     collections::HashMap,
     fs::File,
@@ -36,6 +35,8 @@ pub type Error = Box<dyn std::error::Error + Send + Sync>;
 pub type Result<T> = std::result::Result<T, Error>;
 pub type Offsets = (usize, usize);
 
+use crate::utils::parallelism::*;
+
 #[typetag::serde(tag = "type")]
 /// Takes care of pre-processing strings.
 pub trait Normalizer: Send + Sync {
@@ -532,7 +533,7 @@ impl Tokenizer {
         add_special_tokens: bool,
     ) -> Result<Vec<Encoding>> {
         let mut encodings = inputs
-            .into_par_iter()
+            .into_maybe_par_iter()
             .map(|input| self.encode(input, add_special_tokens))
             .collect::<Result<Vec<Encoding>>>()?;
 
@@ -574,7 +575,7 @@ impl Tokenizer {
         skip_special_tokens: bool,
     ) -> Result<Vec<String>> {
         sentences
-            .into_par_iter()
+            .into_maybe_par_iter()
             .map(|sentence| self.decode(sentence, skip_special_tokens))
             .collect()
     }
@@ -612,8 +613,10 @@ impl Tokenizer {
         // We read new lines using this API instead of the Lines Iterator
         // on purpose. We want to keep the `\n` and potential `\r` between each lines
         // We use an iterator to be able to chain with par_bridge.
+        use rayon::prelude::*;
         let words = file
             .lines_with_ending()
+            //.maybe_par_bridge()
             .par_bridge()
             .map_with(
                 &progress,
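Note that the train() hunk above keeps rayon's own par_bridge and leaves maybe_par_bridge commented out, so training still always runs in parallel at this point. For reference, par_bridge turns any Send serial iterator into a rayon parallel one; a standalone sketch (my illustration, not part of the commit):

use rayon::prelude::*;

fn main() {
    // par_bridge fans a plain iterator out over the rayon thread pool,
    // which is how train() parallelizes the lines read from a file.
    let lines = vec!["hello world", "foo bar"];
    let word_count: usize = lines
        .into_iter()
        .par_bridge()
        .map(|line| line.split_whitespace().count())
        .sum();
    println!("{}", word_count);
}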
tokenizers/src/utils/mod.rs

@@ -1,3 +1,4 @@
 pub mod iter;
 pub mod padding;
+pub mod parallelism;
 pub mod truncation;
tokenizers/src/utils/padding.rs

@@ -1,5 +1,5 @@
+use crate::parallelism::*;
 use crate::tokenizer::{Encoding, Result};
-use rayon::prelude::*;
 use serde::{Deserialize, Serialize};
 
 /// The various possible padding directions.
@@ -55,7 +55,7 @@ pub fn pad_encodings(encodings: &mut [Encoding], params: &PaddingParams) -> Resu
     let mut pad_length = match params.strategy {
         PaddingStrategy::Fixed(size) => size,
         PaddingStrategy::BatchLongest => encodings
-            .par_iter()
+            .maybe_par_iter()
             .map(|e| e.get_ids().len())
             .max()
             .unwrap(),
@@ -67,7 +67,7 @@ pub fn pad_encodings(encodings: &mut [Encoding], params: &PaddingParams) -> Resu
         }
     }
 
-    encodings.par_iter_mut().for_each(|encoding| {
+    encodings.maybe_par_iter_mut().for_each(|encoding| {
         encoding.pad(
             pad_length,
             params.pad_id,
tokenizers/src/utils/parallelism.rs (new file)
//!
//! This module defines helpers to allow optional Rayon usage.
//!

use rayon::iter::IterBridge;
use rayon::prelude::*;
use rayon_cond::CondIterator;

/// Consumes an owned collection and yields a CondIterator that runs
/// either serially or on the rayon thread pool.
pub trait MaybeParallelIterator<P, S>
where
    P: ParallelIterator,
    S: Iterator<Item = P::Item>,
{
    fn into_maybe_par_iter(self) -> CondIterator<P, S>;
}

impl<P, S, I> MaybeParallelIterator<P, S> for I
where
    I: IntoParallelIterator<Iter = P, Item = P::Item> + IntoIterator<IntoIter = S, Item = S::Item>,
    P: ParallelIterator,
    S: Iterator<Item = P::Item>,
{
    fn into_maybe_par_iter(self) -> CondIterator<P, S> {
        // TODO: Define parallelism using std::env
        // Maybe also add another method that takes a bool to limit parallelism when there are
        // enough elements to process
        let parallelism = true;
        CondIterator::new(self, parallelism)
    }
}

/// Shared-reference counterpart of MaybeParallelIterator (cf. rayon's par_iter).
pub trait MaybeParallelRefIterator<'data, P, S>
where
    P: ParallelIterator,
    S: Iterator<Item = P::Item>,
    P::Item: 'data,
{
    fn maybe_par_iter(&'data self) -> CondIterator<P, S>;
}

impl<'data, P, S, I: 'data + ?Sized> MaybeParallelRefIterator<'data, P, S> for I
where
    &'data I: MaybeParallelIterator<P, S>,
    P: ParallelIterator,
    S: Iterator<Item = P::Item>,
    P::Item: 'data,
{
    fn maybe_par_iter(&'data self) -> CondIterator<P, S> {
        self.into_maybe_par_iter()
    }
}

/// Mutable-reference counterpart (cf. rayon's par_iter_mut).
pub trait MaybeParallelRefMutIterator<'data, P, S>
where
    P: ParallelIterator,
    S: Iterator<Item = P::Item>,
    P::Item: 'data,
{
    fn maybe_par_iter_mut(&'data mut self) -> CondIterator<P, S>;
}

impl<'data, P, S, I: 'data + ?Sized> MaybeParallelRefMutIterator<'data, P, S> for I
where
    &'data mut I: MaybeParallelIterator<P, S>,
    P: ParallelIterator,
    S: Iterator<Item = P::Item>,
    P::Item: 'data,
{
    fn maybe_par_iter_mut(&'data mut self) -> CondIterator<P, S> {
        self.into_maybe_par_iter()
    }
}

/// Conditional version of rayon's par_bridge for plain iterators.
pub trait MaybeParallelBridge<T, S>
where
    S: Iterator<Item = T> + Send,
    T: Send,
{
    fn maybe_par_bridge(self) -> CondIterator<IterBridge<S>, S>;
}

impl<T, S> MaybeParallelBridge<T, S> for S
where
    S: Iterator<Item = T> + Send,
    T: Send,
{
    fn maybe_par_bridge(self) -> CondIterator<IterBridge<S>, S> {
        let iter = CondIterator::from_serial(self);
        let parallelism = true;

        if parallelism {
            // from_serial wraps the plain iterator; into_parallel() then hands
            // it back as a rayon IterBridge, which from_parallel stores as the
            // parallel variant of the CondIterator.
            CondIterator::from_parallel(iter.into_parallel().right().unwrap())
        } else {
            iter
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    // Compile-time check that all the trait combinations resolve;
    // marked #[ignore] since running it would hit the final panic!.
    #[test]
    #[ignore]
    fn test_maybe_parallel_iterator() {
        let mut v = vec![1, 2, 3, 4, 5, 6];

        let iter = v.par_iter();
        let iter = (&mut v).into_maybe_par_iter();
        let iter = v.maybe_par_iter();
        let iter = v.iter().maybe_par_bridge();
        let iter = v.maybe_par_iter_mut().for_each(|item| {
            *item *= 2;
            println!("{}", item)
        });
        let iter = (&mut v).maybe_par_iter_mut();
        let iter = v.into_iter().par_bridge();

        panic!();
    }
}
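To show how the three reference traits and the owned variant are meant to be used at call sites, here is a standalone sketch (my illustration, not part of the commit; it assumes the crate is consumed under its `tokenizers` name with the parallelism re-export from lib.rs, and it sticks to the map/for_each adaptors that the diff itself exercises):

use tokenizers::parallelism::*;

fn main() {
    let mut values = vec![1u32, 2, 3, 4];

    // Mutable traversal, as in Encoding::pad dispatching to its
    // overflowing encodings.
    values.maybe_par_iter_mut().for_each(|v| *v += 1);

    // Shared traversal, as in the BpeTrainer and pad_encodings hunks.
    values.maybe_par_iter().for_each(|v| println!("{}", v));

    // Owned traversal, as in Tokenizer::encode_batch; with the flag
    // hard-coded to true in into_maybe_par_iter (see the TODO above),
    // this currently always runs on rayon.
    values
        .into_maybe_par_iter()
        .map(|v| v * 2)
        .for_each(|v| println!("{}", v));
}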