From c4c8c2521438e943314e38482239234ec1340f12 Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Tue, 12 Apr 2022 17:44:43 -0400 Subject: [PATCH 1/2] test for DiffEqFlux 699 --- test/destructure.jl | 39 +++++++++++++++++++++++++++++++++++++++ 1 file changed, 39 insertions(+) diff --git a/test/destructure.jl b/test/destructure.jl index 043315b3..d2b368c7 100644 --- a/test/destructure.jl +++ b/test/destructure.jl @@ -180,3 +180,42 @@ end 4(sum(m.x) + sum(m.y)) + 13*sum(m.z) # again two gradients are ===, so it eliminates one end == ([17,17,4,4],) # Flux gave ([4.0, 4.0, 13.0, 13.0],) end + +@testset "DiffEqFlux issue 699" begin + # The gradient of `re` is a vector into which we accumulate contributions, and the issue + # is that one contribution may have a wider type than `v`, especially for `Dual` numbers. + v, re = destructure((x=[1,2.0], y=[3,4,5.0])) + _, bk = Zygote.pullback(re, ones(5)) + # Testing with `Complex` isn't ideal, but this was an error on 0.2.1. + # If some upgrade inserts ProjectTo, this will fail, and can be changed: + @test bk((x=[1.0,im], y=nothing)) == ([1,im,0,0,0],) + + @test bk((x=nothing, y=[10,20,30]))[1] isa Vector{Float64} # despite some ZeroTangent + @test bk((x=nothing, y=nothing)) == (nothing,) # don't reduce over empty list of eltypes + @test bk((x=nothing, y=@thunk [1,2,3.0] .* 10)) == ([0,0,10,20,30],) +end + +#= + +# Adapted from https://github.com/SciML/DiffEqFlux.jl/pull/699#issuecomment-1092846657 +using ForwardDiff, Zygote, Flux, Optimisers, Test + +y = Float32[0.8564646, 0.21083355] +p = randn(Float32, 252); +t = 1.5f0 +λ = [ForwardDiff.Dual{ForwardDiff.Tag{Nothing,Float32}}(0.87135935, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0), ForwardDiff.Dual{ForwardDiff.Tag{Nothing,Float32}}(1.5225363, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0)] + +model = Chain(x -> x .^ 3, + Dense(2 => 50, tanh), + Dense(50 => 2)) + +p,re = Optimisers.destructure(model) +f(u, p, t) = re(p)(u) +_dy, back = Zygote.pullback(y, p) do u, p + vec(f(u, p, t)) +end +tmp1, tmp2 = back(λ); +tmp1 +@test tmp2 isa Vector{<:ForwardDiff.Dual} + +=# From 1e1fa3578129fd993970bcf70783750338eb1a5a Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Tue, 12 Apr 2022 18:04:54 -0400 Subject: [PATCH 2/2] type widening for _grad --- src/destructure.jl | 17 ++++++++++++----- test/destructure.jl | 19 ++++++++++--------- 2 files changed, 22 insertions(+), 14 deletions(-) diff --git a/src/destructure.jl b/src/destructure.jl index 2b91983d..d000ff75 100644 --- a/src/destructure.jl +++ b/src/destructure.jl @@ -127,15 +127,22 @@ function _grad!(x, dx, off, flat::AbstractVector) x′, _ = functor(typeof(x), x) dx′, _ = functor(typeof(x), base(dx)) off′, _ = functor(typeof(x), off) - foreach((xᵢ, dxᵢ, oᵢ) -> _grad!(xᵢ, dxᵢ, oᵢ, flat), x′, dx′, off′) + for (xᵢ, dxᵢ, oᵢ) in zip(x′, dx′, off′) + flat = _grad!(xᵢ, dxᵢ, oᵢ, flat) + end flat end -function _grad!(x, dx, off::Integer, flat::AbstractVector) - @views flat[off .+ (1:length(x))] .+= vec(dx) # must visit all tied nodes +function _grad!(x, dx, off::Integer, flat::AbstractVector{T}) where T + dx_un = unthunk(dx) + T2 = promote_type(T, eltype(dx_un)) + if T != T2 # then we must widen the type + flat = copyto!(similar(flat, T2), flat) + end + @views flat[off .+ (1:length(x))] .+= vec(dx_un) # must visit all tied nodes flat end -_grad!(x, dx::Zero, off, flat::AbstractVector) = dx -_grad!(x, dx::Zero, off::Integer, flat::AbstractVector) = dx # ambiguity +_grad!(x, dx::Zero, off, flat::AbstractVector) = flat +_grad!(x, dx::Zero, off::Integer, flat::AbstractVector) = flat # ambiguity # These are only needed for 2nd derivatives: function ChainRulesCore.rrule(::typeof(_grad!), x, dx, off, flat) diff --git a/test/destructure.jl b/test/destructure.jl index d2b368c7..d20f4f30 100644 --- a/test/destructure.jl +++ b/test/destructure.jl @@ -184,15 +184,16 @@ end @testset "DiffEqFlux issue 699" begin # The gradient of `re` is a vector into which we accumulate contributions, and the issue # is that one contribution may have a wider type than `v`, especially for `Dual` numbers. - v, re = destructure((x=[1,2.0], y=[3,4,5.0])) - _, bk = Zygote.pullback(re, ones(5)) + v, re = destructure((x=Float32[1,2], y=Float32[3,4,5])) + _, bk = Zygote.pullback(re, ones(Float32, 5)) # Testing with `Complex` isn't ideal, but this was an error on 0.2.1. # If some upgrade inserts ProjectTo, this will fail, and can be changed: @test bk((x=[1.0,im], y=nothing)) == ([1,im,0,0,0],) - @test bk((x=nothing, y=[10,20,30]))[1] isa Vector{Float64} # despite some ZeroTangent - @test bk((x=nothing, y=nothing)) == (nothing,) # don't reduce over empty list of eltypes - @test bk((x=nothing, y=@thunk [1,2,3.0] .* 10)) == ([0,0,10,20,30],) + @test bk((x=nothing, y=[10,20,30]))[1] isa Vector{Float32} # despite some ZeroTangent + @test bk((x=nothing, y=nothing)) == ([0,0,0,0,0],) + @test bk((x=nothing, y=@thunk [1,2,3] .* 10.0)) == ([0,0,10,20,30],) + @test bk((x=[1.2, 3.4], y=Float32[5,6,7])) == ([1.2, 3.4, 5, 6, 7],) end #= @@ -201,13 +202,13 @@ end using ForwardDiff, Zygote, Flux, Optimisers, Test y = Float32[0.8564646, 0.21083355] -p = randn(Float32, 252); +p = randn(Float32, 27); t = 1.5f0 -λ = [ForwardDiff.Dual{ForwardDiff.Tag{Nothing,Float32}}(0.87135935, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0), ForwardDiff.Dual{ForwardDiff.Tag{Nothing,Float32}}(1.5225363, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0)] +λ = [ForwardDiff.Dual(0.87135935, 1, 0, 0, 0, 0, 0), ForwardDiff.Dual(1.5225363, 0, 1, 0, 0, 0, 0)] model = Chain(x -> x .^ 3, - Dense(2 => 50, tanh), - Dense(50 => 2)) + Dense(2 => 5, tanh), + Dense(5 => 2)) p,re = Optimisers.destructure(model) f(u, p, t) = re(p)(u)