mirror of
https://github.com/mii443/tokenizers.git
synced 2025-08-22 16:25:30 +00:00
Unsound call of set_var
(#1664)
* refactor: lift cloning to caller * refactor: do not elide lifetimes as in Rust 2018 * fix: unsound use of env::set_var, was breaking stdlib change to make unsafe It is generally not safe to set env variables. The correct way to set a config value that needs to be overridden is to hold a copy internal to the library and only read from the environment.
This commit is contained in:
@ -553,7 +553,7 @@ impl tk::tokenizer::Normalizer for CustomNormalizer {
|
|||||||
Python::with_gil(|py| {
|
Python::with_gil(|py| {
|
||||||
let normalized = PyNormalizedStringRefMut::new(normalized);
|
let normalized = PyNormalizedStringRefMut::new(normalized);
|
||||||
let py_normalized = self.inner.bind(py);
|
let py_normalized = self.inner.bind(py);
|
||||||
py_normalized.call_method("normalize", (normalized.get(),), None)?;
|
py_normalized.call_method("normalize", (normalized.get().clone(),), None)?;
|
||||||
Ok(())
|
Ok(())
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
@ -634,7 +634,7 @@ impl tk::tokenizer::PreTokenizer for CustomPreTokenizer {
|
|||||||
Python::with_gil(|py| {
|
Python::with_gil(|py| {
|
||||||
let pretok = PyPreTokenizedStringRefMut::new(sentence);
|
let pretok = PyPreTokenizedStringRefMut::new(sentence);
|
||||||
let py_pretok = self.inner.bind(py);
|
let py_pretok = self.inner.bind(py);
|
||||||
py_pretok.call_method("pre_tokenize", (pretok.get(),), None)?;
|
py_pretok.call_method("pre_tokenize", (pretok.get().clone(),), None)?;
|
||||||
Ok(())
|
Ok(())
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
@ -18,11 +18,11 @@ pub trait DestroyPtr {
|
|||||||
fn destroy(&mut self);
|
fn destroy(&mut self);
|
||||||
}
|
}
|
||||||
|
|
||||||
pub struct RefMutGuard<'r, T: DestroyPtr + Clone> {
|
pub struct RefMutGuard<'r, T: DestroyPtr> {
|
||||||
content: T,
|
content: T,
|
||||||
r: PhantomData<&'r mut T>,
|
r: PhantomData<&'r mut T>,
|
||||||
}
|
}
|
||||||
impl<T: DestroyPtr + Clone> RefMutGuard<'_, T> {
|
impl<T: DestroyPtr> RefMutGuard<'_, T> {
|
||||||
pub fn new(content: T) -> Self {
|
pub fn new(content: T) -> Self {
|
||||||
Self {
|
Self {
|
||||||
content,
|
content,
|
||||||
@ -30,12 +30,12 @@ impl<T: DestroyPtr + Clone> RefMutGuard<'_, T> {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn get(&self) -> T {
|
pub fn get(&self) -> &T {
|
||||||
self.content.clone()
|
&self.content
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<T: DestroyPtr + Clone> Drop for RefMutGuard<'_, T> {
|
impl<T: DestroyPtr> Drop for RefMutGuard<'_, T> {
|
||||||
fn drop(&mut self) {
|
fn drop(&mut self) {
|
||||||
self.content.destroy()
|
self.content.destroy()
|
||||||
}
|
}
|
||||||
|
@ -396,7 +396,7 @@ impl DestroyPtr for PyNormalizedStringRefMut {
|
|||||||
}
|
}
|
||||||
|
|
||||||
impl PyNormalizedStringRefMut {
|
impl PyNormalizedStringRefMut {
|
||||||
pub fn new(normalized: &mut NormalizedString) -> RefMutGuard<Self> {
|
pub fn new(normalized: &mut NormalizedString) -> RefMutGuard<'_, Self> {
|
||||||
RefMutGuard::new(Self {
|
RefMutGuard::new(Self {
|
||||||
inner: RefMutContainer::new(normalized),
|
inner: RefMutContainer::new(normalized),
|
||||||
})
|
})
|
||||||
|
@ -39,7 +39,7 @@ fn normalize(pretok: &mut PreTokenizedString, func: &Bound<'_, PyAny>) -> PyResu
|
|||||||
} else {
|
} else {
|
||||||
ToPyResult(pretok.normalize(|normalized| {
|
ToPyResult(pretok.normalize(|normalized| {
|
||||||
let norm = PyNormalizedStringRefMut::new(normalized);
|
let norm = PyNormalizedStringRefMut::new(normalized);
|
||||||
func.call((norm.get(),), None)?;
|
func.call((norm.get().clone(),), None)?;
|
||||||
Ok(())
|
Ok(())
|
||||||
}))
|
}))
|
||||||
.into()
|
.into()
|
||||||
@ -272,7 +272,7 @@ impl DestroyPtr for PyPreTokenizedStringRefMut {
|
|||||||
}
|
}
|
||||||
|
|
||||||
impl PyPreTokenizedStringRefMut {
|
impl PyPreTokenizedStringRefMut {
|
||||||
pub fn new(pretok: &mut tk::PreTokenizedString) -> RefMutGuard<Self> {
|
pub fn new(pretok: &mut tk::PreTokenizedString) -> RefMutGuard<'_, Self> {
|
||||||
// SAFETY: This is safe because we return a RefMutGuard here.
|
// SAFETY: This is safe because we return a RefMutGuard here.
|
||||||
// The compiler will make sure the &mut stays valid as necessary.
|
// The compiler will make sure the &mut stays valid as necessary.
|
||||||
RefMutGuard::new(Self {
|
RefMutGuard::new(Self {
|
||||||
|
@ -6,6 +6,7 @@ use rayon::iter::IterBridge;
|
|||||||
use rayon::prelude::*;
|
use rayon::prelude::*;
|
||||||
use rayon_cond::CondIterator;
|
use rayon_cond::CondIterator;
|
||||||
use std::sync::atomic::AtomicBool;
|
use std::sync::atomic::AtomicBool;
|
||||||
|
use std::sync::atomic::AtomicU8;
|
||||||
use std::sync::atomic::Ordering;
|
use std::sync::atomic::Ordering;
|
||||||
|
|
||||||
// Re-export rayon current_num_threads
|
// Re-export rayon current_num_threads
|
||||||
@ -14,10 +15,11 @@ pub use rayon::current_num_threads;
|
|||||||
pub const ENV_VARIABLE: &str = "TOKENIZERS_PARALLELISM";
|
pub const ENV_VARIABLE: &str = "TOKENIZERS_PARALLELISM";
|
||||||
|
|
||||||
static USED_PARALLELISM: AtomicBool = AtomicBool::new(false);
|
static USED_PARALLELISM: AtomicBool = AtomicBool::new(false);
|
||||||
|
static PARALLELISM: AtomicU8 = AtomicU8::new(0);
|
||||||
|
|
||||||
/// Check if the TOKENIZERS_PARALLELISM env variable has been explicitly set
|
/// Check if the TOKENIZERS_PARALLELISM env variable has been explicitly set
|
||||||
pub fn is_parallelism_configured() -> bool {
|
pub fn is_parallelism_configured() -> bool {
|
||||||
std::env::var(ENV_VARIABLE).is_ok()
|
std::env::var(ENV_VARIABLE).is_ok() || get_override_parallelism().is_some()
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Check if at some point we used a parallel iterator
|
/// Check if at some point we used a parallel iterator
|
||||||
@ -25,8 +27,18 @@ pub fn has_parallelism_been_used() -> bool {
|
|||||||
USED_PARALLELISM.load(Ordering::SeqCst)
|
USED_PARALLELISM.load(Ordering::SeqCst)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Get internally set parallelism
|
||||||
|
fn get_override_parallelism() -> Option<bool> {
|
||||||
|
match PARALLELISM.load(Ordering::SeqCst) {
|
||||||
|
0 => None,
|
||||||
|
1 => Some(false),
|
||||||
|
2 => Some(true),
|
||||||
|
_ => unreachable!(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
/// Get the currently set value for `TOKENIZERS_PARALLELISM` env variable
|
/// Get the currently set value for `TOKENIZERS_PARALLELISM` env variable
|
||||||
pub fn get_parallelism() -> bool {
|
fn get_env_parallelism() -> bool {
|
||||||
match std::env::var(ENV_VARIABLE) {
|
match std::env::var(ENV_VARIABLE) {
|
||||||
Ok(mut v) => {
|
Ok(mut v) => {
|
||||||
v.make_ascii_lowercase();
|
v.make_ascii_lowercase();
|
||||||
@ -36,9 +48,17 @@ pub fn get_parallelism() -> bool {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn get_parallelism() -> bool {
|
||||||
|
if let Some(parallel) = get_override_parallelism() {
|
||||||
|
parallel
|
||||||
|
} else {
|
||||||
|
get_env_parallelism()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
/// Set the value for `TOKENIZERS_PARALLELISM` for the current process
|
/// Set the value for `TOKENIZERS_PARALLELISM` for the current process
|
||||||
pub fn set_parallelism(val: bool) {
|
pub fn set_parallelism(val: bool) {
|
||||||
std::env::set_var(ENV_VARIABLE, if val { "true" } else { "false" })
|
PARALLELISM.store(if val { 2 } else { 1 }, Ordering::SeqCst);
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Allows to convert into an iterator that can be executed either parallelly or serially.
|
/// Allows to convert into an iterator that can be executed either parallelly or serially.
|
||||||
|
Reference in New Issue
Block a user