From 8eb62f2676e0ed6ff879a52ff471caa0a1df7e2e Mon Sep 17 00:00:00 2001 From: Jan Weidner Date: Sun, 4 Nov 2018 19:43:01 +0100 Subject: [PATCH] add pyiterate (#594) * add pyiterate * add PyIterator * Skip PyIterator tests on old julia versions * fix IteratorSize * fix * fix * compute IteratorSize in PyIterator constructor * fix * fix * Update src/pyiterator.jl Co-Authored-By: jw3126 * Update src/pyiterator.jl Co-Authored-By: jw3126 * Update src/pyiterator.jl Co-Authored-By: jw3126 * Update src/pyiterator.jl Co-Authored-By: jw3126 * fix * fix * Update pyiterator.jl --- src/pyiterator.jl | 70 +++++++++++++++++++++++++++++++++++++++++++++-- test/runtests.jl | 39 ++++++++++++++++++++++++++ 2 files changed, 107 insertions(+), 2 deletions(-) diff --git a/src/pyiterator.jl b/src/pyiterator.jl index ccce277e..5b4e44d6 100644 --- a/src/pyiterator.jl +++ b/src/pyiterator.jl @@ -3,6 +3,7 @@ ######################################################################### # Iterating over Python objects in Julia +Base.IteratorSize(::Type{PyObject}) = Base.SizeUnknown() function _start(po::PyObject) sigatomic_begin() try @@ -29,16 +30,81 @@ end Base.done(po::PyObject, s) = ispynull(s[1]) else - function Base.iterate(po::PyObject, s=_start(po)) + """ + PyIterator{T}(pyobject) + + Wrap `pyobject::PyObject` into an iterator, that produces items of type `T`. To be more precise `convert(T, item)` is applied in each iteration. This can be useful to avoid automatic conversion of items into corresponding julia types. + ```jldoctest + julia> using PyCall + + julia> l = PyObject([PyObject(1), PyObject(2)]) + PyObject [1, 2] + + julia> piter = PyCall.PyIterator{PyAny}(l) + PyCall.PyIterator{PyAny,Base.HasLength()}(PyObject [1, 2]) + + julia> collect(piter) + 2-element Array{Any,1}: + 1 + 2 + + julia> piter = PyCall.PyIterator(l) + PyCall.PyIterator{PyObject,Base.HasLength()}(PyObject [1, 2]) + + julia> collect(piter) + 2-element Array{PyObject,1}: + PyObject 1 + PyObject 2 + ``` + """ + struct PyIterator{T,S} + o::PyObject + end + + function _compute_IteratorSize(o::PyObject) + S = try + length(o) + Base.HasLength + catch err + if !(err isa PyError && pyisinstance(err.val, @pyglobalobjptr :PyExc_TypeError)) + rethrow() + end + Base.SizeUnknown + end + end + function PyIterator(o::PyObject) + PyIterator{PyObject}(o) + end + function (::Type{PyIterator{T}})(o::PyObject) where {T} + S = _compute_IteratorSize(o) + PyIterator{T,S}(o) + end + + Base.eltype(::Type{<:PyIterator{T}}) where T = T + Base.eltype(::Type{<:PyIterator{PyAny}}) = Any + Base.length(piter::PyIterator) = length(piter.o) + + Base.IteratorSize(::Type{<: PyIterator{T,S}}) where {T,S} = S() + + _start(piter::PyIterator) = _start(piter.o) + + function Base.iterate(piter::PyIterator{T}, s=_start(piter)) where {T} ispynull(s[1]) && return nothing sigatomic_begin() try nxt = PyObject(@pycheck ccall((@pysym :PyIter_Next), PyPtr, (PyPtr,), s[2])) - return (convert(PyAny, s[1]), (nxt, s[2])) + return (convert(T,s[1]), (nxt, s[2])) finally sigatomic_end() end end + function Base.iterate(po::PyObject, s=_start(po)) + # avoid the constructor that calls length + # since that might be an expensive operation + # even if length is cheap, this adds 10% performance + piter = PyIterator{PyAny, Base.SizeUnknown}(po) + iterate(piter, s) + end end # issue #216 diff --git a/test/runtests.jl b/test/runtests.jl index 9f9d2b30..7de01831 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -687,6 +687,45 @@ def try_call(f): pybuiltin("Exception")) end +@static if VERSION < v"0.7.0-DEV.5126" # julia#25261 + # PyIterator not defined in this julia version +else + @testset "PyIterator" begin + arr = [1,2] + o = PyObject(arr) + c_pyany = collect(PyCall.PyIterator{PyAny}(o)) + @test c_pyany == arr + @test c_pyany[1] isa Integer + @test c_pyany[2] isa Integer + + c_f64 = collect(PyCall.PyIterator{Float64}(o)) + @test c_f64 == arr + @test eltype(c_f64) == Float64 + + i1 = PyObject([1]) + i2 = PyObject([2]) + l = PyObject([i1,i2]) + + piter = PyCall.PyIterator(l) + @test length(piter) == 2 + @test length(collect(piter)) == 2 + r1, r2 = collect(piter) + @test r1.o === i1.o + @test r2.o === i2.o + + @test Base.IteratorSize(PyCall.PyIterator(PyObject(1))) == Base.SizeUnknown() + @test Base.IteratorSize(PyCall.PyIterator(PyObject([1]))) == Base.HasLength() + + # 594 + @test collect(zip(py"iter([1, 2, 3])", 1:3)) == + [(1, 1), (2, 2), (3, 3)] + @test collect(zip(PyCall.PyIterator{Int}(py"iter([1, 2, 3])"), 1:3)) == + [(1, 1), (2, 2), (3, 3)] + @test collect(zip(PyCall.PyIterator(py"[1, 2, 3]"o), 1:3)) == + [(1, 1), (2, 2), (3, 3)] + end +end + @testset "atexit" begin if VERSION < v"0.7-" setup = ""