Skip to content

Commit

Permalink
Use critical section wrapper
Browse files Browse the repository at this point in the history
  • Loading branch information
bschoenmaeckers committed Oct 7, 2024
1 parent f92ddc5 commit ac9fcf9
Showing 1 changed file with 118 additions and 149 deletions.
267 changes: 118 additions & 149 deletions src/types/dict.rs
Original file line number Diff line number Diff line change
Expand Up @@ -366,23 +366,22 @@ impl<'py> PyDictMethods<'py> for Bound<'py, PyDict> {
BoundDictIterator::new(self.clone())
}

#[cfg(Py_GIL_DISABLED)]
fn locked_for_each<F>(&self, closure: F) -> PyResult<()>
#[cfg(all(Py_GIL_DISABLED, not(feature = "nightly")))]
fn locked_for_each<F>(&self, f: F) -> PyResult<()>
where
F: Fn(Bound<'py, PyAny>, Bound<'py, PyAny>) -> PyResult<()>,
{
let mut section = unsafe { std::mem::zeroed() };
unsafe { ffi::PyCriticalSection_Begin(&mut section, self.as_ptr()) };

for (key, value) in self {
if let Err(err) = closure(key, value) {
unsafe { ffi::PyCriticalSection_End(&mut section) };
return Err(err);
}
#[cfg(feature = "nightly")]
{
self.iter().try_for_each(|(key, value)| f(key, value))
}

unsafe { ffi::PyCriticalSection_End(&mut section) };
Ok(())
#[cfg(not(feature = "nightly"))]
{
crate::sync::with_critical_section(self, || {
self.iter().try_for_each(|(key, value)| f(key, value))
})
}
}

fn as_mapping(&self) -> &Bound<'py, PyMapping> {
Expand Down Expand Up @@ -452,29 +451,25 @@ impl<'py> Iterator for BoundDictIterator<'py> {

#[inline]
fn next(&mut self) -> Option<Self::Item> {
match self {
BoundDictIterator::ItemIter { iter, remaining } => {
self.with_critical_section(|iter| match iter {
BoundDictIterator::ItemIter {
iter: ref mut py_iter,
ref mut remaining,
} => {
*remaining = remaining.saturating_sub(1);
iter.next().map(Result::unwrap).map(|tuple| {
py_iter.next().map(Result::unwrap).map(|tuple| {
let tuple = tuple.downcast::<PyTuple>().unwrap();
let key = tuple.get_item(0).unwrap();
let value = tuple.get_item(1).unwrap();
(key, value)
})
}
BoundDictIterator::DictIter {
dict,
ppos,
di_used,
remaining,
ref mut dict,
ref mut ppos,
ref mut di_used,
ref mut remaining,
} => {
#[cfg(Py_GIL_DISABLED)]
let mut section = unsafe { std::mem::zeroed() };
#[cfg(Py_GIL_DISABLED)]
unsafe {
ffi::PyCriticalSection_Begin(&mut section, dict.as_ptr());
};

let ma_used = dict_len(dict);

// These checks are similar to what CPython does.
Expand Down Expand Up @@ -504,10 +499,7 @@ impl<'py> Iterator for BoundDictIterator<'py> {
let mut key: *mut ffi::PyObject = std::ptr::null_mut();
let mut value: *mut ffi::PyObject = std::ptr::null_mut();

let result = if unsafe {
ffi::PyDict_Next(dict.as_ptr(), ppos, &mut key, &mut value)
} != 0
{
if unsafe { ffi::PyDict_Next(dict.as_ptr(), ppos, &mut key, &mut value) } != 0 {
*remaining -= 1;
let py = dict.py();
// Safety:
Expand All @@ -519,16 +511,9 @@ impl<'py> Iterator for BoundDictIterator<'py> {
))
} else {
None
};

#[cfg(Py_GIL_DISABLED)]
unsafe {
ffi::PyCriticalSection_End(&mut section);
}

result
}
}
})
}

#[inline]
Expand All @@ -544,15 +529,13 @@ impl<'py> Iterator for BoundDictIterator<'py> {
Self: Sized,
F: FnMut(B, Self::Item) -> B,
{
let mut section = unsafe { std::mem::zeroed() };
unsafe { ffi::PyCriticalSection_Begin(&mut section, self.as_ptr()) };

let mut accum = init;
for x in &mut self {
accum = f(accum, x);
}
unsafe { ffi::PyCriticalSection_End(&mut section) };
accum
self.with_critical_section(|mut iter| {
let mut accum = init;
for x in &mut iter {
accum = f(accum, x);
}
accum
})
}

#[inline]
Expand All @@ -563,22 +546,13 @@ impl<'py> Iterator for BoundDictIterator<'py> {
F: FnMut(B, Self::Item) -> R,
R: std::ops::Try<Output = B>,
{
let mut section = unsafe { std::mem::zeroed() };
unsafe { ffi::PyCriticalSection_Begin(&mut section, self.as_ptr()) };

let mut accum = init;

for x in &mut self {
match f(accum, x).branch() {
ControlFlow::Continue(a) => accum = a,
ControlFlow::Break(err) => {
unsafe { ffi::PyCriticalSection_End(&mut section) }
return R::from_residual(err);
}
self.with_critical_section(|mut iter| {
let mut accum = init;
for x in &mut iter {
accum = f(accum, x)?
}
}
unsafe { ffi::PyCriticalSection_End(&mut section) };
R::from_output(accum)
R::from_output(accum)
})
}

#[inline]
Expand All @@ -588,22 +562,19 @@ impl<'py> Iterator for BoundDictIterator<'py> {
Self: Sized,
F: FnMut(Self::Item) -> bool,
{
let mut section = unsafe { std::mem::zeroed() };
unsafe { ffi::PyCriticalSection_Begin(&mut section, self.as_ptr()) };

#[inline]
fn check<T>(mut f: impl FnMut(T) -> bool) -> impl FnMut((), T) -> ControlFlow<()> {
move |(), x| {
if f(x) {
ControlFlow::Continue(())
} else {
ControlFlow::Break(())
self.with_critical_section(|iter| {
#[inline]
fn check<T>(mut f: impl FnMut(T) -> bool) -> impl FnMut((), T) -> ControlFlow<()> {
move |(), x| {
if f(x) {
ControlFlow::Continue(())
} else {
ControlFlow::Break(())
}
}
}
}
let result = self.try_fold((), check(f)) == ControlFlow::Continue(());
unsafe { ffi::PyCriticalSection_End(&mut section) };
result
iter.try_fold((), check(f)) == ControlFlow::Continue(())
})
}

#[inline]
Expand All @@ -613,23 +584,20 @@ impl<'py> Iterator for BoundDictIterator<'py> {
Self: Sized,
F: FnMut(Self::Item) -> bool,
{
let mut section = unsafe { std::mem::zeroed() };
unsafe { ffi::PyCriticalSection_Begin(&mut section, self.as_ptr()) };

#[inline]
fn check<T>(mut f: impl FnMut(T) -> bool) -> impl FnMut((), T) -> ControlFlow<()> {
move |(), x| {
if f(x) {
ControlFlow::Break(())
} else {
ControlFlow::Continue(())
self.with_critical_section(|iter| {
#[inline]
fn check<T>(mut f: impl FnMut(T) -> bool) -> impl FnMut((), T) -> ControlFlow<()> {
move |(), x| {
if f(x) {
ControlFlow::Break(())
} else {
ControlFlow::Continue(())
}
}
}
}

let result = self.try_fold((), check(f)) == ControlFlow::Break(());
unsafe { ffi::PyCriticalSection_End(&mut section) };
result
iter.try_fold((), check(f)) == ControlFlow::Break(())
})
}

#[inline]
Expand All @@ -639,26 +607,25 @@ impl<'py> Iterator for BoundDictIterator<'py> {
Self: Sized,
P: FnMut(&Self::Item) -> bool,
{
let mut section = unsafe { std::mem::zeroed() };
unsafe { ffi::PyCriticalSection_Begin(&mut section, self.as_ptr()) };

#[inline]
fn check<T>(mut predicate: impl FnMut(&T) -> bool) -> impl FnMut((), T) -> ControlFlow<T> {
move |(), x| {
if predicate(&x) {
ControlFlow::Break(x)
} else {
ControlFlow::Continue(())
self.with_critical_section(|iter| {
#[inline]
fn check<T>(
mut predicate: impl FnMut(&T) -> bool,
) -> impl FnMut((), T) -> ControlFlow<T> {
move |(), x| {
if predicate(&x) {
ControlFlow::Break(x)
} else {
ControlFlow::Continue(())
}
}
}
}

let result = match self.try_fold((), check(predicate)) {
ControlFlow::Continue(_) => None,
ControlFlow::Break(x) => Some(x),
};
unsafe { ffi::PyCriticalSection_End(&mut section) };
result
match iter.try_fold((), check(predicate)) {
ControlFlow::Continue(_) => None,
ControlFlow::Break(x) => Some(x),
}
})
}

#[inline]
Expand All @@ -668,23 +635,22 @@ impl<'py> Iterator for BoundDictIterator<'py> {
Self: Sized,
F: FnMut(Self::Item) -> Option<B>,
{
let mut section = unsafe { std::mem::zeroed() };
unsafe { ffi::PyCriticalSection_Begin(&mut section, self.as_ptr()) };

#[inline]
fn check<T, B>(mut f: impl FnMut(T) -> Option<B>) -> impl FnMut((), T) -> ControlFlow<B> {
move |(), x| match f(x) {
Some(x) => ControlFlow::Break(x),
None => ControlFlow::Continue(()),
self.with_critical_section(|iter| {
#[inline]
fn check<T, B>(
mut f: impl FnMut(T) -> Option<B>,
) -> impl FnMut((), T) -> ControlFlow<B> {
move |(), x| match f(x) {
Some(x) => ControlFlow::Break(x),
None => ControlFlow::Continue(()),
}
}
}

let result = match self.try_fold((), check(f)) {
ControlFlow::Continue(_) => None,
ControlFlow::Break(x) => Some(x),
};
unsafe { ffi::PyCriticalSection_End(&mut section) };
result
match iter.try_fold((), check(f)) {
ControlFlow::Continue(_) => None,
ControlFlow::Break(x) => Some(x),
}
})
}

#[inline]
Expand All @@ -694,32 +660,28 @@ impl<'py> Iterator for BoundDictIterator<'py> {
Self: Sized,
P: FnMut(Self::Item) -> bool,
{
let mut section = unsafe { std::mem::zeroed() };
unsafe { ffi::PyCriticalSection_Begin(&mut section, self.as_ptr()) };

#[inline]
fn check<'a, T>(
mut predicate: impl FnMut(T) -> bool + 'a,
acc: &'a mut usize,
) -> impl FnMut((), T) -> ControlFlow<usize, ()> + 'a {
move |_, x| {
if predicate(x) {
ControlFlow::Break(*acc)
} else {
*acc += 1;
ControlFlow::Continue(())
self.with_critical_section(|iter| {
#[inline]
fn check<'a, T>(
mut predicate: impl FnMut(T) -> bool + 'a,
acc: &'a mut usize,
) -> impl FnMut((), T) -> ControlFlow<usize, ()> + 'a {
move |_, x| {
if predicate(x) {
ControlFlow::Break(*acc)
} else {
*acc += 1;
ControlFlow::Continue(())
}
}
}
}

let mut acc = 0;
let result = match self.try_fold((), check(predicate, &mut acc)) {
ControlFlow::Continue(_) => None,
ControlFlow::Break(x) => Some(x),
};

unsafe { ffi::PyCriticalSection_End(&mut section) };
result
let mut acc = 0;
match iter.try_fold((), check(predicate, &mut acc)) {
ControlFlow::Continue(_) => None,
ControlFlow::Break(x) => Some(x),
}
})
}
}

Expand Down Expand Up @@ -751,11 +713,18 @@ impl<'py> BoundDictIterator<'py> {
}

#[inline]
#[cfg(Py_GIL_DISABLED)]
fn as_ptr(&self) -> *mut ffi::PyObject {
fn with_critical_section<F, R>(&mut self, f: F) -> R
where
F: FnOnce(&mut Self) -> R,
{
match self {
BoundDictIterator::ItemIter { ref iter, .. } => iter.as_ptr(),
BoundDictIterator::DictIter { ref dict, .. } => dict.as_ptr(),
BoundDictIterator::ItemIter { .. } => f(self),
#[cfg(not(Py_GIL_DISABLED))]
BoundDictIterator::DictIter { .. } => f(self),
#[cfg(Py_GIL_DISABLED)]
BoundDictIterator::DictIter { ref dict, .. } => {
crate::sync::with_critical_section(dict.clone().as_ref(), || f(self))
}
}
}
}
Expand Down

0 comments on commit ac9fcf9

Please sign in to comment.