Improve parallelism tracking and warning

This commit is contained in:
Anthony MOI
2020-07-06 13:05:14 -04:00
parent b91deeaa3d
commit 8bf482cecc
2 changed files with 23 additions and 6 deletions

View File

@@ -20,15 +20,19 @@ use pyo3::wrap_pymodule;
// we register a callback to be called in the event of a fork so that we can warn the user. // we register a callback to be called in the event of a fork so that we can warn the user.
static mut REGISTERED_FORK_CALLBACK: bool = false; static mut REGISTERED_FORK_CALLBACK: bool = false;
extern "C" fn child_after_fork() { extern "C" fn child_after_fork() {
if !tk::parallelism::is_parallelism_configured() { use tk::parallelism::*;
if has_parallelism_been_used() && !is_parallelism_configured() {
println!( println!(
"The current process just got forked. Disabling parallelism to avoid deadlocks..." "huggingface/tokenizers: The current process just got forked, after parallelism has \
already been used. Disabling parallelism to avoid deadlocks..."
); );
println!("To disable this warning, you can either:");
println!( println!(
"To disable this warning, please explicitly set {}=(true | false)", "\t- Avoid using `tokenizers` before the fork if possible\n\
tk::parallelism::ENV_VARIABLE \t- Explicitly set the environment variable {}=(true | false)",
ENV_VARIABLE
); );
tk::parallelism::set_parallelism(false); set_parallelism(false);
} }
} }

View File

@@ -8,11 +8,19 @@ use rayon_cond::CondIterator;
pub const ENV_VARIABLE: &str = "TOKENIZERS_PARALLELISM"; pub const ENV_VARIABLE: &str = "TOKENIZERS_PARALLELISM";
// Reading/Writing this variable should always happen on the main thread
static mut USED_PARALLELISM: bool = false;
/// Check if the TOKENIZERS_PARALLELISM env variable has been explicitly set /// Check if the TOKENIZERS_PARALLELISM env variable has been explicitly set
pub fn is_parallelism_configured() -> bool { pub fn is_parallelism_configured() -> bool {
std::env::var(ENV_VARIABLE).is_ok() std::env::var(ENV_VARIABLE).is_ok()
} }
/// Check if at some point we used a parallel iterator
pub fn has_parallelism_been_used() -> bool {
unsafe { USED_PARALLELISM }
}
/// Get the currently set value for `TOKENIZERS_PARALLELISM` env variable /// Get the currently set value for `TOKENIZERS_PARALLELISM` env variable
pub fn get_parallelism() -> bool { pub fn get_parallelism() -> bool {
match std::env::var(ENV_VARIABLE) { match std::env::var(ENV_VARIABLE) {
@@ -60,7 +68,11 @@ where
S: Iterator<Item = P::Item>, S: Iterator<Item = P::Item>,
{ {
fn into_maybe_par_iter(self) -> CondIterator<P, S> { fn into_maybe_par_iter(self) -> CondIterator<P, S> {
CondIterator::new(self, get_parallelism()) let parallelism = get_parallelism();
if parallelism {
unsafe { USED_PARALLELISM = true };
}
CondIterator::new(self, parallelism)
} }
fn into_maybe_par_iter_cond(self, cond: bool) -> CondIterator<P, S> { fn into_maybe_par_iter_cond(self, cond: bool) -> CondIterator<P, S> {
@@ -147,6 +159,7 @@ where
let iter = CondIterator::from_serial(self); let iter = CondIterator::from_serial(self);
if get_parallelism() { if get_parallelism() {
unsafe { USED_PARALLELISM = true };
CondIterator::from_parallel(iter.into_parallel().right().unwrap()) CondIterator::from_parallel(iter.into_parallel().right().unwrap())
} else { } else {
iter iter