Enzyme internal error while running neural ODE with Lux + Enzyme #2110

Open
heyyeahcrow opened this issue Nov 20, 2024 · 10 comments

@heyyeahcrow

Hi,

I tried to use AutoEnzyme as the AD backend for training a neural network that predicts ODE parameters, following the DiffEqFlux example, but it fails with an Enzyme internal error and a long LLVM IR dump.

using Lux, DiffEqFlux, OrdinaryDiffEq, Plots, Printf, Statistics
using ComponentArrays
using Optimization, OptimizationOptimisers
using Enzyme
using Dates
using Random
using StaticArrays

function evolve!(dc, c, p, t)
    p1 = p[1]
    p2 = p[2]
    dc .= c .* p2 * p1
end

function simulate(i1, i2, a, b, t_span)
    p2 = exp(-i2 * a)
    p1 = i1 * b
    p = (p1, p2)
    c0 = [1.0 0.0; 1.0 0.0]
    prob = ODEProblem(evolve!, c0, t_span, p)
    sol = solve(prob, Euler(), save_everystep=false, dt = 0.5)
    return Array(sol[end])
end


rng = Xoshiro(0)
b = [0.0 1.0; 1.0 0.0]
a = 0.6
n = length(b[1, :])
i1 = 0.18
i2 = 2.5 
timespan = (0.0, 5.0)
ans = simulate(i1, i2, a, b, timespan)

display(ans)

inputs = [i1, i2]
input_size = length(inputs)
output_size = length(a) + length(b)
nn = Chain(
    Dense(input_size, input_size*3*n, tanh),
    Dense(input_size*3*n, output_size*2, tanh),
    Dense(output_size*2, output_size, sigmoid)
)

u, st = Lux.setup(rng, nn)

function predict_neuralode(u)
    # Get parameters from the neural network
    output, outst = nn(inputs, u, st)
    # Segregate the output
    p_a = output[1]
    pp_b = output[length(a)+1:end]
    p_b = zeros(n, n)
    index = 1
    for i in 1:n
        for j in 1:n
            p_b[i, j] = pp_b[index]
            index += 1
        end
    end
    nn_output = [p_a, p_b]
    println("nn_output: ", nn_output)
    pred = simulate(i1, i2, p_a, p_b, timespan)
    return Array(pred)
end

function loss_neuralode(ans, u)
    pred = predict_neuralode(u)
    loss = sum(abs2, ans .- pred)
    return loss, pred
end

loss, pred = loss_neuralode(ans, u)

loss_values = Float64[]
callback = function (p, l, pred; doplot = false)
    println(l)
    push!(loss_values, l)
end

pinit = ComponentArray(u)
callback(pinit, loss_neuralode(ans, pinit)...)

adtype = Optimization.AutoEnzyme()

optf = Optimization.OptimizationFunction((u,_) -> loss_neuralode(ans, u), adtype)
optprob = Optimization.OptimizationProblem(optf, pinit)

result_neuralode = Optimization.solve(
    optprob, OptimizationOptimisers.Adam(0.02); callback = callback, maxiters = 50)

The error log: error_log_2024-11-19_15-56-09.txt
The stacktrace: Stacktrace.txt

I also tried running the DiffEqFlux example with Zygote/AutoZygote replaced by Enzyme/AutoEnzyme, and it returned the same error. The error occurs on both macOS and Windows.

Julia Version 1.11.1
Packages:
[b0b7db55] ComponentArrays v0.15.18
[aae7a2af] DiffEqFlux v4.1.0
[7da242da] Enzyme v0.13.14
⌅ [d9f16b24] Functors v0.4.12
[e6f89c97] LoggingExtras v1.1.0
[b2108857] Lux v1.2.3
[7f7a1694] Optimization v4.0.5
[42dfb2eb] OptimizationOptimisers v0.3.4
[1dea7af3] OrdinaryDiffEq v6.90.1
[91a5bcdd] Plots v1.40.9
[90137ffa] StaticArrays v1.9.8
[10745b16] Statistics v1.11.1
[e88e6eb3] Zygote v0.6.73
[ade2ca70] Dates v1.11.0
[56ddb016] Logging v1.11.0
[de0858da] Printf v1.11.0
[9a3f8284] Random v1.11.0
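The nested loop in predict_neuralode fills p_b row by row from pp_b, i.e. p_b[i, j] == pp_b[(i - 1) * n + j]. A shorter, non-mutating equivalent is permutedims(reshape(pp_b, n, n)). The sketch below only checks that the two agree; unpack_rowmajor and unpack_loop are hypothetical names, and whether the non-mutating form behaves any differently under Enzyme has not been tested here.

```julia
# Hypothetical helper: n×n matrix whose rows are consecutive chunks of `v`,
# so M[i, j] == v[(i - 1) * n + j]. reshape is column-major, hence the transpose.
unpack_rowmajor(v::AbstractVector, n::Integer) = permutedims(reshape(v, n, n))

# The loop from predict_neuralode, kept for comparison (j is the inner loop).
function unpack_loop(v::AbstractVector, n::Integer)
    M = zeros(eltype(v), n, n)
    index = 1
    for i in 1:n, j in 1:n
        M[i, j] = v[index]
        index += 1
    end
    return M
end
```

Inside predict_neuralode this would read p_b = unpack_rowmajor(pp_b, n).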

@wsmoses
Member

wsmoses commented Nov 21, 2024

Hi, this looks like an error in the Julia 1.11 FFI call support in Enzyme. Two quick things:

  1. Can you test on the latest version of Enzyme? (I think this ought to be fixed; if not, we should fix it.)
  2. Can you make a reproducer that only has a direct autodiff call? @ChrisRackauckas may be able to help you with this.
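One way to act on the second point is to take DiffEq out of the differentiated code entirely. The sketch below assumes that Euler() with dt = 0.5 on (0.0, 5.0) takes exactly ten fixed steps of dc/dt = p2 * c * p1, which is what evolve! computes (c .* p2 * p1 parses as (c .* p2) * p1, a matrix product). simulate_euler and simulate_closed_form are hypothetical names; a function like simulate_euler could then be the target of a direct Enzyme.autodiff call without OrdinaryDiffEq involved.

```julia
using LinearAlgebra

# Hypothetical DiffEq-free stand-in for simulate(): fixed-step explicit Euler.
function simulate_euler(i1, i2, a, b, t_span; dt = 0.5)
    p2 = exp(-i2 * a)
    p1 = i1 * b
    c = reshape([1.0, 1.0, 0.0, 0.0], 2, 2)  # same values as [1.0 0.0; 1.0 0.0]
    nsteps = round(Int, (t_span[2] - t_span[1]) / dt)
    for _ in 1:nsteps
        c = c + dt * (c .* p2 * p1)  # one Euler step of dc/dt = p2 * c * p1
    end
    return c
end

# Closed form of the same recursion: c_{k+1} = c_k * (I + dt * p2 * p1).
function simulate_closed_form(i1, i2, a, b, t_span; dt = 0.5)
    p2 = exp(-i2 * a)
    p1 = i1 * b
    c0 = reshape([1.0, 1.0, 0.0, 0.0], 2, 2)
    nsteps = round(Int, (t_span[2] - t_span[1]) / dt)
    return c0 * (I + dt * p2 * p1)^nsteps
end
```

The closed form is only there as a sanity check on the loop; the loop is what a minimal reproducer would differentiate.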

@heyyeahcrow
Author

> Hi, this looks like an error in the Julia 1.11 FFI call support in Enzyme. Two quick things:
>
>   1. Can you test on the latest version of Enzyme? (I think this ought to be fixed; if not, we should fix it.)
>   2. Can you make a reproducer that only has a direct autodiff call? @ChrisRackauckas may be able to help you with this.

I updated Enzyme, and it still shows the same error. I'm currently trying the second suggestion.

BTW, do I need to vectorize all the inputs and outputs of my ODE?

@wsmoses
Member

wsmoses commented Nov 21, 2024

I don't think that should be needed here to get it to fail, but that's just an intuition

@heyyeahcrow
Author

heyyeahcrow commented Nov 25, 2024

@wsmoses

> Hi, this looks like an error in the Julia 1.11 FFI call support in Enzyme. Two quick things:
>
>   1. Can you test on the latest version of Enzyme? (I think this ought to be fixed; if not, we should fix it.)
>   2. Can you make a reproducer that only has a direct autodiff call? @ChrisRackauckas may be able to help you with this.

I tried making a reproducer that runs a direct autodiff call on the loss function, but it fails with a similar error (a long LLVM IR dump and a similar stacktrace):
error_log2.txt

using Lux, DiffEqFlux, OrdinaryDiffEq, Plots, Printf, Statistics
using ComponentArrays
using Optimization, OptimizationOptimisers
#using Optimisers
using Enzyme
using Dates
using Random
using StaticArrays


timestamp = Dates.format(now(), "yyyy-mm-dd_HH-MM-SS")

function evolve!(dc, c, p, t)
    p1 = p[1]
    p2 = p[2]
    dc .= c .* p2 * p1
end

function simulate(i1, i2, a, b, t_span)
    p2 = exp(-i2 * a)
    p1 = i1 * b
    p = (p1, p2)
    c0 = [1.0 0.0; 1.0 0.0]
    prob = ODEProblem(evolve!, c0, t_span, p)
    sol = solve(prob, Euler(), save_everystep=false, dt = 0.5)
    return Array(sol[end])
end


rng = Xoshiro(0)
b = [0.0 1.0; 1.0 0.0]
a = 0.6
n = length(b[1, :])
println("n:", n)
println("b:", b)
i1 = 0.18
i2 = 2.5 
timespan = (0.0, 5.0)
ans = simulate(i1, i2, a, b, timespan)

display(ans)

inputs = [i1, i2]
input_size = length(inputs)
output_size = length(a) + length(b)
nn = Chain(
    Dense(input_size, input_size*3*n, tanh),
    Dense(input_size*3*n, output_size*2, tanh),
    Dense(output_size*2, output_size, sigmoid)
)

u, st = Lux.setup(rng, nn)

function predict_neuralode(u)
    # Get parameters from the neural network
    output, outst = nn(inputs, u, st)

    # Segregate the output
    p_a = output[1]
    pp_b = output[length(a)+1:end]
    p_b = zeros(n, n)
    index = 1
    for i in 1:n
        for j in 1:n
            p_b[i, j] = pp_b[index]
            index += 1
        end
    end

    nn_output = [p_a, p_b]
    #println("nn_output: ", nn_output)
    pred = simulate(i1, i2, p_a, p_b, timespan)
    return Array(pred)
end

function loss_neuralode(ans, u)
    pred = predict_neuralode(u)
    loss = sum(abs2, ans .- pred)
    return [loss], [pred]
end

function loss!(loss, ans, pinit)
    # Write into the caller's buffer; `loss = ...` would only rebind the local name
    loss[1] = loss_neuralode(ans, pinit)[1][1]
    return nothing
end

pinit = ComponentArray(u)

pred = predict_neuralode(u)
println("Training data: ", size(ans))
println("Prediction:", size(pred))

loss, pred = loss_neuralode(ans, u)
dloss = zero(loss)
dloss[1] = 1.0  # seed the reverse pass: d(loss)/d(loss) = 1
dp = zero(pinit)
Enzyme.autodiff(Reverse, loss!, Duplicated(loss, dloss), Const(ans), Duplicated(pinit, dp))
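A note on getting results out of loss!: in Julia, assigning to an argument name inside a function (for example loss = ...) only rebinds the local variable, so neither the caller's array nor its Duplicated shadow is written. The output has to be updated in place (loss[1] = ...), and the output shadow needs a nonzero seed, as the Lux-only reproducer later in this thread does with dloss[1] = 1.0. The pure-Julia check below illustrates the rebinding difference; the function names are hypothetical.

```julia
# Rebinding: `out` now names a new array; the caller's array is untouched.
function set_by_rebinding!(out, value)
    out = [value]
    return nothing
end

# Mutation: the caller's array is updated in place.
function set_by_mutation!(out, value)
    out[1] = value
    return nothing
end

buf = [0.0]
set_by_rebinding!(buf, 3.0)
println(buf)  # still [0.0]
set_by_mutation!(buf, 3.0)
println(buf)  # now [3.0]
```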

I also tried differentiating the ODE solve itself, but it failed with the error below, both with and without wrapping the scalar inputs in vectors.

function evolve!(dc, c, p, t)
    p1 = p[1]
    p2 = p[2]
    dc .= c .* p2 * p1
end

function simulate(i1, i2, a, b, t_span)
    p2 = exp(-i2[1] * a[1])
    p1 = i1[1] * b
    p = (p1, p2)
    c0 = [1.0 0.0; 1.0 0.0]
    prob = ODEProblem(evolve!, c0, t_span, p)
    sol = solve(prob, Euler(), save_everystep=false, dt = 0.5)
    return Array(sol[end])
end

function simulate!(sol, i1, i2, a, b, t_span)
    sol .= simulate(i1, i2, a, b, t_span)
    return nothing
end

rng = Xoshiro(0)
b = [0.0 1.0; 1.0 0.0]
a = [0.6]
n = length(b[1, :])
println("n:", n)
println("b:", b)
i1 = [0.18]
i2 = [2.5]
timespan = (0.0, 5.0)
sol = simulate(i1, i2, a, b, timespan)
di1 = zero(i1)
di2 = zero(i2)
da = zero(a)
db = zero(b)
dsol = zero(sol)
dsol[1,1] = 1.0 
gs = Enzyme.autodiff(Reverse, simulate!, Duplicated(sol, dsol), Duplicated(i1,di1), Duplicated(i2,di2), Duplicated(a,da), Duplicated(b,db), Const(timespan))
display(ans)
ERROR: StackOverflowError:
Stacktrace:
     [1] hvcat
       @ ~/.julia/juliaup/julia-1.11.1+0.aarch64.apple.darwin14/share/julia/stdlib/v1.11/SparseArrays/src/sparsevector.jl:1307 [inlined]
     [2] hvcat
       @ ~/.julia/juliaup/julia-1.11.1+0.aarch64.apple.darwin14/share/julia/stdlib/v1.11/SparseArrays/src/sparsevector.jl:0 [inlined]
     [3] augmented_julia_hvcat_100084_inner_1wrap
       @ ~/.julia/juliaup/julia-1.11.1+0.aarch64.apple.darwin14/share/julia/stdlib/v1.11/SparseArrays/src/sparsevector.jl:0
     [4] macro expansion
       @ ~/.julia/packages/Enzyme/RvNgp/src/compiler.jl:8305 [inlined]
     [5] enzyme_call
       @ ~/.julia/packages/Enzyme/RvNgp/src/compiler.jl:7868 [inlined]
     [6] AugmentedForwardThunk
       @ ~/.julia/packages/Enzyme/RvNgp/src/compiler.jl:7705 [inlined]
     [7] runtime_generic_augfwd(activity::Type{…}, runtimeActivity::Val{…}, width::Val{…}, ModifiedBetween::Val{…}, RT::Val{…}, f::typeof(hvcat), df::Nothing, primal_1::Tuple{…}, shadow_1_1::Nothing, primal_2::Float64, shadow_2_1::Base.RefValue{…}, primal_3::Float64, shadow_3_1::Base.RefValue{…}, primal_4::Float64, shadow_4_1::Base.RefValue{…}, primal_5::Float64, shadow_5_1::Base.RefValue{…})
       @ Enzyme.Compiler ~/.julia/packages/Enzyme/RvNgp/src/rules/jitrules.jl:483
--- the above 7 lines are repeated 4795 more times ---
 [33573] hvcat
       @ ~/.julia/juliaup/julia-1.11.1+0.aarch64.apple.darwin14/share/julia/stdlib/v1.11/SparseArrays/src/sparsevector.jl:1307 [inlined]
 [33574] hvcat
       @ ~/.julia/juliaup/julia-1.11.1+0.aarch64.apple.darwin14/share/julia/stdlib/v1.11/SparseArrays/src/sparsevector.jl:0 [inlined]
 [33575] augmented_julia_hvcat_99339_inner_1wrap
       @ ~/.julia/juliaup/julia-1.11.1+0.aarch64.apple.darwin14/share/julia/stdlib/v1.11/SparseArrays/src/sparsevector.jl:0
 [33576] macro expansion
       @ ~/.julia/packages/Enzyme/RvNgp/src/compiler.jl:8305 [inlined]
 [33577] enzyme_call
       @ ~/.julia/packages/Enzyme/RvNgp/src/compiler.jl:7868 [inlined]
 [33578] AugmentedForwardThunk
       @ ~/.julia/packages/Enzyme/RvNgp/src/compiler.jl:7705 [inlined]
 [33579] runtime_generic_augfwd(activity::Type{…}, runtimeActivity::Val{…}, width::Val{…}, ModifiedBetween::Val{…}, RT::Val{…}, f::typeof(hvcat), df::Nothing, primal_1::Tuple{…}, shadow_1_1::Nothing, primal_2::Float64, shadow_2_1::Nothing, primal_3::Float64, shadow_3_1::Nothing, primal_4::Float64, shadow_4_1::Nothing, primal_5::Float64, shadow_5_1::Nothing)
       @ Enzyme.Compiler ~/.julia/packages/Enzyme/RvNgp/src/rules/jitrules.jl:483
 [33580] hvcat
       @ ~/.julia/juliaup/julia-1.11.1+0.aarch64.apple.darwin14/share/julia/stdlib/v1.11/SparseArrays/src/sparsevector.jl:1307 [inlined]
 [33581] simulate
       @ ~/Library/CloudStorage/[email protected]/My Drive/Ga2O3/Neural-Mechanistic-Model/simple autodiff.jl:24
 [33582] simulate!
       @ ~/Library/CloudStorage/[email protected]/My Drive/Ga2O3/Neural-Mechanistic-Model/simple autodiff.jl:31 [inlined]
 [33583] simulate!
       @ ~/Library/CloudStorage/[email protected]/My Drive/Ga2O3/Neural-Mechanistic-Model/simple autodiff.jl:0 [inlined]
 [33584] diffejulia_simulate__101985_inner_3wrap
       @ ~/Library/CloudStorage/[email protected]/My Drive/Ga2O3/Neural-Mechanistic-Model/simple autodiff.jl:0
 [33585] macro expansion
       @ ~/.julia/packages/Enzyme/RvNgp/src/compiler.jl:8305 [inlined]
 [33586] enzyme_call
       @ ~/.julia/packages/Enzyme/RvNgp/src/compiler.jl:7868 [inlined]
 [33587] CombinedAdjointThunk
       @ ~/.julia/packages/Enzyme/RvNgp/src/compiler.jl:7641 [inlined]
 [33588] autodiff
       @ ~/.julia/packages/Enzyme/RvNgp/src/Enzyme.jl:491 [inlined]
 [33589] autodiff
       @ ~/.julia/packages/Enzyme/RvNgp/src/Enzyme.jl:537 [inlined]
 [33590] autodiff(::ReverseMode{…}, ::typeof(simulate!), ::Duplicated{…}, ::Duplicated{…}, ::Duplicated{…}, ::Duplicated{…}, ::Duplicated{…}, ::Const{…})
       @ Enzyme ~/.julia/packages/Enzyme/RvNgp/src/Enzyme.jl:504
 [33591] eval
       @ ./boot.jl:430 [inlined]

@heyyeahcrow
Author

heyyeahcrow commented Dec 2, 2024

> See https://docs.sciml.ai/SciMLSensitivity/stable/faq/#How-do-I-isolate-potential-gradient-issues-and-improve-performance?

Hi, I tried to run this, but I need help understanding how it works. It eventually ran on the ODE itself without errors or warnings, but it did not produce any sensitivities.

  1. Normally I have to set an entry of du to 1.0 to compute the corresponding sensitivity. Where should I do this in that form? (Is it tmp4?)
  2. Why is ytmp zeroed instead of being set with ytmp = u0?

Also, could you explain how this relates to my problem? I don't see the connection between the warnings and the error I got.
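For reference on the first question, here is the general reverse-mode convention (independent of the FAQ code, whose tmp variables are not reproduced here). For an in-place function y = f(x), the reverse pass accumulates the vector-Jacobian product of the output shadow into the input shadows. Seeding a single output entry with 1.0, as dsol[1,1] = 1.0 does in the simulate! reproducer, therefore picks out the derivatives of that one output, and the accumulation (+=) is why input shadows such as di1 start from zero:

```math
\bar{x} \mathrel{+}= J_f(x)^{\top}\,\bar{y},
\qquad
\bar{y} = e_{(1,1)} \;\Longrightarrow\; \bar{x}_k \mathrel{+}= \frac{\partial y_{1,1}}{\partial x_k}
```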

@heyyeahcrow
Copy link
Author

A quick update: I installed Julia 1.10, updated the packages, and ran the autodiff call on loss! again. The error is shown below.

ERROR: StackOverflowError:
Stacktrace:
     [1] hvcat
       @ C:\Users\heyye\.julia\juliaup\julia-1.10.7+0.x64.w64.mingw32\share\julia\stdlib\v1.10\SparseArrays\src\sparsevector.jl:1270 [inlined]
     [2] hvcat
       @ C:\Users\heyye\.julia\juliaup\julia-1.10.7+0.x64.w64.mingw32\share\julia\stdlib\v1.10\SparseArrays\src\sparsevector.jl:0 [inlined]
     [3] augmented_julia_hvcat_16568_inner_1wrap
       @ C:\Users\heyye\.julia\juliaup\julia-1.10.7+0.x64.w64.mingw32\share\julia\stdlib\v1.10\SparseArrays\src\sparsevector.jl:0
     [4] macro expansion
       @ C:\Users\heyye\.julia\packages\Enzyme\fpA3W\src\compiler.jl:6229 [inlined]
     [5] enzyme_call
       @ C:\Users\heyye\.julia\packages\Enzyme\fpA3W\src\compiler.jl:5775 [inlined]
     [6] AugmentedForwardThunk
       @ C:\Users\heyye\.julia\packages\Enzyme\fpA3W\src\compiler.jl:5697 [inlined]
     [7] runtime_generic_augfwd(activity::Type{…}, runtimeActivity::Val{…}, width::Val{…}, ModifiedBetween::Val{…}, RT::Val{…}, f::typeof(hvcat), df::Nothing, primal_1::Tuple{…}, shadow_1_1::Nothing, primal_2::Float64, shadow_2_1::Base.RefValue{…}, primal_3::Float64, shadow_3_1::Base.RefValue{…}, primal_4::Float64, shadow_4_1::Base.RefValue{…}, primal_5::Float64, shadow_5_1::Base.RefValue{…})
       @ Enzyme.Compiler C:\Users\heyye\.julia\packages\Enzyme\fpA3W\src\rules\jitrules.jl:480
--- the last 7 lines are repeated 7246 more times ---
 [50730] hvcat
       @ C:\Users\heyye\.julia\juliaup\julia-1.10.7+0.x64.w64.mingw32\share\julia\stdlib\v1.10\SparseArrays\src\sparsevector.jl:1270 [inlined]
 [50731] hvcat
       @ C:\Users\heyye\.julia\juliaup\julia-1.10.7+0.x64.w64.mingw32\share\julia\stdlib\v1.10\SparseArrays\src\sparsevector.jl:0 [inlined]
 [50732] augmented_julia_hvcat_16544_inner_1wrap
       @ C:\Users\heyye\.julia\juliaup\julia-1.10.7+0.x64.w64.mingw32\share\julia\stdlib\v1.10\SparseArrays\src\sparsevector.jl:0
 [50733] macro expansion
       @ C:\Users\heyye\.julia\packages\Enzyme\fpA3W\src\compiler.jl:6229 [inlined]
 [50734] enzyme_call
       @ C:\Users\heyye\.julia\packages\Enzyme\fpA3W\src\compiler.jl:5775 [inlined]
 [50735] AugmentedForwardThunk
       @ C:\Users\heyye\.julia\packages\Enzyme\fpA3W\src\compiler.jl:5697 [inlined]
 [50736] runtime_generic_augfwd(activity::Type{…}, runtimeActivity::Val{…}, width::Val{…}, ModifiedBetween::Val{…}, RT::Val{…}, f::typeof(hvcat), df::Nothing, primal_1::Tuple{…}, shadow_1_1::Nothing, primal_2::Float64, shadow_2_1::Nothing, primal_3::Float64, shadow_3_1::Nothing, primal_4::Float64, shadow_4_1::Nothing, primal_5::Float64, shadow_5_1::Nothing)
       @ Enzyme.Compiler C:\Users\heyye\.julia\packages\Enzyme\fpA3W\src\rules\jitrules.jl:480
 [50737] hvcat
       @ C:\Users\heyye\.julia\juliaup\julia-1.10.7+0.x64.w64.mingw32\share\julia\stdlib\v1.10\SparseArrays\src\sparsevector.jl:1270 [inlined]
 [50738] simulate
       @ G:\My Drive\Ga2O3\Neural-Mechanistic-Model\simple copy 3.jl:24 [inlined]
 [50739] simulate
       @ G:\My Drive\Ga2O3\Neural-Mechanistic-Model\simple copy 3.jl:0 [inlined]
 [50740] augmented_julia_simulate_16494_inner_1wrap
       @ G:\My Drive\Ga2O3\Neural-Mechanistic-Model\simple copy 3.jl:0
 [50741] macro expansion
       @ C:\Users\heyye\.julia\packages\Enzyme\fpA3W\src\compiler.jl:6229 [inlined]
 [50742] enzyme_call
       @ C:\Users\heyye\.julia\packages\Enzyme\fpA3W\src\compiler.jl:5775 [inlined]
 [50743] AugmentedForwardThunk
       @ C:\Users\heyye\.julia\packages\Enzyme\fpA3W\src\compiler.jl:5697 [inlined]
 [50744] runtime_generic_augfwd(activity::Type{…}, runtimeActivity::Val{…}, width::Val{…}, ModifiedBetween::Val{…}, RT::Val{…}, f::typeof(simulate), df::Nothing, primal_1::Float64, shadow_1_1::Nothing, primal_2::Float64, shadow_2_1::Nothing, primal_3::Float64, shadow_3_1::Base.RefValue{…}, primal_4::Matrix{…}, shadow_4_1::Matrix{…}, primal_5::Tuple{…}, shadow_5_1::Nothing)
       @ Enzyme.Compiler C:\Users\heyye\.julia\packages\Enzyme\fpA3W\src\rules\jitrules.jl:480
Some type information was truncated. Use `show(err)` to see complete types.

@wsmoses
Member

wsmoses commented Dec 4, 2024

This issue is due to a bug in SparseArrays.jl; a fix has been made upstream, but it is still waiting on the Julia 1.10.8 release.
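Until that fix is available, one possible workaround, assuming (as the stack traces suggest) that the trigger is the matrix literal c0 = [1.0 0.0; 1.0 0.0] inside simulate, which lowers to a call to hvcat, is to build the initial condition without matrix-literal syntax. The sketch only checks that the replacement yields the same matrix; it has not been tested against Enzyme here, and initial_state is a hypothetical name.

```julia
# Hypothetical helper: the initial condition from simulate(), built with
# reshape on a vector literal instead of the 2×2 matrix literal (hvcat).
# Julia arrays are column-major, so the vector lists column 1, then column 2.
initial_state() = reshape([1.0, 1.0, 0.0, 0.0], 2, 2)

# Inside simulate(), the line
#     c0 = [1.0 0.0; 1.0 0.0]
# would become
#     c0 = initial_state()
```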

@heyyeahcrow
Author

@wsmoses
Some packages seem to have been updated recently, and the previous problems no longer appear.
All of my tests now fail with the same error:
First run:

ERROR: Enzyme.Compiler.EnzymeRuntimeActivityError(Cstring(0x000001c3ffed3d86))
Stacktrace:
 [1] reshape
   @ .\reshapedarray.jl:54 [inlined]
 [2] reshape
   @ .\reshapedarray.jl:129 [inlined]
 [3] reshape
   @ .\reshapedarray.jl:128 [inlined]
 [4] make_abstract_matrix
   @ C:\Users\heyye\.julia\packages\Lux\CXGnc\src\utils.jl:204 [inlined]
 [5] Dense
   @ C:\Users\heyye\.julia\packages\Lux\CXGnc\src\layers\basic.jl:343

Second run in the same terminal:

ERROR: Constant memory is stored (or returned) to a differentiable variable.
As a result, Enzyme cannot provably ensure correctness and throws this error.
This might be due to the use of a constant variable as temporary storage for active memory (https://enzyme.mit.edu/julia/stable/faq/#Runtime-Activity).
If Enzyme should be able to prove this use non-differentable, open an issue!
To work around this issue, either:
 a) rewrite this variable to not be conditionally active (fastest, but requires a code change), or
 b) set the Enzyme mode to turn on runtime activity (e.g. autodiff(set_runtime_activity(Reverse), ...) ). This will maintain correctness, but may slightly reduce performance.
Mismatched activity for:   %.pn69 = phi {} addrspace(10)* [ %60, %L205 ], [ %49, %L198 ] const val:   %49 = load {} addrspace(10)*, {} addrspace(10)* addrspace(11)* %48, align 8, !dbg !523, !tbaa !525, !alias.scope !504, !noalias !505, !dereferenceable_or_null !530, !align !339, !enzyme_type !531, !enzymejl_byref_MUT_REF !0, !enzymejl_source_type_Memory\7BFloat64\7D !0
 value=Unknown object of type Memory{Float64}
 llvalue=  %49 = load {} addrspace(10)*, {} addrspace(10)* addrspace(11)* %48, align 8, !dbg !523, !tbaa !525, !alias.scope !504, !noalias !505, !dereferenceable_or_null !530, !align !339, !enzyme_type !531, !enzymejl_byref_MUT_REF !0, !enzymejl_source_type_Memory\7BFloat64\7D !0

Stacktrace:
 [1] reshape
   @ .\reshapedarray.jl:60
 [2] reshape
   @ .\reshapedarray.jl:129
 [3] reshape
   @ .\reshapedarray.jl:128
 [4] make_abstract_matrix
   @ C:\Users\heyye\.julia\packages\Lux\CXGnc\src\utils.jl:204
 [5] Dense
   @ C:\Users\heyye\.julia\packages\Lux\CXGnc\src\layers\basic.jl:343

Stacktrace:
 [1] reshape
   @ .\reshapedarray.jl:54 [inlined]
 [2] reshape
   @ .\reshapedarray.jl:129 [inlined]
 [3] reshape
   @ .\reshapedarray.jl:128 [inlined]
 [4] make_abstract_matrix
   @ C:\Users\heyye\.julia\packages\Lux\CXGnc\src\utils.jl:204 [inlined]
 [5] Dense
   @ C:\Users\heyye\.julia\packages\Lux\CXGnc\src\layers\basic.jl:343

@heyyeahcrow
Author

@wsmoses

Hi, I'm wondering if there are any updates on this, since Julia 1.10.8 still doesn't seem to be released.

I tested Lux.jl on its own, without any ODEs, and it shows a similar error with a different stacktrace.

using Lux, Printf, Statistics
using ComponentArrays
using Optimization, OptimizationOptimisers
#using Optimisers
using Enzyme
using Dates
using Random
using StaticArrays

function generate_data(rng::AbstractRNG)
    x = reshape(collect(range(-2.0f0, 2.0f0, 128)), (1, 128))
    y = evalpoly.(x, ((0, -2, 1),)) .+ randn(rng, Float32, (1, 128)) .* 0.1f0
    return (x, y)
end
rng = MersenneTwister()
Random.seed!(rng, 12345)

(x, y) = generate_data(rng)

# Define a simple neural network
nn = Chain(Dense(1 => 16, relu), Dense(16 => 1))

# Initialize the parameters
ps, st = Lux.setup(rng, nn)

# Define the prediction function
function predict(p, x)
    y_pred, _ = nn(x, p, st)
    return y_pred
end

# Define the loss function
function loss_neuralode(p)
    y_pred = predict(p, x)
    return [sum(abs2, y .- y_pred)]
end

function loss!(loss, p)
    loss[1] = loss_neuralode(p)[1]
    return nothing
end

ypred = predict(ps, x)


loss = loss_neuralode(ps)
dloss = zero(loss)
dloss[1] = 1.0
dp = make_zero(ps)

Enzyme.autodiff(Reverse, loss!, Duplicated(loss, dloss), Duplicated(ps, dp))
Constant memory is stored (or returned) to a differentiable variable.
As a result, Enzyme cannot provably ensure correctness and throws this error.
This might be due to the use of a constant variable as temporary storage for active memory (https://enzyme.mit.edu/julia/stable/faq/#Runtime-Activity).
If Enzyme should be able to prove this use non-differentable, open an issue!
To work around this issue, either:
 a) rewrite this variable to not be conditionally active (fastest, but requires a code change), or
 b) set the Enzyme mode to turn on runtime activity (e.g. autodiff(set_runtime_activity(Reverse), ...) ). This will maintain correctness, but may slightly reduce performance.
Mismatched activity for:   %4 = insertvalue [2 x {} addrspace(10)*] zeroinitializer, {} addrspace(10)* %0, 0, !dbg !19 const val: {} addrspace(10)* %0
 value=Unknown object of type Matrix{Float32}
 llvalue={} addrspace(10)* %0

Stacktrace:
 [1] broadcasted
   @ .\broadcast.jl:1328
 [2] broadcasted
   @ .\broadcast.jl:1326
 [3] broadcasted
   @ .\broadcast.jl:0

Stacktrace:
 [1] broadcasted
   @ .\broadcast.jl:1328 [inlined]
 [2] broadcasted
   @ .\broadcast.jl:1326 [inlined]
 [3] broadcasted
   @ .\broadcast.jl:0 [inlined]
 [4] augmented_julia_broadcasted_45975_inner_1wrap
   @ .\broadcast.jl:0

However, it works with set_runtime_activity(Reverse). That approach does not work for the original ODE case, though; it produces the warning and error below.

┌ Warning: Mixed-Precision `matmul_cpu_fallback!` detected and Octavian.jl cannot be used for this set of inputs (C [Matrix{Float64}]: A [Matrix{Float32}] x B [Matrix{Float64}]). Falling back to generic implementation. This may be slow.
└ @ LuxLib.Impl C:\Users\heyye\.julia\packages\LuxLib\qJGVS\src\impl\matmul.jl:145
ERROR: Enzyme execution failed.
Enzyme: Non-constant keyword argument found for Tuple{UInt64, typeof(Core.kwcall), Duplicated{@NamedTuple{save_everystep::Bool, dt::Float64}}, typeof(EnzymeCore.EnzymeRules.augmented_primal), EnzymeCore.EnzymeRules.RevConfigWidth{1, true, true, (false, true, false, true, false, false), true}, Const{typeof(DiffEqBase.solve_up)}, Type{Duplicated{Any}}, Duplicated{ODEProblem{Matrix{Float64}, Tuple{Float64, Float64}, true, Tuple{Matrix{Float64}, Float64}, ODEFunction{true, SciMLBase.AutoSpecialize, typeof(evolve!), LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing, Nothing, Nothing, Nothing}, Base.Pairs{Symbol, Union{}, Tuple{}, @NamedTuple{}}, SciMLBase.StandardODEProblem}}, Const{Nothing}, Duplicated{Matrix{Float64}}, MixedDuplicated{Tuple{Matrix{Float64}, Float64}}, Const{Euler}}

I further changed my ODE solve from sol = solve(prob, Euler(), save_everystep = false, dt = 0.5) to sol = solve(prob, Tsit5()), and the error becomes:

ERROR: `p` is not a SciMLStructure. This is required for adjoint sensitivity analysis. For more information,
see the documentation on SciMLStructures.jl for the definition of the SciMLStructures interface.
In particular, adjoint sensitivities only applies to `Tunable`.

Stacktrace:
 [1] automatic_sensealg_choice(prob::ODEProblem{…}, u0::Matrix{…}, p::Tuple{…}, verbose::Bool, repack::typeof(identity))
   @ SciMLSensitivity C:\Users\heyye\.julia\packages\SciMLSensitivity\RQ8Av\src\concrete_solve.jl:87
 [2] _concrete_solve_adjoint(::ODEProblem{…}, ::Tsit5{…}, ::Nothing, ::Matrix{…}, ::Tuple{…}, ::SciMLBase.EnzymeOriginator; verbose::Bool, kwargs::@Kwargs{})
   @ SciMLSensitivity C:\Users\heyye\.julia\packages\SciMLSensitivity\RQ8Av\src\concrete_solve.jl:274
 [3] _concrete_solve_adjoint(::ODEProblem{…}, ::Tsit5{…}, ::Nothing, ::Matrix{…}, ::Tuple{…}, ::SciMLBase.EnzymeOriginator)
   @ SciMLSensitivity C:\Users\heyye\.julia\packages\SciMLSensitivity\RQ8Av\src\concrete_solve.jl:245
 [4] #_solve_adjoint#75
   @ C:\Users\heyye\.julia\packages\DiffEqBase\R2Vjs\src\solve.jl:1585
Some type information was truncated. Use `show(err)` to see complete types.

I'm currently looking into https://sciml.github.io/SciMLStructures.jl/stable/example/ to see if there's any clue to solve this.
Do you have any suggestions?
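The SciMLStructure error is raised because p is a Tuple (a matrix p1 and a scalar p2). One possible direction, sketched under the assumption that a plain Vector{Float64} parameter is accepted for adjoint sensitivities (SciMLStructures.jl documents arrays as supported, but this has not been checked against this exact setup), is to pack p1 and p2 into one flat vector and unpack them inside the right-hand side. pack_params and evolve_flat! are hypothetical names; a ComponentArray would be an alternative container.

```julia
# Original tuple-parameter RHS, as in the issue.
function evolve_tuple!(dc, c, p, t)
    p1, p2 = p
    dc .= c .* p2 * p1
    return nothing
end

# Hypothetical flat layout: p = [vec(p1); p2], with p1 of size n×n.
pack_params(p1::AbstractMatrix, p2::Real) = vcat(vec(p1), p2)

# Same RHS reading p1 and p2 back out of the flat vector (no copies).
function evolve_flat!(dc, c, p, t)
    n = size(c, 2)
    p1 = reshape(view(p, 1:n*n), n, n)
    p2 = p[n*n + 1]
    dc .= c .* p2 * p1
    return nothing
end
```

In simulate this would mean p = pack_params(i1 * b, exp(-i2 * a)) and ODEProblem(evolve_flat!, c0, t_span, p).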
