Python - Automatically disable parallelism after fork

This commit is contained in:
Anthony MOI
2020-06-22 20:23:06 -04:00
parent 5d20322319
commit ae743f5dc1
4 changed files with 39 additions and 2 deletions

View File

@@ -643,6 +643,7 @@ dependencies = [
name = "tokenizers-python"
version = "0.8.0-rc3"
dependencies = [
"libc 0.2.68 (registry+https://github.com/rust-lang/crates.io-index)",
"pyo3 0.9.2 (registry+https://github.com/rust-lang/crates.io-index)",
"rayon 1.3.0 (registry+https://github.com/rust-lang/crates.io-index)",
"serde 1.0.106 (registry+https://github.com/rust-lang/crates.io-index)",

View File

@@ -13,6 +13,7 @@ rayon = "1.3"
typetag = "0.1"
serde = "1.0"
serde_json = "1.0"
libc = "0.2"
[dependencies.pyo3]
version = "0.9.2"

View File

@@ -1,3 +1,5 @@
extern crate tokenizers as tk;
mod decoders;
mod encoding;
mod error;
@@ -13,6 +15,23 @@ mod utils;
use pyo3::prelude::*;
use pyo3::wrap_pymodule;
// For users using multiprocessing in python, it is quite easy to fork the process running
// tokenizers, ending up with a deadlock because we internaly make use of multithreading. So
// we register a callback to be called in the event of a fork so that we can warn the user.
static mut REGISTERED_FORK_CALLBACK: bool = false;
extern "C" fn child_after_fork() {
if !tk::parallelism::is_parallelism_configured() {
println!(
"The current process just got forked. Disabling parallelism to avoid deadlocks..."
);
println!(
"To disable this warning, please explicitly set {}=(true | false)",
tk::parallelism::ENV_VARIABLE
);
tk::parallelism::set_parallelism(false);
}
}
/// Trainers Module
#[pymodule]
fn trainers(_py: Python, m: &PyModule) -> PyResult<()> {
@@ -84,6 +103,15 @@ fn normalizers(_py: Python, m: &PyModule) -> PyResult<()> {
/// Tokenizers Module
#[pymodule]
fn tokenizers(_py: Python, m: &PyModule) -> PyResult<()> {
// Register the fork callback
#[cfg(target_os = "linux")]
unsafe {
if !REGISTERED_FORK_CALLBACK {
libc::pthread_atfork(None, None, Some(child_after_fork));
REGISTERED_FORK_CALLBACK = true;
}
}
m.add_class::<tokenizer::Tokenizer>()?;
m.add_class::<tokenizer::AddedToken>()?;
m.add_class::<encoding::Encoding>()?;