mirror of
https://github.com/mii443/tokenizers.git
synced 2025-12-16 17:18:43 +00:00
Rust - Fix optional parallelism with par_bridge
This commit is contained in:
6
bindings/python/Cargo.lock
generated
6
bindings/python/Cargo.lock
generated
@@ -497,7 +497,7 @@ dependencies = [
|
||||
[[package]]
|
||||
name = "rayon-cond"
|
||||
version = "0.1.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
source = "git+https://github.com/n1t0/rayon-cond#c56e4f1ded0fcb92eac70e0533703bba3ca2983f"
|
||||
dependencies = [
|
||||
"either 1.5.3 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
"itertools 0.8.2 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
@@ -629,7 +629,7 @@ dependencies = [
|
||||
"onig 6.0.0 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
"rand 0.7.3 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
"rayon 1.3.0 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
"rayon-cond 0.1.0 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
"rayon-cond 0.1.0 (git+https://github.com/n1t0/rayon-cond)",
|
||||
"regex 1.3.6 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
"regex-syntax 0.6.17 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
"serde 1.0.106 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
@@ -793,7 +793,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
"checksum rand_core 0.5.1 (registry+https://github.com/rust-lang/crates.io-index)" = "90bde5296fc891b0cef12a6d03ddccc162ce7b2aff54160af9338f8d40df6d19"
|
||||
"checksum rand_hc 0.2.0 (registry+https://github.com/rust-lang/crates.io-index)" = "ca3129af7b92a17112d59ad498c6f81eaf463253766b90396d39ea7a39d6613c"
|
||||
"checksum rayon 1.3.0 (registry+https://github.com/rust-lang/crates.io-index)" = "db6ce3297f9c85e16621bb8cca38a06779ffc31bb8184e1be4bed2be4678a098"
|
||||
"checksum rayon-cond 0.1.0 (registry+https://github.com/rust-lang/crates.io-index)" = "fd1259362c9065e5ea39a789ef40b1e3fd934c94beb7b5ab3ac6629d3b5e7cb7"
|
||||
"checksum rayon-cond 0.1.0 (git+https://github.com/n1t0/rayon-cond)" = "<none>"
|
||||
"checksum rayon-core 1.7.0 (registry+https://github.com/rust-lang/crates.io-index)" = "08a89b46efaf957e52b18062fb2f4660f8b8a4dde1807ca002690868ef2c85a9"
|
||||
"checksum redox_syscall 0.1.56 (registry+https://github.com/rust-lang/crates.io-index)" = "2439c63f3f6139d1b57529d16bc3b8bb855230c8efcc5d3a896c8bea7c3b1e84"
|
||||
"checksum regex 1.3.6 (registry+https://github.com/rust-lang/crates.io-index)" = "7f6946991529684867e47d86474e3a6d0c0ab9b82d5821e314b1ede31fa3a4b3"
|
||||
|
||||
@@ -47,4 +47,4 @@ tokenizer.train(
|
||||
)
|
||||
|
||||
# Save the files
|
||||
tokenizer.save(args.out, args.name)
|
||||
tokenizer.save_model(args.out, args.name)
|
||||
|
||||
@@ -44,7 +44,7 @@ tokenizer.train(
|
||||
)
|
||||
|
||||
# Save the files
|
||||
tokenizer.save(args.out, args.name)
|
||||
tokenizer.save_model(args.out, args.name)
|
||||
|
||||
# Restoring model from learned vocab/merges
|
||||
tokenizer = ByteLevelBPETokenizer(
|
||||
|
||||
@@ -36,7 +36,7 @@ onig = { version = "6.0", default-features = false }
|
||||
regex = "1.3"
|
||||
regex-syntax = "0.6"
|
||||
rayon = "1.3"
|
||||
rayon-cond = "0.1"
|
||||
rayon-cond = { version = "*", git = "https://github.com/n1t0/rayon-cond" }
|
||||
serde = { version = "1.0", features = [ "derive" ] }
|
||||
serde_json = "1.0"
|
||||
typetag = "0.1"
|
||||
|
||||
@@ -379,11 +379,11 @@ impl BpeTrainer {
|
||||
h
|
||||
});
|
||||
*pair_counts.get_mut(&cur_pair).unwrap() += count as i32;
|
||||
}
|
||||
|
||||
if let Some(p) = &p {
|
||||
p.inc(1);
|
||||
}
|
||||
}
|
||||
|
||||
(pair_counts, where_to_update)
|
||||
})
|
||||
|
||||
@@ -613,11 +613,8 @@ impl Tokenizer {
|
||||
// We read new lines using this API instead of the Lines Iterator
|
||||
// on purpose. We want to keep the `\n` and potential `\r` between each line
|
||||
// We use an iterator to be able to chain with par_bridge.
|
||||
use rayon::prelude::*;
|
||||
let words = file
|
||||
.lines_with_ending()
|
||||
//.maybe_par_bridge()
|
||||
.par_bridge()
|
||||
file.lines_with_ending()
|
||||
.maybe_par_bridge()
|
||||
.map_with(
|
||||
&progress,
|
||||
|progress, line| -> Result<HashMap<String, u32>> {
|
||||
@@ -638,14 +635,16 @@ impl Tokenizer {
|
||||
Ok(words)
|
||||
},
|
||||
)
|
||||
.try_reduce(HashMap::new, |mut acc, ws| {
|
||||
for (k, v) in ws {
|
||||
.reduce(
|
||||
|| Ok(HashMap::new()),
|
||||
|acc, ws| {
|
||||
let mut acc = acc?;
|
||||
for (k, v) in ws? {
|
||||
acc.entry(k).and_modify(|c| *c += v).or_insert(v);
|
||||
}
|
||||
Ok(acc)
|
||||
})?;
|
||||
|
||||
Ok(words)
|
||||
},
|
||||
)
|
||||
})
|
||||
.try_fold(
|
||||
HashMap::new(),
|
||||
|
||||
@@ -6,12 +6,44 @@ use rayon::iter::IterBridge;
|
||||
use rayon::prelude::*;
|
||||
use rayon_cond::CondIterator;
|
||||
|
||||
/// Reports whether parallelism is enabled, according to the `TOKENIZERS_PARALLELISM`
/// environment variable.
///
/// The value is lowercased (ASCII only) before being compared. Returns `false` for
/// "" (empty value), "off", "false", "f", "no", "n" and "0", and `true` for any other
/// value, or when `std::env::var` returns an error.
|
||||
pub fn get_parallelism() -> bool {
|
||||
match std::env::var("TOKENIZERS_PARALLELISM") {
|
||||
Ok(mut v) => {
|
||||
// Lowercase in place so that the comparison below ignores ASCII case
v.make_ascii_lowercase();
|
||||
match v.as_ref() {
|
||||
"" | "off" | "false" | "f" | "no" | "n" | "0" => false,
|
||||
_ => true,
|
||||
}
|
||||
}
|
||||
Err(_) => true, // The variable could not be read: fall back to the default (enabled)
|
||||
}
|
||||
}
|
||||
|
||||
/// Sets the `TOKENIZERS_PARALLELISM` environment variable for the current process:
/// "true" when `val` is `true`, "false" otherwise.
|
||||
pub fn set_parallelism(val: bool) {
|
||||
std::env::set_var("TOKENIZERS_PARALLELISM", if val { "true" } else { "false" })
|
||||
}
|
||||
|
||||
/// Converts a value into an iterator that can be executed either in parallel or serially.
|
||||
///
|
||||
/// The choice is made according to the currently set `TOKENIZERS_PARALLELISM` environment variable.
|
||||
/// This variable can have one of the following values:
|
||||
/// - False => "" (empty value), "false", "f", "off", "no", "n", "0"
|
||||
/// - True => Any other value
|
||||
///
|
||||
pub trait MaybeParallelIterator<P, S>
|
||||
where
|
||||
P: ParallelIterator,
|
||||
S: Iterator<Item = P::Item>,
|
||||
{
|
||||
/// Converts `self` into a `CondIterator` that will be executed either in parallel or serially,
|
||||
/// based solely on the `TOKENIZERS_PARALLELISM` environment variable.
|
||||
fn into_maybe_par_iter(self) -> CondIterator<P, S>;
|
||||
/// Converts `self` into a `CondIterator` that will be executed either in parallel or serially,
|
||||
/// based on both the `TOKENIZERS_PARALLELISM` environment variable and the provided bool.
|
||||
/// Both must be true to run with parallelism activated.
|
||||
fn into_maybe_par_iter_cond(self, cond: bool) -> CondIterator<P, S>;
|
||||
}
|
||||
|
||||
impl<P, S, I> MaybeParallelIterator<P, S> for I
|
||||
@@ -21,14 +53,20 @@ where
|
||||
S: Iterator<Item = P::Item>,
|
||||
{
|
||||
fn into_maybe_par_iter(self) -> CondIterator<P, S> {
|
||||
// TODO: Define parallelism using std::env
|
||||
// Maybe also add another method that takes a bool to limit parallelism when there are
|
||||
// enough elements to process
|
||||
let parallelism = true;
|
||||
CondIterator::new(self, parallelism)
|
||||
CondIterator::new(self, get_parallelism())
|
||||
}
|
||||
|
||||
fn into_maybe_par_iter_cond(self, cond: bool) -> CondIterator<P, S> {
|
||||
if cond {
|
||||
self.into_maybe_par_iter()
|
||||
} else {
|
||||
CondIterator::from_serial(self)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Shared reference version of MaybeParallelIterator, works the same but returns an iterator
|
||||
/// over references, does not consume self
|
||||
pub trait MaybeParallelRefIterator<'data, P, S>
|
||||
where
|
||||
P: ParallelIterator,
|
||||
@@ -36,6 +74,7 @@ where
|
||||
P::Item: 'data,
|
||||
{
|
||||
fn maybe_par_iter(&'data self) -> CondIterator<P, S>;
|
||||
fn maybe_par_iter_cond(&'data self, cond: bool) -> CondIterator<P, S>;
|
||||
}
|
||||
|
||||
impl<'data, P, S, I: 'data + ?Sized> MaybeParallelRefIterator<'data, P, S> for I
|
||||
@@ -48,8 +87,14 @@ where
|
||||
fn maybe_par_iter(&'data self) -> CondIterator<P, S> {
|
||||
self.into_maybe_par_iter()
|
||||
}
|
||||
|
||||
fn maybe_par_iter_cond(&'data self, cond: bool) -> CondIterator<P, S> {
|
||||
self.into_maybe_par_iter_cond(cond)
|
||||
}
|
||||
}
|
||||
|
||||
/// Exclusive reference version of MaybeParallelIterator, works the same but returns an iterator
|
||||
/// over mutable references, does not consume self
|
||||
pub trait MaybeParallelRefMutIterator<'data, P, S>
|
||||
where
|
||||
P: ParallelIterator,
|
||||
@@ -57,6 +102,7 @@ where
|
||||
P::Item: 'data,
|
||||
{
|
||||
fn maybe_par_iter_mut(&'data mut self) -> CondIterator<P, S>;
|
||||
fn maybe_par_iter_mut_cond(&'data mut self, cond: bool) -> CondIterator<P, S>;
|
||||
}
|
||||
|
||||
impl<'data, P, S, I: 'data + ?Sized> MaybeParallelRefMutIterator<'data, P, S> for I
|
||||
@@ -69,14 +115,20 @@ where
|
||||
fn maybe_par_iter_mut(&'data mut self) -> CondIterator<P, S> {
|
||||
self.into_maybe_par_iter()
|
||||
}
|
||||
|
||||
fn maybe_par_iter_mut_cond(&'data mut self, cond: bool) -> CondIterator<P, S> {
|
||||
self.into_maybe_par_iter_cond(cond)
|
||||
}
|
||||
}
|
||||
|
||||
/// Converts any serial iterator into a `CondIterator` that can run either in parallel or serially.
|
||||
pub trait MaybeParallelBridge<T, S>
|
||||
where
|
||||
S: Iterator<Item = T> + Send,
|
||||
T: Send,
|
||||
{
|
||||
/// Bridges `self` into a `CondIterator` that runs in parallel when `get_parallelism()`
/// returns `true`, and serially otherwise.
fn maybe_par_bridge(self) -> CondIterator<IterBridge<S>, S>;
|
||||
/// Same as `maybe_par_bridge`, except that the iterator is always serial when `cond`
/// is `false`.
fn maybe_par_bridge_cond(self, cond: bool) -> CondIterator<IterBridge<S>, S>;
|
||||
}
|
||||
|
||||
impl<T, S> MaybeParallelBridge<T, S> for S
|
||||
@@ -86,14 +138,21 @@ where
|
||||
{
|
||||
fn maybe_par_bridge(self) -> CondIterator<IterBridge<S>, S> {
|
||||
let iter = CondIterator::from_serial(self);
|
||||
let parallelism = true;
|
||||
|
||||
if parallelism {
|
||||
if get_parallelism() {
|
||||
CondIterator::from_parallel(iter.into_parallel().right().unwrap())
|
||||
} else {
|
||||
iter
|
||||
}
|
||||
}
|
||||
|
||||
fn maybe_par_bridge_cond(self, cond: bool) -> CondIterator<IterBridge<S>, S> {
|
||||
if cond {
|
||||
self.maybe_par_bridge()
|
||||
} else {
|
||||
CondIterator::from_serial(self)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
@@ -103,19 +162,11 @@ mod tests {
|
||||
#[test]
|
||||
#[ignore]
|
||||
fn test_maybe_parallel_iterator() {
|
||||
let mut v = vec![1, 2, 3, 4, 5, 6];
|
||||
let mut v = vec![1u32, 2, 3, 4, 5, 6];
|
||||
|
||||
let iter = v.par_iter();
|
||||
let iter = (&mut v).into_maybe_par_iter();
|
||||
let iter = v.maybe_par_iter();
|
||||
let iter = v.iter().maybe_par_bridge();
|
||||
let iter = v.maybe_par_iter_mut().for_each(|item| {
|
||||
*item *= 2;
|
||||
println!("{}", item)
|
||||
});
|
||||
let iter = (&mut v).maybe_par_iter_mut();
|
||||
let iter = v.into_iter().par_bridge();
|
||||
|
||||
panic!();
|
||||
assert_eq!(v.maybe_par_iter().sum::<u32>(), 21);
|
||||
assert_eq!(v.maybe_par_iter_mut().map(|v| *v * 2).sum::<u32>(), 42);
|
||||
assert_eq!(v.maybe_par_iter().sum::<u32>(), 42);
|
||||
assert_eq!(v.into_maybe_par_iter().sum::<u32>(), 42);
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user