From ab984a5b44ddb1b9af45ea66f2b8cb01df7f25a0 Mon Sep 17 00:00:00 2001 From: Andreas Noack Date: Fri, 30 Dec 2016 12:37:33 -0500 Subject: [PATCH] Use containertype to determine array type for array broadcast (#19745) --- base/broadcast.jl | 13 +------------ test/broadcast.jl | 35 ++++++++++++++++++++++++++++++++++- 2 files changed, 35 insertions(+), 13 deletions(-) diff --git a/base/broadcast.jl b/base/broadcast.jl index a5163e3cd181b..a4cec47b4dd6d 100644 --- a/base/broadcast.jl +++ b/base/broadcast.jl @@ -11,20 +11,9 @@ export bitbroadcast, dotview export broadcast_getindex, broadcast_setindex! ## Broadcasting utilities ## - -broadcast_array_type() = Array -broadcast_array_type(A, As...) = - if is_nullable_array(A) || broadcast_array_type(As...) === Array{Nullable} - Array{Nullable} - else - Array - end - # fallbacks for some special cases @inline broadcast(f, x::Number...) = f(x...) @inline broadcast{N}(f, t::NTuple{N}, ts::Vararg{NTuple{N}}) = map(f, t, ts...) -@inline broadcast(f, As::AbstractArray...) = - broadcast_c(f, broadcast_array_type(As...), As...) # special cases for "X .= ..." (broadcast!) assignments broadcast!(::typeof(identity), X::AbstractArray, x::Number) = fill!(X, x) @@ -313,7 +302,7 @@ ziptype{T}(::Type{T}, A) = typestuple(T, A) ziptype{T}(::Type{T}, A, B) = (Base.@_pure_meta; Iterators.Zip2{typestuple(T, A), typestuple(T, B)}) @inline ziptype{T}(::Type{T}, A, B, C, D...) = Iterators.Zip{typestuple(T, A), ziptype(T, B, C, D...)} -_broadcast_type{S}(::Type{S}, f, T::Type, As...) = Base._return_type(S, typestuple(S, T, As...)) +_broadcast_type{S}(::Type{S}, f, T::Type, As...) = Base._return_type(f, typestuple(S, T, As...)) _broadcast_type{T}(::Type{T}, f, A, Bs...) = Base._default_eltype(Base.Generator{ziptype(T, A, Bs...), ftype(f, A, Bs...)}) # broadcast methods that dispatch on the type of the final container diff --git a/test/broadcast.jl b/test/broadcast.jl index 314a94844c2bf..d92b3991e4760 100644 --- a/test/broadcast.jl +++ b/test/broadcast.jl @@ -363,7 +363,7 @@ StrangeType18623(x,y) = (x,y) let f(A, n) = broadcast(x -> +(x, n), A) @test @inferred(f([1.0], 1)) == [2.0] - g() = (a = 1; Base.Broadcast._broadcast_type(x -> x + a, 1.0)) + g() = (a = 1; Base.Broadcast._broadcast_type(Any, x -> x + a, 1.0)) @test @inferred(g()) === Float64 end @@ -376,3 +376,36 @@ end # Check that broadcast!(f, A) populates A via independent calls to f (#12277, #19722). @test let z = 1; A = broadcast!(() -> z += 1, zeros(2)); A[1] != A[2]; end + +# broadcasting for custom AbstractArray +immutable Array19745{T,N} <: AbstractArray{T,N} + data::Array{T,N} +end +Base.getindex(A::Array19745, i::Integer...) = A.data[i...] +Base.size(A::Array19745) = size(A.data) + +Base.Broadcast.containertype{T<:Array19745}(::Type{T}) = Array19745 + +Base.Broadcast.promote_containertype(::Type{Array19745}, ::Type{Array19745}) = Array19745 +Base.Broadcast.promote_containertype(::Type{Array19745}, ::Type{Array}) = Array19745 +Base.Broadcast.promote_containertype(::Type{Array19745}, ct) = Array19745 +Base.Broadcast.promote_containertype(::Type{Array}, ::Type{Array19745}) = Array19745 +Base.Broadcast.promote_containertype(ct, ::Type{Array19745}) = Array19745 + +Base.Broadcast.broadcast_indices(::Type{Array19745}, A) = indices(A) +Base.Broadcast.broadcast_indices(::Type{Array19745}, A::Ref) = () + +getfield19745(x::Array19745) = x.data +getfield19745(x) = x + +Base.Broadcast.broadcast_c(f, ::Type{Array19745}, A, Bs...) = + Array19745(Base.Broadcast.broadcast_c(f, Array, getfield19745(A), map(getfield19745, Bs)...)) + +@testset "broadcasting for custom AbstractArray" begin + a = randn(10) + aa = Array19745(a) + @test a .+ 1 == @inferred(aa .+ 1) + @test a .* a' == @inferred(aa .* aa') + @test isa(aa .+ 1, Array19745) + @test isa(aa .* aa', Array19745) +end