Patch summary (free text ahead of the first "diff --git" header):

  src/lib/broadcast.jl
    * broadcast_forward now takes the dual-number pullback path when
      eltype(out) is not concrete, as well as when it is a Dual or a
      Complex of Duals; otherwise it still returns (out, _ -> nothing).
    * The four _broadcast_forward / _broadcast_forward_complex methods that
      dispatched on the output eltype are merged into two. Real-vs-complex
      output is now decided per element by _extract_value,
      _broadcast_scalar_pullback and _broadcast_scalar_pullback_complex.
  test/features.jl
    * Adds regression tests for FluxML/Zygote.jl issue 1439 (forward pass is
      type stable on the given input but type unstable on dualized input).

NOTE(review): this copy was recovered from a whitespace-collapsed paste.
Line breaks are restored so that each hunk matches the line counts in its
header. Leading indentation is reconstructed (2 spaces; 4 spaces on the
removed _adjoint_complex body) and should be confirmed with
`git apply --check` before use.

diff --git a/src/lib/broadcast.jl b/src/lib/broadcast.jl
index 504ef614d..52da0dad7 100644
--- a/src/lib/broadcast.jl
+++ b/src/lib/broadcast.jl
@@ -281,72 +281,62 @@ end
 @inline function broadcast_forward(f, args::Vararg{Any,N}) where N
   out = dual_function(f).(args...)
   T = eltype(out)
-  T <: Union{Dual, Complex{<:Dual}} || return (out, _ -> nothing)
-  if any(eltype(a) <: Complex for a in args)
-    _broadcast_forward_complex(T, out, args...)
+  if !isconcretetype(T) || T <: Union{Dual, Complex{<:Dual}}
+    if any(eltype(a) <: Complex for a in args)
+      return _broadcast_forward_complex(out, args...)
+    else
+      return _broadcast_forward(out, args...)
+    end
   else
-    _broadcast_forward(T, out, args...)
+    return (out, _ -> nothing)
   end
 end
 
-# Real input and real output pullback
-@inline function _broadcast_forward(::Type{<:Dual}, out, args::Vararg{Any, N}) where {N}
+# Real input
+@inline _extract_value(x) = value(x)
+@inline _extract_value(x::Complex) = Complex(value(real(x)), value(imag(x)))
+@inline _broadcast_scalar_pullback(ȳ, out, i) = ȳ * partials(out, i)
+@inline function _broadcast_scalar_pullback(ȳ, out::Complex, i)
+  return real(ȳ) * partials(real(out), i) + imag(ȳ) * partials(imag(out), i)
+end
+@inline function _broadcast_forward(out, args::Vararg{Any, N}) where {N}
   valN = Val(N)
-  y = broadcast(x -> value(x), out)
+  y = broadcast(x -> _extract_value(x), out)
   function bc_fwd_back(ȳ)
     dargs = ntuple(valN) do i
-      unbroadcast(args[i], broadcast((y1, o1) -> y1 * partials(o1,i), ȳ, out))
+      unbroadcast(args[i],
+        broadcast((y1, o1) -> _broadcast_scalar_pullback(y1, o1, i), ȳ, out)
+      )
     end
     (nothing, nothing, dargs...) # nothings for broadcasted & f
   end
   return y, bc_fwd_back
 end
 
-# This handles the complex output and real input pullback
-@inline function _broadcast_forward(::Type{<:Complex}, out, args::Vararg{Any, N}) where {N}
-  valN = Val(N)
-  y = broadcast(x -> Complex(value(real(x)), value(imag(x))), out)
-  function bc_fwd_back(ȳ)
-    dargs = ntuple(valN) do i
-      unbroadcast(args[i], broadcast((y1, o1) -> (real(y1)*partials(real(o1),i) + imag(y1)*partials(imag(o1), i)), ȳ, out))
-    end
-    (nothing, nothing, dargs...) # nothings for broadcasted & f
-  end
-  return y, bc_fwd_back
- end
-
 # This handles complex input and real output. We use the gradient definition from ChainRules here
 # since it agrees with what Zygote did for real(x).
-@inline function _broadcast_forward_complex(::Type{<:Dual}, out, args::Vararg{Any, N}) where {N}
-  valN = Val(N)
-  y = broadcast(x -> value(x), out)
-  function bc_fwd_back(ȳ)
-    dargs = ntuple(valN) do i
-      unbroadcast(args[i], broadcast((y1, o1) -> y1 * Complex(partials(o1, i), partials(o1, i+N)), ȳ, out))
-    end
-    (nothing, nothing, dargs...) # nothings for broadcasted & f
-  end
-  return y, bc_fwd_back
+@inline function _broadcast_scalar_pullback_complex(N, Δz, df, i)
+  return Δz * Complex(partials(df, i), partials(df, i + N))
 end
-
 # # # This is for complex input and complex output
 # If we assume that
 # f(x + iy) = u(x,y) + iv(x,y)
 # then we do the following for the adjoint
 # Δu ∂u/∂x + Δv∂v/∂x + i(Δu∂u/∂y + Δv ∂v/∂y )
 # this follows https://juliadiff.org/ChainRulesCore.jl/stable/maths/complex.html
-function _adjoint_complex(N, Δz, df, i)
-    Δu, Δv = reim(Δz)
-    du, dv = reim(df)
-    return Complex(Δu*partials(du, i) + Δv*partials(dv, i), Δu*partials(du, i+N) + Δv*partials(dv, i+N))
+@inline function _broadcast_scalar_pullback_complex(N, Δz, df::Complex, i)
+  Δu, Δv = reim(Δz)
+  du, dv = reim(df)
+  return Complex(Δu * partials(du, i) + Δv * partials(dv, i), Δu * partials(du, i + N) + Δv * partials(dv, i + N))
 end
-
-@inline function _broadcast_forward_complex(::Type{<:Complex}, out, args::Vararg{Any, N}) where {N}
+@inline function _broadcast_forward_complex(out, args::Vararg{Any, N}) where {N}
   valN = Val(N)
-  y = broadcast(x -> Complex(value(real(x)), value(imag(x))), out)
+  y = broadcast(x -> _extract_value(x), out)
   function bc_fwd_back(ȳ)
     dargs = ntuple(valN) do i
-      unbroadcast(args[i], broadcast((y1, o1) -> _adjoint_complex(N, y1, o1, i), ȳ, out))
+      unbroadcast(args[i],
+        broadcast((y1, o1) -> _broadcast_scalar_pullback_complex(N, y1, o1, i), ȳ, out)
+      )
     end
     (nothing, nothing, dargs...) # nothings for broadcasted & f
   end
diff --git a/test/features.jl b/test/features.jl
index 908ae5815..25f6661b4 100644
--- a/test/features.jl
+++ b/test/features.jl
@@ -798,6 +798,29 @@ end
   @test gradient(xs -> sum(map((x -> x<2 ? false : x^2), xs)), [1,2,3])[1][2:3] == [4, 6]
   @test gradient(xs -> mapreduce((x -> x<2 ? false : x^2), +, xs), [1,2,3])[1][2:3] == [4, 6]
 
+  # https://github.com/FluxML/Zygote.jl/issues/1439
+  # type stable forward pass with given input, but type unstable with dualized input
+  # Real input, real output
+  f = x -> x > 1.0 ? 1.0 : x^2
+  @test gradient(xs -> sum(f.(xs)), [0.5, 1.0, 1.5])[1] == [1.0, 2.0, 0.0]
+  # Real input, complex output
+  f = x -> x > 1.0 ? 1.0im : (x + 1.0im)^2
+  @test gradient(xs -> sum(abs2, f.(xs)), [0.5, 1.0, 1.5])[1] == [2.5, 8.0, 0.0]
+  # Complex input, complex output
+  f = x -> imag(x) > 1.0 ? 1.0im : x^2
+  @test gradient(xs -> sum(abs2, f.(xs)), [0.5im, 1.0im, 1.5im])[1] == [
+    0.0 + 0.5im, 0.0 + 4.0im, 0.0 + 0.0im
+  ]
+  # Complex input, real output
+  f = x -> imag(x) > 1.0 ? 1.0 : abs2(x)
+  @test gradient(xs -> sum(abs2, f.(xs)), [0.5im, 1.0im, 1.5im])[1] == [
+    0.0 + 0.5im, 0.0 + 4.0im, 0.0 + 0.0im
+  ]
+  # Slightly more complex case that used to error
+  f = x -> x > 1.0 ? 1.0 : x^2
+  g = x -> sum(repeat(x, inner=2) .* f.(repeat(x, inner=2)))
+  @test gradient(g, [0.5, 1.0, 1.5])[1] == [1.5, 6.0, 2.0]
+
   # with Ref, Val, Symbol
   @test gradient(x -> sum(x .+ Ref(x[1])), [1,2,3]) == ([4,1,1],)
   @test gradient(x -> sum(x .+ (x[1],)), [1,2,3]) == ([4,1,1],)