Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add pyiterate #594

Merged
merged 17 commits into from
Nov 4, 2018
67 changes: 65 additions & 2 deletions src/pyiterator.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
#########################################################################
# Iterating over Python objects in Julia

Base.IteratorSize(::Type{PyObject}) = Base.SizeUnknown()
function _start(po::PyObject)
sigatomic_begin()
try
Expand All @@ -29,16 +30,78 @@ end

Base.done(po::PyObject, s) = ispynull(s[1])
else
function Base.iterate(po::PyObject, s=_start(po))
"""
PyIterator{T}(pyobject)

Wrap `pyobject::PyObject` into an iterator, that produces items of type `T`. To be more precise `convert(T, item)` is applied in each iteration. This can be useful to avoid automatic conversion of items into corresponding julia types.
```jldoctest
julia> using PyCall

julia> l = PyObject([PyObject(1), PyObject(2)])
PyObject [1, 2]

julia> piter = PyCall.PyIterator{PyAny}(l)
PyCall.PyIterator{PyAny,Base.HasLength()}(PyObject [1, 2])

julia> collect(piter)
2-element Array{Any,1}:
1
2

julia> piter = PyCall.PyIterator(l)
PyCall.PyIterator{PyObject,Base.HasLength()}(PyObject [1, 2])

julia> collect(piter)
2-element Array{PyObject,1}:
PyObject 1
PyObject 2
```
"""
struct PyIterator{T,S}
o::PyObject
end

function _compute_IteratorSize(o::PyObject)
S = try
length(o)
Base.HasLength()
jw3126 marked this conversation as resolved.
Show resolved Hide resolved
catch err
jw3126 marked this conversation as resolved.
Show resolved Hide resolved
if !(err isa PyError && pyisinstance(err.val, @pyglobalobjptr :PyExc_TypeError))
rethrow()
end
Base.SizeUnknown()
jw3126 marked this conversation as resolved.
Show resolved Hide resolved
end
end
function PyIterator(o::PyObject)
PyIterator{PyObject}(o)
end
function (::Type{PyIterator{T}})(o::PyObject) where {T}
S = _compute_IteratorSize(o)
PyIterator{T,S}(o)
end

Base.eltype(::Type{<:PyIterator{T}}) where T = T
Base.eltype(::Type{<:PyIterator{PyAny}}) = Any
Base.length(piter::PyIterator) = length(piter.o)

jw3126 marked this conversation as resolved.
Show resolved Hide resolved
Base.IteratorSize(::Type{<: PyIterator{T,S}}) where {T,S} = S
jw3126 marked this conversation as resolved.
Show resolved Hide resolved

_start(piter::PyIterator) = _start(piter.o)

function Base.iterate(piter::PyIterator{T}, s=_start(piter)) where {T}
ispynull(s[1]) && return nothing
sigatomic_begin()
try
nxt = PyObject(@pycheck ccall((@pysym :PyIter_Next), PyPtr, (PyPtr,), s[2]))
return (convert(PyAny, s[1]), (nxt, s[2]))
return (convert(T,s[1]), (nxt, s[2]))
finally
sigatomic_end()
end
end
function Base.iterate(po::PyObject, s=_start(po))
piter = PyIterator{PyAny}(po)
iterate(piter, s)
end
end

# issue #216
Expand Down
40 changes: 40 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -666,4 +666,44 @@ def try_call(f):
pybuiltin("Exception"))
end

@static if VERSION < v"0.7.0-DEV.5126" # julia#25261
# PyIterator not defined in this julia version
else
@testset "PyIterator" begin
arr = [1,2]
o = PyObject(arr)
c_pyany = collect(PyCall.PyIterator{PyAny}(o))
@test c_pyany == arr
@test c_pyany[1] isa Integer
@test c_pyany[2] isa Integer

c_f64 = collect(PyCall.PyIterator{Float64}(o))
@test c_f64 == arr
@test eltype(c_f64) == Float64

i1 = PyObject([1])
i2 = PyObject([2])
l = PyObject([i1,i2])

piter = PyCall.PyIterator(l)
@test length(piter) == 2
@test length(collect(piter)) == 2
r1, r2 = collect(piter)
@test r1.o === i1.o
@test r2.o === i2.o

@test Base.IteratorSize(PyCall.PyIterator(PyObject(1))) == Base.SizeUnknown()
@test Base.IteratorSize(PyCall.PyIterator(PyObject([1]))) == Base.HasLength()

# 594
@test collect(zip(py"iter([1, 2, 3])", 1:3)) ==
[(1, 1), (2, 2), (3, 3)]
@test collect(zip(PyCall.PyIterator{Int}(py"iter([1, 2, 3])"), 1:3)) ==
[(1, 1), (2, 2), (3, 3)]
@test collect(zip(PyCall.PyIterator(py"[1, 2, 3]"o), 1:3)) ==
[(1, 1), (2, 2), (3, 3)]
end

end

include("test_pyfncall.jl")