Rust - Fix optional parallelism with par_bridge

This commit is contained in:
Anthony MOI
2020-06-22 16:17:07 -04:00
parent dce52621c6
commit 5d20322319
7 changed files with 92 additions and 42 deletions

View File

@@ -497,7 +497,7 @@ dependencies = [
[[package]] [[package]]
name = "rayon-cond" name = "rayon-cond"
version = "0.1.0" version = "0.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "git+https://github.com/n1t0/rayon-cond#c56e4f1ded0fcb92eac70e0533703bba3ca2983f"
dependencies = [ dependencies = [
"either 1.5.3 (registry+https://github.com/rust-lang/crates.io-index)", "either 1.5.3 (registry+https://github.com/rust-lang/crates.io-index)",
"itertools 0.8.2 (registry+https://github.com/rust-lang/crates.io-index)", "itertools 0.8.2 (registry+https://github.com/rust-lang/crates.io-index)",
@@ -629,7 +629,7 @@ dependencies = [
"onig 6.0.0 (registry+https://github.com/rust-lang/crates.io-index)", "onig 6.0.0 (registry+https://github.com/rust-lang/crates.io-index)",
"rand 0.7.3 (registry+https://github.com/rust-lang/crates.io-index)", "rand 0.7.3 (registry+https://github.com/rust-lang/crates.io-index)",
"rayon 1.3.0 (registry+https://github.com/rust-lang/crates.io-index)", "rayon 1.3.0 (registry+https://github.com/rust-lang/crates.io-index)",
"rayon-cond 0.1.0 (registry+https://github.com/rust-lang/crates.io-index)", "rayon-cond 0.1.0 (git+https://github.com/n1t0/rayon-cond)",
"regex 1.3.6 (registry+https://github.com/rust-lang/crates.io-index)", "regex 1.3.6 (registry+https://github.com/rust-lang/crates.io-index)",
"regex-syntax 0.6.17 (registry+https://github.com/rust-lang/crates.io-index)", "regex-syntax 0.6.17 (registry+https://github.com/rust-lang/crates.io-index)",
"serde 1.0.106 (registry+https://github.com/rust-lang/crates.io-index)", "serde 1.0.106 (registry+https://github.com/rust-lang/crates.io-index)",
@@ -793,7 +793,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
"checksum rand_core 0.5.1 (registry+https://github.com/rust-lang/crates.io-index)" = "90bde5296fc891b0cef12a6d03ddccc162ce7b2aff54160af9338f8d40df6d19" "checksum rand_core 0.5.1 (registry+https://github.com/rust-lang/crates.io-index)" = "90bde5296fc891b0cef12a6d03ddccc162ce7b2aff54160af9338f8d40df6d19"
"checksum rand_hc 0.2.0 (registry+https://github.com/rust-lang/crates.io-index)" = "ca3129af7b92a17112d59ad498c6f81eaf463253766b90396d39ea7a39d6613c" "checksum rand_hc 0.2.0 (registry+https://github.com/rust-lang/crates.io-index)" = "ca3129af7b92a17112d59ad498c6f81eaf463253766b90396d39ea7a39d6613c"
"checksum rayon 1.3.0 (registry+https://github.com/rust-lang/crates.io-index)" = "db6ce3297f9c85e16621bb8cca38a06779ffc31bb8184e1be4bed2be4678a098" "checksum rayon 1.3.0 (registry+https://github.com/rust-lang/crates.io-index)" = "db6ce3297f9c85e16621bb8cca38a06779ffc31bb8184e1be4bed2be4678a098"
"checksum rayon-cond 0.1.0 (registry+https://github.com/rust-lang/crates.io-index)" = "fd1259362c9065e5ea39a789ef40b1e3fd934c94beb7b5ab3ac6629d3b5e7cb7" "checksum rayon-cond 0.1.0 (git+https://github.com/n1t0/rayon-cond)" = "<none>"
"checksum rayon-core 1.7.0 (registry+https://github.com/rust-lang/crates.io-index)" = "08a89b46efaf957e52b18062fb2f4660f8b8a4dde1807ca002690868ef2c85a9" "checksum rayon-core 1.7.0 (registry+https://github.com/rust-lang/crates.io-index)" = "08a89b46efaf957e52b18062fb2f4660f8b8a4dde1807ca002690868ef2c85a9"
"checksum redox_syscall 0.1.56 (registry+https://github.com/rust-lang/crates.io-index)" = "2439c63f3f6139d1b57529d16bc3b8bb855230c8efcc5d3a896c8bea7c3b1e84" "checksum redox_syscall 0.1.56 (registry+https://github.com/rust-lang/crates.io-index)" = "2439c63f3f6139d1b57529d16bc3b8bb855230c8efcc5d3a896c8bea7c3b1e84"
"checksum regex 1.3.6 (registry+https://github.com/rust-lang/crates.io-index)" = "7f6946991529684867e47d86474e3a6d0c0ab9b82d5821e314b1ede31fa3a4b3" "checksum regex 1.3.6 (registry+https://github.com/rust-lang/crates.io-index)" = "7f6946991529684867e47d86474e3a6d0c0ab9b82d5821e314b1ede31fa3a4b3"

View File

@@ -47,4 +47,4 @@ tokenizer.train(
) )
# Save the files # Save the files
tokenizer.save(args.out, args.name) tokenizer.save_model(args.out, args.name)

View File

@@ -44,7 +44,7 @@ tokenizer.train(
) )
# Save the files # Save the files
tokenizer.save(args.out, args.name) tokenizer.save_model(args.out, args.name)
# Restoring model from learned vocab/merges # Restoring model from learned vocab/merges
tokenizer = ByteLevelBPETokenizer( tokenizer = ByteLevelBPETokenizer(

View File

@@ -36,7 +36,7 @@ onig = { version = "6.0", default-features = false }
regex = "1.3" regex = "1.3"
regex-syntax = "0.6" regex-syntax = "0.6"
rayon = "1.3" rayon = "1.3"
rayon-cond = "0.1" rayon-cond = { version = "*", git = "https://github.com/n1t0/rayon-cond" }
serde = { version = "1.0", features = [ "derive" ] } serde = { version = "1.0", features = [ "derive" ] }
serde_json = "1.0" serde_json = "1.0"
typetag = "0.1" typetag = "0.1"

View File

@@ -379,11 +379,11 @@ impl BpeTrainer {
h h
}); });
*pair_counts.get_mut(&cur_pair).unwrap() += count as i32; *pair_counts.get_mut(&cur_pair).unwrap() += count as i32;
}
if let Some(p) = &p { if let Some(p) = &p {
p.inc(1); p.inc(1);
} }
}
(pair_counts, where_to_update) (pair_counts, where_to_update)
}) })

View File

@@ -613,11 +613,8 @@ impl Tokenizer {
// We read new lines using this API instead of the Lines Iterator // We read new lines using this API instead of the Lines Iterator
// on purpose. We want to keep the `\n` and potential `\r` between each lines // on purpose. We want to keep the `\n` and potential `\r` between each lines
// We use an iterator to be able to chain with par_bridge. // We use an iterator to be able to chain with par_bridge.
use rayon::prelude::*; file.lines_with_ending()
let words = file .maybe_par_bridge()
.lines_with_ending()
//.maybe_par_bridge()
.par_bridge()
.map_with( .map_with(
&progress, &progress,
|progress, line| -> Result<HashMap<String, u32>> { |progress, line| -> Result<HashMap<String, u32>> {
@@ -638,14 +635,16 @@ impl Tokenizer {
Ok(words) Ok(words)
}, },
) )
.try_reduce(HashMap::new, |mut acc, ws| { .reduce(
for (k, v) in ws { || Ok(HashMap::new()),
|acc, ws| {
let mut acc = acc?;
for (k, v) in ws? {
acc.entry(k).and_modify(|c| *c += v).or_insert(v); acc.entry(k).and_modify(|c| *c += v).or_insert(v);
} }
Ok(acc) Ok(acc)
})?; },
)
Ok(words)
}) })
.try_fold( .try_fold(
HashMap::new(), HashMap::new(),

View File

@@ -6,12 +6,44 @@ use rayon::iter::IterBridge;
use rayon::prelude::*; use rayon::prelude::*;
use rayon_cond::CondIterator; use rayon_cond::CondIterator;
/// Get the currently set value for `TOKENIZERS_PARALLELISM` env variable
pub fn get_parallelism() -> bool {
match std::env::var("TOKENIZERS_PARALLELISM") {
Ok(mut v) => {
v.make_ascii_lowercase();
match v.as_ref() {
"" | "off" | "false" | "f" | "no" | "n" | "0" => false,
_ => true,
}
}
Err(_) => true, // If we couldn't get the variable, we use the default
}
}
/// Set the value for `TOKENIZERS_PARALLELISM` for the current process
pub fn set_parallelism(val: bool) {
std::env::set_var("TOKENIZERS_PARALLELISM", if val { "true" } else { "false" })
}
/// Allows to convert into an iterator that can be executed either parallelly or serially.
///
/// The choice is made according to the currently set `TOKENIZERS_PARALLELISM` environment variable.
/// This variable can have one of the following values
/// - False => "" (empty value), "false", "f", "off", "no", "n", "0"
/// - True => Any other value
///
pub trait MaybeParallelIterator<P, S> pub trait MaybeParallelIterator<P, S>
where where
P: ParallelIterator, P: ParallelIterator,
S: Iterator<Item = P::Item>, S: Iterator<Item = P::Item>,
{ {
/// Convert ourself in a CondIterator, that will be executed either in parallel or serially,
/// based solely on the `TOKENIZERS_PARALLELISM` environment variable
fn into_maybe_par_iter(self) -> CondIterator<P, S>; fn into_maybe_par_iter(self) -> CondIterator<P, S>;
/// Convert ourself in a CondIterator, that will be executed either in parallel or serially,
/// based on both the `TOKENIZERS_PARALLELISM` environment variable and the provided bool.
/// Both must be true to run with parallelism activated.
fn into_maybe_par_iter_cond(self, cond: bool) -> CondIterator<P, S>;
} }
impl<P, S, I> MaybeParallelIterator<P, S> for I impl<P, S, I> MaybeParallelIterator<P, S> for I
@@ -21,14 +53,20 @@ where
S: Iterator<Item = P::Item>, S: Iterator<Item = P::Item>,
{ {
fn into_maybe_par_iter(self) -> CondIterator<P, S> { fn into_maybe_par_iter(self) -> CondIterator<P, S> {
// TODO: Define parallelism using std::env CondIterator::new(self, get_parallelism())
// Maybe also add another method that takes a bool to limit parallelism when there are }
// enough elements to process
let parallelism = true; fn into_maybe_par_iter_cond(self, cond: bool) -> CondIterator<P, S> {
CondIterator::new(self, parallelism) if cond {
self.into_maybe_par_iter()
} else {
CondIterator::from_serial(self)
}
} }
} }
/// Shared reference version of MaybeParallelIterator, works the same but returns an iterator
/// over references, does not consume self
pub trait MaybeParallelRefIterator<'data, P, S> pub trait MaybeParallelRefIterator<'data, P, S>
where where
P: ParallelIterator, P: ParallelIterator,
@@ -36,6 +74,7 @@ where
P::Item: 'data, P::Item: 'data,
{ {
fn maybe_par_iter(&'data self) -> CondIterator<P, S>; fn maybe_par_iter(&'data self) -> CondIterator<P, S>;
fn maybe_par_iter_cond(&'data self, cond: bool) -> CondIterator<P, S>;
} }
impl<'data, P, S, I: 'data + ?Sized> MaybeParallelRefIterator<'data, P, S> for I impl<'data, P, S, I: 'data + ?Sized> MaybeParallelRefIterator<'data, P, S> for I
@@ -48,8 +87,14 @@ where
fn maybe_par_iter(&'data self) -> CondIterator<P, S> { fn maybe_par_iter(&'data self) -> CondIterator<P, S> {
self.into_maybe_par_iter() self.into_maybe_par_iter()
} }
fn maybe_par_iter_cond(&'data self, cond: bool) -> CondIterator<P, S> {
self.into_maybe_par_iter_cond(cond)
}
} }
/// Exclusive reference version of MaybeParallelIterator, works the same but returns an iterator
/// over mutable references, does not consume self
pub trait MaybeParallelRefMutIterator<'data, P, S> pub trait MaybeParallelRefMutIterator<'data, P, S>
where where
P: ParallelIterator, P: ParallelIterator,
@@ -57,6 +102,7 @@ where
P::Item: 'data, P::Item: 'data,
{ {
fn maybe_par_iter_mut(&'data mut self) -> CondIterator<P, S>; fn maybe_par_iter_mut(&'data mut self) -> CondIterator<P, S>;
fn maybe_par_iter_mut_cond(&'data mut self, cond: bool) -> CondIterator<P, S>;
} }
impl<'data, P, S, I: 'data + ?Sized> MaybeParallelRefMutIterator<'data, P, S> for I impl<'data, P, S, I: 'data + ?Sized> MaybeParallelRefMutIterator<'data, P, S> for I
@@ -69,14 +115,20 @@ where
fn maybe_par_iter_mut(&'data mut self) -> CondIterator<P, S> { fn maybe_par_iter_mut(&'data mut self) -> CondIterator<P, S> {
self.into_maybe_par_iter() self.into_maybe_par_iter()
} }
fn maybe_par_iter_mut_cond(&'data mut self, cond: bool) -> CondIterator<P, S> {
self.into_maybe_par_iter_cond(cond)
}
} }
/// Converts any serial iterator into a CondIterator, that can either run parallelly or serially.
pub trait MaybeParallelBridge<T, S> pub trait MaybeParallelBridge<T, S>
where where
S: Iterator<Item = T> + Send, S: Iterator<Item = T> + Send,
T: Send, T: Send,
{ {
fn maybe_par_bridge(self) -> CondIterator<IterBridge<S>, S>; fn maybe_par_bridge(self) -> CondIterator<IterBridge<S>, S>;
fn maybe_par_bridge_cond(self, cond: bool) -> CondIterator<IterBridge<S>, S>;
} }
impl<T, S> MaybeParallelBridge<T, S> for S impl<T, S> MaybeParallelBridge<T, S> for S
@@ -86,14 +138,21 @@ where
{ {
fn maybe_par_bridge(self) -> CondIterator<IterBridge<S>, S> { fn maybe_par_bridge(self) -> CondIterator<IterBridge<S>, S> {
let iter = CondIterator::from_serial(self); let iter = CondIterator::from_serial(self);
let parallelism = true;
if parallelism { if get_parallelism() {
CondIterator::from_parallel(iter.into_parallel().right().unwrap()) CondIterator::from_parallel(iter.into_parallel().right().unwrap())
} else { } else {
iter iter
} }
} }
fn maybe_par_bridge_cond(self, cond: bool) -> CondIterator<IterBridge<S>, S> {
if cond {
self.maybe_par_bridge()
} else {
CondIterator::from_serial(self)
}
}
} }
#[cfg(test)] #[cfg(test)]
@@ -103,19 +162,11 @@ mod tests {
#[test] #[test]
#[ignore] #[ignore]
fn test_maybe_parallel_iterator() { fn test_maybe_parallel_iterator() {
let mut v = vec![1, 2, 3, 4, 5, 6]; let mut v = vec![1u32, 2, 3, 4, 5, 6];
let iter = v.par_iter(); assert_eq!(v.maybe_par_iter().sum::<u32>(), 21);
let iter = (&mut v).into_maybe_par_iter(); assert_eq!(v.maybe_par_iter_mut().map(|v| *v * 2).sum::<u32>(), 42);
let iter = v.maybe_par_iter(); assert_eq!(v.maybe_par_iter().sum::<u32>(), 42);
let iter = v.iter().maybe_par_bridge(); assert_eq!(v.into_maybe_par_iter().sum::<u32>(), 42);
let iter = v.maybe_par_iter_mut().for_each(|item| {
*item *= 2;
println!("{}", item)
});
let iter = (&mut v).maybe_par_iter_mut();
let iter = v.into_iter().par_bridge();
panic!();
} }
} }