diff --git a/Project.toml b/Project.toml index e4309a1a..cbfed503 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "StaticArrays" uuid = "90137ffa-7385-5640-81b9-e52037218182" -version = "1.5.4" +version = "1.5.5" [deps] LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" diff --git a/src/broadcast.jl b/src/broadcast.jl index 4080a7f2..0240060c 100644 --- a/src/broadcast.jl +++ b/src/broadcast.jl @@ -58,7 +58,7 @@ static_check_broadcast_shape(::Tuple{}, ::Tuple{SOneTo{1},Vararg{SOneTo{1}}}) = static_check_broadcast_shape(::Tuple{}, ::Tuple{}) = () # copy overload @inline function Base.copy(B::Broadcasted{StaticArrayStyle{M}}) where M - flat = Broadcast.flatten(B); as = flat.args; f = flat.f + flat = broadcast_flatten(B); as = flat.args; f = flat.f argsizes = broadcast_sizes(as...) ax = axes(B) ax isa Tuple{Vararg{SOneTo}} || error("Dimension is not static. Please file a bug.") @@ -68,7 +68,7 @@ end @inline Base.copyto!(dest, B::Broadcasted{<:StaticArrayStyle}) = _copyto!(dest, B) @inline Base.copyto!(dest::AbstractArray, B::Broadcasted{<:StaticArrayStyle}) = _copyto!(dest, B) @inline function _copyto!(dest, B::Broadcasted{StaticArrayStyle{M}}) where M - flat = Broadcast.flatten(B); as = flat.args; f = flat.f + flat = broadcast_flatten(B); as = flat.args; f = flat.f argsizes = broadcast_sizes(as...) ax = axes(B) if ax isa Tuple{Vararg{SOneTo}} @@ -165,3 +165,68 @@ end return dest end end + +# Work around for https://github.com/JuliaLang/julia/issues/27988 +# The following code is borrowed from https://github.com/JuliaLang/julia/pull/43322 +# with some modification to make it also works on 1.6. +# TODO: make `broadcast_flatten` call `Broadcast.flatten` once julia#43322 is merged. +module StableFlatten + +export broadcast_flatten + +using Base: tail +using Base.Broadcast: isflat, Broadcasted + +maybeconstructor(f) = f +maybeconstructor(::Type{F}) where {F} = (args...; kwargs...) -> F(args...; kwargs...) + +function broadcast_flatten(bc::Broadcasted{Style}) where {Style} + isflat(bc) && return bc + args = cat_nested(bc) + len = Val{length(args)}() + makeargs = make_makeargs(bc.args, len, ntuple(_->true, len)) + f = maybeconstructor(bc.f) + @inline newf(args...) = f(prepare_args(makeargs, args)...) + return Broadcasted{Style}(newf, args, bc.axes) +end + +cat_nested(bc::Broadcasted) = cat_nested_args(bc.args) +cat_nested_args(::Tuple{}) = () +cat_nested_args(t::Tuple) = (cat_nested(t[1])..., cat_nested_args(tail(t))...) +cat_nested(@nospecialize(a)) = (a,) + +function make_makeargs(args::Tuple, len, flags) + makeargs, r = _make_makeargs(args, len, flags) + r isa Tuple{} || error("Internal error. Please file a bug") + return makeargs +end + +# We build `makeargs` by traversing the broadcast nodes recursively. +# note: `len` isa `Val` indicates the length of whole flattened argument list. +# `flags` is a tuple of `Bool` with the same length of the rest arguments. +@inline function _make_makeargs(args::Tuple, len::Val, flags::Tuple) + head, flags′ = _make_makeargs1(args[1], len, flags) + rest, flags″ = _make_makeargs(tail(args), len, flags′) + (head, rest...), flags″ +end +_make_makeargs(::Tuple{}, ::Val, x::Tuple) = (), x + +# For flat nodes: +# 1. we just consume one argument, and return the "pick" function +@inline function _make_makeargs1(@nospecialize(a), ::Val{N}, flags::Tuple) where {N} + pickargs(::Val{N}) where {N} = (@nospecialize(x::Tuple)) -> x[N] + return pickargs(Val{N-length(flags)+1}()), tail(flags) +end + +# For nested nodes, we form the `makeargs1` based on the child `makeargs` (n += length(cat_nested(bc))) +@inline function _make_makeargs1(bc::Broadcasted, len::Val, flags::Tuple) + makeargs, flags′ = _make_makeargs(bc.args, len, flags) + f = maybeconstructor(bc.f) + @inline makeargs1(@nospecialize(args::Tuple)) = f(prepare_args(makeargs, args)...) + makeargs1, flags′ +end + +prepare_args(::Tuple{}, @nospecialize(::Tuple)) = () +@inline prepare_args(makeargs::Tuple, @nospecialize(x::Tuple)) = (makeargs[1](x), prepare_args(tail(makeargs), x)...) +end +using .StableFlatten diff --git a/test/broadcast.jl b/test/broadcast.jl index b1202df4..ccb90b51 100644 --- a/test/broadcast.jl +++ b/test/broadcast.jl @@ -335,3 +335,20 @@ end @test @inferred(Broadcast.instantiate(f(a; ax))).axes isa Tuple{SOneTo,SOneTo,Base.OneTo} @test @inferred(Broadcast.instantiate(f(a; ax = ax[1:2]))).axes isa NTuple{2,SOneTo} end + +@testset "`broadcast`'s stability" begin + issue1078(t) = t ./ (1 .- t .^ 2) + a = @SVector rand(3) + @test @inferred(issue1078(a)) == issue1078(Vector(a)) + issue560(ũ, u₀, u₁, ρ) = ũ ./ (1e-6 .+ max.(abs.(u₀), abs.(u₁)) .* ρ) + issue797(a, b, c, d) = @. a + 5 * b + 3 * c - d + manual(a, b, c, d) = @. 0.1a^2 + 0.2b^3 * 0.4c^1 + 0.5d + manual2(a, b, c, d) = @. Float32(a) * Float32(b) + Float32(c) * Float32(d) + args = rand(3), rand(3), rand(3), rand(3) + @test @inferred(issue560(map(SVector{3}, args)...)) == issue560(args...) + @test @inferred(issue797(map(SVector{3}, args)...)) == issue797(args...) + @test @inferred(manual(map(SVector{3}, args)...)) == manual(args...) + @test @inferred(manual2(map(SVector{3}, args)...)) == manual2(args...) + issue609(s, c::Integer) = (s .- s.^2) ./ c + @test @inferred(issue609(SA[1.], 2)) == issue609([1.], 2) +end