From ae743f5dc15d0440ce43acff1b7f9a41c1c4579c Mon Sep 17 00:00:00 2001 From: Anthony MOI Date: Mon, 22 Jun 2020 20:23:06 -0400 Subject: [PATCH] Python - Automatically disable parallelism after fork --- bindings/python/Cargo.lock | 1 + bindings/python/Cargo.toml | 1 + bindings/python/src/lib.rs | 28 ++++++++++++++++++++++++++++ tokenizers/src/utils/parallelism.rs | 11 +++++++++-- 4 files changed, 39 insertions(+), 2 deletions(-) diff --git a/bindings/python/Cargo.lock b/bindings/python/Cargo.lock index 0cbdf995..2b912230 100644 --- a/bindings/python/Cargo.lock +++ b/bindings/python/Cargo.lock @@ -643,6 +643,7 @@ dependencies = [ name = "tokenizers-python" version = "0.8.0-rc3" dependencies = [ + "libc 0.2.68 (registry+https://github.com/rust-lang/crates.io-index)", "pyo3 0.9.2 (registry+https://github.com/rust-lang/crates.io-index)", "rayon 1.3.0 (registry+https://github.com/rust-lang/crates.io-index)", "serde 1.0.106 (registry+https://github.com/rust-lang/crates.io-index)", diff --git a/bindings/python/Cargo.toml b/bindings/python/Cargo.toml index f49d143d..a7a5796b 100644 --- a/bindings/python/Cargo.toml +++ b/bindings/python/Cargo.toml @@ -13,6 +13,7 @@ rayon = "1.3" typetag = "0.1" serde = "1.0" serde_json = "1.0" +libc = "0.2" [dependencies.pyo3] version = "0.9.2" diff --git a/bindings/python/src/lib.rs b/bindings/python/src/lib.rs index cee23fdd..04033be5 100644 --- a/bindings/python/src/lib.rs +++ b/bindings/python/src/lib.rs @@ -1,3 +1,5 @@ +extern crate tokenizers as tk; + mod decoders; mod encoding; mod error; @@ -13,6 +15,23 @@ mod utils; use pyo3::prelude::*; use pyo3::wrap_pymodule; +// For users using multiprocessing in python, it is quite easy to fork the process running +// tokenizers, ending up with a deadlock because we internally make use of multithreading. So +// we register a callback to be called in the event of a fork so that we can warn the user. 
+static mut REGISTERED_FORK_CALLBACK: bool = false; +extern "C" fn child_after_fork() { + if !tk::parallelism::is_parallelism_configured() { + println!( + "The current process just got forked. Disabling parallelism to avoid deadlocks..." + ); + println!( + "To disable this warning, please explicitly set {}=(true | false)", + tk::parallelism::ENV_VARIABLE + ); + tk::parallelism::set_parallelism(false); + } +} + /// Trainers Module #[pymodule] fn trainers(_py: Python, m: &PyModule) -> PyResult<()> { @@ -84,6 +103,15 @@ fn normalizers(_py: Python, m: &PyModule) -> PyResult<()> { /// Tokenizers Module #[pymodule] fn tokenizers(_py: Python, m: &PyModule) -> PyResult<()> { + // Register the fork callback + #[cfg(target_os = "linux")] + unsafe { + if !REGISTERED_FORK_CALLBACK { + libc::pthread_atfork(None, None, Some(child_after_fork)); + REGISTERED_FORK_CALLBACK = true; + } + } + m.add_class::()?; m.add_class::()?; m.add_class::()?; diff --git a/tokenizers/src/utils/parallelism.rs b/tokenizers/src/utils/parallelism.rs index eb54a46d..fe6d9584 100644 --- a/tokenizers/src/utils/parallelism.rs +++ b/tokenizers/src/utils/parallelism.rs @@ -6,9 +6,16 @@ use rayon::iter::IterBridge; use rayon::prelude::*; use rayon_cond::CondIterator; +pub const ENV_VARIABLE: &str = "TOKENIZERS_PARALLELISM"; + +/// Check if the TOKENIZERS_PARALLELISM env variable has been explicitly set +pub fn is_parallelism_configured() -> bool { + std::env::var(ENV_VARIABLE).is_ok() +} + /// Get the currently set value for `TOKENIZERS_PARALLELISM` env variable pub fn get_parallelism() -> bool { - match std::env::var("TOKENIZERS_PARALLELISM") { + match std::env::var(ENV_VARIABLE) { Ok(mut v) => { v.make_ascii_lowercase(); match v.as_ref() { @@ -22,7 +29,7 @@ pub fn get_parallelism() -> bool { /// Set the value for `TOKENIZERS_PARALLELISM` for the current process pub fn set_parallelism(val: bool) { - std::env::set_var("TOKENIZERS_PARALLELISM", if val { "true" } else { "false" }) + 
std::env::set_var(ENV_VARIABLE, if val { "true" } else { "false" }) } /// Allows to convert into an iterator that can be executed either parallelly or serially.