-
Notifications
You must be signed in to change notification settings - Fork 784
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Optimize nth and nth_back for BoundListIterator #4810
base: main
Are you sure you want to change the base?
Changes from 3 commits
cc9cabd
3a0c196
40d38f3
1b19616
f6e95a8
b0c749b
b2bf973
4e86709
e4269c2
6e18229
5bab05b
0b23173
e88f8be
3a7a171
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
Optimizes `nth` and `nth_back` for `BoundListIterator` |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -494,7 +494,6 @@ impl<'py> Iterator for BoundListIterator<'py> { | |
#[inline] | ||
fn next(&mut self) -> Option<Self::Item> { | ||
let length = self.length.min(self.list.len()); | ||
|
||
if self.index < length { | ||
let item = unsafe { self.get_item(self.index) }; | ||
self.index += 1; | ||
|
@@ -509,6 +508,20 @@ impl<'py> Iterator for BoundListIterator<'py> { | |
let len = self.len(); | ||
(len, Some(len)) | ||
} | ||
|
||
#[inline] | ||
fn nth(&mut self, n: usize) -> Option<Self::Item> { | ||
let length = self.length.min(self.list.len()); | ||
let target_index = self.index + n; | ||
if self.index + n < length { | ||
let item = unsafe { self.get_item(target_index) }; | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I wonder, is there a time-of-check to time-of-use bug here on the length? Not one for this PR, but a follow up I will try not to forget... There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yup I think that's a potential TOCTOU bug here. This particular implementation assumes the user ensures proper synchronization if they intend to use the iterator in a multi-threaded or mutable environment. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think the current implementation of https://github.com/PyO3/pyo3/blob/main/src/types/list.rs#L495-L498 |
||
self.index = target_index + 1; | ||
Some(item) | ||
} else { | ||
self.index = self.list.len(); | ||
None | ||
} | ||
} | ||
} | ||
|
||
impl DoubleEndedIterator for BoundListIterator<'_> { | ||
|
@@ -524,6 +537,20 @@ impl DoubleEndedIterator for BoundListIterator<'_> { | |
None | ||
} | ||
} | ||
|
||
#[inline] | ||
fn nth_back(&mut self, n: usize) -> Option<Self::Item> { | ||
let length = self.length.min(self.list.len()); | ||
if self.index + n < length { | ||
let target_index = length - n - 1; | ||
let item = unsafe { self.get_item(target_index) }; | ||
self.length = target_index; | ||
Some(item) | ||
} else { | ||
self.length = length; | ||
None | ||
} | ||
} | ||
} | ||
|
||
impl ExactSizeIterator for BoundListIterator<'_> { | ||
|
@@ -720,6 +747,106 @@ mod tests { | |
}); | ||
} | ||
|
||
#[test]
fn test_iter_nth() {
    Python::with_gil(|py| {
        // Skipping forward after one element has already been consumed.
        {
            let values = vec![6, 7, 8, 9, 10];
            let obj = (&values).into_pyobject(py).unwrap();
            let list = obj.downcast::<PyList>().unwrap();

            let mut it = list.iter();
            it.next();
            assert_eq!(it.nth(1).map(|x| x.extract::<i32>().unwrap()), Some(8));
            assert_eq!(it.nth(1).map(|x| x.extract::<i32>().unwrap()), Some(10));
            assert!(it.nth(1).is_none());
        }

        // An empty list never yields anything.
        {
            let values: Vec<i32> = vec![];
            let obj = (&values).into_pyobject(py).unwrap();
            let list = obj.downcast::<PyList>().unwrap();

            let mut it = list.iter();
            it.next();
            assert!(it.nth(1).is_none());
        }

        // Asking for an element far past the end.
        {
            let values = vec![1, 2, 3];
            let obj = (&values).into_pyobject(py).unwrap();
            let list = obj.downcast::<PyList>().unwrap();

            let mut it = list.iter();
            assert!(it.nth(10).is_none());
        }

        // Mixing `nth` with `next` and with `nth_back` on the same list.
        {
            let values = vec![6, 7, 8, 9, 10];
            let obj = (&values).into_pyobject(py).unwrap();
            let list = obj.downcast::<PyList>().unwrap();

            let mut it = list.iter();
            assert_eq!(it.next().map(|x| x.extract::<i32>().unwrap()), Some(6));
            assert_eq!(it.nth(2).map(|x| x.extract::<i32>().unwrap()), Some(9));
            assert_eq!(it.next().map(|x| x.extract::<i32>().unwrap()), Some(10));

            let mut it = list.iter();
            it.nth_back(1);
            assert_eq!(it.nth(2).map(|x| x.extract::<i32>().unwrap()), Some(8));
            assert!(it.next().is_none());
        }
    });
}
|
||
#[test]
fn test_iter_nth_back() {
    Python::with_gil(|py| {
        // Repeated skips from the back until the request overshoots.
        {
            let values = vec![1, 2, 3, 4, 5];
            let obj = (&values).into_pyobject(py).unwrap();
            let list = obj.downcast::<PyList>().unwrap();

            let mut it = list.iter();
            assert_eq!(it.nth_back(0).map(|x| x.extract::<i32>().unwrap()), Some(5));
            assert_eq!(it.nth_back(1).map(|x| x.extract::<i32>().unwrap()), Some(3));
            assert!(it.nth_back(2).is_none());
        }

        // An empty list never yields anything.
        {
            let values: Vec<i32> = vec![];
            let obj = (&values).into_pyobject(py).unwrap();
            let list = obj.downcast::<PyList>().unwrap();

            let mut it = list.iter();
            assert!(it.nth_back(0).is_none());
            assert!(it.nth_back(1).is_none());
        }

        // Asking for an element far past the front.
        {
            let values = vec![1, 2, 3];
            let obj = (&values).into_pyobject(py).unwrap();
            let list = obj.downcast::<PyList>().unwrap();

            let mut it = list.iter();
            assert!(it.nth_back(5).is_none());
        }

        // Interleaving `next_back` with `nth_back`.
        {
            let values = vec![1, 2, 3, 4, 5];
            let obj = (&values).into_pyobject(py).unwrap();
            let list = obj.downcast::<PyList>().unwrap();

            let mut it = list.iter();
            it.next_back(); // Consume the last element
            assert_eq!(it.nth_back(1).map(|x| x.extract::<i32>().unwrap()), Some(3));
            assert_eq!(it.next_back().map(|x| x.extract::<i32>().unwrap()), Some(2));
            assert_eq!(it.nth_back(0).map(|x| x.extract::<i32>().unwrap()), Some(1));
        }

        // Several independent iterators over one list, including one that
        // has been advanced from the front first.
        {
            let values = vec![1, 2, 3, 4, 5];
            let obj = (&values).into_pyobject(py).unwrap();
            let list = obj.downcast::<PyList>().unwrap();

            let mut first = list.iter();
            assert_eq!(first.nth_back(1).map(|x| x.extract::<i32>().unwrap()), Some(4));
            assert_eq!(first.nth_back(2).map(|x| x.extract::<i32>().unwrap()), Some(1));

            let mut second = list.iter();
            second.next_back();
            assert_eq!(second.nth_back(1).map(|x| x.extract::<i32>().unwrap()), Some(3));
            assert_eq!(second.next_back().map(|x| x.extract::<i32>().unwrap()), Some(2));

            let mut third = list.iter();
            third.nth(1);
            assert_eq!(third.nth_back(2).map(|x| x.extract::<i32>().unwrap()), Some(3));
            assert!(third.nth_back(0).is_none());
        }
    });
}
|
||
#[test] | ||
fn test_iter_rev() { | ||
Python::with_gil(|py| { | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Totally missed this when I added the new free-threaded implementations for folding operations in list. I'll add benchmarks for those in a separate PR.