Mirror of https://github.com/mii443/tokenizers.git — synced 2025-12-03 11:18:29 +00:00
Python - Automatically disable parallelism after fork
This commit is contained in:
1 changed file — bindings/python/Cargo.lock (generated)
@@ -643,6 +643,7 @@ dependencies = [
|
||||
name = "tokenizers-python"
|
||||
version = "0.8.0-rc3"
|
||||
dependencies = [
|
||||
"libc 0.2.68 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
"pyo3 0.9.2 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
"rayon 1.3.0 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
"serde 1.0.106 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
|
||||
@@ -13,6 +13,7 @@ rayon = "1.3"
|
||||
typetag = "0.1"
|
||||
serde = "1.0"
|
||||
serde_json = "1.0"
|
||||
libc = "0.2"
|
||||
|
||||
[dependencies.pyo3]
|
||||
version = "0.9.2"
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
extern crate tokenizers as tk;
|
||||
|
||||
mod decoders;
|
||||
mod encoding;
|
||||
mod error;
|
||||
@@ -13,6 +15,23 @@ mod utils;
|
||||
use pyo3::prelude::*;
|
||||
use pyo3::wrap_pymodule;
|
||||
|
||||
// For users using multiprocessing in python, it is quite easy to fork the process running
|
||||
// tokenizers, ending up with a deadlock because we internaly make use of multithreading. So
|
||||
// we register a callback to be called in the event of a fork so that we can warn the user.
|
||||
static mut REGISTERED_FORK_CALLBACK: bool = false;
|
||||
extern "C" fn child_after_fork() {
|
||||
if !tk::parallelism::is_parallelism_configured() {
|
||||
println!(
|
||||
"The current process just got forked. Disabling parallelism to avoid deadlocks..."
|
||||
);
|
||||
println!(
|
||||
"To disable this warning, please explicitly set {}=(true | false)",
|
||||
tk::parallelism::ENV_VARIABLE
|
||||
);
|
||||
tk::parallelism::set_parallelism(false);
|
||||
}
|
||||
}
|
||||
|
||||
/// Trainers Module
|
||||
#[pymodule]
|
||||
fn trainers(_py: Python, m: &PyModule) -> PyResult<()> {
|
||||
@@ -84,6 +103,15 @@ fn normalizers(_py: Python, m: &PyModule) -> PyResult<()> {
|
||||
/// Tokenizers Module
|
||||
#[pymodule]
|
||||
fn tokenizers(_py: Python, m: &PyModule) -> PyResult<()> {
|
||||
// Register the fork callback
|
||||
#[cfg(target_os = "linux")]
|
||||
unsafe {
|
||||
if !REGISTERED_FORK_CALLBACK {
|
||||
libc::pthread_atfork(None, None, Some(child_after_fork));
|
||||
REGISTERED_FORK_CALLBACK = true;
|
||||
}
|
||||
}
|
||||
|
||||
m.add_class::<tokenizer::Tokenizer>()?;
|
||||
m.add_class::<tokenizer::AddedToken>()?;
|
||||
m.add_class::<encoding::Encoding>()?;
|
||||
|
||||
@@ -6,9 +6,16 @@ use rayon::iter::IterBridge;
|
||||
use rayon::prelude::*;
|
||||
use rayon_cond::CondIterator;
|
||||
|
||||
/// Name of the environment variable used to configure parallelism.
pub const ENV_VARIABLE: &str = "TOKENIZERS_PARALLELISM";

/// Check if the `TOKENIZERS_PARALLELISM` env variable has been explicitly set.
///
/// Returns `true` as soon as the variable exists, regardless of its value:
/// this only tells callers whether the user expressed *any* preference.
pub fn is_parallelism_configured() -> bool {
    std::env::var(ENV_VARIABLE).is_ok()
}
|
||||
|
||||
/// Get the currently set value for `TOKENIZERS_PARALLELISM` env variable
|
||||
pub fn get_parallelism() -> bool {
|
||||
match std::env::var("TOKENIZERS_PARALLELISM") {
|
||||
match std::env::var(ENV_VARIABLE) {
|
||||
Ok(mut v) => {
|
||||
v.make_ascii_lowercase();
|
||||
match v.as_ref() {
|
||||
@@ -22,7 +29,7 @@ pub fn get_parallelism() -> bool {
|
||||
|
||||
/// Set the value for `TOKENIZERS_PARALLELISM` for the current process
|
||||
pub fn set_parallelism(val: bool) {
|
||||
std::env::set_var("TOKENIZERS_PARALLELISM", if val { "true" } else { "false" })
|
||||
std::env::set_var(ENV_VARIABLE, if val { "true" } else { "false" })
|
||||
}
|
||||
|
||||
/// Allows to convert into an iterator that can be executed either parallelly or serially.
|
||||
|
||||
Reference in New Issue
Block a user