diff --git a/bindings/python/Cargo.lock b/bindings/python/Cargo.lock index a5b4b95f..8fb6fe3b 100644 --- a/bindings/python/Cargo.lock +++ b/bindings/python/Cargo.lock @@ -244,6 +244,14 @@ dependencies = [ "syn 1.0.17 (registry+https://github.com/rust-lang/crates.io-index)", ] +[[package]] +name = "itertools" +version = "0.8.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +dependencies = [ + "either 1.5.3 (registry+https://github.com/rust-lang/crates.io-index)", +] + [[package]] name = "itoa" version = "0.4.5" @@ -486,6 +494,16 @@ dependencies = [ "rayon-core 1.7.0 (registry+https://github.com/rust-lang/crates.io-index)", ] +[[package]] +name = "rayon-cond" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +dependencies = [ + "either 1.5.3 (registry+https://github.com/rust-lang/crates.io-index)", + "itertools 0.8.2 (registry+https://github.com/rust-lang/crates.io-index)", + "rayon 1.3.0 (registry+https://github.com/rust-lang/crates.io-index)", +] + [[package]] name = "rayon-core" version = "1.7.0" @@ -611,6 +629,7 @@ dependencies = [ "onig 6.0.0 (registry+https://github.com/rust-lang/crates.io-index)", "rand 0.7.3 (registry+https://github.com/rust-lang/crates.io-index)", "rayon 1.3.0 (registry+https://github.com/rust-lang/crates.io-index)", + "rayon-cond 0.1.0 (registry+https://github.com/rust-lang/crates.io-index)", "regex 1.3.6 (registry+https://github.com/rust-lang/crates.io-index)", "regex-syntax 0.6.17 (registry+https://github.com/rust-lang/crates.io-index)", "serde 1.0.106 (registry+https://github.com/rust-lang/crates.io-index)", @@ -744,6 +763,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" "checksum indoc-impl 0.3.5 (registry+https://github.com/rust-lang/crates.io-index)" = "54554010aa3d17754e484005ea0022f1c93839aabc627c2c55f3d7b47206134c" "checksum inventory 0.1.6 (registry+https://github.com/rust-lang/crates.io-index)" = "82d3f4b90287725c97b17478c60dda0c6324e7c84ee1ed72fb9179d0fdf13956" "checksum inventory-impl 0.1.6 (registry+https://github.com/rust-lang/crates.io-index)" = "9092a4fefc9d503e9287ef137f03180a6e7d1b04c419563171ee14947c5e80ec" +"checksum itertools 0.8.2 (registry+https://github.com/rust-lang/crates.io-index)" = "f56a2d0bc861f9165be4eb3442afd3c236d8a98afd426f65d92324ae1091a484" "checksum itoa 0.4.5 (registry+https://github.com/rust-lang/crates.io-index)" = "b8b7a7c0c47db5545ed3fef7468ee7bb5b74691498139e4b3f6a20685dc6dd8e" "checksum lazy_static 1.4.0 (registry+https://github.com/rust-lang/crates.io-index)" = "e2abad23fbc42b3700f2f279844dc832adb2b2eb069b2df918f455c4e18cc646" "checksum libc 0.2.68 (registry+https://github.com/rust-lang/crates.io-index)" = "dea0c0405123bba743ee3f91f49b1c7cfb684eef0da0a50110f758ccf24cdff0" @@ -773,6 +793,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" "checksum rand_core 0.5.1 (registry+https://github.com/rust-lang/crates.io-index)" = "90bde5296fc891b0cef12a6d03ddccc162ce7b2aff54160af9338f8d40df6d19" "checksum rand_hc 0.2.0 (registry+https://github.com/rust-lang/crates.io-index)" = "ca3129af7b92a17112d59ad498c6f81eaf463253766b90396d39ea7a39d6613c" "checksum rayon 1.3.0 (registry+https://github.com/rust-lang/crates.io-index)" = "db6ce3297f9c85e16621bb8cca38a06779ffc31bb8184e1be4bed2be4678a098" +"checksum rayon-cond 0.1.0 (registry+https://github.com/rust-lang/crates.io-index)" = "fd1259362c9065e5ea39a789ef40b1e3fd934c94beb7b5ab3ac6629d3b5e7cb7" "checksum rayon-core 1.7.0 (registry+https://github.com/rust-lang/crates.io-index)" = "08a89b46efaf957e52b18062fb2f4660f8b8a4dde1807ca002690868ef2c85a9" "checksum redox_syscall 0.1.56 (registry+https://github.com/rust-lang/crates.io-index)" = "2439c63f3f6139d1b57529d16bc3b8bb855230c8efcc5d3a896c8bea7c3b1e84" "checksum regex 1.3.6 (registry+https://github.com/rust-lang/crates.io-index)" = "7f6946991529684867e47d86474e3a6d0c0ab9b82d5821e314b1ede31fa3a4b3" diff --git a/bindings/python/src/models.rs b/bindings/python/src/models.rs index 677e6608..8214c441 100644 --- a/bindings/python/src/models.rs +++ b/bindings/python/src/models.rs @@ -6,8 +6,8 @@ use super::utils::Container; use pyo3::exceptions; use pyo3::prelude::*; use pyo3::types::*; -use rayon::prelude::*; use std::path::Path; +use tk::parallelism::*; #[pyclass] struct EncodeInput { @@ -154,7 +154,7 @@ impl Model { fn encode_batch(&self, sequences: Vec, type_id: u32) -> PyResult> { ToPyResult(self.model.execute(|model| { sequences - .into_par_iter() + .into_maybe_par_iter() .map(|sequence| { let sequence = sequence.into_input(); if sequence.is_empty() { diff --git a/tokenizers/Cargo.toml b/tokenizers/Cargo.toml index 2d11ac9b..9e7ab98a 100644 --- a/tokenizers/Cargo.toml +++ b/tokenizers/Cargo.toml @@ -36,6 +36,7 @@ onig = { version = "6.0", default-features = false } regex = "1.3" regex-syntax = "0.6" rayon = "1.3" +rayon-cond = "0.1" serde = { version = "1.0", features = [ "derive" ] } serde_json = "1.0" typetag = "0.1" diff --git a/tokenizers/src/lib.rs b/tokenizers/src/lib.rs index 437e61db..74e3eb2d 100644 --- a/tokenizers/src/lib.rs +++ b/tokenizers/src/lib.rs @@ -57,3 +57,6 @@ pub mod utils; // Re-export from tokenizer pub use tokenizer::*; + +// Re-export also parallelism utils +pub use utils::parallelism; diff --git a/tokenizers/src/models/bpe/trainer.rs b/tokenizers/src/models/bpe/trainer.rs index 44e24523..47ba790e 100644 --- a/tokenizers/src/models/bpe/trainer.rs +++ b/tokenizers/src/models/bpe/trainer.rs @@ -1,9 +1,9 @@ #![allow(clippy::map_entry)] use super::{Pair, WithFirstLastIterator, Word, BPE}; +use crate::parallelism::*; use crate::tokenizer::{AddedToken, Model, Result, Trainer}; use indicatif::{ProgressBar, ProgressStyle}; -use rayon::prelude::*; use std::cmp::Ordering; use std::collections::{BinaryHeap, HashMap, HashSet}; @@ -352,7 +352,7 @@ impl BpeTrainer { p: &Option, ) -> (HashMap, HashMap>) { words - .par_iter() + .maybe_par_iter() .enumerate() .map(|(i, word)| { let mut pair_counts = HashMap::new(); @@ -499,7 +499,7 @@ impl BpeTrainer { // Merge the new pair in every words let changes = top .pos - .par_iter() + .maybe_par_iter() .flat_map(|i| { let w = &words[*i] as *const _ as *mut _; // We can merge each of these words in parallel here because each position diff --git a/tokenizers/src/pre_tokenizers/byte_level.rs b/tokenizers/src/pre_tokenizers/byte_level.rs index c082ac9b..096900ff 100644 --- a/tokenizers/src/pre_tokenizers/byte_level.rs +++ b/tokenizers/src/pre_tokenizers/byte_level.rs @@ -1,8 +1,8 @@ +use crate::parallelism::*; use crate::tokenizer::{ Decoder, Encoding, NormalizedString, Offsets, PostProcessor, PreTokenizer, Result, }; use onig::Regex; -use rayon::prelude::*; use serde::{Deserialize, Serialize}; use std::collections::{HashMap, HashSet}; @@ -97,7 +97,7 @@ impl PreTokenizer for ByteLevel { .collect::>(); let splits = positions - .into_par_iter() + .into_maybe_par_iter() .map(|range| { // Process one of the splits let slice = &normalized.get()[range]; diff --git a/tokenizers/src/tokenizer/encoding.rs b/tokenizers/src/tokenizer/encoding.rs index 390d864b..d280a16b 100644 --- a/tokenizers/src/tokenizer/encoding.rs +++ b/tokenizers/src/tokenizer/encoding.rs @@ -1,6 +1,6 @@ +use crate::parallelism::*; use crate::tokenizer::{Offsets, Token}; use crate::utils::padding::PaddingDirection; -use rayon::prelude::*; use serde::{Deserialize, Serialize}; /// Represents the output of a `Tokenizer`. @@ -362,7 +362,7 @@ impl Encoding { direction: PaddingDirection, ) { // Dispatch call to all the overflowings first - self.overflowing.par_iter_mut().for_each(|encoding| { + self.overflowing.maybe_par_iter_mut().for_each(|encoding| { encoding.pad(target_length, pad_id, pad_type_id, pad_token, direction) }); diff --git a/tokenizers/src/tokenizer/mod.rs b/tokenizers/src/tokenizer/mod.rs index be98a821..c99cacb8 100644 --- a/tokenizers/src/tokenizer/mod.rs +++ b/tokenizers/src/tokenizer/mod.rs @@ -14,7 +14,6 @@ use crate::utils::iter::ResultShunt; pub use crate::utils::padding::{pad_encodings, PaddingDirection, PaddingParams, PaddingStrategy}; pub use crate::utils::truncation::{truncate_encodings, TruncationParams, TruncationStrategy}; use indicatif::{ProgressBar, ProgressStyle}; -use rayon::prelude::*; use std::{ collections::HashMap, fs::File, @@ -36,6 +35,8 @@ pub type Error = Box; pub type Result = std::result::Result; pub type Offsets = (usize, usize); +use crate::utils::parallelism::*; + #[typetag::serde(tag = "type")] /// Takes care of pre-processing strings. pub trait Normalizer: Send + Sync { @@ -532,7 +533,7 @@ impl Tokenizer { add_special_tokens: bool, ) -> Result> { let mut encodings = inputs - .into_par_iter() + .into_maybe_par_iter() .map(|input| self.encode(input, add_special_tokens)) .collect::>>()?; @@ -574,7 +575,7 @@ impl Tokenizer { skip_special_tokens: bool, ) -> Result> { sentences - .into_par_iter() + .into_maybe_par_iter() .map(|sentence| self.decode(sentence, skip_special_tokens)) .collect() } @@ -612,8 +613,10 @@ impl Tokenizer { // We read new lines using this API instead of the Lines Iterator // on purpose. We want to keep the `\n` and potential `\r` between each lines // We use an iterator to be able to chain with par_bridge. + use rayon::prelude::*; let words = file .lines_with_ending() + //.maybe_par_bridge() .par_bridge() .map_with( &progress, diff --git a/tokenizers/src/utils/mod.rs b/tokenizers/src/utils/mod.rs index f961fd7d..e641102e 100644 --- a/tokenizers/src/utils/mod.rs +++ b/tokenizers/src/utils/mod.rs @@ -1,3 +1,4 @@ pub mod iter; pub mod padding; +pub mod parallelism; pub mod truncation; diff --git a/tokenizers/src/utils/padding.rs b/tokenizers/src/utils/padding.rs index 9d03df10..07cbf1fc 100644 --- a/tokenizers/src/utils/padding.rs +++ b/tokenizers/src/utils/padding.rs @@ -1,5 +1,5 @@ +use crate::parallelism::*; use crate::tokenizer::{Encoding, Result}; -use rayon::prelude::*; use serde::{Deserialize, Serialize}; /// The various possible padding directions. @@ -55,7 +55,7 @@ pub fn pad_encodings(encodings: &mut [Encoding], params: &PaddingParams) -> Resu let mut pad_length = match params.strategy { PaddingStrategy::Fixed(size) => size, PaddingStrategy::BatchLongest => encodings - .par_iter() + .maybe_par_iter() .map(|e| e.get_ids().len()) .max() .unwrap(), @@ -67,7 +67,7 @@ pub fn pad_encodings(encodings: &mut [Encoding], params: &PaddingParams) -> Resu } } - encodings.par_iter_mut().for_each(|encoding| { + encodings.maybe_par_iter_mut().for_each(|encoding| { encoding.pad( pad_length, params.pad_id, diff --git a/tokenizers/src/utils/parallelism.rs b/tokenizers/src/utils/parallelism.rs new file mode 100644 index 00000000..3e324abc --- /dev/null +++ b/tokenizers/src/utils/parallelism.rs @@ -0,0 +1,121 @@ +//! +//! This module defines helpers to allow optional Rayon usage. +//! + +use rayon::iter::IterBridge; +use rayon::prelude::*; +use rayon_cond::CondIterator; + +pub trait MaybeParallelIterator +where + P: ParallelIterator, + S: Iterator, +{ + fn into_maybe_par_iter(self) -> CondIterator; +} + +impl MaybeParallelIterator for I +where + I: IntoParallelIterator + IntoIterator, + P: ParallelIterator, + S: Iterator, +{ + fn into_maybe_par_iter(self) -> CondIterator { + // TODO: Define parallelism using std::env + // Maybe also add another method that takes a bool to limit parallelism when there are + // enough elements to process + let parallelism = true; + CondIterator::new(self, parallelism) + } +} + +pub trait MaybeParallelRefIterator<'data, P, S> +where + P: ParallelIterator, + S: Iterator, + P::Item: 'data, +{ + fn maybe_par_iter(&'data self) -> CondIterator; +} + +impl<'data, P, S, I: 'data + ?Sized> MaybeParallelRefIterator<'data, P, S> for I +where + &'data I: MaybeParallelIterator, + P: ParallelIterator, + S: Iterator, + P::Item: 'data, +{ + fn maybe_par_iter(&'data self) -> CondIterator { + self.into_maybe_par_iter() + } +} + +pub trait MaybeParallelRefMutIterator<'data, P, S> +where + P: ParallelIterator, + S: Iterator, + P::Item: 'data, +{ + fn maybe_par_iter_mut(&'data mut self) -> CondIterator; +} + +impl<'data, P, S, I: 'data + ?Sized> MaybeParallelRefMutIterator<'data, P, S> for I +where + &'data mut I: MaybeParallelIterator, + P: ParallelIterator, + S: Iterator, + P::Item: 'data, +{ + fn maybe_par_iter_mut(&'data mut self) -> CondIterator { + self.into_maybe_par_iter() + } +} + +pub trait MaybeParallelBridge +where + S: Iterator + Send, + T: Send, +{ + fn maybe_par_bridge(self) -> CondIterator, S>; +} + +impl MaybeParallelBridge for S +where + S: Iterator + Send, + T: Send, +{ + fn maybe_par_bridge(self) -> CondIterator, S> { + let iter = CondIterator::from_serial(self); + let parallelism = true; + + if parallelism { + CondIterator::from_parallel(iter.into_parallel().right().unwrap()) + } else { + iter + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + #[ignore] + fn test_maybe_parallel_iterator() { + let mut v = vec![1, 2, 3, 4, 5, 6]; + + let iter = v.par_iter(); + let iter = (&mut v).into_maybe_par_iter(); + let iter = v.maybe_par_iter(); + let iter = v.iter().maybe_par_bridge(); + let iter = v.maybe_par_iter_mut().for_each(|item| { + *item *= 2; + println!("{}", item) + }); + let iter = (&mut v).maybe_par_iter_mut(); + let iter = v.into_iter().par_bridge(); + + panic!(); + } +}