Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Optimize nth and nth_back for BoundListIterator #4810

Open
wants to merge 14 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions newsfragments/4810.added.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Optimizes `nth` and `nth_back` for `BoundListIterator`
4 changes: 2 additions & 2 deletions noxfile.py
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

changes in this file seem unrelated, maybe send in a separate PR so we can get these ruff fixes merged quickly?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah - I noticed that ruff released 0.9.0 two hours ago and the CI didn't pin the ruff version, hence I hit this error, which didn't exist when the latest MR was merged.

I'll file another MR to quickly format this file

Original file line number Diff line number Diff line change
Expand Up @@ -351,7 +351,7 @@ def test_emscripten(session: nox.Session):
session,
"bash",
"-c",
f"source {info.builddir/'emsdk/emsdk_env.sh'} && cargo test",
f"source {info.builddir / 'emsdk/emsdk_env.sh'} && cargo test",
)


Expand Down Expand Up @@ -951,7 +951,7 @@ def set(
f"""\
implementation={implementation}
version={version}
build_flags={','.join(build_flags)}
build_flags={",".join(build_flags)}
suppress_build_script_link_lines=true
"""
)
Expand Down
30 changes: 29 additions & 1 deletion pyo3-benches/benches/bench_list.rs
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,33 @@ fn list_get_item(b: &mut Bencher<'_>) {
});
}

#[cfg(not(any(Py_LIMITED_API, Py_GIL_DISABLED)))]
/// Benchmarks `Iterator::nth` on a list iterator.
///
/// Each measured iteration builds a fresh iterator for every position
/// `0..LEN`, jumps straight to that position with `nth`, and extracts the
/// element as a `usize`.
fn list_nth(b: &mut Bencher<'_>) {
    Python::with_gil(|py| {
        const LEN: usize = 50;
        let list = PyList::new_bound(py, 0..LEN);
        b.iter(|| {
            // Return the per-iteration total instead of accumulating into a
            // captured variable that is never read, so the harness receives
            // the computed value (criterion-style `Bencher::iter` is expected
            // to black-box the closure's return value).
            (0..LEN)
                .map(|i| list.iter().nth(i).unwrap().extract::<usize>().unwrap())
                .sum::<usize>()
        });
    });
}

/// Benchmarks `DoubleEndedIterator::nth_back` on a list iterator.
///
/// Each measured iteration builds a fresh iterator for every offset
/// `0..LEN`, jumps to that offset from the back with `nth_back`, and extracts
/// the element as a `usize`.
fn list_nth_back(b: &mut Bencher<'_>) {
    Python::with_gil(|py| {
        const LEN: usize = 50;
        let list = PyList::new_bound(py, 0..LEN);
        b.iter(|| {
            // Return the per-iteration total instead of accumulating into a
            // captured variable that is never read, so the harness receives
            // the computed value (criterion-style `Bencher::iter` is expected
            // to black-box the closure's return value).
            (0..LEN)
                .map(|i| list.iter().nth_back(i).unwrap().extract::<usize>().unwrap())
                .sum::<usize>()
        });
    });
}

#[cfg(not(Py_LIMITED_API))]
fn list_get_item_unchecked(b: &mut Bencher<'_>) {
Python::with_gil(|py| {
const LEN: usize = 50_000;
Expand All @@ -66,6 +92,8 @@ fn sequence_from_list(b: &mut Bencher<'_>) {
fn criterion_benchmark(c: &mut Criterion) {
c.bench_function("iter_list", iter_list);
c.bench_function("list_new", list_new);
c.bench_function("list_nth", list_nth);
c.bench_function("list_nth_back", list_nth_back);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Totally missed this when I added the new free-threaded implementations for folding operations in list. I'll add benchmarks for those in a separate PR.

c.bench_function("list_get_item", list_get_item);
#[cfg(not(any(Py_LIMITED_API, Py_GIL_DISABLED)))]
c.bench_function("list_get_item_unchecked", list_get_item_unchecked);
Expand Down
190 changes: 188 additions & 2 deletions src/types/list.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
use std::iter::FusedIterator;

use crate::err::{self, PyResult};
use crate::ffi::{self, Py_ssize_t};
use crate::ffi_ptr_ext::FfiPtrExt;
Expand All @@ -8,6 +6,7 @@ use crate::types::{PySequence, PyTuple};
use crate::{
Borrowed, Bound, BoundObject, IntoPyObject, IntoPyObjectExt, PyAny, PyErr, PyObject, Python,
};
use std::iter::FusedIterator;

use crate::types::any::PyAnyMethods;
use crate::types::sequence::PySequenceMethods;
Expand Down Expand Up @@ -547,6 +546,31 @@ impl<'py> BoundListIterator<'py> {
}
}

/// Skips `n` elements from the front cursor and returns the element at the
/// resulting position, moving the cursor just past it.
///
/// The usable length is the smaller of the recorded `length` and the list's
/// current `len()`. Returns `None` and leaves `index` untouched when
/// `index + n` is not below that length, including when the addition would
/// overflow `usize`.
///
/// # Safety
///
/// On the free-threaded build, caller must verify they have exclusive
/// access to the list by holding a lock or by holding the innermost
/// critical section on the list.
#[inline]
#[cfg(not(Py_LIMITED_API))]
#[deny(unsafe_op_in_unsafe_fn)]
unsafe fn nth_unchecked(
    index: &mut Index,
    length: &mut Length,
    list: &Bound<'py, PyList>,
    n: usize,
) -> Option<Bound<'py, PyAny>> {
    // Clamp to the list's current length so the unchecked access below can
    // never go past the end of the list.
    let length = length.0.min(list.len());
    // `checked_add` turns an overflowing `index + n` (e.g. `n == usize::MAX`)
    // into `None` instead of panicking in debug builds or wrapping around to
    // a small, seemingly valid index in release builds.
    let target_index = index.0.checked_add(n).filter(|&target| target < length)?;
    // SAFETY: `target_index < length <= list.len()`, so the index is in
    // bounds; exclusive access to the list is the caller's obligation.
    let item = unsafe { list.get_item_unchecked(target_index) };
    index.0 = target_index + 1;
    Some(item)
}

/// # Safety
///
/// On the free-threaded build, caller must verify they have exclusive
Expand Down Expand Up @@ -589,6 +613,31 @@ impl<'py> BoundListIterator<'py> {
}
}

/// # Safety
///
/// On the free-threaded build, caller must verify they have exclusive
/// access to the list by holding a lock or by holding the innermost
/// critical section on the list.
#[inline]
#[cfg(not(Py_LIMITED_API))]
#[deny(unsafe_op_in_unsafe_fn)]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

you should also add branches to nth and nth_back that use get_item instead of get_item_unchecked on the limited API

unsafe fn nth_back_unchecked(
    index: &mut Index,
    length: &mut Length,
    list: &Bound<'py, PyList>,
    n: usize,
) -> Option<Bound<'py, PyAny>> {
    // Skips `n` elements from the back and returns the element found there,
    // shrinking `length` so that element becomes the new exclusive end.
    // Returns `None` (leaving both cursors untouched) when fewer than `n + 1`
    // elements remain between `index` and the end.

    // Clamp to the list's current length so the unchecked access below can
    // never go past the end of the list.
    let length_size = length.0.min(list.len());
    // Compare `n` against the remaining element count rather than testing
    // `index + n < length`: the sum can overflow `usize` for very large `n`,
    // which would panic in debug builds or wrap around in release builds.
    let remaining = length_size.saturating_sub(index.0);
    if n >= remaining {
        return None;
    }
    // `n < remaining <= length_size`, so this subtraction cannot underflow.
    let target_index = length_size - n - 1;
    // SAFETY: `target_index < length_size <= list.len()`, so the index is in
    // bounds; exclusive access to the list is the caller's obligation.
    let item = unsafe { list.get_item_unchecked(target_index) };
    *length = Length(target_index);
    Some(item)
}

#[cfg(not(Py_LIMITED_API))]
fn with_critical_section<R>(
&mut self,
Expand Down Expand Up @@ -625,6 +674,14 @@ impl<'py> Iterator for BoundListIterator<'py> {
}
}

#[inline]
#[cfg(all(Py_GIL_DISABLED, not(feature = "nightly")))]
fn nth(&mut self, n: usize) -> Option<Self::Item> {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
#[cfg(all(Py_GIL_DISABLED, not(feature = "nightly")))]
#[cfg(feature = "nightly")]

self.with_critical_section(|index, length, list| unsafe {
Self::nth_unchecked(index, length, list, n)
})
}

#[inline]
fn size_hint(&self) -> (usize, Option<usize>) {
let len = self.len();
Expand Down Expand Up @@ -750,6 +807,27 @@ impl<'py> Iterator for BoundListIterator<'py> {
None
})
}

#[inline]
#[cfg(all(Py_GIL_DISABLED, feature = "nightly"))]
fn advance_by(&mut self, n: usize) -> Result<(), NonZero<usize>> {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
#[cfg(all(Py_GIL_DISABLED, feature = "nightly"))]
#[cfg(feature = "nightly")]

I think the critical section is necessary though, since another thread might try to use the list while you're in advance_by. You don't need to be in a Py_GIL_DISABLED cfg block to use a critical section, it's a no-op on the GIL-enabled build.

self.with_critical_section(|index, length, list| {
let length = length.0.min(list.len());
let target_index = index.0 + n;
if index.0 + n < length {
let item = list.get_item(target_index);
match item {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think we need to actually do the get_item or check for an Err value, we only need to advance the iterator, which just requires the index.0 + n < length check.

Ok(_) => {
index.0 = target_index;
Ok(())
}
Err(_) => Err(NonZero::new(n - index.0)),
}
} else {
Err(NonZero::new(n - index.0))
}
})
}
}

impl DoubleEndedIterator for BoundListIterator<'_> {
Expand All @@ -772,6 +850,14 @@ impl DoubleEndedIterator for BoundListIterator<'_> {
}
}

#[inline]
#[cfg(all(Py_GIL_DISABLED, not(feature = "nightly")))]
fn nth_back(&mut self, n: usize) -> Option<Self::Item> {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
#[cfg(all(Py_GIL_DISABLED, not(feature = "nightly")))]
#[cfg(feature = "nightly")]

self.with_critical_section(|index, length, list| unsafe {
Self::nth_back_unchecked(index, length, list, n)
})
}

#[inline]
#[cfg(all(Py_GIL_DISABLED, not(feature = "nightly")))]
fn rfold<B, F>(mut self, init: B, mut f: F) -> B
Expand Down Expand Up @@ -1502,4 +1588,104 @@ mod tests {
assert!(tuple.eq(tuple_expected).unwrap());
})
}

#[test]
fn test_iter_nth() {
    Python::with_gil(|py| {
        // Repeated `nth` calls after a `next`, running off the end.
        {
            let values = vec![6, 7, 8, 9, 10];
            let obj = (&values).into_pyobject(py).unwrap();
            let py_list = obj.downcast::<PyList>().unwrap();

            let mut it = py_list.iter();
            it.next();
            assert_eq!(it.nth(1).unwrap().extract::<i32>().unwrap(), 8);
            assert_eq!(it.nth(1).unwrap().extract::<i32>().unwrap(), 10);
            assert!(it.nth(1).is_none());
        }

        // Empty list: nothing to skip to.
        {
            let values: Vec<i32> = vec![];
            let obj = (&values).into_pyobject(py).unwrap();
            let py_list = obj.downcast::<PyList>().unwrap();

            let mut it = py_list.iter();
            it.next();
            assert!(it.nth(1).is_none());
        }

        // `n` larger than the list.
        {
            let values = vec![1, 2, 3];
            let obj = (&values).into_pyobject(py).unwrap();
            let py_list = obj.downcast::<PyList>().unwrap();

            let mut it = py_list.iter();
            assert!(it.nth(10).is_none());
        }

        // Mixing `next`, `nth` and `nth_back` on the same list.
        {
            let values = vec![6, 7, 8, 9, 10];
            let obj = (&values).into_pyobject(py).unwrap();
            let py_list = obj.downcast::<PyList>().unwrap();

            let mut forward = py_list.iter();
            assert_eq!(forward.next().unwrap().extract::<i32>().unwrap(), 6);
            assert_eq!(forward.nth(2).unwrap().extract::<i32>().unwrap(), 9);
            assert_eq!(forward.next().unwrap().extract::<i32>().unwrap(), 10);

            let mut trimmed = py_list.iter();
            trimmed.nth_back(1);
            assert_eq!(trimmed.nth(2).unwrap().extract::<i32>().unwrap(), 8);
            assert!(trimmed.next().is_none());
        }
    });
}

#[test]
fn test_iter_nth_back() {
    Python::with_gil(|py| {
        // Repeated `nth_back` calls until too few elements remain.
        {
            let values = vec![1, 2, 3, 4, 5];
            let obj = (&values).into_pyobject(py).unwrap();
            let py_list = obj.downcast::<PyList>().unwrap();

            let mut it = py_list.iter();
            assert_eq!(it.nth_back(0).unwrap().extract::<i32>().unwrap(), 5);
            assert_eq!(it.nth_back(1).unwrap().extract::<i32>().unwrap(), 3);
            assert!(it.nth_back(2).is_none());
        }

        // Empty list.
        {
            let values: Vec<i32> = vec![];
            let obj = (&values).into_pyobject(py).unwrap();
            let py_list = obj.downcast::<PyList>().unwrap();

            let mut it = py_list.iter();
            assert!(it.nth_back(0).is_none());
            assert!(it.nth_back(1).is_none());
        }

        // `n` larger than the list.
        {
            let values = vec![1, 2, 3];
            let obj = (&values).into_pyobject(py).unwrap();
            let py_list = obj.downcast::<PyList>().unwrap();

            let mut it = py_list.iter();
            assert!(it.nth_back(5).is_none());
        }

        // Interleaving `next_back` with `nth_back`.
        {
            let values = vec![1, 2, 3, 4, 5];
            let obj = (&values).into_pyobject(py).unwrap();
            let py_list = obj.downcast::<PyList>().unwrap();

            let mut it = py_list.iter();
            it.next_back(); // Consume the last element
            assert_eq!(it.nth_back(1).unwrap().extract::<i32>().unwrap(), 3);
            assert_eq!(it.next_back().unwrap().extract::<i32>().unwrap(), 2);
            assert_eq!(it.nth_back(0).unwrap().extract::<i32>().unwrap(), 1);
        }

        // Several independent iterators over one list, including a front
        // cursor advanced by `nth` before stepping from the back.
        {
            let values = vec![1, 2, 3, 4, 5];
            let obj = (&values).into_pyobject(py).unwrap();
            let py_list = obj.downcast::<PyList>().unwrap();

            let mut first = py_list.iter();
            assert_eq!(first.nth_back(1).unwrap().extract::<i32>().unwrap(), 4);
            assert_eq!(first.nth_back(2).unwrap().extract::<i32>().unwrap(), 1);

            let mut second = py_list.iter();
            second.next_back();
            assert_eq!(second.nth_back(1).unwrap().extract::<i32>().unwrap(), 3);
            assert_eq!(second.next_back().unwrap().extract::<i32>().unwrap(), 2);

            let mut third = py_list.iter();
            third.nth(1);
            assert_eq!(third.nth_back(2).unwrap().extract::<i32>().unwrap(), 3);
            assert!(third.nth_back(0).is_none());
        }
    });
}
}
Loading