Skip to content

Commit

Permalink
avoid unnecessary critical sections in PyDict
Browse files Browse the repository at this point in the history
  • Loading branch information
ngoldbaum committed Dec 10, 2024
1 parent 919e58e commit 944b78e
Showing 1 changed file with 15 additions and 12 deletions.
27 changes: 15 additions & 12 deletions src/types/dict.rs
Original file line number Diff line number Diff line change
Expand Up @@ -438,7 +438,7 @@ enum DictIterImpl {

impl DictIterImpl {
#[inline]
fn next<'py>(
unsafe fn next_unchecked<'py>(
&mut self,
dict: &Bound<'py, PyDict>,
) -> Option<(Bound<'py, PyAny>, Bound<'py, PyAny>)> {
Expand Down Expand Up @@ -478,15 +478,15 @@ impl DictIterImpl {
let mut key: *mut ffi::PyObject = std::ptr::null_mut();
let mut value: *mut ffi::PyObject = std::ptr::null_mut();

if unsafe { ffi::PyDict_Next(dict.as_ptr(), ppos, &mut key, &mut value) } != 0 {
if ffi::PyDict_Next(dict.as_ptr(), ppos, &mut key, &mut value) != 0 {
*remaining -= 1;
let py = dict.py();
// Safety:
// - PyDict_Next returns borrowed values
// - we have already checked that `PyDict_Next` succeeded, so we can assume these to be non-null
Some((
unsafe { key.assume_borrowed_unchecked(py) }.to_owned(),
unsafe { value.assume_borrowed_unchecked(py) }.to_owned(),
key.assume_borrowed_unchecked(py).to_owned(),
value.assume_borrowed_unchecked(py).to_owned(),
))
} else {
None
Expand All @@ -512,7 +512,10 @@ impl<'py> Iterator for BoundDictIterator<'py> {

#[inline]
fn next(&mut self) -> Option<Self::Item> {
self.inner.next(&self.dict)
self.inner
.with_critical_section(&self.dict, |inner| unsafe {
inner.next_unchecked(&self.dict)
})
}

#[inline]
Expand All @@ -530,7 +533,7 @@ impl<'py> Iterator for BoundDictIterator<'py> {
{
self.inner.with_critical_section(&self.dict, |inner| {
let mut accum = init;
while let Some(x) = inner.next(&self.dict) {
while let Some(x) = unsafe { inner.next_unchecked(&self.dict) } {
accum = f(accum, x);
}
accum
Expand All @@ -547,7 +550,7 @@ impl<'py> Iterator for BoundDictIterator<'py> {
{
self.inner.with_critical_section(&self.dict, |inner| {
let mut accum = init;
while let Some(x) = inner.next(&self.dict) {
while let Some(x) = unsafe { inner.next_unchecked(&self.dict) } {
accum = f(accum, x)?
}
R::from_output(accum)
Expand All @@ -562,7 +565,7 @@ impl<'py> Iterator for BoundDictIterator<'py> {
F: FnMut(Self::Item) -> bool,
{
self.inner.with_critical_section(&self.dict, |inner| {
while let Some(x) = inner.next(&self.dict) {
while let Some(x) = unsafe { inner.next_unchecked(&self.dict) } {
if !f(x) {
return false;
}
Expand All @@ -579,7 +582,7 @@ impl<'py> Iterator for BoundDictIterator<'py> {
F: FnMut(Self::Item) -> bool,
{
self.inner.with_critical_section(&self.dict, |inner| {
while let Some(x) = inner.next(&self.dict) {
while let Some(x) = unsafe { inner.next_unchecked(&self.dict) } {
if f(x) {
return true;
}
Expand All @@ -596,7 +599,7 @@ impl<'py> Iterator for BoundDictIterator<'py> {
P: FnMut(&Self::Item) -> bool,
{
self.inner.with_critical_section(&self.dict, |inner| {
while let Some(x) = inner.next(&self.dict) {
while let Some(x) = unsafe { inner.next_unchecked(&self.dict) } {
if predicate(&x) {
return Some(x);
}
Expand All @@ -613,7 +616,7 @@ impl<'py> Iterator for BoundDictIterator<'py> {
F: FnMut(Self::Item) -> Option<B>,
{
self.inner.with_critical_section(&self.dict, |inner| {
while let Some(x) = inner.next(&self.dict) {
while let Some(x) = unsafe { inner.next_unchecked(&self.dict) } {
if let found @ Some(_) = f(x) {
return found;
}
Expand All @@ -631,7 +634,7 @@ impl<'py> Iterator for BoundDictIterator<'py> {
{
self.inner.with_critical_section(&self.dict, |inner| {
let mut acc = 0;
while let Some(x) = inner.next(&self.dict) {
while let Some(x) = unsafe { inner.next_unchecked(&self.dict) } {
if predicate(x) {
return Some(acc);
}
Expand Down

0 comments on commit 944b78e

Please sign in to comment.