mirror of
https://github.com/mii443/tokenizers.git
synced 2025-12-16 17:18:43 +00:00
Rust - Fix optional parallelism with par_bridge
This commit is contained in:
6
bindings/python/Cargo.lock
generated
6
bindings/python/Cargo.lock
generated
@@ -497,7 +497,7 @@ dependencies = [
|
|||||||
[[package]]
|
[[package]]
|
||||||
name = "rayon-cond"
|
name = "rayon-cond"
|
||||||
version = "0.1.0"
|
version = "0.1.0"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "git+https://github.com/n1t0/rayon-cond#c56e4f1ded0fcb92eac70e0533703bba3ca2983f"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"either 1.5.3 (registry+https://github.com/rust-lang/crates.io-index)",
|
"either 1.5.3 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||||
"itertools 0.8.2 (registry+https://github.com/rust-lang/crates.io-index)",
|
"itertools 0.8.2 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||||
@@ -629,7 +629,7 @@ dependencies = [
|
|||||||
"onig 6.0.0 (registry+https://github.com/rust-lang/crates.io-index)",
|
"onig 6.0.0 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||||
"rand 0.7.3 (registry+https://github.com/rust-lang/crates.io-index)",
|
"rand 0.7.3 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||||
"rayon 1.3.0 (registry+https://github.com/rust-lang/crates.io-index)",
|
"rayon 1.3.0 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||||
"rayon-cond 0.1.0 (registry+https://github.com/rust-lang/crates.io-index)",
|
"rayon-cond 0.1.0 (git+https://github.com/n1t0/rayon-cond)",
|
||||||
"regex 1.3.6 (registry+https://github.com/rust-lang/crates.io-index)",
|
"regex 1.3.6 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||||
"regex-syntax 0.6.17 (registry+https://github.com/rust-lang/crates.io-index)",
|
"regex-syntax 0.6.17 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||||
"serde 1.0.106 (registry+https://github.com/rust-lang/crates.io-index)",
|
"serde 1.0.106 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||||
@@ -793,7 +793,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
|
|||||||
"checksum rand_core 0.5.1 (registry+https://github.com/rust-lang/crates.io-index)" = "90bde5296fc891b0cef12a6d03ddccc162ce7b2aff54160af9338f8d40df6d19"
|
"checksum rand_core 0.5.1 (registry+https://github.com/rust-lang/crates.io-index)" = "90bde5296fc891b0cef12a6d03ddccc162ce7b2aff54160af9338f8d40df6d19"
|
||||||
"checksum rand_hc 0.2.0 (registry+https://github.com/rust-lang/crates.io-index)" = "ca3129af7b92a17112d59ad498c6f81eaf463253766b90396d39ea7a39d6613c"
|
"checksum rand_hc 0.2.0 (registry+https://github.com/rust-lang/crates.io-index)" = "ca3129af7b92a17112d59ad498c6f81eaf463253766b90396d39ea7a39d6613c"
|
||||||
"checksum rayon 1.3.0 (registry+https://github.com/rust-lang/crates.io-index)" = "db6ce3297f9c85e16621bb8cca38a06779ffc31bb8184e1be4bed2be4678a098"
|
"checksum rayon 1.3.0 (registry+https://github.com/rust-lang/crates.io-index)" = "db6ce3297f9c85e16621bb8cca38a06779ffc31bb8184e1be4bed2be4678a098"
|
||||||
"checksum rayon-cond 0.1.0 (registry+https://github.com/rust-lang/crates.io-index)" = "fd1259362c9065e5ea39a789ef40b1e3fd934c94beb7b5ab3ac6629d3b5e7cb7"
|
"checksum rayon-cond 0.1.0 (git+https://github.com/n1t0/rayon-cond)" = "<none>"
|
||||||
"checksum rayon-core 1.7.0 (registry+https://github.com/rust-lang/crates.io-index)" = "08a89b46efaf957e52b18062fb2f4660f8b8a4dde1807ca002690868ef2c85a9"
|
"checksum rayon-core 1.7.0 (registry+https://github.com/rust-lang/crates.io-index)" = "08a89b46efaf957e52b18062fb2f4660f8b8a4dde1807ca002690868ef2c85a9"
|
||||||
"checksum redox_syscall 0.1.56 (registry+https://github.com/rust-lang/crates.io-index)" = "2439c63f3f6139d1b57529d16bc3b8bb855230c8efcc5d3a896c8bea7c3b1e84"
|
"checksum redox_syscall 0.1.56 (registry+https://github.com/rust-lang/crates.io-index)" = "2439c63f3f6139d1b57529d16bc3b8bb855230c8efcc5d3a896c8bea7c3b1e84"
|
||||||
"checksum regex 1.3.6 (registry+https://github.com/rust-lang/crates.io-index)" = "7f6946991529684867e47d86474e3a6d0c0ab9b82d5821e314b1ede31fa3a4b3"
|
"checksum regex 1.3.6 (registry+https://github.com/rust-lang/crates.io-index)" = "7f6946991529684867e47d86474e3a6d0c0ab9b82d5821e314b1ede31fa3a4b3"
|
||||||
|
|||||||
@@ -47,4 +47,4 @@ tokenizer.train(
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Save the files
|
# Save the files
|
||||||
tokenizer.save(args.out, args.name)
|
tokenizer.save_model(args.out, args.name)
|
||||||
|
|||||||
@@ -44,7 +44,7 @@ tokenizer.train(
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Save the files
|
# Save the files
|
||||||
tokenizer.save(args.out, args.name)
|
tokenizer.save_model(args.out, args.name)
|
||||||
|
|
||||||
# Restoring model from learned vocab/merges
|
# Restoring model from learned vocab/merges
|
||||||
tokenizer = ByteLevelBPETokenizer(
|
tokenizer = ByteLevelBPETokenizer(
|
||||||
|
|||||||
@@ -36,7 +36,7 @@ onig = { version = "6.0", default-features = false }
|
|||||||
regex = "1.3"
|
regex = "1.3"
|
||||||
regex-syntax = "0.6"
|
regex-syntax = "0.6"
|
||||||
rayon = "1.3"
|
rayon = "1.3"
|
||||||
rayon-cond = "0.1"
|
rayon-cond = { version = "*", git = "https://github.com/n1t0/rayon-cond" }
|
||||||
serde = { version = "1.0", features = [ "derive" ] }
|
serde = { version = "1.0", features = [ "derive" ] }
|
||||||
serde_json = "1.0"
|
serde_json = "1.0"
|
||||||
typetag = "0.1"
|
typetag = "0.1"
|
||||||
|
|||||||
@@ -379,11 +379,11 @@ impl BpeTrainer {
|
|||||||
h
|
h
|
||||||
});
|
});
|
||||||
*pair_counts.get_mut(&cur_pair).unwrap() += count as i32;
|
*pair_counts.get_mut(&cur_pair).unwrap() += count as i32;
|
||||||
|
}
|
||||||
|
|
||||||
if let Some(p) = &p {
|
if let Some(p) = &p {
|
||||||
p.inc(1);
|
p.inc(1);
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
(pair_counts, where_to_update)
|
(pair_counts, where_to_update)
|
||||||
})
|
})
|
||||||
|
|||||||
@@ -613,11 +613,8 @@ impl Tokenizer {
|
|||||||
// We read new lines using this API instead of the Lines Iterator
|
// We read new lines using this API instead of the Lines Iterator
|
||||||
// on purpose. We want to keep the `\n` and potential `\r` between each lines
|
// on purpose. We want to keep the `\n` and potential `\r` between each lines
|
||||||
// We use an iterator to be able to chain with par_bridge.
|
// We use an iterator to be able to chain with par_bridge.
|
||||||
use rayon::prelude::*;
|
file.lines_with_ending()
|
||||||
let words = file
|
.maybe_par_bridge()
|
||||||
.lines_with_ending()
|
|
||||||
//.maybe_par_bridge()
|
|
||||||
.par_bridge()
|
|
||||||
.map_with(
|
.map_with(
|
||||||
&progress,
|
&progress,
|
||||||
|progress, line| -> Result<HashMap<String, u32>> {
|
|progress, line| -> Result<HashMap<String, u32>> {
|
||||||
@@ -638,14 +635,16 @@ impl Tokenizer {
|
|||||||
Ok(words)
|
Ok(words)
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
.try_reduce(HashMap::new, |mut acc, ws| {
|
.reduce(
|
||||||
for (k, v) in ws {
|
|| Ok(HashMap::new()),
|
||||||
|
|acc, ws| {
|
||||||
|
let mut acc = acc?;
|
||||||
|
for (k, v) in ws? {
|
||||||
acc.entry(k).and_modify(|c| *c += v).or_insert(v);
|
acc.entry(k).and_modify(|c| *c += v).or_insert(v);
|
||||||
}
|
}
|
||||||
Ok(acc)
|
Ok(acc)
|
||||||
})?;
|
},
|
||||||
|
)
|
||||||
Ok(words)
|
|
||||||
})
|
})
|
||||||
.try_fold(
|
.try_fold(
|
||||||
HashMap::new(),
|
HashMap::new(),
|
||||||
|
|||||||
@@ -6,12 +6,44 @@ use rayon::iter::IterBridge;
|
|||||||
use rayon::prelude::*;
|
use rayon::prelude::*;
|
||||||
use rayon_cond::CondIterator;
|
use rayon_cond::CondIterator;
|
||||||
|
|
||||||
|
/// Get the currently set value for `TOKENIZERS_PARALLELISM` env variable
|
||||||
|
pub fn get_parallelism() -> bool {
|
||||||
|
match std::env::var("TOKENIZERS_PARALLELISM") {
|
||||||
|
Ok(mut v) => {
|
||||||
|
v.make_ascii_lowercase();
|
||||||
|
match v.as_ref() {
|
||||||
|
"" | "off" | "false" | "f" | "no" | "n" | "0" => false,
|
||||||
|
_ => true,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Err(_) => true, // If we couldn't get the variable, we use the default
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Set the value for `TOKENIZERS_PARALLELISM` for the current process
|
||||||
|
pub fn set_parallelism(val: bool) {
|
||||||
|
std::env::set_var("TOKENIZERS_PARALLELISM", if val { "true" } else { "false" })
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Allows converting into an iterator that can be executed either in parallel or serially.
|
||||||
|
///
|
||||||
|
/// The choice is made according to the currently set `TOKENIZERS_PARALLELISM` environment variable.
|
||||||
|
/// This variable can have one of the following values
|
||||||
|
/// - False => "" (empty value), "false", "f", "off", "no", "n", "0"
|
||||||
|
/// - True => Any other value
|
||||||
|
///
|
||||||
pub trait MaybeParallelIterator<P, S>
|
pub trait MaybeParallelIterator<P, S>
|
||||||
where
|
where
|
||||||
P: ParallelIterator,
|
P: ParallelIterator,
|
||||||
S: Iterator<Item = P::Item>,
|
S: Iterator<Item = P::Item>,
|
||||||
{
|
{
|
||||||
|
/// Convert ourselves into a CondIterator that will be executed either in parallel or serially,
|
||||||
|
/// based solely on the `TOKENIZERS_PARALLELISM` environment variable
|
||||||
fn into_maybe_par_iter(self) -> CondIterator<P, S>;
|
fn into_maybe_par_iter(self) -> CondIterator<P, S>;
|
||||||
|
/// Convert ourselves into a CondIterator that will be executed either in parallel or serially,
|
||||||
|
/// based on both the `TOKENIZERS_PARALLELISM` environment variable and the provided bool.
|
||||||
|
/// Both must be true to run with parallelism activated.
|
||||||
|
fn into_maybe_par_iter_cond(self, cond: bool) -> CondIterator<P, S>;
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<P, S, I> MaybeParallelIterator<P, S> for I
|
impl<P, S, I> MaybeParallelIterator<P, S> for I
|
||||||
@@ -21,14 +53,20 @@ where
|
|||||||
S: Iterator<Item = P::Item>,
|
S: Iterator<Item = P::Item>,
|
||||||
{
|
{
|
||||||
fn into_maybe_par_iter(self) -> CondIterator<P, S> {
|
fn into_maybe_par_iter(self) -> CondIterator<P, S> {
|
||||||
// TODO: Define parallelism using std::env
|
CondIterator::new(self, get_parallelism())
|
||||||
// Maybe also add another method that takes a bool to limit parallelism when there are
|
}
|
||||||
// enough elements to process
|
|
||||||
let parallelism = true;
|
fn into_maybe_par_iter_cond(self, cond: bool) -> CondIterator<P, S> {
|
||||||
CondIterator::new(self, parallelism)
|
if cond {
|
||||||
|
self.into_maybe_par_iter()
|
||||||
|
} else {
|
||||||
|
CondIterator::from_serial(self)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Shared reference version of MaybeParallelIterator, works the same but returns an iterator
|
||||||
|
/// over references, does not consume self
|
||||||
pub trait MaybeParallelRefIterator<'data, P, S>
|
pub trait MaybeParallelRefIterator<'data, P, S>
|
||||||
where
|
where
|
||||||
P: ParallelIterator,
|
P: ParallelIterator,
|
||||||
@@ -36,6 +74,7 @@ where
|
|||||||
P::Item: 'data,
|
P::Item: 'data,
|
||||||
{
|
{
|
||||||
fn maybe_par_iter(&'data self) -> CondIterator<P, S>;
|
fn maybe_par_iter(&'data self) -> CondIterator<P, S>;
|
||||||
|
fn maybe_par_iter_cond(&'data self, cond: bool) -> CondIterator<P, S>;
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<'data, P, S, I: 'data + ?Sized> MaybeParallelRefIterator<'data, P, S> for I
|
impl<'data, P, S, I: 'data + ?Sized> MaybeParallelRefIterator<'data, P, S> for I
|
||||||
@@ -48,8 +87,14 @@ where
|
|||||||
fn maybe_par_iter(&'data self) -> CondIterator<P, S> {
|
fn maybe_par_iter(&'data self) -> CondIterator<P, S> {
|
||||||
self.into_maybe_par_iter()
|
self.into_maybe_par_iter()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn maybe_par_iter_cond(&'data self, cond: bool) -> CondIterator<P, S> {
|
||||||
|
self.into_maybe_par_iter_cond(cond)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Exclusive reference version of MaybeParallelIterator, works the same but returns an iterator
|
||||||
|
/// over mutable references, does not consume self
|
||||||
pub trait MaybeParallelRefMutIterator<'data, P, S>
|
pub trait MaybeParallelRefMutIterator<'data, P, S>
|
||||||
where
|
where
|
||||||
P: ParallelIterator,
|
P: ParallelIterator,
|
||||||
@@ -57,6 +102,7 @@ where
|
|||||||
P::Item: 'data,
|
P::Item: 'data,
|
||||||
{
|
{
|
||||||
fn maybe_par_iter_mut(&'data mut self) -> CondIterator<P, S>;
|
fn maybe_par_iter_mut(&'data mut self) -> CondIterator<P, S>;
|
||||||
|
fn maybe_par_iter_mut_cond(&'data mut self, cond: bool) -> CondIterator<P, S>;
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<'data, P, S, I: 'data + ?Sized> MaybeParallelRefMutIterator<'data, P, S> for I
|
impl<'data, P, S, I: 'data + ?Sized> MaybeParallelRefMutIterator<'data, P, S> for I
|
||||||
@@ -69,14 +115,20 @@ where
|
|||||||
fn maybe_par_iter_mut(&'data mut self) -> CondIterator<P, S> {
|
fn maybe_par_iter_mut(&'data mut self) -> CondIterator<P, S> {
|
||||||
self.into_maybe_par_iter()
|
self.into_maybe_par_iter()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn maybe_par_iter_mut_cond(&'data mut self, cond: bool) -> CondIterator<P, S> {
|
||||||
|
self.into_maybe_par_iter_cond(cond)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Converts any serial iterator into a CondIterator that can run either in parallel or serially.
|
||||||
pub trait MaybeParallelBridge<T, S>
|
pub trait MaybeParallelBridge<T, S>
|
||||||
where
|
where
|
||||||
S: Iterator<Item = T> + Send,
|
S: Iterator<Item = T> + Send,
|
||||||
T: Send,
|
T: Send,
|
||||||
{
|
{
|
||||||
fn maybe_par_bridge(self) -> CondIterator<IterBridge<S>, S>;
|
fn maybe_par_bridge(self) -> CondIterator<IterBridge<S>, S>;
|
||||||
|
fn maybe_par_bridge_cond(self, cond: bool) -> CondIterator<IterBridge<S>, S>;
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<T, S> MaybeParallelBridge<T, S> for S
|
impl<T, S> MaybeParallelBridge<T, S> for S
|
||||||
@@ -86,14 +138,21 @@ where
|
|||||||
{
|
{
|
||||||
fn maybe_par_bridge(self) -> CondIterator<IterBridge<S>, S> {
|
fn maybe_par_bridge(self) -> CondIterator<IterBridge<S>, S> {
|
||||||
let iter = CondIterator::from_serial(self);
|
let iter = CondIterator::from_serial(self);
|
||||||
let parallelism = true;
|
|
||||||
|
|
||||||
if parallelism {
|
if get_parallelism() {
|
||||||
CondIterator::from_parallel(iter.into_parallel().right().unwrap())
|
CondIterator::from_parallel(iter.into_parallel().right().unwrap())
|
||||||
} else {
|
} else {
|
||||||
iter
|
iter
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn maybe_par_bridge_cond(self, cond: bool) -> CondIterator<IterBridge<S>, S> {
|
||||||
|
if cond {
|
||||||
|
self.maybe_par_bridge()
|
||||||
|
} else {
|
||||||
|
CondIterator::from_serial(self)
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
@@ -103,19 +162,11 @@ mod tests {
|
|||||||
#[test]
|
#[test]
|
||||||
#[ignore]
|
#[ignore]
|
||||||
fn test_maybe_parallel_iterator() {
|
fn test_maybe_parallel_iterator() {
|
||||||
let mut v = vec![1, 2, 3, 4, 5, 6];
|
let mut v = vec![1u32, 2, 3, 4, 5, 6];
|
||||||
|
|
||||||
let iter = v.par_iter();
|
assert_eq!(v.maybe_par_iter().sum::<u32>(), 21);
|
||||||
let iter = (&mut v).into_maybe_par_iter();
|
assert_eq!(v.maybe_par_iter_mut().map(|v| *v * 2).sum::<u32>(), 42);
|
||||||
let iter = v.maybe_par_iter();
|
assert_eq!(v.maybe_par_iter().sum::<u32>(), 42);
|
||||||
let iter = v.iter().maybe_par_bridge();
|
assert_eq!(v.into_maybe_par_iter().sum::<u32>(), 42);
|
||||||
let iter = v.maybe_par_iter_mut().for_each(|item| {
|
|
||||||
*item *= 2;
|
|
||||||
println!("{}", item)
|
|
||||||
});
|
|
||||||
let iter = (&mut v).maybe_par_iter_mut();
|
|
||||||
let iter = v.into_iter().par_bridge();
|
|
||||||
|
|
||||||
panic!();
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user