Skip to content

Commit

Permalink
Make return type of map inferrable with heterogeneous arrays (#42046)
Browse files Browse the repository at this point in the history
Inference is not able to detect the element type automatically,
but we can do it manually since we know promote_typejoin is used for widening.
This is similar to the approach used for `broadcast` at #30485.
  • Loading branch information
nalimilan authored Sep 1, 2021
1 parent 2a0ab37 commit 49e3aec
Show file tree
Hide file tree
Showing 6 changed files with 78 additions and 54 deletions.
15 changes: 10 additions & 5 deletions base/array.jl
Original file line number Diff line number Diff line change
Expand Up @@ -739,18 +739,19 @@ if isdefined(Core, :Compiler)
I = esc(itr)
return quote
if $I isa Generator && ($I).f isa Type
($I).f
T = ($I).f
else
Core.Compiler.return_type(_iterator_upper_bound, Tuple{typeof($I)})
T = Core.Compiler.return_type(_iterator_upper_bound, Tuple{typeof($I)})
end
promote_typejoin_union(T)
end
end
else
macro default_eltype(itr)
I = esc(itr)
return quote
if $I isa Generator && ($I).f isa Type
($I).f
promote_typejoin_union($I.f)
else
Any
end
Expand All @@ -775,8 +776,12 @@ function collect(itr::Generator)
return _array_for(et, itr.iter, isz)
end
v1, st = y
arr = _array_for(typeof(v1), itr.iter, isz, shape)
return collect_to_with_first!(arr, v1, itr, st)
dest = _array_for(typeof(v1), itr.iter, isz, shape)
# The typeassert gives inference a helping hand on the element type and dimensionality
# (work-around for #28382)
et′ = et <: Type ? Type : et
RT = dest isa AbstractArray ? AbstractArray{<:et′, ndims(dest)} : Any
collect_to_with_first!(dest, v1, itr, st)::RT
end
end

Expand Down
46 changes: 1 addition & 45 deletions base/broadcast.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ Module containing the broadcasting implementation.
module Broadcast

using .Base.Cartesian
using .Base: Indices, OneTo, tail, to_shape, isoperator, promote_typejoin, @pure,
using .Base: Indices, OneTo, tail, to_shape, isoperator, promote_typejoin, promote_typejoin_union, @pure,
_msk_end, unsafe_bitgetindex, bitcache_chunks, bitcache_size, dumpbitcache, unalias
import .Base: copy, copyto!, axes
export broadcast, broadcast!, BroadcastStyle, broadcast_axes, broadcastable, dotview, @__dot__, BroadcastFunction
Expand Down Expand Up @@ -713,50 +713,6 @@ eltypes(t::Tuple{Any}) = Tuple{_broadcast_getindex_eltype(t[1])}
eltypes(t::Tuple{Any,Any}) = Tuple{_broadcast_getindex_eltype(t[1]), _broadcast_getindex_eltype(t[2])}
eltypes(t::Tuple) = Tuple{_broadcast_getindex_eltype(t[1]), eltypes(tail(t)).types...}

function promote_typejoin_union(::Type{T}) where T
if T === Union{}
return Union{}
elseif T isa UnionAll
return Any # TODO: compute more precise bounds
elseif T isa Union
return promote_typejoin(promote_typejoin_union(T.a), promote_typejoin_union(T.b))
elseif T <: Tuple
return typejoin_union_tuple(T)
else
return T
end
end

@pure function typejoin_union_tuple(T::Type)
u = Base.unwrap_unionall(T)
u isa Union && return typejoin(
typejoin_union_tuple(Base.rewrap_unionall(u.a, T)),
typejoin_union_tuple(Base.rewrap_unionall(u.b, T)))
p = (u::DataType).parameters
lr = length(p)::Int
if lr == 0
return Tuple{}
end
c = Vector{Any}(undef, lr)
for i = 1:lr
pi = p[i]
U = Core.Compiler.unwrapva(pi)
if U === Union{}
ci = Union{}
elseif U isa Union
ci = typejoin(U.a, U.b)
else
ci = U
end
if i == lr && Core.Compiler.isvarargtype(pi)
c[i] = isdefined(pi, :N) ? Vararg{ci, pi.N} : Vararg{ci}
else
c[i] = ci
end
end
return Base.rewrap_unionall(Tuple{c...}, T)
end

# Inferred eltype of result of broadcast(f, args...)
combine_eltypes(f, args::Tuple) =
promote_typejoin_union(Base._return_type(f, eltypes(args)))
Expand Down
44 changes: 44 additions & 0 deletions base/promotion.jl
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,50 @@ function promote_typejoin(@nospecialize(a), @nospecialize(b))
end
_promote_typesubtract(@nospecialize(a)) = typesplit(a, Union{Nothing, Missing})

function promote_typejoin_union(::Type{T}) where T
if T === Union{}
return Union{}
elseif T isa UnionAll
return Any # TODO: compute more precise bounds
elseif T isa Union
return promote_typejoin(promote_typejoin_union(T.a), promote_typejoin_union(T.b))
elseif T <: Tuple
return typejoin_union_tuple(T)
else
return T
end
end

function typejoin_union_tuple(T::Type)
@_pure_meta
u = Base.unwrap_unionall(T)
u isa Union && return typejoin(
typejoin_union_tuple(Base.rewrap_unionall(u.a, T)),
typejoin_union_tuple(Base.rewrap_unionall(u.b, T)))
p = (u::DataType).parameters
lr = length(p)::Int
if lr == 0
return Tuple{}
end
c = Vector{Any}(undef, lr)
for i = 1:lr
pi = p[i]
U = Core.Compiler.unwrapva(pi)
if U === Union{}
ci = Union{}
elseif U isa Union
ci = typejoin(U.a, U.b)
else
ci = U
end
if i == lr && Core.Compiler.isvarargtype(pi)
c[i] = isdefined(pi, :N) ? Vararg{ci, pi.N} : Vararg{ci}
else
c[i] = ci
end
end
return Base.rewrap_unionall(Tuple{c...}, T)
end

# Returns length, isfixed
function full_va_len(p)
Expand Down
4 changes: 0 additions & 4 deletions test/broadcast.jl
Original file line number Diff line number Diff line change
Expand Up @@ -991,10 +991,6 @@ end
@test Core.Compiler.return_type(broadcast, Tuple{typeof(+), Vector{Int},
Vector{Union{Float64, Missing}}}) ==
Union{Vector{Missing}, Vector{Union{Missing, Float64}}, Vector{Float64}}
@test isequal([1, 2] + [3.0, missing], [4.0, missing])
@test Core.Compiler.return_type(+, Tuple{Vector{Int},
Vector{Union{Float64, Missing}}}) ==
Union{Vector{Missing}, Vector{Union{Missing, Float64}}, Vector{Float64}}
@test Core.Compiler.return_type(+, Tuple{Vector{Int},
Vector{Union{Float64, Missing}}}) ==
Union{Vector{Missing}, Vector{Union{Missing, Float64}}, Vector{Float64}}
Expand Down
22 changes: 22 additions & 0 deletions test/generic_map_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,28 @@ function generic_map_tests(mapf, inplace_mapf=nothing)
@test A == map(x->x*x*x, Float64[1:10...])
@test A === B
end

# Issue #28382: inferrability of map with Union eltype
@test isequal(map(+, [1, 2], [3.0, missing]), [4.0, missing])
@test Core.Compiler.return_type(map, Tuple{typeof(+), Vector{Int},
Vector{Union{Float64, Missing}}}) ==
Union{Vector{Missing}, Vector{Union{Missing, Float64}}, Vector{Float64}}
@test isequal(map(tuple, [1, 2], [3.0, missing]), [(1, 3.0), (2, missing)])
@test Core.Compiler.return_type(map, Tuple{typeof(tuple), Vector{Int},
Vector{Union{Float64, Missing}}}) ==
Vector{<:Tuple{Int, Any}}
# Check that corner cases do not throw an error
@test isequal(map(x -> x === 1 ? nothing : x, [1, 2, missing]),
[nothing, 2, missing])
@test isequal(map(x -> x === 1 ? nothing : x, Any[1, 2, 3.0, missing]),
[nothing, 2, 3, missing])
@test map((x,y)->(x==1 ? 1.0 : x, y), [1, 2, 3], ["a", "b", "c"]) ==
[(1.0, "a"), (2, "b"), (3, "c")]
@test map(typeof, [iszero, isdigit]) == [typeof(iszero), typeof(isdigit)]
@test map(typeof, [iszero, iszero]) == [typeof(iszero), typeof(iszero)]
@test isequal(map(identity, Vector{<:Union{Int, Missing}}[[1, 2],[missing, 1]]),
[[1, 2],[missing, 1]])
@test map(x -> x < 0 ? false : x, Int[]) isa Vector{Integer}
end

function testmap_equivalence(mapf, f, c...)
Expand Down
1 change: 1 addition & 0 deletions test/sets.jl
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ using Dates
@test isa(Set(sin(x) for x = 1:3), Set{Float64})
@test isa(Set(f17741(x) for x = 1:3), Set{Int})
@test isa(Set(f17741(x) for x = -1:1), Set{Integer})
@test isa(Set(f17741(x) for x = 1:0), Set{Integer})
end
let s1 = Set(["foo", "bar"]), s2 = Set(s1)
@test s1 == s2
Expand Down

0 comments on commit 49e3aec

Please sign in to comment.