Skip to content

Commit

Permalink
Restructure of the promotion mechanism for broadcast (#18642)
Browse files Browse the repository at this point in the history
* Restructure the promotion mechanism for broadcast

* More broadcast tests

* Use broadcast for element wise operators where appropriate
  • Loading branch information
pabloferz authored and stevengj committed Nov 7, 2016
1 parent 410b39c commit d16d994
Show file tree
Hide file tree
Showing 8 changed files with 93 additions and 56 deletions.
6 changes: 1 addition & 5 deletions base/abstractarray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1762,12 +1762,8 @@ end
# These are needed because map(eltype, As) is not inferrable
promote_eltype_op(::Any) = (@_pure_meta; Any)
promote_eltype_op(op, A) = (@_pure_meta; promote_op(op, eltype(A)))
promote_eltype_op{T}(op, ::AbstractArray{T}) = (@_pure_meta; promote_op(op, T))
promote_eltype_op{T}(op, ::AbstractArray{T}, A) = (@_pure_meta; promote_op(op, T, eltype(A)))
promote_eltype_op{T}(op, A, ::AbstractArray{T}) = (@_pure_meta; promote_op(op, eltype(A), T))
promote_eltype_op{R,S}(op, ::AbstractArray{R}, ::AbstractArray{S}) = (@_pure_meta; promote_op(op, R, S))
promote_eltype_op(op, A, B) = (@_pure_meta; promote_op(op, eltype(A), eltype(B)))
promote_eltype_op(op, A, B, C, D...) = (@_pure_meta; promote_eltype_op(op, promote_eltype_op(op, A, B), C, D...))
promote_eltype_op(op, A, B, C, D...) = (@_pure_meta; promote_eltype_op(op, eltype(A), promote_eltype_op(op, B, C, D...)))

## 1 argument

Expand Down
69 changes: 47 additions & 22 deletions base/broadcast.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
module Broadcast

using Base.Cartesian
using Base: promote_eltype_op, linearindices, tail, OneTo, to_shape,
using Base: promote_eltype_op, _default_eltype, linearindices, tail, OneTo, to_shape,
_msk_end, unsafe_bitgetindex, bitcache_chunks, bitcache_size, dumpbitcache
import Base: .+, .-, .*, ./, .\, .//, .==, .<, .!=, .<=, , .%, .<<, .>>, .^
import Base: broadcast
Expand All @@ -16,7 +16,7 @@ export broadcast_getindex, broadcast_setindex!
broadcast(f) = f()
@inline broadcast(f, x::Number...) = f(x...)
@inline broadcast{N}(f, t::NTuple{N}, ts::Vararg{NTuple{N}}) = map(f, t, ts...)
@inline broadcast(f, As::AbstractArray...) = broadcast_t(f, promote_eltype_op(f, As...), As...)
@inline broadcast(f, As::AbstractArray...) = broadcast_c(f, Array, As...)

# special cases for "X .= ..." (broadcast!) assignments
broadcast!(::typeof(identity), X::AbstractArray, x::Number) = fill!(X, x)
Expand Down Expand Up @@ -127,14 +127,14 @@ Base.@propagate_inbounds _broadcast_getindex(::Any, A, I) = A[I]
## Broadcasting core
# nargs encodes the number of As arguments (which matches the number
# of keeps). The first two type parameters are to ensure specialization.
@generated function _broadcast!{K,ID,AT,nargs}(f, B::AbstractArray, keeps::K, Idefaults::ID, As::AT, ::Type{Val{nargs}})
@generated function _broadcast!{K,ID,AT,nargs}(f, B::AbstractArray, keeps::K, Idefaults::ID, As::AT, ::Type{Val{nargs}}, iter)
quote
$(Expr(:meta, :noinline))
# destructure the keeps and As tuples
@nexprs $nargs i->(A_i = As[i])
@nexprs $nargs i->(keep_i = keeps[i])
@nexprs $nargs i->(Idefault_i = Idefaults[i])
@simd for I in CartesianRange(indices(B))
@simd for I in iter
# reverse-broadcast the indices
@nexprs $nargs i->(I_i = newindex(I, keep_i, Idefault_i))
# extract array values
Expand All @@ -148,7 +148,7 @@ end

# For BitArray outputs, we cache the result in a "small" Vector{Bool},
# and then copy in chunks into the output
@generated function _broadcast!{K,ID,AT,nargs}(f, B::BitArray, keeps::K, Idefaults::ID, As::AT, ::Type{Val{nargs}})
@generated function _broadcast!{K,ID,AT,nargs}(f, B::BitArray, keeps::K, Idefaults::ID, As::AT, ::Type{Val{nargs}}, iter)
quote
$(Expr(:meta, :noinline))
# destructure the keeps and As tuples
Expand All @@ -159,7 +159,7 @@ end
Bc = B.chunks
ind = 1
cind = 1
@simd for I in CartesianRange(indices(B))
@simd for I in iter
# reverse-broadcast the indices
@nexprs $nargs i->(I_i = newindex(I, keep_i, Idefault_i))
# extract array values
Expand Down Expand Up @@ -193,12 +193,12 @@ as in `broadcast!(f, A, A, B)` to perform `A[:] = broadcast(f, A, B)`.
shape = indices(B)
check_broadcast_indices(shape, As...)
keeps, Idefaults = map_newindexer(shape, As)
_broadcast!(f, B, keeps, Idefaults, As, Val{nargs})
B
iter = CartesianRange(shape)
_broadcast!(f, B, keeps, Idefaults, As, Val{nargs}, iter)
return B
end

# broadcast with computed element type

@generated function _broadcast!{K,ID,AT,nargs}(f, B::AbstractArray, keeps::K, Idefaults::ID, As::AT, ::Type{Val{nargs}}, iter, st, count)
quote
$(Expr(:meta, :noinline))
Expand Down Expand Up @@ -233,12 +233,8 @@ end
end
end

function broadcast_t(f, ::Type{Any}, As...)
shape = broadcast_indices(As...)
iter = CartesianRange(shape)
if isempty(iter)
return similar(Array{Any}, shape)
end
# broadcast methods that dispatch on the type found by inference
function broadcast_t(f, ::Type{Any}, shape, iter, As...)
nargs = length(As)
keeps, Idefaults = map_newindexer(shape, As)
st = start(iter)
Expand All @@ -248,17 +244,46 @@ function broadcast_t(f, ::Type{Any}, As...)
B[I] = val
return _broadcast!(f, B, keeps, Idefaults, As, Val{nargs}, iter, st, 1)
end
@inline function broadcast_t(f, T, shape, iter, As...)
B = similar(Array{T}, shape)
nargs = length(As)
keeps, Idefaults = map_newindexer(shape, As)
_broadcast!(f, B, keeps, Idefaults, As, Val{nargs}, iter)
return B
end

@inline broadcast_t(f, T, As...) = broadcast!(f, similar(Array{T}, broadcast_indices(As...)), As...)

# broadcast method that uses inference to find the type, but preserves abstract
# container types when possible (used by binary elementwise operators)
@inline broadcast_elwise_op(f, As...) =
broadcast!(f, similar(Array{promote_eltype_op(f, As...)}, broadcast_indices(As...)), As...)

This comment has been minimized.

Copy link
@stevengj

stevengj Dec 8, 2016

Member

I'm not sure what to do with this in #17623, where .+ etcetera simply call broadcast. I could make the parser transform them to calls to broadcast_elwise_op instead, but since this won't work with fusion I'm not sure I see the point.


ftype(f, A) = typeof(a -> f(a))
ftype(f, A...) = typeof(a -> f(a...))
ftype(T::DataType, A) = Type{T}
ftype(T::DataType, A...) = Type{T}
ziptype(A) = Tuple{eltype(A)}
ziptype(A, B) = Iterators.Zip2{Tuple{eltype(A)}, Tuple{eltype(B)}}
@inline ziptype(A, B, C, D...) = Iterators.Zip{Tuple{eltype(A)}, ziptype(B, C, D...)}

# broadcast methods that dispatch on the type of the final container
@inline function broadcast_c(f, ::Type{Array}, As...)
T = _default_eltype(Base.Generator{ziptype(As...), ftype(f, As...)})
shape = broadcast_indices(As...)
iter = CartesianRange(shape)
if isleaftype(T)
return broadcast_t(f, T, shape, iter, As...)
end
if isempty(iter)
return similar(Array{T}, shape)
end
return broadcast_t(f, Any, shape, iter, As...)
end
function broadcast_c(f, ::Type{Tuple}, As...)
shape = broadcast_indices(As...)
check_broadcast_indices(shape, As...)
n = length(shape[1])
return ntuple(k->f((_broadcast_getindex(A, k) for A in As)...), n)
end
@inline broadcast_c(f, ::Type{Any}, a...) = f(a...)
@inline broadcast_c(f, ::Type{Array}, As...) = broadcast_t(f, promote_eltype_op(f, As...), As...)

"""
broadcast(f, As...)
Expand Down Expand Up @@ -441,10 +466,10 @@ end
## elementwise operators ##

for op in (:÷, :%, :<<, :>>, :-, :/, :\, ://, :^)
@eval $(Symbol(:., op))(A::AbstractArray, B::AbstractArray) = broadcast($op, A, B)
@eval $(Symbol(:., op))(A::AbstractArray, B::AbstractArray) = broadcast_elwise_op($op, A, B)
end
.+(As::AbstractArray...) = broadcast(+, As...)
.*(As::AbstractArray...) = broadcast(*, As...)
.+(As::AbstractArray...) = broadcast_elwise_op(+, As...)
.*(As::AbstractArray...) = broadcast_elwise_op(*, As...)

# ## element-wise comparison operators returning BitArray ##

Expand Down
11 changes: 11 additions & 0 deletions base/deprecated.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1035,6 +1035,17 @@ end))
@deprecate_binding cycle Iterators.cycle
@deprecate_binding repeated Iterators.repeated

# promote_op method where the operator is also a type
function promote_op(op::Type, Ts::Type...)
depwarn("promote_op(op::Type, ::Type...) is deprecated as it is no " *
"longer needed in Base. If you need its functionality, consider " *
"defining it locally.", :promote_op)
if isdefined(Core, :Inference)
return Core.Inference.return_type(op, Tuple{Ts...})
end
return op
end

# NOTE: Deprecation of Channel{T}() is implemented in channels.jl.
# To be removed from there when 0.6 deprecations are removed.

Expand Down
1 change: 1 addition & 0 deletions base/nullable.jl
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ convert( ::Type{Nullable }, ::Void) = Nullable{Union{}}()
promote_rule{S,T}(::Type{Nullable{S}}, ::Type{T}) = Nullable{promote_type(S, T)}
promote_rule{S,T}(::Type{Nullable{S}}, ::Type{Nullable{T}}) = Nullable{promote_type(S, T)}
promote_op{S,T}(op::Any, ::Type{Nullable{S}}, ::Type{Nullable{T}}) = Nullable{promote_op(op, S, T)}
promote_op{S,T}(op::Type, ::Type{Nullable{S}}, ::Type{Nullable{T}}) = Nullable{promote_op(op, S, T)}

function show{T}(io::IO, x::Nullable{T})
if get(io, :compact, false)
Expand Down
29 changes: 9 additions & 20 deletions base/promotion.jl
Original file line number Diff line number Diff line change
Expand Up @@ -217,34 +217,23 @@ max(x::Real, y::Real) = max(promote(x,y)...)
min(x::Real, y::Real) = min(promote(x,y)...)
minmax(x::Real, y::Real) = minmax(promote(x, y)...)

# "Promotion" that takes a function into account. These are meant to be
# used mainly by broadcast methods, so it is advised against overriding them
if isdefined(Core, :Inference)
function _promote_op(op, T::ANY)
G = Tuple{Generator{Tuple{T},typeof(op)}}
return Core.Inference.return_type(first, G)
end
function _promote_op(op, R::ANY, S::ANY)
F = typeof(a -> op(a...))
G = Tuple{Generator{Iterators.Zip2{Tuple{R},Tuple{S}},F}}
return Core.Inference.return_type(first, G)
end
else
_promote_op(::ANY...) = (@_pure_meta; Any)
end
# "Promotion" that takes a function into account and tries to preserve
# non-concrete types. These are meant to be used mainly by elementwise
# operations, so it is advised against overriding them
_default_type(T::Type) = (@_pure_meta; T)

promote_op(::Any...) = (@_pure_meta; Any)
promote_op(T::Type, ::Any) = (@_pure_meta; T)
promote_op(T::Type, ::Type) = (@_pure_meta; T) # To handle ambiguities
# Promotion that tries to preserve non-concrete types
function promote_op{S}(f, ::Type{S})
T = _promote_op(f, _default_type(S))
@_pure_meta
Z = Tuple{_default_type(S)}
T = _default_eltype(Generator{Z, typeof(a -> f(a))})
isleaftype(S) && return isleaftype(T) ? T : Any
return typejoin(S, T)
end
function promote_op{R,S}(f, ::Type{R}, ::Type{S})
T = _promote_op(f, _default_type(R), _default_type(S))
@_pure_meta
Z = Iterators.Zip2{Tuple{_default_type(R)}, Tuple{_default_type(S)}}
T = _default_eltype(Generator{Z, typeof(a -> f(a...))})
isleaftype(R) && isleaftype(S) && return isleaftype(T) ? T : Any
return typejoin(R, S, T)
end
Expand Down
2 changes: 1 addition & 1 deletion base/sparse/sparsematrix.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1740,7 +1740,7 @@ broadcast{Tv1,Ti1,Tv2,Ti2}(f::Function, A_1::SparseMatrixCSC{Tv1,Ti1}, A_2::Spar
broadcast!(f, spzeros(promote_type(Tv1, Tv2), promote_type(Ti1, Ti2), to_shape(broadcast_indices(A_1, A_2))), A_1, A_2)

@inline broadcast_zpreserving!(args...) = broadcast!(args...)
@inline broadcast_zpreserving(args...) = broadcast(args...)
@inline broadcast_zpreserving(args...) = Base.Broadcast.broadcast_elwise_op(args...)
broadcast_zpreserving{Tv1,Ti1,Tv2,Ti2}(f::Function, A_1::SparseMatrixCSC{Tv1,Ti1}, A_2::SparseMatrixCSC{Tv2,Ti2}) =
broadcast_zpreserving!(f, spzeros(promote_type(Tv1, Tv2), promote_type(Ti1, Ti2), to_shape(broadcast_indices(A_1, A_2))), A_1, A_2)
broadcast_zpreserving{Tv,Ti}(f::Function, A_1::SparseMatrixCSC{Tv,Ti}, A_2::Union{Array,BitArray,Number}) =
Expand Down
19 changes: 17 additions & 2 deletions test/broadcast.jl
Original file line number Diff line number Diff line change
Expand Up @@ -300,7 +300,6 @@ import Base.Meta: isexpr
# PR 16988
@test Base.promote_op(+, Bool) === Int
@test isa(broadcast(+, [true]), Array{Int,1})
@test Base.promote_op(Float64, Bool) === Float64

# issue #17304
let foo = [[1,2,3],[4,5,6],[7,8,9]]
Expand All @@ -312,7 +311,7 @@ end
let f17314 = x -> x < 0 ? false : x
@test eltype(broadcast(f17314, 1:3)) === Int
@test eltype(broadcast(f17314, -1:1)) === Integer
@test eltype(broadcast(f17314, Int[])) === Any
@test eltype(broadcast(f17314, Int[])) === Union{Bool,Int}
end
let io = IOBuffer()
broadcast(x->print(io,x), 1:5) # broadcast with side effects
Expand All @@ -337,3 +336,19 @@ end
@test broadcast(+, 1.0, (0, -2.0)) == (1.0,-1.0)
@test broadcast(+, 1.0, (0, -2.0), [1]) == [2.0, 0.0]
@test broadcast(*, ["Hello"], ", ", ["World"], "!") == ["Hello, World!"]

# Ensure that even strange constructors that break `T(x)::T` work with broadcast
immutable StrangeType18623 end
StrangeType18623(x) = x
StrangeType18623(x,y) = (x,y)
@test @inferred broadcast(StrangeType18623, 1:3) == [1,2,3]
@test @inferred broadcast(StrangeType18623, 1:3, 4:6) == [(1,4),(2,5),(3,6)]

@test typeof(Int.(Number[1, 2, 3])) === typeof((x->Int(x)).(Number[1, 2, 3]))

@test @inferred broadcast(CartesianIndex, 1:2) == [CartesianIndex(1), CartesianIndex(2)]
@test @inferred broadcast(CartesianIndex, 1:2, 3:4) == [CartesianIndex(1,3), CartesianIndex(2,4)]

# Issue 18622
@test @inferred muladd.([1.0], [2.0], [3.0])::Vector{Float64} == [5.0]
@test @inferred tuple.(1:3, 4:6, 7:9)::Vector{Tuple{Int,Int,Int}} == [(1,4,7), (2,5,8), (3,6,9)]
12 changes: 6 additions & 6 deletions test/numbers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2814,42 +2814,42 @@ let types = (Base.BitInteger_types..., BigInt, Bool,
Complex{Int}, Complex{UInt}, Complex32, Complex64, Complex128)
for S in types
for op in (+, -)
T = @inferred Base._promote_op(op, S)
T = @inferred Base.promote_op(op, S)
t = @inferred op(one(S))
@test T === typeof(t)
end

for R in types
for op in (+, -, *, /, ^)
T = @inferred Base._promote_op(op, S, R)
T = @inferred Base.promote_op(op, S, R)
t = @inferred op(one(S), one(R))
@test T === typeof(t)
end
end
end

@test @inferred(Base._promote_op(!, Bool)) === Bool
@test @inferred(Base.promote_op(!, Bool)) === Bool
end

let types = (Base.BitInteger_types..., BigInt, Bool,
Rational{Int}, Rational{BigInt},
Float16, Float32, Float64, BigFloat)
for S in types, T in types
for op in (<, >, <=, >=, (==))
@test @inferred(Base._promote_op(op, S, T)) === Bool
@test @inferred(Base.promote_op(op, S, T)) === Bool
end
end
end

let types = (Base.BitInteger_types..., BigInt, Bool)
for S in types
T = @inferred Base._promote_op(~, S)
T = @inferred Base.promote_op(~, S)
t = @inferred ~one(S)
@test T === typeof(t)

for R in types
for op in (&, |, <<, >>, (>>>), %, ÷)
T = @inferred Base._promote_op(op, S, R)
T = @inferred Base.promote_op(op, S, R)
t = @inferred op(one(S), one(R))
@test T === typeof(t)
end
Expand Down

0 comments on commit d16d994

Please sign in to comment.