Python - Automatically disable parallelism after fork

This commit is contained in:
Anthony MOI
2020-06-22 20:23:06 -04:00
parent 5d20322319
commit ae743f5dc1
4 changed files with 39 additions and 2 deletions

View File

@@ -643,6 +643,7 @@ dependencies = [
name = "tokenizers-python"
version = "0.8.0-rc3"
dependencies = [
"libc 0.2.68 (registry+https://github.com/rust-lang/crates.io-index)",
"pyo3 0.9.2 (registry+https://github.com/rust-lang/crates.io-index)",
"rayon 1.3.0 (registry+https://github.com/rust-lang/crates.io-index)",
"serde 1.0.106 (registry+https://github.com/rust-lang/crates.io-index)",

View File

@@ -13,6 +13,7 @@ rayon = "1.3"
typetag = "0.1"
serde = "1.0"
serde_json = "1.0"
libc = "0.2"
[dependencies.pyo3]
version = "0.9.2"

View File

@@ -1,3 +1,5 @@
extern crate tokenizers as tk;
mod decoders;
mod encoding;
mod error;
@@ -13,6 +15,23 @@ mod utils;
use pyo3::prelude::*;
use pyo3::wrap_pymodule;
// For users using multiprocessing in python, it is quite easy to fork the process running
// tokenizers, ending up with a deadlock because we internaly make use of multithreading. So
// we register a callback to be called in the event of a fork so that we can warn the user.
static mut REGISTERED_FORK_CALLBACK: bool = false;
extern "C" fn child_after_fork() {
if !tk::parallelism::is_parallelism_configured() {
println!(
"The current process just got forked. Disabling parallelism to avoid deadlocks..."
);
println!(
"To disable this warning, please explicitly set {}=(true | false)",
tk::parallelism::ENV_VARIABLE
);
tk::parallelism::set_parallelism(false);
}
}
/// Trainers Module
#[pymodule]
fn trainers(_py: Python, m: &PyModule) -> PyResult<()> {
@@ -84,6 +103,15 @@ fn normalizers(_py: Python, m: &PyModule) -> PyResult<()> {
/// Tokenizers Module
#[pymodule]
fn tokenizers(_py: Python, m: &PyModule) -> PyResult<()> {
// Register the fork callback
#[cfg(target_os = "linux")]
unsafe {
if !REGISTERED_FORK_CALLBACK {
libc::pthread_atfork(None, None, Some(child_after_fork));
REGISTERED_FORK_CALLBACK = true;
}
}
m.add_class::<tokenizer::Tokenizer>()?;
m.add_class::<tokenizer::AddedToken>()?;
m.add_class::<encoding::Encoding>()?;

View File

@@ -6,9 +6,16 @@ use rayon::iter::IterBridge;
use rayon::prelude::*;
use rayon_cond::CondIterator;
pub const ENV_VARIABLE: &str = "TOKENIZERS_PARALLELISM";
/// Check if the TOKENIZERS_PARALLELISM env variable has been explicitly set
pub fn is_parallelism_configured() -> bool {
std::env::var(ENV_VARIABLE).is_ok()
}
/// Get the currently set value for `TOKENIZERS_PARALLELISM` env variable
pub fn get_parallelism() -> bool {
match std::env::var("TOKENIZERS_PARALLELISM") {
match std::env::var(ENV_VARIABLE) {
Ok(mut v) => {
v.make_ascii_lowercase();
match v.as_ref() {
@@ -22,7 +29,7 @@ pub fn get_parallelism() -> bool {
/// Set the value for `TOKENIZERS_PARALLELISM` for the current process
pub fn set_parallelism(val: bool) {
std::env::set_var("TOKENIZERS_PARALLELISM", if val { "true" } else { "false" })
std::env::set_var(ENV_VARIABLE, if val { "true" } else { "false" })
}
/// Allows to convert into an iterator that can be executed either parallelly or serially.