Rust - Make parallelism optional

Anthony MOI
2020-06-19 10:12:01 -04:00
parent 74d812d401
commit dce52621c6
11 changed files with 165 additions and 15 deletions
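At its core, the change swaps rayon's entry points (par_iter, into_par_iter, par_iter_mut) for maybe_-prefixed equivalents built on rayon_cond::CondIterator, which runs the same adapter chain either on the rayon thread pool or as a plain sequential iterator, depending on a boolean. A minimal sketch of the idea, using only the rayon-cond API that appears in this diff (the commit still hard-codes the flag to true and leaves the runtime toggle as a TODO):

use rayon_cond::CondIterator;

fn main() {
    // Illustrative flag: the commit hard-codes `true` for now.
    let parallel = true;

    // CondIterator::new dispatches to rayon when `parallel` is true and
    // to a plain Iterator otherwise; the adapter chain is identical.
    let squares: Vec<i32> = CondIterator::new(vec![1, 2, 3, 4], parallel)
        .map(|x| x * x)
        .collect();
    assert_eq!(squares, vec![1, 4, 9, 16]);
}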

View File

@@ -244,6 +244,14 @@ dependencies = [
"syn 1.0.17 (registry+https://github.com/rust-lang/crates.io-index)",
]
[[package]]
name = "itertools"
version = "0.8.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
dependencies = [
"either 1.5.3 (registry+https://github.com/rust-lang/crates.io-index)",
]
[[package]]
name = "itoa"
version = "0.4.5"
@@ -486,6 +494,16 @@ dependencies = [
"rayon-core 1.7.0 (registry+https://github.com/rust-lang/crates.io-index)",
]
[[package]]
name = "rayon-cond"
version = "0.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
dependencies = [
"either 1.5.3 (registry+https://github.com/rust-lang/crates.io-index)",
"itertools 0.8.2 (registry+https://github.com/rust-lang/crates.io-index)",
"rayon 1.3.0 (registry+https://github.com/rust-lang/crates.io-index)",
]
[[package]]
name = "rayon-core"
version = "1.7.0"
@@ -611,6 +629,7 @@ dependencies = [
"onig 6.0.0 (registry+https://github.com/rust-lang/crates.io-index)",
"rand 0.7.3 (registry+https://github.com/rust-lang/crates.io-index)",
"rayon 1.3.0 (registry+https://github.com/rust-lang/crates.io-index)",
"rayon-cond 0.1.0 (registry+https://github.com/rust-lang/crates.io-index)",
"regex 1.3.6 (registry+https://github.com/rust-lang/crates.io-index)",
"regex-syntax 0.6.17 (registry+https://github.com/rust-lang/crates.io-index)",
"serde 1.0.106 (registry+https://github.com/rust-lang/crates.io-index)",
@@ -744,6 +763,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
"checksum indoc-impl 0.3.5 (registry+https://github.com/rust-lang/crates.io-index)" = "54554010aa3d17754e484005ea0022f1c93839aabc627c2c55f3d7b47206134c"
"checksum inventory 0.1.6 (registry+https://github.com/rust-lang/crates.io-index)" = "82d3f4b90287725c97b17478c60dda0c6324e7c84ee1ed72fb9179d0fdf13956"
"checksum inventory-impl 0.1.6 (registry+https://github.com/rust-lang/crates.io-index)" = "9092a4fefc9d503e9287ef137f03180a6e7d1b04c419563171ee14947c5e80ec"
"checksum itertools 0.8.2 (registry+https://github.com/rust-lang/crates.io-index)" = "f56a2d0bc861f9165be4eb3442afd3c236d8a98afd426f65d92324ae1091a484"
"checksum itoa 0.4.5 (registry+https://github.com/rust-lang/crates.io-index)" = "b8b7a7c0c47db5545ed3fef7468ee7bb5b74691498139e4b3f6a20685dc6dd8e"
"checksum lazy_static 1.4.0 (registry+https://github.com/rust-lang/crates.io-index)" = "e2abad23fbc42b3700f2f279844dc832adb2b2eb069b2df918f455c4e18cc646"
"checksum libc 0.2.68 (registry+https://github.com/rust-lang/crates.io-index)" = "dea0c0405123bba743ee3f91f49b1c7cfb684eef0da0a50110f758ccf24cdff0"
@@ -773,6 +793,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
"checksum rand_core 0.5.1 (registry+https://github.com/rust-lang/crates.io-index)" = "90bde5296fc891b0cef12a6d03ddccc162ce7b2aff54160af9338f8d40df6d19"
"checksum rand_hc 0.2.0 (registry+https://github.com/rust-lang/crates.io-index)" = "ca3129af7b92a17112d59ad498c6f81eaf463253766b90396d39ea7a39d6613c"
"checksum rayon 1.3.0 (registry+https://github.com/rust-lang/crates.io-index)" = "db6ce3297f9c85e16621bb8cca38a06779ffc31bb8184e1be4bed2be4678a098"
"checksum rayon-cond 0.1.0 (registry+https://github.com/rust-lang/crates.io-index)" = "fd1259362c9065e5ea39a789ef40b1e3fd934c94beb7b5ab3ac6629d3b5e7cb7"
"checksum rayon-core 1.7.0 (registry+https://github.com/rust-lang/crates.io-index)" = "08a89b46efaf957e52b18062fb2f4660f8b8a4dde1807ca002690868ef2c85a9"
"checksum redox_syscall 0.1.56 (registry+https://github.com/rust-lang/crates.io-index)" = "2439c63f3f6139d1b57529d16bc3b8bb855230c8efcc5d3a896c8bea7c3b1e84"
"checksum regex 1.3.6 (registry+https://github.com/rust-lang/crates.io-index)" = "7f6946991529684867e47d86474e3a6d0c0ab9b82d5821e314b1ede31fa3a4b3"

View File

@@ -6,8 +6,8 @@ use super::utils::Container;
use pyo3::exceptions;
use pyo3::prelude::*;
use pyo3::types::*;
use rayon::prelude::*;
use std::path::Path;
use tk::parallelism::*;
#[pyclass]
struct EncodeInput {
@@ -154,7 +154,7 @@ impl Model {
fn encode_batch(&self, sequences: Vec<EncodeInput>, type_id: u32) -> PyResult<Vec<Encoding>> {
ToPyResult(self.model.execute(|model| {
sequences
.into_par_iter()
.into_maybe_par_iter()
.map(|sequence| {
let sequence = sequence.into_input();
if sequence.is_empty() {

View File

@@ -36,6 +36,7 @@ onig = { version = "6.0", default-features = false }
regex = "1.3"
regex-syntax = "0.6"
rayon = "1.3"
rayon-cond = "0.1"
serde = { version = "1.0", features = [ "derive" ] }
serde_json = "1.0"
typetag = "0.1"
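The new rayon-cond dependency (with itertools and either as its transitive dependencies, matching the Cargo.lock entries above) provides the CondIterator type. Besides CondIterator::new, it offers the from_serial / from_parallel constructors that maybe_par_bridge relies on below. A small sketch, assuming the 0.1 signatures used in this diff; the explicit annotation is needed because from_serial alone leaves the parallel side undetermined:

use rayon_cond::CondIterator;

// The parallel side must be named explicitly; `from_serial` only pins
// down the serial side.
type ParSide = rayon::vec::IntoIter<i32>;
type SerSide = std::vec::IntoIter<i32>;

fn incremented() -> Vec<i32> {
    let iter: CondIterator<ParSide, SerSide> = CondIterator::from_serial(vec![1, 2, 3]);
    iter.map(|x| x + 1).collect()
}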

View File

@@ -57,3 +57,6 @@ pub mod utils;
// Re-export from tokenizer
pub use tokenizer::*;
// Also re-export the parallelism utils
pub use utils::parallelism;

View File

@@ -1,9 +1,9 @@
#![allow(clippy::map_entry)]
use super::{Pair, WithFirstLastIterator, Word, BPE};
use crate::parallelism::*;
use crate::tokenizer::{AddedToken, Model, Result, Trainer};
use indicatif::{ProgressBar, ProgressStyle};
use rayon::prelude::*;
use std::cmp::Ordering;
use std::collections::{BinaryHeap, HashMap, HashSet};
@@ -352,7 +352,7 @@ impl BpeTrainer {
p: &Option<ProgressBar>,
) -> (HashMap<Pair, i32>, HashMap<Pair, HashSet<usize>>) {
words
.par_iter()
.maybe_par_iter()
.enumerate()
.map(|(i, word)| {
let mut pair_counts = HashMap::new();
@@ -499,7 +499,7 @@ impl BpeTrainer {
// Merge the new pair in every words
let changes = top
.pos
.par_iter()
.maybe_par_iter()
.flat_map(|i| {
let w = &words[*i] as *const _ as *mut _;
// We can merge each of these words in parallel here because each position
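The hunk is cut off here, but the shape of the computation is visible from the signature above: build one small map per word in (maybe) parallel, then combine the partial results. A simplified sketch of that map-then-reduce pattern, using plain rayon so it stands alone; the real trainer also returns pair positions (the HashMap<Pair, HashSet<usize>> in the signature) and would use maybe_par_iter():

use rayon::prelude::*;
use std::collections::HashMap;

// Simplified stand-in for the trainer's Pair type.
type Pair = (u32, u32);

fn count_pairs(words: &[Vec<u32>]) -> HashMap<Pair, i32> {
    words
        .par_iter() // the commit's version: `words.maybe_par_iter()`
        .map(|word| {
            // One local map per word, so no locking is needed.
            let mut counts: HashMap<Pair, i32> = HashMap::new();
            for w in word.windows(2) {
                *counts.entry((w[0], w[1])).or_insert(0) += 1;
            }
            counts
        })
        // Merge the per-word maps pairwise into one.
        .reduce(HashMap::new, |mut acc, m| {
            for (pair, c) in m {
                *acc.entry(pair).or_insert(0) += c;
            }
            acc
        })
}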

View File

@@ -1,8 +1,8 @@
use crate::parallelism::*;
use crate::tokenizer::{
Decoder, Encoding, NormalizedString, Offsets, PostProcessor, PreTokenizer, Result,
};
use onig::Regex;
use rayon::prelude::*;
use serde::{Deserialize, Serialize};
use std::collections::{HashMap, HashSet};
@@ -97,7 +97,7 @@ impl PreTokenizer for ByteLevel {
.collect::<Vec<_>>();
let splits = positions
.into_par_iter()
.into_maybe_par_iter()
.map(|range| {
// Process one of the splits
let slice = &normalized.get()[range];

View File

@@ -1,6 +1,6 @@
use crate::parallelism::*;
use crate::tokenizer::{Offsets, Token};
use crate::utils::padding::PaddingDirection;
use rayon::prelude::*;
use serde::{Deserialize, Serialize};
/// Represents the output of a `Tokenizer`.
@@ -362,7 +362,7 @@ impl Encoding {
direction: PaddingDirection,
) {
// Dispatch call to all the overflowings first
self.overflowing.par_iter_mut().for_each(|encoding| {
self.overflowing.maybe_par_iter_mut().for_each(|encoding| {
encoding.pad(target_length, pad_id, pad_type_id, pad_token, direction)
});
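maybe_par_iter_mut mirrors rayon's par_iter_mut: it hands out exclusive references so each overflowing encoding can be padded in place, with or without threads. The same in-place pattern as a self-contained sketch (plain rayon, so it compiles without this crate):

use rayon::prelude::*;

fn pad_all(lengths: &mut [usize], target: usize) {
    // The commit's version: `lengths.maybe_par_iter_mut()`.
    lengths.par_iter_mut().for_each(|len| {
        if *len < target {
            *len = target;
        }
    });
}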

View File

@@ -14,7 +14,6 @@ use crate::utils::iter::ResultShunt;
pub use crate::utils::padding::{pad_encodings, PaddingDirection, PaddingParams, PaddingStrategy};
pub use crate::utils::truncation::{truncate_encodings, TruncationParams, TruncationStrategy};
use indicatif::{ProgressBar, ProgressStyle};
use rayon::prelude::*;
use std::{
collections::HashMap,
fs::File,
@@ -36,6 +35,8 @@ pub type Error = Box<dyn std::error::Error + Send + Sync>;
pub type Result<T> = std::result::Result<T, Error>;
pub type Offsets = (usize, usize);
use crate::utils::parallelism::*;
#[typetag::serde(tag = "type")]
/// Takes care of pre-processing strings.
pub trait Normalizer: Send + Sync {
@@ -532,7 +533,7 @@ impl Tokenizer {
add_special_tokens: bool,
) -> Result<Vec<Encoding>> {
let mut encodings = inputs
.into_par_iter()
.into_maybe_par_iter()
.map(|input| self.encode(input, add_special_tokens))
.collect::<Result<Vec<Encoding>>>()?;
@@ -574,7 +575,7 @@ impl Tokenizer {
skip_special_tokens: bool,
) -> Result<Vec<String>> {
sentences
.into_par_iter()
.into_maybe_par_iter()
.map(|sentence| self.decode(sentence, skip_special_tokens))
.collect()
}
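Both collect() calls work unchanged in either mode because Result implements both FromIterator and FromParallelIterator: if any element fails, the whole batch fails. A small standalone sketch of the same error-propagating pattern (plain rayon; the commit's version uses into_maybe_par_iter()):

use rayon::prelude::*;
use std::num::ParseIntError;

fn parse_all(inputs: Vec<&str>) -> Result<Vec<i64>, ParseIntError> {
    inputs
        .into_par_iter()
        .map(|s| s.parse::<i64>())
        // Collecting into Result stops the batch at a failure.
        .collect()
}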
@@ -612,8 +613,10 @@ impl Tokenizer {
// We read new lines using this API instead of the Lines Iterator
// on purpose. We want to keep the `\n` and potential `\r` between each line
// We use an iterator to be able to chain with par_bridge.
use rayon::prelude::*;
let words = file
.lines_with_ending()
//.maybe_par_bridge()
.par_bridge()
.map_with(
&progress,
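This is the one call site the commit leaves unconditionally parallel: maybe_par_bridge() exists in the new module but is commented out in favor of rayon's par_bridge(), which turns any Send iterator into a parallel one. A minimal sketch of bridging a stream of lines (std and rayon only; the file-reading details are elided in this hunk):

use rayon::prelude::*;
use std::io::{BufRead, BufReader, Cursor};

fn count_lines_in_parallel() -> usize {
    let reader = BufReader::new(Cursor::new("a\nb\nc\n"));
    reader
        .lines()
        // Feed the sequential reader to the rayon thread pool;
        // `maybe_par_bridge()` would make this step conditional.
        .par_bridge()
        .filter(|line| line.is_ok())
        .count()
}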

View File

@@ -1,3 +1,4 @@
pub mod iter;
pub mod padding;
pub mod parallelism;
pub mod truncation;

View File

@@ -1,5 +1,5 @@
use crate::parallelism::*;
use crate::tokenizer::{Encoding, Result};
use rayon::prelude::*;
use serde::{Deserialize, Serialize};
/// The various possible padding directions.
@@ -55,7 +55,7 @@ pub fn pad_encodings(encodings: &mut [Encoding], params: &PaddingParams) -> Resu
let mut pad_length = match params.strategy {
PaddingStrategy::Fixed(size) => size,
PaddingStrategy::BatchLongest => encodings
.par_iter()
.maybe_par_iter()
.map(|e| e.get_ids().len())
.max()
.unwrap(),
@@ -67,7 +67,7 @@ pub fn pad_encodings(encodings: &mut [Encoding], params: &PaddingParams) -> Resu
}
}
encodings.par_iter_mut().for_each(|encoding| {
encodings.maybe_par_iter_mut().for_each(|encoding| {
encoding.pad(
pad_length,
params.pad_id,

View File

@@ -0,0 +1,121 @@
//!
//! This module defines helpers to allow optional Rayon usage.
//!
use rayon::iter::IterBridge;
use rayon::prelude::*;
use rayon_cond::CondIterator;
pub trait MaybeParallelIterator<P, S>
where
P: ParallelIterator,
S: Iterator<Item = P::Item>,
{
fn into_maybe_par_iter(self) -> CondIterator<P, S>;
}
impl<P, S, I> MaybeParallelIterator<P, S> for I
where
I: IntoParallelIterator<Iter = P, Item = P::Item> + IntoIterator<IntoIter = S, Item = S::Item>,
P: ParallelIterator,
S: Iterator<Item = P::Item>,
{
fn into_maybe_par_iter(self) -> CondIterator<P, S> {
// TODO: Define parallelism using std::env
// Maybe also add another method that takes a bool to limit parallelism when there are
// enough elements to process
let parallelism = true;
CondIterator::new(self, parallelism)
}
}
pub trait MaybeParallelRefIterator<'data, P, S>
where
P: ParallelIterator,
S: Iterator<Item = P::Item>,
P::Item: 'data,
{
fn maybe_par_iter(&'data self) -> CondIterator<P, S>;
}
impl<'data, P, S, I: 'data + ?Sized> MaybeParallelRefIterator<'data, P, S> for I
where
&'data I: MaybeParallelIterator<P, S>,
P: ParallelIterator,
S: Iterator<Item = P::Item>,
P::Item: 'data,
{
fn maybe_par_iter(&'data self) -> CondIterator<P, S> {
self.into_maybe_par_iter()
}
}
pub trait MaybeParallelRefMutIterator<'data, P, S>
where
P: ParallelIterator,
S: Iterator<Item = P::Item>,
P::Item: 'data,
{
fn maybe_par_iter_mut(&'data mut self) -> CondIterator<P, S>;
}
impl<'data, P, S, I: 'data + ?Sized> MaybeParallelRefMutIterator<'data, P, S> for I
where
&'data mut I: MaybeParallelIterator<P, S>,
P: ParallelIterator,
S: Iterator<Item = P::Item>,
P::Item: 'data,
{
fn maybe_par_iter_mut(&'data mut self) -> CondIterator<P, S> {
self.into_maybe_par_iter()
}
}
pub trait MaybeParallelBridge<T, S>
where
S: Iterator<Item = T> + Send,
T: Send,
{
fn maybe_par_bridge(self) -> CondIterator<IterBridge<S>, S>;
}
impl<T, S> MaybeParallelBridge<T, S> for S
where
S: Iterator<Item = T> + Send,
T: Send,
{
fn maybe_par_bridge(self) -> CondIterator<IterBridge<S>, S> {
let iter = CondIterator::from_serial(self);
let parallelism = true;
if parallelism {
CondIterator::from_parallel(iter.into_parallel().right().unwrap())
} else {
iter
}
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_maybe_parallel_iterator() {
    let mut v = vec![1, 2, 3, 4, 5, 6];

    // Exercise each entry point of the maybe-parallel traits. The
    // bindings are dropped immediately; we only care that they compile.
    let _ = v.par_iter();
    let _ = (&mut v).into_maybe_par_iter();
    let _ = v.maybe_par_iter();
    let _ = v.iter().maybe_par_bridge();
    let _ = (&mut v).maybe_par_iter_mut();

    // Mutable iteration actually runs: double every element in place.
    v.maybe_par_iter_mut().for_each(|item| {
        *item *= 2;
        println!("{}", item)
    });
    assert_eq!(v, vec![2, 4, 6, 8, 10, 12]);

    let _ = v.into_iter().par_bridge();
}
}
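The TODO in into_maybe_par_iter hints at where this is headed: read the flag from the environment instead of hard-coding true. A hedged sketch of such a toggle; the TOKENIZERS_PARALLELISM variable name and the accepted values are assumptions, not part of this commit:

use std::env;

// Hypothetical toggle; the commit itself still hard-codes
// `let parallelism = true;`.
fn parallelism_enabled() -> bool {
    match env::var("TOKENIZERS_PARALLELISM") {
        Ok(v) => !matches!(v.as_str(), "" | "0" | "false" | "off"),
        Err(_) => true, // unset: keep the current default (parallel)
    }
}

With something like this in place, into_maybe_par_iter would become CondIterator::new(self, parallelism_enabled()), and the bool-taking variant mentioned in the TODO could additionally gate on element count.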
}