Skip to content

Commit

Permalink
add pyiterate (#594)
Browse files Browse the repository at this point in the history
* add pyiterate

* add PyIterator

* Skip PyIterator tests on old julia versions

* fix IteratorSize

* fix

* fix

* compute IteratorSize in PyIterator constructor

* fix

* fix

* Update src/pyiterator.jl

Co-Authored-By: jw3126 <[email protected]>

* Update src/pyiterator.jl

Co-Authored-By: jw3126 <[email protected]>

* Update src/pyiterator.jl

Co-Authored-By: jw3126 <[email protected]>

* Update src/pyiterator.jl

Co-Authored-By: jw3126 <[email protected]>

* fix

* fix

* Update pyiterator.jl
  • Loading branch information
jw3126 authored and stevengj committed Nov 4, 2018
1 parent 48d730f commit 8eb62f2
Show file tree
Hide file tree
Showing 2 changed files with 107 additions and 2 deletions.
70 changes: 68 additions & 2 deletions src/pyiterator.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
#########################################################################
# Iterating over Python objects in Julia

Base.IteratorSize(::Type{PyObject}) = Base.SizeUnknown()
function _start(po::PyObject)
sigatomic_begin()
try
Expand All @@ -29,16 +30,81 @@ end

Base.done(po::PyObject, s) = ispynull(s[1])
else
function Base.iterate(po::PyObject, s=_start(po))
"""
PyIterator{T}(pyobject)
Wrap `pyobject::PyObject` into an iterator, that produces items of type `T`. To be more precise `convert(T, item)` is applied in each iteration. This can be useful to avoid automatic conversion of items into corresponding julia types.
```jldoctest
julia> using PyCall
julia> l = PyObject([PyObject(1), PyObject(2)])
PyObject [1, 2]
julia> piter = PyCall.PyIterator{PyAny}(l)
PyCall.PyIterator{PyAny,Base.HasLength()}(PyObject [1, 2])
julia> collect(piter)
2-element Array{Any,1}:
1
2
julia> piter = PyCall.PyIterator(l)
PyCall.PyIterator{PyObject,Base.HasLength()}(PyObject [1, 2])
julia> collect(piter)
2-element Array{PyObject,1}:
PyObject 1
PyObject 2
```
"""
struct PyIterator{T,S}
o::PyObject
end

function _compute_IteratorSize(o::PyObject)
S = try
length(o)
Base.HasLength
catch err
if !(err isa PyError && pyisinstance(err.val, @pyglobalobjptr :PyExc_TypeError))
rethrow()
end
Base.SizeUnknown
end
end
function PyIterator(o::PyObject)
PyIterator{PyObject}(o)
end
function (::Type{PyIterator{T}})(o::PyObject) where {T}
S = _compute_IteratorSize(o)
PyIterator{T,S}(o)
end

Base.eltype(::Type{<:PyIterator{T}}) where T = T
Base.eltype(::Type{<:PyIterator{PyAny}}) = Any
Base.length(piter::PyIterator) = length(piter.o)

Base.IteratorSize(::Type{<: PyIterator{T,S}}) where {T,S} = S()

_start(piter::PyIterator) = _start(piter.o)

function Base.iterate(piter::PyIterator{T}, s=_start(piter)) where {T}
ispynull(s[1]) && return nothing
sigatomic_begin()
try
nxt = PyObject(@pycheck ccall((@pysym :PyIter_Next), PyPtr, (PyPtr,), s[2]))
return (convert(PyAny, s[1]), (nxt, s[2]))
return (convert(T,s[1]), (nxt, s[2]))
finally
sigatomic_end()
end
end
function Base.iterate(po::PyObject, s=_start(po))
# avoid the constructor that calls length
# since that might be an expensive operation
# even if length is cheap, this adds 10% performance
piter = PyIterator{PyAny, Base.SizeUnknown}(po)
iterate(piter, s)
end
end

# issue #216
Expand Down
39 changes: 39 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -687,6 +687,45 @@ def try_call(f):
pybuiltin("Exception"))
end

@static if VERSION < v"0.7.0-DEV.5126" # julia#25261
# PyIterator not defined in this julia version
else
@testset "PyIterator" begin
arr = [1,2]
o = PyObject(arr)
c_pyany = collect(PyCall.PyIterator{PyAny}(o))
@test c_pyany == arr
@test c_pyany[1] isa Integer
@test c_pyany[2] isa Integer

c_f64 = collect(PyCall.PyIterator{Float64}(o))
@test c_f64 == arr
@test eltype(c_f64) == Float64

i1 = PyObject([1])
i2 = PyObject([2])
l = PyObject([i1,i2])

piter = PyCall.PyIterator(l)
@test length(piter) == 2
@test length(collect(piter)) == 2
r1, r2 = collect(piter)
@test r1.o === i1.o
@test r2.o === i2.o

@test Base.IteratorSize(PyCall.PyIterator(PyObject(1))) == Base.SizeUnknown()
@test Base.IteratorSize(PyCall.PyIterator(PyObject([1]))) == Base.HasLength()

# 594
@test collect(zip(py"iter([1, 2, 3])", 1:3)) ==
[(1, 1), (2, 2), (3, 3)]
@test collect(zip(PyCall.PyIterator{Int}(py"iter([1, 2, 3])"), 1:3)) ==
[(1, 1), (2, 2), (3, 3)]
@test collect(zip(PyCall.PyIterator(py"[1, 2, 3]"o), 1:3)) ==
[(1, 1), (2, 2), (3, 3)]
end
end

@testset "atexit" begin
if VERSION < v"0.7-"
setup = ""
Expand Down

0 comments on commit 8eb62f2

Please sign in to comment.