Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

make PyErrState thread-safe #4671

Merged
merged 8 commits into from
Nov 5, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions newsfragments/4671.fixed.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Make `PyErr` internals thread-safe.
163 changes: 137 additions & 26 deletions src/err/err_state.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,8 @@
use std::cell::UnsafeCell;
use std::{
cell::UnsafeCell,
sync::{Mutex, Once},
thread::ThreadId,
};

use crate::{
exceptions::{PyBaseException, PyTypeError},
Expand All @@ -11,15 +15,18 @@ use crate::{
pub(crate) struct PyErrState {
// Safety: can only hand out references when in the "normalized" state. Will never change
// after normalization.
//
// The state is temporarily removed from the PyErr during normalization, to avoid
// concurrent modifications.
normalized: Once,
// Guard against re-entrancy when normalizing the exception state.
normalizing_thread: Mutex<Option<ThreadId>>,
inner: UnsafeCell<Option<PyErrStateInner>>,
}

// The inner value is only accessed through ways that require the gil is held.
// Safety: The inner value is protected by locking to ensure that only the normalized state is
// handed out as a reference.
unsafe impl Send for PyErrState {}
unsafe impl Sync for PyErrState {}
#[cfg(feature = "nightly")]
unsafe impl crate::marker::Ungil for PyErrState {}

impl PyErrState {
pub(crate) fn lazy(f: Box<PyErrStateLazyFn>) -> Self {
Expand Down Expand Up @@ -48,17 +55,22 @@ impl PyErrState {

fn from_inner(inner: PyErrStateInner) -> Self {
Self {
normalized: Once::new(),
normalizing_thread: Mutex::new(None),
inner: UnsafeCell::new(Some(inner)),
}
}

#[inline]
pub(crate) fn as_normalized(&self, py: Python<'_>) -> &PyErrStateNormalized {
if let Some(PyErrStateInner::Normalized(n)) = unsafe {
// Safety: self.inner will never be written again once normalized.
&*self.inner.get()
} {
return n;
if self.normalized.is_completed() {
match unsafe {
// Safety: self.inner will never be written again once normalized.
&*self.inner.get()
} {
Some(PyErrStateInner::Normalized(n)) => return n,
_ => unreachable!(),
}
}

self.make_normalized(py)
Expand All @@ -69,25 +81,47 @@ impl PyErrState {
// This process is safe because:
// - Access is guaranteed not to be concurrent thanks to `Python` GIL token
// - Write happens only once, and then never will change again.
// - State is set to None during the normalization process, so that a second
// concurrent normalization attempt will panic before changing anything.

// FIXME: this needs to be rewritten to deal with free-threaded Python
// see https://github.com/PyO3/pyo3/issues/4584
// Guard against re-entrant normalization, because `Once` does not provide
// re-entrancy guarantees.
if let Some(thread) = self.normalizing_thread.lock().unwrap().as_ref() {
assert!(
!(*thread == std::thread::current().id()),
"Re-entrant normalization of PyErrState detected"
);
ngoldbaum marked this conversation as resolved.
Show resolved Hide resolved
}

let state = unsafe {
(*self.inner.get())
.take()
.expect("Cannot normalize a PyErr while already normalizing it.")
};
// avoid deadlock of `.call_once` with the GIL
py.allow_threads(|| {
self.normalized.call_once(|| {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I guess somehow dropping the GIL somehow allows a race condition to happen where multiple threads try to simultaneously create a module...

Copy link
Member Author

@davidhewitt davidhewitt Oct 31, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, I think it's a combination with GILOnceCell in the test_declarative_module; we allow racing in GILOnceCell under the condition where switching the GIL, so this module does actually attempt to get created multiple times. I think it's a bug in using GILOnceCell for that test, but this also just makes me dislike this lazy stuff even more...

Copy link
Contributor

@ngoldbaum ngoldbaum Oct 31, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I guess this is just a fundamental issue with GILOnceCell being racey if the code it wraps ever drops the GIL.

EDIT: jinx!

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've opened #4676, if I apply that patch on this branch, the problem goes away.

self.normalizing_thread
.lock()
.unwrap()
.replace(std::thread::current().id());

// Safety: no other thread can access the inner value while we are normalizing it.
let state = unsafe {
(*self.inner.get())
.take()
.expect("Cannot normalize a PyErr while already normalizing it.")
};

let normalized_state =
Python::with_gil(|py| PyErrStateInner::Normalized(state.normalize(py)));

// Safety: no other thread can access the inner value while we are normalizing it.
unsafe {
*self.inner.get() = Some(normalized_state);
}
})
});

unsafe {
let self_state = &mut *self.inner.get();
*self_state = Some(PyErrStateInner::Normalized(state.normalize(py)));
match self_state {
Some(PyErrStateInner::Normalized(n)) => n,
_ => unreachable!(),
}
match unsafe {
// Safety: self.inner will never be written again once normalized.
&*self.inner.get()
} {
Some(PyErrStateInner::Normalized(n)) => n,
_ => unreachable!(),
}
}
}
Expand Down Expand Up @@ -321,3 +355,80 @@ fn raise_lazy(py: Python<'_>, lazy: Box<PyErrStateLazyFn>) {
}
}
}

#[cfg(test)]
mod tests {

use crate::{
exceptions::PyValueError, sync::GILOnceCell, PyErr, PyErrArguments, PyObject, Python,
};

#[test]
#[should_panic(expected = "Re-entrant normalization of PyErrState detected")]
ngoldbaum marked this conversation as resolved.
Show resolved Hide resolved
fn test_reentrant_normalization() {
static ERR: GILOnceCell<PyErr> = GILOnceCell::new();

struct RecursiveArgs;

impl PyErrArguments for RecursiveArgs {
fn arguments(self, py: Python<'_>) -> PyObject {
// .value(py) triggers normalization
ERR.get(py)
.expect("is set just below")
.value(py)
.clone()
.into()
}
}

Python::with_gil(|py| {
ERR.set(py, PyValueError::new_err(RecursiveArgs)).unwrap();
ERR.get(py).expect("is set just above").value(py);
})
}

#[test]
#[cfg(not(target_arch = "wasm32"))] // We are building wasm Python with pthreads disabled
fn test_no_deadlock_thread_switch() {
static ERR: GILOnceCell<PyErr> = GILOnceCell::new();

struct GILSwitchArgs;

impl PyErrArguments for GILSwitchArgs {
fn arguments(self, py: Python<'_>) -> PyObject {
// releasing the GIL potentially allows for other threads to deadlock
// with the normalization going on here
py.allow_threads(|| {
std::thread::sleep(std::time::Duration::from_millis(10));
});
py.None()
}
}

Python::with_gil(|py| ERR.set(py, PyValueError::new_err(GILSwitchArgs)).unwrap());

// Let many threads attempt to read the normalized value at the same time
let handles = (0..10)
.map(|_| {
std::thread::spawn(|| {
Python::with_gil(|py| {
ERR.get(py).expect("is set just above").value(py);
});
})
})
.collect::<Vec<_>>();

for handle in handles {
handle.join().unwrap();
}

// We should never have deadlocked, and should be able to run
// this assertion
Python::with_gil(|py| {
assert!(ERR
.get(py)
.expect("is set above")
.is_instance_of::<PyValueError>(py))
});
}
}
Loading