Make USED_PARALLELISM atomic (#1532)

This commit is contained in:
nathaniel-daniel
2024-06-06 04:02:26 -07:00
committed by GitHub
parent 25aee8b88c
commit bfefcf676d

View File

@ -5,14 +5,15 @@
use rayon::iter::IterBridge; use rayon::iter::IterBridge;
use rayon::prelude::*; use rayon::prelude::*;
use rayon_cond::CondIterator; use rayon_cond::CondIterator;
use std::sync::atomic::AtomicBool;
use std::sync::atomic::Ordering;
// Re-export rayon current_num_threads // Re-export rayon current_num_threads
pub use rayon::current_num_threads; pub use rayon::current_num_threads;
pub const ENV_VARIABLE: &str = "TOKENIZERS_PARALLELISM"; pub const ENV_VARIABLE: &str = "TOKENIZERS_PARALLELISM";
// Reading/Writing this variable should always happen on the main thread static USED_PARALLELISM: AtomicBool = AtomicBool::new(false);
static mut USED_PARALLELISM: bool = false;
/// Check if the TOKENIZERS_PARALLELISM env variable has been explicitly set /// Check if the TOKENIZERS_PARALLELISM env variable has been explicitly set
pub fn is_parallelism_configured() -> bool { pub fn is_parallelism_configured() -> bool {
@ -21,7 +22,7 @@ pub fn is_parallelism_configured() -> bool {
/// Check if at some point we used a parallel iterator /// Check if at some point we used a parallel iterator
pub fn has_parallelism_been_used() -> bool { pub fn has_parallelism_been_used() -> bool {
unsafe { USED_PARALLELISM } USED_PARALLELISM.load(Ordering::SeqCst)
} }
/// Get the currently set value for `TOKENIZERS_PARALLELISM` env variable /// Get the currently set value for `TOKENIZERS_PARALLELISM` env variable
@ -70,7 +71,7 @@ where
fn into_maybe_par_iter(self) -> CondIterator<P, S> { fn into_maybe_par_iter(self) -> CondIterator<P, S> {
let parallelism = get_parallelism(); let parallelism = get_parallelism();
if parallelism { if parallelism {
unsafe { USED_PARALLELISM = true }; USED_PARALLELISM.store(true, Ordering::SeqCst);
} }
CondIterator::new(self, parallelism) CondIterator::new(self, parallelism)
} }
@ -159,7 +160,7 @@ where
let iter = CondIterator::from_serial(self); let iter = CondIterator::from_serial(self);
if get_parallelism() { if get_parallelism() {
unsafe { USED_PARALLELISM = true }; USED_PARALLELISM.store(true, Ordering::SeqCst);
CondIterator::from_parallel(iter.into_parallel().right().unwrap()) CondIterator::from_parallel(iter.into_parallel().right().unwrap())
} else { } else {
iter iter