Potential gradient issues with Flux chains when changing parameter type #533
Here's a simplification:

```julia
using DiffEqFlux, Flux, NeuralPDE, ModelingToolkit, DomainSets, Zygote
@parameters x y
@variables u(..)
Dxx = Differential(x)^2
Dyy = Differential(y)^2
# 2D PDE
eq = Dxx(u(x,y)) + Dyy(u(x,y)) ~ -sin(pi*x)*sin(pi*y)
# Initial and boundary conditions
bcs = [u(0,y) ~ 0.0, u(1,y) ~ -sin(pi*1)*sin(pi*y),
u(x,0) ~ 0.0, u(x,1) ~ -sin(pi*x)*sin(pi*1)]
# Space and time domains
domains = [x ∈ Interval(0.0,1.0),
y ∈ Interval(0.0,1.0)]
@named pde_system = PDESystem(eq,bcs,domains,[x,y],[u(x, y)])
fastchain = FastChain(FastDense(2,12,Flux.σ),FastDense(12,12,Flux.σ),FastDense(12,1))
fluxchain = Chain(Dense(2,12,Flux.σ),Dense(12,12,Flux.σ),Dense(12,1))
initθ = Float64.(DiffEqFlux.initial_params(fastchain))
grid_strategy = NeuralPDE.GridTraining(0.1)
discretization1 = NeuralPDE.PhysicsInformedNN(fastchain,
grid_strategy;
init_params = initθ)
discretization2 = NeuralPDE.PhysicsInformedNN(fluxchain,
grid_strategy;
init_params = initθ)
prob1 = NeuralPDE.discretize(pde_system,discretization1)
prob2 = NeuralPDE.discretize(pde_system,discretization2)
sym_prob = NeuralPDE.symbolic_discretize(pde_system,discretization1)
Zygote.gradient((x)->prob1.f(x,nothing),initθ)
Zygote.gradient((x)->prob2.f(x,nothing),initθ) # Very very different???
## Fixed
initθ = DiffEqFlux.initial_params(fastchain)
grid_strategy = NeuralPDE.GridTraining(0.1)
discretization1 = NeuralPDE.PhysicsInformedNN(fastchain,
grid_strategy;
init_params = initθ)
_, re = Flux.destructure(fluxchain) # `re` is needed for the custom `phi` below (missing from the original snippet)
discretization2 = NeuralPDE.PhysicsInformedNN(fluxchain,
grid_strategy;
init_params = initθ,
phi = (x,p)->re(p)(x))
prob1 = NeuralPDE.discretize(pde_system,discretization1)
prob2 = NeuralPDE.discretize(pde_system,discretization2)
sym_prob = NeuralPDE.symbolic_discretize(pde_system,discretization1)
Zygote.gradient((x)->prob1.f(x,nothing),initθ)
Zygote.gradient((x)->prob2.f(x,nothing),initθ) # it's fine now!
```

Notice that it has incorrect gradients with the Float64 parameters.
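A minimal sketch of the suspected mechanism, for illustration (assuming `Flux.destructure` is backed by Optimisers.jl here, and the usual Float32 layer defaults):

```julia
using Flux

m = Chain(Dense(2, 12, σ), Dense(12, 1))   # Float32 parameters by default
p, re = Flux.destructure(m)
p64 = Float64.(p)

m64 = re(p64)                 # rebuild the chain from Float64 parameters
eltype(m64[1].weight)         # Float32 — the template's eltype wins, the Float64s are narrowed
```

The restructured chain keeps the template's Float32 eltype, so the Flux-chain path runs at a different precision from the `FastChain` path even though both receive the same Float64 `initθ`.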
I thought it would be simple to isolate:

```julia
using DiffEqFlux, Flux, Adapt, Zygote
fastchain = FastChain(FastDense(2,12,Flux.σ),FastDense(12,12,Flux.σ),FastDense(12,1))
fluxchain = Chain(Dense(2,12,Flux.σ),Dense(12,12,Flux.σ),Dense(12,1))
initθ = DiffEqFlux.initial_params(fastchain)
p,re = Flux.destructure(fluxchain)
x = Float32[1.5,0.5]
dx1,dp1 = Zygote.gradient((x,p)->sum(fastchain(adapt(Array,x),p)),x,initθ)
dx2,dp2 = Zygote.gradient((x,p)->sum(re(p)(adapt(Array,x))),x,initθ)
dx1 ≈ dx2 # true
dp1 ≈ dp2 # true
initθ = Float64.(DiffEqFlux.initial_params(fastchain))
x = Float64[1.5,0.5]
dx3,dp3 = Zygote.gradient((x,p)->sum(fastchain(x,p)),x,initθ)
dx4,dp4 = Zygote.gradient((x,p)->sum(re(p)(x)),x,initθ)
dx3 ≈ dx1 # true
dx4 ≈ dx1 # true
dp3 ≈ dp1 # true
dp4 ≈ dp1 # true
```
But it goes away if I do this (note the `|> f64` on the Flux chain):

```julia
using DiffEqFlux, Flux, NeuralPDE, ModelingToolkit, DomainSets, Zygote
@parameters x y
@variables u(..)
Dxx = Differential(x)^2
Dyy = Differential(y)^2
# 2D PDE
eq = Dxx(u(x,y)) + Dyy(u(x,y)) ~ -sin(pi*x)*sin(pi*y)
# Initial and boundary conditions
bcs = [u(0,y) ~ 0.0, u(1,y) ~ -sin(pi*1)*sin(pi*y),
u(x,0) ~ 0.0, u(x,1) ~ -sin(pi*x)*sin(pi*1)]
# Space and time domains
domains = [x ∈ Interval(0.0,1.0),
y ∈ Interval(0.0,1.0)]
@named pde_system = PDESystem(eq,bcs,domains,[x,y],[u(x, y)])
fastchain = FastChain(FastDense(2,12,Flux.σ),FastDense(12,12,Flux.σ),FastDense(12,1))
fluxchain = Chain(Dense(2,12,Flux.σ),Dense(12,12,Flux.σ),Dense(12,1)) |> f64
initθ = Float64.(DiffEqFlux.initial_params(fastchain))
grid_strategy = NeuralPDE.GridTraining(0.1)
discretization1 = NeuralPDE.PhysicsInformedNN(fastchain,
grid_strategy;
init_params = initθ)
discretization2 = NeuralPDE.PhysicsInformedNN(fluxchain,
grid_strategy;
init_params = initθ)
prob1 = NeuralPDE.discretize(pde_system,discretization1)
prob2 = NeuralPDE.discretize(pde_system,discretization2)
Zygote.gradient((x)->prob1.f(x,nothing),initθ)
Zygote.gradient((x)->prob2.f(x,nothing),initθ) # it's fine now!
```
We worked around it here by just changing the number type ourselves so NeuralPDE is safe, but @CarloLucibello @mcabbott this is a pretty dangerous bug to have lurking around. Have you considered merging @DhairyaLGandhi's branch https://github.com/FluxML/Optimisers.jl/tree/dg/noproject and adding a test to catch this in the future?
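For a sense of what "changing the number type ourselves" can look like, a purely hypothetical sketch (`match_param_eltype` is a made-up helper, not the actual NeuralPDE code):

```julia
using Flux

# Hypothetical guard: bring the Flux chain to the eltype of the user-supplied
# initial parameters, so restructuring never has to narrow anything.
function match_param_eltype(chain::Chain, init_params::AbstractVector)
    return eltype(init_params) == Float64 ? f64(chain) : f32(chain)
end
```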
Have not tried to reproduce this, but this change FluxML/Optimisers.jl@9c61c8a looks like it ought to allow you to make an MWE, or at least to figure out what types are actually involved here. It does not look safe to merge.
My attempts at an MWE failed, but maybe @DhairyaLGandhi found a nicer one. I think you need a `map` right after the restructure or something, but 🤷 my isolations all worked, so it's something rather specific.
Can you tell me what this prints, and also post the stacktrace somewhere?

```julia
julia> using Optimisers
julia> @eval Optimisers begin
function _getat(y::AbstractArray, o::Int, flat::AbstractVector)
res = ProjectTo(y)(reshape(flat[o .+ (1:length(y))], axes(y)))
if eltype(res) != eltype(y)
@info "found one" summary(y) summmary(flat) summary(res)
end
res
end
end
_getat (generic function with 2 methods)
julia> using DiffEqFlux, Flux, NeuralPDE, ModelingToolkit, DomainSets
```

So far, my attempt to install everything:

Master: (output not preserved)

1.7: (output not preserved)
That "skipping precompilation" message is something that shows up with any package on v1.7 if you update and re-`using` without restarting the REPL.
@DhairyaLGandhi identified the right spot, but his fix is incorrect. Here's a deterministic example:

```julia
using DiffEqFlux, Flux, NeuralPDE, ModelingToolkit, DomainSets, Zygote
@parameters x y
@variables u(..)
Dxx = Differential(x)^2
Dyy = Differential(y)^2
# 2D PDE
eq = Dxx(u(x,y)) + Dyy(u(x,y)) ~ -sin(pi*x)*sin(pi*y)
# Initial and boundary conditions
bcs = [u(0,y) ~ 0.0, u(1,y) ~ -sin(pi*1)*sin(pi*y),
u(x,0) ~ 0.0, u(x,1) ~ -sin(pi*x)*sin(pi*1)]
# Space and time domains
domains = [x ∈ Interval(0.0,1.0),
y ∈ Interval(0.0,1.0)]
@named pde_system = PDESystem(eq,bcs,domains,[x,y],[u(x, y)])
fastchain = FastChain(FastDense(2,12,Flux.σ),FastDense(12,12,Flux.σ),FastDense(12,1))
fluxchain = Chain(Dense(2,12,Flux.σ),Dense(12,12,Flux.σ),Dense(12,1))
initθ = range(0,1,length=205)
grid_strategy = NeuralPDE.GridTraining(0.1)
discretization1 = NeuralPDE.PhysicsInformedNN(fastchain,
grid_strategy;
init_params = initθ)
discretization2 = NeuralPDE.PhysicsInformedNN(fluxchain,
grid_strategy;
init_params = initθ)
prob1 = NeuralPDE.discretize(pde_system,discretization1)
prob2 = NeuralPDE.discretize(pde_system,discretization2)
Zygote.gradient((x)->prob1.f(x,nothing),initθ)
# ([0.34135572161301464, 0.4405596388580093, 0.5395470482221245, 0.6382976197739982, 0.7367915738790711, 0.8350087259178556, 0.9329294568383262, 1.030534260497729, 1.1278035963200221, 1.2247180249089877 … 98.49443032452626, 98.54203213497166, 98.58792677663844, 98.63218175029297, 98.6748692924491, 98.71603782132827, 98.75571771403575, 98.7939947935279, 98.8308816355171, 99.81405020016672],)
Zygote.gradient((x)->prob2.f(x,nothing),initθ)
# ([0.34197112172842026, -3.558639347553253, 4.5405376851558685, 0.6394826173782349, 0.7381760030984879, -3.1634015142917633, -7.0652690678834915, 17.032554239034653, -6.869950324296951, -14.772801548242569 … -413.50527000427246, 98.54232025146484, 610.5882167816162, 98.63247489929199, 98.6751537322998, 98.71630668640137, 610.7559909820557, 610.7942523956299, 98.83114624023438, 99.81404876708984],)
## Fixed
initθ = Float32.(range(0,1,length=205))
grid_strategy = NeuralPDE.GridTraining(0.1)
discretization1 = NeuralPDE.PhysicsInformedNN(fastchain,
grid_strategy;
init_params = initθ)
discretization2 = NeuralPDE.PhysicsInformedNN(fluxchain,
grid_strategy;
init_params = initθ)
prob1 = NeuralPDE.discretize(pde_system,discretization1)
prob2 = NeuralPDE.discretize(pde_system,discretization2)
Zygote.gradient((x)->prob1.f(x,nothing),initθ)
# (Float32[0.34133404, 0.440552, 0.5395764, 0.63829243, 0.73679507, 0.83497345, 0.9328923, 1.0304716, 1.1278142, 1.2247037 … 98.49864, 98.546234, 98.59311, 98.63931, 98.6693, 98.71631, 98.76674, 98.80206, 98.82138, 99.81404],)
Zygote.gradient((x)->prob2.f(x,nothing),initθ)
# (Float32[0.3413074, 0.44055194, 0.5395802, 0.6383001, 0.7367189, 0.83488965, 0.9327855, 1.0305097, 1.127738, 1.2246275 … 98.49864, 98.546234, 98.59311, 98.63931, 98.6693, 98.71631, 98.76674, 98.80206, 98.82138, 99.81404],)
# Doesn't do anything:
using Optimisers
@eval Optimisers begin
function _getat(y::AbstractArray, o::Int, flat::AbstractVector)
res = ProjectTo(y)(reshape(flat[o .+ (1:length(y))], axes(y)))
if eltype(res) != eltype(y)
@info "found one" summary(y) summmary(flat) summary(res)
end
res
end
end
initθ = range(0,1,length=205)
grid_strategy = NeuralPDE.GridTraining(0.1)
discretization1 = NeuralPDE.PhysicsInformedNN(fastchain,
grid_strategy;
init_params = initθ)
discretization2 = NeuralPDE.PhysicsInformedNN(fluxchain,
grid_strategy;
init_params = initθ)
prob1 = NeuralPDE.discretize(pde_system,discretization1)
prob2 = NeuralPDE.discretize(pde_system,discretization2)
Zygote.gradient((x)->prob1.f(x,nothing),initθ)
# ([0.34135572161301464, 0.4405596388580093, 0.5395470482221245, 0.6382976197739982, 0.7367915738790711, 0.8350087259178556, 0.9329294568383262, 1.030534260497729, 1.1278035963200221, 1.2247180249089877 … 98.49443032452626, 98.54203213497166, 98.58792677663844, 98.63218175029297, 98.6748692924491, 98.71603782132827, 98.75571771403575, 98.7939947935279, 98.8308816355171, 99.81405020016672],)
Zygote.gradient((x)->prob2.f(x,nothing),initθ)
# ([0.34197112172842026, -3.558639347553253, 4.5405376851558685, 0.6394826173782349, 0.7381760030984879, -3.1634015142917633, -7.0652690678834915, 17.032554239034653, -6.869950324296951, -14.772801548242569 … -413.50527000427246, 98.54232025146484, 610.5882167816162, 98.63247489929199, 98.6751537322998, 98.71630668640137, 610.7559909820557, 610.7942523956299, 98.83114624023438, 99.81404876708984],)
# Doesn't do anything
using Optimisers
@eval Optimisers begin
_getat(y::AbstractArray{T}, o::Int, flat::AbstractVector) where T =
T.(reshape(flat[o .+ (1:length(y))], axes(y))) # ProjectTo is just correcting eltypes
end
Zygote.gradient((x)->prob1.f(x,nothing),initθ)
# ([0.34135572161301464, 0.4405596388580093, 0.5395470482221245, 0.6382976197739982, 0.7367915738790711, 0.8350087259178556, 0.9329294568383262, 1.030534260497729, 1.1278035963200221, 1.2247180249089877 … 98.49443032452626, 98.54203213497166, 98.58792677663844, 98.63218175029297, 98.6748692924491, 98.71603782132827, 98.75571771403575, 98.7939947935279, 98.8308816355171, 99.81405020016672],)
Zygote.gradient((x)->prob2.f(x,nothing),initθ)
# ([0.34197112172842026, -3.558639347553253, 4.5405376851558685, 0.6394826173782349, 0.7381760030984879, -3.1634015142917633, -7.0652690678834915, 17.032554239034653, -6.869950324296951, -14.772801548242569 … -413.50527000427246, 98.54232025146484, 610.5882167816162, 98.63247489929199, 98.6751537322998, 98.71630668640137, 610.7559909820557, 610.7942523956299, 98.83114624023438, 99.81404876708984],)
# Fixed!!! ?
using Optimisers
@eval Optimisers begin
_getat(y::AbstractArray{T}, o::Int, flat::AbstractVector) where T =
@show Float64.(reshape(flat[o .+ (1:length(y))], axes(y))) # ProjectTo is just correcting eltypes
end
Zygote.gradient((x)->prob1.f(x,nothing),initθ)
# ([0.34135572161301464, 0.4405596388580093, 0.5395470482221245, 0.6382976197739982, 0.7367915738790711, 0.8350087259178556, 0.9329294568383262, 1.030534260497729, 1.1278035963200221, 1.2247180249089877 … 98.49443032452626, 98.54203213497166, 98.58792677663844, 98.63218175029297, 98.6748692924491, 98.71603782132827, 98.75571771403575, 98.7939947935279, 98.8308816355171, 99.81405020016672],)
Zygote.gradient((x)->prob2.f(x,nothing),initθ)
# ([0.34135578866824007, 0.44055960533039673, 0.539547040771544, 0.6382977315327074, 0.736791544076749, 0.8350087706213394, 0.932929464288907, 1.030534171090762, 1.127803417506088, 1.2247179504031824 … 98.49443127820058, 98.54203022762303, 98.58792582296412, 98.63217888927002, 98.67485975570594, 98.7160397286769, 98.75572534343029, 98.79398334943609, 98.83090070900344, 99.81405020016672],)
using Optimisers
@eval Optimisers begin
_getat(y::AbstractArray{T}, o::Int, flat::AbstractVector) where T =
reshape(flat[o .+ (1:length(y))], axes(y))
end
Zygote.gradient((x)->prob1.f(x,nothing),initθ)
# ([0.34135572161301464, 0.4405596388580093, 0.5395470482221245, 0.6382976197739982, 0.7367915738790711, 0.8350087259178556, 0.9329294568383262, 1.030534260497729, 1.1278035963200221, 1.2247180249089877 … 98.49443032452626, 98.54203213497166, 98.58792677663844, 98.63218175029297, 98.6748692924491, 98.71603782132827, 98.75571771403575, 98.7939947935279, 98.8308816355171, 99.81405020016672],)
Zygote.gradient((x)->prob2.f(x,nothing),initθ)
# ([0.34135578866824007, 0.44055960533039673, 0.539547040771544, 0.6382977315327074, 0.736791544076749, 0.8350087706213394, 0.932929464288907, 1.030534171090762, 1.127803417506088, 1.2247179504031824 … 98.49443127820058, 98.54203022762303, 98.58792582296412, 98.63217888927002, 98.67485975570594, 98.7160397286769, 98.75572534343029, 98.79398334943609, 98.83090070900344, 99.81405020016672],)
using Optimisers
@eval Optimisers begin
function _getat(y::AbstractArray{T}, o::Int, flat::AbstractVector) where T
@show eltype(y), eltype(flat)
reshape(flat[o .+ (1:length(y))], axes(y))
end
end
Zygote.gradient((x)->prob1.f(x,nothing),initθ)
# ([0.34135572161301464, 0.4405596388580093, 0.5395470482221245, 0.6382976197739982, 0.7367915738790711, 0.8350087259178556, 0.9329294568383262, 1.030534260497729, 1.1278035963200221, 1.2247180249089877 … 98.49443032452626, 98.54203213497166, 98.58792677663844, 98.63218175029297, 98.6748692924491, 98.71603782132827, 98.75571771403575, 98.7939947935279, 98.8308816355171, 99.81405020016672],)
Zygote.gradient((x)->prob2.f(x,nothing),initθ)
# (eltype(y), eltype(flat)) = (Float32, Float64)
# ([0.34135578866824007, 0.44055960533039673, 0.539547040771544, 0.6382977315327074, 0.736791544076749, 0.8350087706213394, 0.932929464288907, 1.030534171090762, 1.127803417506088, 1.2247179504031824 … 98.49443127820058, 98.54203022762303, 98.58792582296412, 98.63217888927002, 98.67485975570594, 98.7160397286769, 98.75572534343029, 98.79398334943609, 98.83090070900344, 99.81405020016672],)
```

Instead of going down to `Float32`, why not just keep the eltype of the flat vector? What's the reason for that `ProjectTo(y)` in `_getat`?
This is what's expected when the primal is Float32 for this variable. If something else wants a Float64 gradient for a Float32 variable, then maybe that's the problem.
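For illustration, a small standalone sketch using ChainRulesCore directly — this projection onto the primal's type is where the Float64 → Float32 narrowing of the gradient comes from:

```julia
using ChainRulesCore

y  = Float32[1, 2, 3]      # primal
dy = [0.1, 0.2, 0.3]       # Float64 cotangent coming back through the chain rule
ProjectTo(y)(dy)           # Float32[0.1, 0.2, 0.3] — standardized to the primal's eltype
```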
I understand why it exists now, but I don't understand why the type of the model's arrays should win out over the type of the flat vector here. Basically, why wouldn't `re(p)` just use the values (and eltype) of `p`?
Because its one job is reconstruction? It's explicitly designed to allow for mixed precisions, and not forget this. And not just precisions, the help example is:
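For reference, the help example being referred to is roughly this (paraphrased from the Optimisers.jl `destructure` docstring of that era; exact printing may differ):

```julia
julia> using Optimisers

julia> v, re = destructure((x = [1.0, 2.0], y = (sin, [3.0 + 4.0im])))
(ComplexF64[1.0 + 0.0im, 2.0 + 0.0im, 3.0 + 4.0im], Restructure(NamedTuple, ..., 3))

julia> re([3, 5 - im, 7 + 11im])   # the imaginary part of 5 - im is dropped for the real-typed leaf x
(x = [3.0, 5.0], y = (sin, ComplexF64[7.0 + 11.0im]))
```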
No, it does not.
Does it? What issues?
Well, that's why I'm confused about why it's doing more than just reconstruction. It's not just reconstructing the array structure: I would have expected

```julia
julia> re([3, 5-im, 7+11im])
(x = [3+0*im, 5-im], y = (sin, [7 + 11im]))
```

i.e. it would be "the same form" but with the values of the vector you pass in.

No, I'm saying I would've expected:

```julia
julia> re(cu([3, 5-im, 7+11im]))
(x = cu([3+0*im, 5-im]), y = (sin, cu([7 + 11im])))
```

"It's the same as the destructured thing, but with the values taken from the flat vector", and likewise:

```julia
julia> re(Tracked([3, 5-im, 7+11im]))
(x = Tracked([3+0*im, 5-im]), y = (sin, Tracked([7 + 11im])))
```

But anyways, now I'm worried about that complex case: that should definitely be counted as a bug IMO, or throw an error.
Right. It is actually related to why we expect some custom leaf types to return structures as adjoints in the backward pass, and therefore need to reconstruct the type as opposed to operating on the fields directly.
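As a rough illustration of "structures as adjoints" — a sketch with a made-up `Affine` type, not code from any of the packages involved:

```julia
using ChainRulesCore

struct Affine{M,V}
    W::M
    b::V
end
(a::Affine)(x) = a.W * x .+ a.b

# The adjoint of an Affine call w.r.t. the layer is naturally a structural
# Tangent with one entry per field, not a flat vector; reconstruction has to
# map it back onto the right fields.
dA = Tangent{Affine}(W = ones(2, 2), b = ones(2))
```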
The fundamental issue here is that `destructure`/`Restructure` is being asked to serve two different purposes at once.

Chris has already described why no type conversion during reconstruction makes sense for 1). For 2), I recall we went down this path after many issues where users were expecting certain invariants to be upheld — roughly, that the reconstructed model keeps the original's structure, array types, and eltypes. The complex example is actually a great one because it shows how these can be broken without type conversion. If you pass in a complex flat vector for a model with real parameters, either the values change or the parameter eltypes change, so one of those expectations breaks either way. I think the only way to resolve this tension is to bifurcate the API into two explicitly separate behaviours.
For use as 1, is it really too much to ask that you make the "template" model with the desired number type? That seems like a simpler, easier-to-understand API than having some special additional mode to know about, document & test. At present it reconstructs with the template model's types, and that's what is documented.
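In code, that suggestion amounts to something like this sketch (`f64` is Flux's element-type conversion helper; assuming the Optimisers.jl `destructure`):

```julia
using Flux, Optimisers

m64   = f64(Chain(Dense(2, 12, σ), Dense(12, 1)))   # make the template Float64 up front
p, re = Optimisers.destructure(m64)

eltype(p)                  # Float64
eltype(re(p)[1].weight)    # Float64 — no narrowing, because the template already agrees
```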
If that's the case, it should probably throw an error instead of computing incorrect gradients. The complex number case would be a particularly nasty one to try and debug. Even finding this behavior took a long time.
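Something like the following is one way to read "throw an error instead" — a purely hypothetical check, just to show the shape of it (`check_no_narrowing` is a made-up name):

```julia
# Hypothetical: refuse to silently narrow the flat vector's eltype during reconstruction.
function check_no_narrowing(y::AbstractArray, flat::AbstractVector)
    promote_type(eltype(y), eltype(flat)) == eltype(y) ||
        throw(ArgumentError("restructuring would narrow $(eltype(flat)) parameters to $(eltype(y))"))
    return nothing
end
```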
But there are no incorrect gradients here. Like everything else using ChainRules, these are gradients with respect to the Float32 primal that actually exists. Your complaint is, if I understand right, entirely that the reconstruction does not adopt the eltype of the flat vector you pass in.
And that makes it surprising.
This whole discussion has been about 2. The issue is that type conversion presents itself as incorrect gradients in the case of (2). A calculation which says "I want to use Complex{Float64}" will silently use Float32, returning 0's for the complex values and computing with incorrect precision on the real parts. The only way this is exposed to the user is if one checks the gradient calculation (something that isn't a user-level property anyway, so it's actually just hidden as "it didn't train"). One issue here is that it only even shows up if you run the forward pass in isolation and check the types; here we do things like build the loss functions automatically from the symbolic PDE system, so the user never sees that conversion happen.

Look back at FluxML/Flux.jl#1901 (comment). Now that I've finally isolated this 5 months later, I realize that this behavior change is what caused the downstream tests to fail, sometimes, depending on the Optimisers version that was resolved. The precision change made test failures more likely (since the tests still used random initializations), so the tests actually found it, but rerun enough times the last run was green and it looked like a fluke. Almost imperceptibly the behavior just became "things are a little bit more janky these days, nobody really knows why", until I finally got it isolated as just that the gradient precision was different from the precision that was specified. It might now be clearly documented, but this is very easy to accidentally hit, and very hard to diagnose unless you already know it's possible. Multiple people took a look and no one realized that passing Float64s around isn't a good idea if you forget `|> f64`.
I don't see why that should be the case.
Great, well now that the documentation is clear, no need to guess.
Again, no incorrect gradients have been exhibited here.
The way you say this is by making the primal complex. Real numbers may not have complex gradients. The answer to "which way is uphill from the summit of this mountain?" cannot sensibly be a complex vector. Allowing that was a source of truly mysterious gradient bugs.
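A tiny illustration of that rule via ChainRulesCore, for concreteness:

```julia
using ChainRulesCore

ProjectTo(1.0)(2.0 + 3.0im)                 # 2.0 — a real primal cannot have a complex gradient
ProjectTo([1.0, 2.0])([3 + 0im, 5 - im])    # [3.0, 5.0] — imaginary parts are projected away
```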
As you know, the old `destructure` […]
If you think this ought to be yet more strict, please make an issue.
Continuing this discussion upstream.
MWE:
See that `fluxchain` fails and the gradient is off.