diff --git a/Project.toml b/Project.toml index 7a2f335b..0e292d64 100644 --- a/Project.toml +++ b/Project.toml @@ -8,6 +8,7 @@ DiffEqCallbacks = "459566f4-90b8-5000-8ac3-15dfb0a30def" DifferentiableEigen = "73a20539-4e65-4dcb-a56d-dc20f210a01b" DifferentialEquations = "0c46a032-eb83-5123-abaf-570d42b7fbaa" FMIImport = "9fcbc62e-52a0-44e9-a616-1359a0008194" +FMISensitivity = "3e748fe5-cd7f-4615-8419-3159287187d2" Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c" Optim = "429524aa-4258-5aef-a3af-852621145aeb" Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7" @@ -18,13 +19,13 @@ ThreadPools = "b189fb0b-2eb5-4ed4-bc0c-d34c51242431" [compat] Colors = "0.12.8" -DiffEqCallbacks = "2.26.0" DifferentiableEigen = "0.2.0" -DifferentialEquations = "7.8.0" -FMIImport = "0.15.8" -Flux = "0.13, 0.14" +DifferentialEquations = "7.10.0 - 7.11.0" +FMIImport = "0.16.2" +FMISensitivity = "0.1.0" +Flux = "0.14" Optim = "1.7.0" -ProgressMeter = "1.7.0" +ProgressMeter = "1.7.0 - 1.9.0" Requires = "1.3.0" ThreadPools = "2.1.1" julia = "1.6" diff --git a/README.md b/README.md index 193c881d..d5107a40 100644 --- a/README.md +++ b/README.md @@ -41,6 +41,7 @@ You can evaluate FMUs inside of your loss function. - building and training NeuralFMUs consisiting of multiple FMUs - building and training FMUINNs (PINNs) - different AD-frameworks: ForwardDiff.jl (CI-tested), ReverseDiff.jl (CI-tested, default setting), FiniteDiff.jl (not CI-tested) and Zygote.jl (not CI-tested) +- use `Flux.jl` optimisers as well as the ones from `Optim.jl` - ... ## What is under development in FMIFlux.jl? @@ -62,6 +63,7 @@ To keep dependencies nice and clean, the original package [*FMI.jl*](https://git - [*FMIImport.jl*](https://github.com/ThummeTo/FMIImport.jl): Importing FMUs into Julia - [*FMIExport.jl*](https://github.com/ThummeTo/FMIExport.jl): Exporting stand-alone FMUs from Julia Code - [*FMICore.jl*](https://github.com/ThummeTo/FMICore.jl): C-code wrapper for the FMI-standard +- [*FMISensitivity.jl*](https://github.com/ThummeTo/FMISensitivity.jl): Static and dynamic sensitivities over FMUs - [*FMIBuild.jl*](https://github.com/ThummeTo/FMIBuild.jl): Compiler/Compilation dependencies for FMIExport.jl - [*FMIFlux.jl*](https://github.com/ThummeTo/FMIFlux.jl): Machine Learning with FMUs (differentiation over FMUs) - [*FMIZoo.jl*](https://github.com/ThummeTo/FMIZoo.jl): A collection of testing and example FMUs diff --git a/examples/src/advanced_hybrid_ME.ipynb b/examples/src/advanced_hybrid_ME.ipynb index 30e78709..6aef2e2a 100644 --- a/examples/src/advanced_hybrid_ME.ipynb +++ b/examples/src/advanced_hybrid_ME.ipynb @@ -65,10 +65,9 @@ "| 1. | Enter Package Manager via | ] |\n", "| 2. | Install FMI via | add FMI | \n", "| 3. | Install FMIFlux via | add FMIFlux | \n", - "| 4. | Install FMIZoo via | add FMIZoo | \n", - "| 5. | Install DifferentialEquations via | add DifferentialEquations | \n", - "| 6. | Install Plots via | add Plots | \n", - "| 7. | Install Random via | add Random | " + "| 4. | Install FMIZoo via | add FMIZoo | \n", + "| 5. | Install Plots via | add Plots | \n", + "| 6. | Install Random via | add Random | " ] }, { @@ -100,7 +99,7 @@ "using FMIFlux\n", "using FMIFlux.Flux\n", "using FMIZoo\n", - "using DifferentialEquations: Tsit5\n", + "using FMI.DifferentialEquations: Tsit5\n", "using Statistics: mean, std\n", "import Plots\n", "\n", @@ -120,7 +119,7 @@ "\n", "![svg](https://github.com/thummeto/FMIFlux.jl/blob/main/docs/src/examples/img/SpringPendulum1D.svg?raw=true)\n", "\n", - "In contrast, the model *SpringFrictionPendulum1D* (*realFMU*) is somewhat more accurate, because it includes a friction component. \n", + "In contrast, the model *SpringFrictionPendulum1D* (*fricFMU*) is somewhat more accurate, because it includes a friction component. \n", "\n", "![svg](https://github.com/thummeto/FMIFlux.jl/blob/main/docs/src/examples/img/SpringFrictionPendulum1D.svg?raw=true)" ] @@ -156,9 +155,9 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "### RealFMU\n", + "### *fricFMU*\n", "\n", - "In the next lines of code the FMU of the *realFMU* model from *FMIZoo.jl* is loaded and the information about the FMU is shown." + "In the next lines of code the FMU of the *fricFMU* model from *FMIZoo.jl* is loaded and the information about the FMU is shown." ] }, { @@ -175,15 +174,15 @@ }, "outputs": [], "source": [ - "realFMU = fmiLoad(\"SpringFrictionPendulum1D\", \"Dymola\", \"2022x\")\n", - "fmiInfo(realFMU)" + "fricFMU = fmiLoad(\"SpringFrictionPendulum1D\", \"Dymola\", \"2022x\")\n", + "fmiInfo(fricFMU)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "In the function fmiSimulate() the *realFMU* is simulated, still specifying the start and end time, the parameters and which variables are recorded. After the simulation is finished the result of the *realFMU* can be plotted. This plot also serves as a reference for the other model (*simpleFMU*)." + "In the function fmiSimulate() the *fricFMU* is simulated, still specifying the start and end time, the parameters and which variables are recorded. After the simulation is finished the result of the *fricFMU* can be plotted. This plot also serves as a reference for the other model (*simpleFMU*)." ] }, { @@ -201,7 +200,7 @@ "outputs": [], "source": [ "vrs = [\"mass.s\", \"mass.v\", \"mass.a\", \"mass.f\"]\n", - "realSimData = fmiSimulate(realFMU, (tStart, tStop); recordValues=vrs, saveat=tSave)\n", + "realSimData = fmiSimulate(fricFMU, (tStart, tStop); recordValues=vrs, saveat=tSave)\n", "fmiPlot(realSimData)" ] }, @@ -209,7 +208,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "The data from the simulation of the *realFMU*, are divided into position and velocity data. These data will be needed later. " + "The data from the simulation of the *fricFMU*, are divided into position and velocity data. These data will be needed later. " ] }, { @@ -273,7 +272,7 @@ }, "outputs": [], "source": [ - "fmiUnload(realFMU)" + "fmiUnload(fricFMU)" ] }, { @@ -282,7 +281,7 @@ "source": [ "### SimpleFMU\n", "\n", - "The following lines load, simulate and plot the *simpleFMU* just like the *realFMU*. The differences between both systems can be clearly seen from the plots. In the plot for the *realFMU* it can be seen that the oscillation continues to decrease due to the effect of the friction. If you simulate long enough, the oscillation would come to a standstill in a certain time. The oscillation in the *simpleFMU* behaves differently, since the friction was not taken into account here. The oscillation in this model would continue to infinity with the same oscillation amplitude. From this observation the desire of an improvement of this model arises. " + "The following lines load, simulate and plot the *simpleFMU* just like the *fricFMU*. The differences between both systems can be clearly seen from the plots. In the plot for the *fricFMU* it can be seen that the oscillation continues to decrease due to the effect of the friction. If you simulate long enough, the oscillation would come to a standstill in a certain time. The oscillation in the *simpleFMU* behaves differently, since the friction was not taken into account here. The oscillation in this model would continue to infinity with the same oscillation amplitude. From this observation the desire of an improvement of this model arises. " ] }, { @@ -343,7 +342,7 @@ "\n", "In order to train our model, a loss function must be implemented. The solver of the NeuralFMU can calculate the gradient of the loss function. The gradient descent is needed to adjust the weights in the neural network so that the sum of the error is reduced and the model becomes more accurate.\n", "\n", - "The loss function in this implementation consists of the mean squared error (mse) from the real position of the *realFMU* simulation (posReal) and the position data of the network (posNet).\n", + "The loss function in this implementation consists of the mean squared error (mse) from the real position of the *fricFMU* simulation (posReal) and the position data of the network (posNet).\n", "$$ e_{mse} = \\frac{1}{n} \\sum\\limits_{i=0}^n (posReal[i] - posNet[i])^2 $$\n", "A growing horizon is applied, whereby the horizon only goes over the first five values. For this horizon the mse is calculated." ] @@ -382,7 +381,7 @@ "source": [ "#### Function for plotting\n", "\n", - "In this section the function for plotting is defined. The function `plotResults()` creates a new figure object. In dieses figure objekt werden dann die aktuellsten Ergebnisse von *realFMU*, *simpleFMU* und *neuralFMU* gegenübergestellt. \n", + "In this section the function for plotting is defined. The function `plotResults()` creates a new figure object. In dieses figure objekt werden dann die aktuellsten Ergebnisse von *fricFMU*, *simpleFMU* und *neuralFMU* gegenübergestellt. \n", "\n", "To output the loss in certain time intervals, a callback is implemented as a function in the following. Here a counter is incremented, every twentieth pass the loss function is called and the average error is printed out." ] @@ -413,7 +412,7 @@ " legendfontsize=8, legend=:topright)\n", " \n", " Plots.plot!(fig, tSave, posSimple, label=\"SimpleFMU\", linewidth=2)\n", - " Plots.plot!(fig, tSave, posReal, label=\"RealFMU\", linewidth=2)\n", + " Plots.plot!(fig, tSave, posReal, label=\"fricFMU\", linewidth=2)\n", " Plots.plot!(fig, time, posNeural, label=\"NeuralFMU\", linewidth=2)\n", " fig\n", "end" @@ -658,7 +657,7 @@ "source": [ "#### Comparison of the plots\n", "\n", - "Here three plots are compared with each other and only the position of the mass is considered. The first plot represents the *simpleFMU*, the second represents the *realFMU* (reference) and the third plot represents the result after training the NeuralFMU. " + "Here three plots are compared with each other and only the position of the mass is considered. The first plot represents the *simpleFMU*, the second represents the *fricFMU* (reference) and the third plot represents the result after training the NeuralFMU. " ] }, { @@ -716,7 +715,7 @@ "source": [ "### Summary\n", "\n", - "Based on the plots, it can be seen that the NeuralFMU is able to adapt the friction model of the *realFMU*. After 1000 training steps, the curves already overlap quite well, but this can be further improved by longer training or a better initialization." + "Based on the plots, it can be seen that the NeuralFMU is able to adapt the friction model of the *fricFMU*. After 1000 training steps, the curves already overlap quite well, but this can be further improved by longer training or a better initialization." ] }, { diff --git a/src/FMIFlux.jl b/src/FMIFlux.jl index 42330e5e..a25c6161 100644 --- a/src/FMIFlux.jl +++ b/src/FMIFlux.jl @@ -71,6 +71,10 @@ using FMIImport: fmi2SetTime, fmi2CompletedIntegratorStep, fmi2GetEventIndicator using FMIImport: fmi2SampleJacobian, fmi2GetDirectionalDerivative, fmi2GetJacobian, fmi2GetJacobian! using FMIImport: fmi2True, fmi2False +import FMIImport.FMICore: fmi2ValueReferenceFormat + +include("optimiser.jl") +include("hotfixes.jl") include("convert.jl") include("flux_overload.jl") include("neural.jl") @@ -80,7 +84,7 @@ include("deprecated.jl") include("batch.jl") include("losses.jl") include("scheduler.jl") -#include("optimiser.jl") +include("compatibility_check.jl") # from Plots.jl # No export here, Plots.plot is extended if available. diff --git a/src/compatibility_check.jl b/src/compatibility_check.jl new file mode 100644 index 00000000..1365cdc1 --- /dev/null +++ b/src/compatibility_check.jl @@ -0,0 +1,206 @@ +# +# Copyright (c) 2021 Tobias Thummerer, Lars Mikelsons +# Licensed under the MIT license. See LICENSE file in the project root for details. +# + +# checks gradient determination for all available sensitivity configurations, see: +# https://docs.sciml.ai/SciMLSensitivity/stable/manual/differential_equation_sensitivities/ +using FMISensitivity.SciMLSensitivity + +function checkSensalgs!(loss, neuralFMU::Union{ME_NeuralFMU, CS_NeuralFMU}; + gradients=(:ReverseDiff, :Zygote, :ForwardDiff), # :FiniteDiff is slow ... + max_msg_len=192, chunk_size=DEFAULT_CHUNK_SIZE, + OtD_autojacvecs=(false, true, TrackerVJP(), ZygoteVJP(), ReverseDiffVJP(false), ReverseDiffVJP(true)), # EnzymeVJP() deadlocks in the current release xD + OtD_sensealgs=(BacksolveAdjoint, InterpolatingAdjoint, QuadratureAdjoint), + OtD_checkpointings=(true, false), + DtO_sensealgs=(ReverseDiffAdjoint, ForwardDiffSensitivity, TrackerAdjoint), # TrackerAdjoint, ZygoteAdjoint freeze the REPL + multiObjective::Bool=false, + bestof::Int=2, + timeout_seconds::Real=60.0, + kwargs...) + + params = Flux.params(neuralFMU) + initial_sensalg = neuralFMU.fmu.executionConfig.sensealg + + best_timing = Inf + best_gradient = nothing + best_sensealg = nothing + + printstyled("Mode: Optimize-then-Discretize\n") + for gradient ∈ gradients + printstyled("\tGradient: $(gradient)\n") + + for sensealg ∈ OtD_sensealgs + printstyled("\t\tSensealg: $(sensealg)\n") + for checkpointing ∈ OtD_checkpointings + printstyled("\t\t\tCheckpointing: $(checkpointing)\n") + + if sensealg == QuadratureAdjoint && checkpointing + printstyled("\t\t\t\tQuadratureAdjoint doesn't implement checkpointing, skipping ...\n") + continue + end + + for autojacvec ∈ OtD_autojacvecs + printstyled("\t\t\t\tAutojacvec: $(autojacvec)\n") + + if sensealg ∈ (BacksolveAdjoint, InterpolatingAdjoint) + neuralFMU.fmu.executionConfig.sensealg = sensealg(; autojacvec=autojacvec, chunk_size=chunk_size, checkpointing=checkpointing) + else + neuralFMU.fmu.executionConfig.sensealg = sensealg(; autojacvec=autojacvec, chunk_size=chunk_size) + end + + call = () -> _tryrun(loss, params, gradient, chunk_size, 5, max_msg_len, multiObjective; timeout_seconds=timeout_seconds) + for i in 1:bestof + timing = call() + + if timing < best_timing + best_timing = timing + best_gradient = gradient + best_sensealg = neuralFMU.fmu.executionConfig.sensealg + end + end + + end + end + end + end + + printstyled("Mode: Discretize-then-Optimize\n") + for gradient ∈ gradients + printstyled("\tGradient: $(gradient)\n") + for sensealg ∈ DtO_sensealgs + printstyled("\t\tSensealg: $(sensealg)\n") + + if sensealg == ForwardDiffSensitivity + neuralFMU.fmu.executionConfig.sensealg = sensealg(; chunk_size=chunk_size, convert_tspan=true) + else + neuralFMU.fmu.executionConfig.sensealg = sensealg() + end + + call = () -> _tryrun(loss, params, gradient, chunk_size, 3, max_msg_len, multiObjective; timeout_seconds=timeout_seconds) + for i in 1:bestof + timing = call() + + if timing < best_timing + best_timing = timing + best_gradient = gradient + best_sensealg = neuralFMU.fmu.executionConfig.sensealg + end + end + + end + end + + neuralFMU.fmu.executionConfig.sensealg = initial_sensalg + + printstyled("------------------------------\nBest time: $(best_timing)\nBest gradient: $(best_gradient)\nBest sensealg: $(best_sensealg)\n", color=:blue) + + return nothing +end + +# Thanks to: +# https://discourse.julialang.org/t/help-writing-a-timeout-macro/16591/11 +function timeout(f, arg, seconds, fail) + tsk = @task f(arg...) + schedule(tsk) + Timer(seconds) do timer + istaskdone(tsk) || Base.throwto(tsk, InterruptException()) + end + try + fetch(tsk) + catch _; + fail + end +end + +function runGrads(loss, params, gradient, chunk_size, multiObjective) + tstart = time() + + grads = nothing + if multiObjective + dim = loss(params[1]) + grads = zeros(Float64, length(params[1]), length(dim)) + else + grads = zeros(Float64, length(params[1])) + end + + computeGradient!(grads, loss, params[1], gradient, chunk_size, multiObjective) + + timing = time() - tstart + + if length(grads[1]) == 1 + grads = [grads] + end + + return grads, timing +end + +function _tryrun(loss, params, gradient, chunk_size, ts, max_msg_len, multiObjective::Bool=false; print_stdout::Bool=true, print_stderr::Bool=true, timeout_seconds::Real=60.0) + + spacing = "" + for t in ts + spacing *= "\t" + end + + message = "" + color = :black + timing = Inf + + original_stdout = stdout + original_stderr = stderr + (rd_stdout, wr_stdout) = redirect_stdout(); + (rd_stderr, wr_stderr) = redirect_stderr(); + + try + + #grads, timing = timeout(runGrads, (loss, params, gradient, chunk_size, multiObjective), timeout_seconds, ([Inf], -1.0)) + grads, timing = runGrads(loss, params, gradient, chunk_size, multiObjective) + + if timing == -1.0 + message = spacing * "TIMEOUT\n" + color = :red + else + val = collect(sum(abs.(grad)) for grad in grads) + message = spacing * "SUCCESS | $(round(timing; digits=2))s | GradAbsSum: $(round.(val; digits=6))\n" + color = :green + end + + catch e + msg = "$(e)" + if length(msg) > max_msg_len + msg = msg[1:max_msg_len] * "..." + end + + message = spacing * "$(msg)\n" + color = :red + end + + redirect_stdout(original_stdout) + redirect_stderr(original_stderr) + close(wr_stdout) + close(wr_stderr) + + if print_stdout + msg = read(rd_stdout, String) + if length(msg) > 0 + if length(msg) > max_msg_len + msg = msg[1:max_msg_len] * "..." + end + printstyled(spacing * "STDOUT: $(msg)\n", color=:yellow) + end + end + + if print_stderr + msg = read(rd_stderr, String) + if length(msg) > 0 + if length(msg) > max_msg_len + msg = msg[1:max_msg_len] * "..." + end + printstyled(spacing * "STDERR: $(msg)\n", color=:yellow) + end + end + + printstyled(message, color=color) + + return timing +end \ No newline at end of file diff --git a/src/layers.jl b/src/layers.jl index 0ddb8c1b..5f8e2595 100644 --- a/src/layers.jl +++ b/src/layers.jl @@ -19,8 +19,14 @@ struct FMUParameterRegistrator{T} function FMUParameterRegistrator{T}(fmu::FMU2, p_refs::fmi2ValueReferenceFormat, p::AbstractArray{T}) where {T} @assert length(p_refs) == length(p) "`p_refs` and `p` need to be the same length!" p_refs = prepareValueReference(fmu, p_refs) - fmu.optim_p_refs = p_refs - fmu.optim_p = p + + fmu.default_p_refs = p_refs + fmu.default_p = p + for c in fmu.components + c.default_p_refs = p_refs + c.default_p = p + end + return new(fmu, p_refs, p) end @@ -31,8 +37,14 @@ end export FMUParameterRegistrator function (l::FMUParameterRegistrator)(x) - l.fmu.optim_p = l.p - l.fmu.optim_p_refs = l.p_refs + + l.fmu.default_p_refs = l.p_refs + l.fmu.default_p = l.p + for c in l.fmu.components + c.default_p_refs = l.p_refs + c.default_p = l.p + end + return x end diff --git a/src/neural.jl b/src/neural.jl index b8ec8882..f77ee334 100644 --- a/src/neural.jl +++ b/src/neural.jl @@ -17,7 +17,7 @@ using FMIImport.SciMLSensitivity.ReverseDiff: TrackedArray import FMIImport.SciMLSensitivity: InterpolatingAdjoint, ReverseDiffVJP import ThreadPools -using DiffEqCallbacks +using DifferentialEquations.DiffEqCallbacks using DifferentialEquations: ODEFunction, ODEProblem, solve using FMIImport: FMU2Component, FMU2Event, FMU2Solution, fmi2ComponentState, fmi2ComponentStateContinuousTimeMode, fmi2ComponentStateError, @@ -29,8 +29,13 @@ using Flux using Flux.Zygote using FMIImport.SciMLSensitivity: ForwardDiffSensitivity, InterpolatingAdjoint, ReverseDiffVJP, ZygoteVJP +import DifferentiableEigen + + +import FMIImport.FMICore: EMPTY_fmi2Real, EMPTY_fmi2ValueReference -zero_tgrad(u,p,t) = zero(u) +DEFAULT_PROGRESS_DESCR = "Simulating ME-NeuralFMU ..." +DEFAULT_CHUNK_SIZE = 32 """ The mutable struct representing an abstract (simulation mode unknown) NeuralFMU. @@ -73,35 +78,28 @@ mutable struct ME_NeuralFMU{M, P, R} <: NeuralFMU modifiedState::Bool - startState - stopState - startEventInfo - stopEventInfo - start_t - stop_t - execution_start::Real - function ME_NeuralFMU{M, P, R}(model::M, p::P, re::R) where {M, P, R} + function ME_NeuralFMU{M, R}(model::M, p::AbstractArray{<:Real}, re::R) where {M, R} inst = new() inst.model = model inst.p = p inst.re = re inst.x0 = nothing + inst.saveat = nothing inst.modifiedState = true - inst.startState = nothing - inst.stopState = nothing - - inst.startEventInfo = nothing - inst.stopEventInfo = nothing + # inst.startState = nothing + # inst.stopState = nothing + # inst.startEventInfo = nothing + # inst.stopEventInfo = nothing inst.customCallbacksBefore = [] inst.customCallbacksAfter = [] inst.execution_start = 0.0 - + return inst end end @@ -114,8 +112,7 @@ mutable struct CS_NeuralFMU{F, C} <: NeuralFMU fmu::F tspan - saveat - + re # restrucure function function CS_NeuralFMU{F, C}() where {F, C} @@ -127,29 +124,31 @@ mutable struct CS_NeuralFMU{F, C} <: NeuralFMU end end -function evaluateModel(nfmu::ME_NeuralFMU, c::FMU2Component, x) - @assert getCurrentComponent(nfmu.fmu) == c "Thread `$(Threads.threadid())` wants to evaluate wrong component!" - - return nfmu.model(x) -end - -function evaluateModel(nfmu::ME_NeuralFMU, c::FMU2Component, dx, x) +function evaluateModel(nfmu::ME_NeuralFMU, c::FMU2Component, x; p=nfmu.p) @assert getCurrentComponent(nfmu.fmu) == c "Thread `$(Threads.threadid())` wants to evaluate wrong component!" - dx[:] = nfmu.model(x) - - return nothing -end - -function evaluateReModel(nfmu::ME_NeuralFMU, c::FMU2Component, x, p) - @assert getCurrentComponent(nfmu.fmu) == c "Thread `$(Threads.threadid())` wants to evaluate wrong component!" + # [ToDo]: Cache the restructure if possible + # if isnothing(nfmu.re_model) || p != nfmu.re_p + # nfmu.re_p = p # fast_copy!(nfmu, :re_p, p) + # nfmu.re_model = nfmu.re(p) + # end + # return nfmu.re_model(x) + nfmu.p = p return nfmu.re(p)(x) end -function evaluateReModel(nfmu::ME_NeuralFMU, c::FMU2Component, dx, x, p) +function evaluateModel(nfmu::ME_NeuralFMU, c::FMU2Component, dx, x; p=nfmu.p) @assert getCurrentComponent(nfmu.fmu) == c "Thread `$(Threads.threadid())` wants to evaluate wrong component!" + # [ToDo]: Cache the restructure if possible + # if isnothing(nfmu.re_model) || p != nfmu.re_p + # nfmu.re_p = p # fast_copy!(nfmu, :re_p, p) + # nfmu.re_model = nfmu.re(p) + # end + # dx[:] = nfmu.re_model(x) + + nfmu.p = p dx[:] = nfmu.re(p)(x) return nothing @@ -160,8 +159,6 @@ end function startCallback(integrator, nfmu::ME_NeuralFMU, c::Union{FMU2Component, Nothing}, t) ignore_derivatives() do - #nfmu.solveCycle += 1 - #@debug "[$(nfmu.solveCycle)][FIRST STEP]" nfmu.execution_start = time() @@ -177,7 +174,6 @@ function startCallback(integrator, nfmu::ME_NeuralFMU, c::Union{FMU2Component, N @debug "No initial time events ..." end - #@assert fmi2EnterContinuousTimeMode(c) == fmi2StatusOK end return c @@ -188,14 +184,10 @@ function stopCallback(nfmu::ME_NeuralFMU, c::FMU2Component, t) @assert getCurrentComponent(nfmu.fmu) == c "Thread `$(Threads.threadid())` wants to evaluate wrong component!" ignore_derivatives() do - #@debug "[$(nfmu.solveCycle)][LAST STEP]" - + t = unsense(t) @assert t == nfmu.tspan[end] "stopCallback(...): Called for non-start-point t=$(t)" - - #c = finishSolveFMU(nfmu.fmu, c, nfmu.freeInstance, nfmu.terminate) - end return c @@ -205,6 +197,7 @@ end function time_choice(nfmu::ME_NeuralFMU, c::FMU2Component, integrator, tStart, tStop) @assert getCurrentComponent(nfmu.fmu) == c "Thread `$(Threads.threadid())` wants to evaluate wrong component!" + @assert c.fmu.executionConfig.handleTimeEvents "time_choice(...) was called, but execution config disables time events.\nPlease open a issue." @debug assert_integrator_valid(integrator) # last call may be after simulation end @@ -212,17 +205,11 @@ function time_choice(nfmu::ME_NeuralFMU, c::FMU2Component, integrator, tStart, t return nothing end - if !c.fmu.executionConfig.handleTimeEvents - return nothing - end - c.solution.evals_timechoice += 1 if c.eventInfo.nextEventTimeDefined == fmi2True - #@debug "time_choice(...): $(c.eventInfo.nextEventTime) at t=$(ForwardDiff.value(integrator.t))" - + if c.eventInfo.nextEventTime >= tStart && c.eventInfo.nextEventTime <= tStop - #@assert sizeof(integrator.t) == sizeof(c.eventInfo.nextEventTime) "The NeuralFMU/solver are initialized in $(sizeof(integrator.t))-bit-mode, but FMU events are defined in $(sizeof(c.eventInfo.nextEventTime))-bit. Please define your ANN in $(sizeof(c.eventInfo.nextEventTime))-bit mode." @debug "time_choice(...): At $(integrator.t) next time event announced @$(c.eventInfo.nextEventTime)s" return c.eventInfo.nextEventTime else @@ -231,93 +218,47 @@ function time_choice(nfmu::ME_NeuralFMU, c::FMU2Component, integrator, tStart, t return nothing end else - #@debug "time_choice(...): nothing at t=$(ForwardDiff.value(integrator.t))" return nothing end - - end -# Returns the event indicators for an FMU. -function condition(nfmu::ME_NeuralFMU, c::FMU2Component, out::SubArray{<:ForwardDiff.Dual{T, V, N}, A, B, C, D}, _x, t, integrator) where {T, V, N, A, B, C, D} # Event when event_f(u,t) == 0 +# [ToDo] for now, ReverseDiff (together with the rrule) seems to have a problem with the SubArray here (when `collect` it accesses array elements that are #undef), +# so I added an additional (single allocating) dispatch... +# Type is ReverseDiff.TrackedReal{Float64, Float64, ReverseDiff.TrackedArray{Float64, Float64, 1, Vector{Float64}, Vector{Float64}}}[#undef, #undef, #undef, ...] +function condition!(nfmu::ME_NeuralFMU, c::FMU2Component, out::AbstractArray{<:ReverseDiff.TrackedReal}, x, t, integrator, handleEventIndicators) - @assert getCurrentComponent(nfmu.fmu) == c "Thread `$(Threads.threadid())` wants to evaluate wrong component!" - @debug assert_integrator_valid(integrator) - - #@assert c.state == fmi2ComponentStateContinuousTimeMode "condition(...): Must be called in mode continuous time." - - # ToDo: set inputs here - #fmiSetReal(myFMU, InputRef, Value) - - t = undual(t) - x = undual(_x) - - # ToDo: Evaluate on light-weight model (sub-model) without fmi2GetXXX or similar and the bottom ANN - #c.t = t # this will auto-set time via fx-call! - c.next_t = t - evaluateModel(nfmu, c, x) + if !isassigned(out, 1) + logWarning(nfmu.fmu, "There is currently an issue with the condition buffer pre-allocation, the buffer can't be overwritten by the generated rrule.") + out[:] = zeros(fmi2Real, length(out)) + end - out_tmp = zeros(c.fmu.modelDescription.numberOfEventIndicators) - fmi2GetEventIndicators!(c, out_tmp) - - fd_set!(out, out_tmp) - - c.solution.evals_condition += 1 - - @debug assert_integrator_valid(integrator) - - return nothing -end -function condition(nfmu::ME_NeuralFMU, c::FMU2Component, out::SubArray{<:ReverseDiff.TrackedReal}, _x, t, integrator) + invoke(condition!, Tuple{ME_NeuralFMU, FMU2Component, Any, Any, Any, Any, Any}, nfmu, c, out, x, t, integrator, handleEventIndicators) - @assert getCurrentComponent(nfmu.fmu) == c "Thread `$(Threads.threadid())` wants to evaluate wrong component!" - @debug assert_integrator_valid(integrator) - - #@assert c.state == fmi2ComponentStateContinuousTimeMode "condition(...): Must be called in mode continuous time." - - # ToDo: set inputs here - #fmiSetReal(myFMU, InputRef, Value) - - t = untrack(t) - x = untrack(_x) - - # ToDo: Evaluate on light-weight model (sub-model) without fmi2GetXXX or similar and the bottom ANN - c.next_t = t - evaluateModel(nfmu, c, x) - - out_tmp = zeros(c.fmu.modelDescription.numberOfEventIndicators) - fmi2GetEventIndicators!(c, out_tmp) - - rd_set!(out, out_tmp) - - @debug assert_integrator_valid(integrator) - - c.solution.evals_condition += 1 - return nothing end function condition(nfmu::ME_NeuralFMU, c::FMU2Component, out, _x, t, integrator) # Event when event_f(u,t) == 0 - @debug assert_integrator_valid(integrator) - @assert getCurrentComponent(nfmu.fmu) == c "Thread `$(Threads.threadid())` wants to evaluate wrong component!" +function condition!(nfmu::ME_NeuralFMU, c::FMU2Component, out, x, t, integrator, handleEventIndicators) - @debug @assert c.state == fmi2ComponentStateContinuousTimeMode "condition(...): Must be called in mode continuous time." - - #@debug "State condition..." + @assert getCurrentComponent(nfmu.fmu) == c "Thread `$(Threads.threadid())` wants to evaluate wrong component!" + @assert c.state == fmi2ComponentStateContinuousTimeMode "condition!(...):\n" * FMICore.ERR_MSG_CONT_TIME_MODE - # ToDo: set inputs here - #fmiSetReal(myFMU, InputRef, Value) + # [ToDo] Evaluate on light-weight model (sub-model) without fmi2GetXXX or similar and the bottom ANN. + # Basically only the layers from very top to FMU need to be evaluated here. + c.default_t = t + c.default_ec = out + c.default_ec_idcs = handleEventIndicators + evaluateModel(nfmu, c, x) - t = unsense(t) - x = unsense(_x) + # [TODO] for generic applications, reset to previous values + c.default_t = -1.0 + c.default_ec = EMPTY_fmi2Real + c.default_ec_idcs = EMPTY_fmi2ValueReference - # ToDo: Evaluate on light-weight model (sub-model) without fmi2GetXXX or similar and the bottom ANN - c.next_t = t - evaluateModel(nfmu, c, x) # evaluate NeuralFMU (set new states) + # write back to condition buffer + out[:] = c.eval_output.ec # [ToDo] This seems not to be necessary, because of `c.default_ec = out` - fmi2GetEventIndicators!(c, out) - - @debug assert_integrator_valid(integrator) + c.solution.evals_condition += 1 c.solution.evals_condition += 1 @@ -329,30 +270,25 @@ global lastIndicatorX = nothing global lastIndicatorT = nothing function conditionSingle(nfmu::ME_NeuralFMU, c::FMU2Component, index, _x, t, integrator) - @assert c.state == fmi2ComponentStateContinuousTimeMode "condition(...): Must be called in mode continuous time." + @assert c.state == fmi2ComponentStateContinuousTimeMode "condition(...):\n" * FMICore.ERR_MSG_CONT_TIME_MODE @assert getCurrentComponent(nfmu.fmu) == c "Thread `$(Threads.threadid())` wants to evaluate wrong component!" - # ToDo: set inputs here - #fmiSetReal(myFMU, InputRef, Value) - - if c.fmu.executionConfig.handleEventIndicators != nothing && index ∉ c.fmu.executionConfig.handleEventIndicators + if c.fmu.handleEventIndicators != nothing && index ∉ c.fmu.handleEventIndicators return 1.0 end - t = unsense(t) - x = unsense(_x) - - global lastIndicator # , lastIndicatorX, lastIndicatorT + global lastIndicator if lastIndicator == nothing || length(lastIndicator) != c.fmu.modelDescription.numberOfEventIndicators lastIndicator = zeros(c.fmu.modelDescription.numberOfEventIndicators) end - # ToDo: Input Function - - # ToDo: Evaluate on light-weight model (sub-model) without fmi2GetXXX or similar and the bottom ANN - c.next_t = t - evaluateModel(nfmu, c, x) # evaluate NeuralFMU (set new states) + # [ToDo] Evaluate on light-weight model (sub-model) without fmi2GetXXX or similar and the bottom ANN + c.default_t = t + c.default_ec = lastIndicator + evaluateModel(nfmu, c, x) + c.default_t = -1.0 + c.default_ec = EMPTY_fmi2Real fmi2GetEventIndicators!(c, lastIndicator) @@ -361,6 +297,7 @@ function conditionSingle(nfmu::ME_NeuralFMU, c::FMU2Component, index, _x, t, int return lastIndicator[index] end +# [ToDo] Check, that the new determined state is the right root of the event instant! function f_optim(x, nfmu::ME_NeuralFMU, c::FMU2Component, right_x_fmu) # , idx, direction::Real) # propagete the new state-guess `x` through the NeuralFMU evaluateModel(nfmu, c, x) @@ -374,63 +311,43 @@ function affectFMU!(nfmu::ME_NeuralFMU, c::FMU2Component, integrator, idx) @assert getCurrentComponent(nfmu.fmu) == c "Thread `$(Threads.threadid())` wants to evaluate wrong component!" @debug assert_integrator_valid(integrator) - @debug @assert c.state == fmi2ComponentStateContinuousTimeMode "affectFMU!(...): Must be in continuous time mode!" + @assert c.state == fmi2ComponentStateContinuousTimeMode "affectFMU!(...):\n" * FMICore.ERR_MSG_CONT_TIME_MODE - t = unsense(integrator.t) - x = unsense(integrator.u) - - # there are fx-evaluations before the event is handled, reset the FMU state to the current integrator step - mode = c.force - c.force = true - - c.next_t = t - evaluateModel(nfmu, c, x) # evaluate NeuralFMU (set new states) + # [NOTE] Here unsensing is OK, because we just want to reset the FMU to the correct state! + # The values come directly from the integrator and are NOT function arguments! + t = integrator.t # unsense(integrator.t) + x = integrator.u # unsense(integrator.u) - c.force = mode - - # if inputFunction !== nothing - # fmi2SetReal(c, inputValues, inputFunction(integrator.t)) - # end + if c.x != x + # capture status of `force` + mode = c.force + c.force = true + + # there are fx-evaluations before the event is handled, reset the FMU state to the current integrator step + c.default_t = t + evaluateModel(nfmu, c, x) # evaluate NeuralFMU (set new states) + # [NOTE] No need to reset time here, because we did pass a event instance! + # c.default_t = -1.0 + + c.force = mode + end fmi2EnterEventMode(c) - - ############# - - #left_x_fmu = fmi2GetContinuousStates(c) - - # Todo set inputs - - # Event found - handle it - #@assert fmi2EnterEventMode(c) == fmi2StatusOK handleEvents(c) - ignore_derivatives() do - if idx == 0 - #@debug "affectFMU!(...): Handle time event at t=$t" - end - - if idx > 0 - #@debug "affectFMU!(...): Handle state event at t=$t" - end - end - left_x = nothing right_x = nothing if c.eventInfo.valuesOfContinuousStatesChanged == fmi2True - left_x = x - + left_x = unsense(x) right_x_fmu = fmi2GetContinuousStates(c) # the new FMU state after handled events ignore_derivatives() do - #@debug "affectFMU!(_, _, $idx): NeuralFMU state event from $(left_x) (fmu: $(left_x_fmu)). Indicator [$idx]: $(indicators[idx]). Optimizing new state ..." + @debug "affectFMU!(_, _, $idx): NeuralFMU state event from $(left_x) (fmu: $(left_x_fmu)). Indicator [$idx]: $(indicators[idx]). Optimizing new state ..." end - # ToDo: Problem-related parameterization of optimize-call - #result = optimize(x_seek -> f_optim(x_seek, nfmu, right_x_fmu), left_x, LBFGS(); autodiff = :forward) - #result = Optim.optimize(x_seek -> f_optim(x_seek, nfmu, right_x_fmu, idx, sign(indicators[idx])), left_x, Optim.NelderMead()) - + # [ToDo] use gradient-based optimization here? # if there is an ANN above the FMU, propaget FMU state through top ANN: if nfmu.modifiedState == true result = Optim.optimize(x_seek -> f_optim(x_seek, nfmu, c, right_x_fmu), left_x, Optim.NelderMead()) @@ -439,19 +356,16 @@ function affectFMU!(nfmu::ME_NeuralFMU, c::FMU2Component, integrator, idx) right_x = right_x_fmu end - if isdual(integrator.u) - T, V, N = fd_eltypes(integrator.u) - - new_x = collect(ForwardDiff.Dual{T, V, N}(V(right_x[i]), ForwardDiff.partials(integrator.u[i])) for i in 1:length(integrator.u)) - #set_u!(integrator, new_x) - integrator.u .= new_x - - @debug "affectFMU!(_, _, $idx): NeuralFMU event with state change at $t. Indicator [$idx]. (ForwardDiff) " - else - #set_u!(integrator, right_x) - integrator.u .= right_x - - @debug "affectFMU!(_, _, $idx): NeuralFMU event with state change at $t. Indicator [$idx]." + # [ToDo] This should only be done in the frule/rrule, the actual affect should do a hard "set state" + for i in 1:length(left_x) + if left_x[i] != 0.0 + scale = right_x[i] ./ left_x[i] + integrator.u[i] *= scale + else # integrator state zero can't be scaled, need to add (but no sensitivities in this case!) + shift = right_x[i] - left_x[i] + integrator.u[i] += shift + logWarning(c.fmu, "Probably wrong sensitivities for ∂x^+ / ∂x^-\nCan't scale from zero state (state #$(i)=0.0)") + end end u_modified!(integrator, true) @@ -465,6 +379,7 @@ function affectFMU!(nfmu::ME_NeuralFMU, c::FMU2Component, integrator, idx) end if c.eventInfo.nominalsOfContinuousStatesChanged == fmi2True + # ToDo: Do something with that information, e.g. use for FiniteDiff sampling step size determination x_nom = fmi2GetNominalsOfContinuousStates(c) end @@ -512,15 +427,10 @@ end # Does one step in the simulation. function stepCompleted(nfmu::ME_NeuralFMU, c::FMU2Component, x, t, integrator, tStart, tStop) - #@assert getCurrentComponent(nfmu.fmu) == c "Thread `$(Threads.threadid())` wants to evaluate wrong component!" - @debug assert_integrator_valid(integrator) + assert_integrator_valid(integrator) c.solution.evals_stepcompleted += 1 - #@debug "Step" - # there might be no component (in Zygote)! - # @assert c.state == fmi2ComponentStateContinuousTimeMode "stepCompleted(...): Must be in continuous time mode." - if !isnothing(c.progressMeter) ProgressMeter.update!(c.progressMeter, floor(Integer, 1000.0*(t-tStart)/(tStop-tStart)) ) end @@ -534,17 +444,17 @@ function stepCompleted(nfmu::ME_NeuralFMU, c::FMU2Component, x, t, integrator, t if enterEventMode == fmi2True affectFMU!(nfmu, c, integrator, -1) - else - # ToDo: set inputs here - #fmiSetReal(myFMU, InputRef, Value) end - #@debug "Step completed at $(ForwardDiff.value(t)) with $(collect(ForwardDiff.value(xs) for xs in x))" + @debug "Step completed at $(ForwardDiff.value(t)) with $(collect(ForwardDiff.value(xs) for xs in x))" end @debug assert_integrator_valid(integrator) end +# [ToDo] (1) This must be in-place +# (2) getReal must be replaced with the inplace getter within c(...) +# (3) remove unsense to determine save value sensitivities # save FMU values function saveValues(nfmu::ME_NeuralFMU, c::FMU2Component, recordValues, _x, t, integrator) @@ -554,42 +464,33 @@ function saveValues(nfmu::ME_NeuralFMU, c::FMU2Component, recordValues, _x, t, i c.solution.evals_savevalues += 1 # ToDo: Evaluate on light-weight model (sub-model) without fmi2GetXXX or similar and the bottom ANN - c.next_t = t + c.default_t = t evaluateModel(nfmu, c, x) # evaluate NeuralFMU (set new states) # Todo set inputs - return (fmi2GetReal(c, recordValues)...,) end -# TODO -import DifferentiableEigen -function saveEigenvalues(nfmu::ME_NeuralFMU, c::FMU2Component, _x, p, _t, integrator, sensitivity) +function saveEigenvalues(nfmu::ME_NeuralFMU, c::FMU2Component, _x, _t, integrator, sensitivity::Symbol) - #@assert c.state == fmi2ComponentStateContinuousTimeMode "saveEigenvalues(...): Must be in continuous time mode." + @assert c.state == fmi2ComponentStateContinuousTimeMode "saveEigenvalues(...):\n" * FMICore.ERR_MSG_CONT_TIME_MODE c.solution.evals_saveeigenvalues += 1 - t = unsense(_t) - - c.next_t = t + c.default_t = t A = nothing - #if sensitivity == :ForwardDiff - A = ForwardDiff.jacobian(x -> evaluateReModel(nfmu, c, x, p), _x) # TODO: chunk_size! - # elseif sensitivity == :ReverseDiff - # A = ReverseDiff.jacobian(x -> evaluateReModel(nfmu, c, x, p), _x) - # elseif sensitivity == :Zygote - # A = Zygote.jacobian(x -> evaluateReModel(nfmu, c, x, p), _x)[1] - # elseif sensitivity == :none - # A = ForwardDiff.jacobian(x -> evaluateReModel(nfmu, c, x, p), unsense(_x)) - # end + if sensitivity == :ForwardDiff + A = ForwardDiff.jacobian(x -> evaluateModel(nfmu, c, x), _x) # TODO: chunk_size! + elseif sensitivity == :ReverseDiff + A = ReverseDiff.jacobian(x -> evaluateModel(nfmu, c, x), _x) + elseif sensitivity == :Zygote + A = Zygote.jacobian(x -> evaluateModel(nfmu, c, x), _x)[1] + elseif sensitivity == :none + A = ForwardDiff.jacobian(x -> evaluateModel(nfmu, c, x), unsense(_x)) + end eigs, _ = DifferentiableEigen.eigen(A) - # x = unsense(_x) - # c.next_t = t - # evaluateModel(nfmu, c, x) - return (eigs...,) end @@ -600,40 +501,20 @@ function fx(nfmu::ME_NeuralFMU, p,#::Array, t)#::Real) - #nanx = !any(isnan.(collect(any(isnan.(ForwardDiff.partials.(x[i]).values)) for i in 1:length(x)))) - #nanu = !any(isnan(ForwardDiff.partials(t))) - - #@assert nanx && nanu "NaN in start fx nanx = $nanx nanu = $nanu @ $(t)." - - if c === nothing + if isnothing(c) # this should never happen! return zeros(length(x)) end ignore_derivatives() do - t = unsense(t) - c.next_t = t + c.default_t = t end ############ - evaluateReModel(nfmu, c, dx, x, p) - - # if isdual(dx) - # dx_tmp = evaluateReModel(nfmu, c, x, p) - # fd_set!(dx, dx_tmp) - - # elseif istracked(dx) - # dx_tmp = evaluateReModel(nfmu, c, x, p) - # rd_set!(dx, dx_tmp) - # else - # #@info "dx: $(dx)" - # #@info "dx_tmp: $(dx_tmp)" - # evaluateReModel(nfmu, c, dx, x, p) - # end + evaluateModel(nfmu, c, dx, x; p=p) ignore_derivatives() do - c.solution.evals_fx_inplace += 1 end @@ -654,12 +535,10 @@ function fx(nfmu::ME_NeuralFMU, ignore_derivatives() do c.solution.evals_fx_outofplace += 1 - t = unsense(t) - c.next_t = t + c.default_t = t end - return evaluateReModel(nfmu, c, x, p) - + return evaluateModel(nfmu, c, x; p=p) end ##### EVENT HANDLING END @@ -667,28 +546,26 @@ end """ Constructs a ME-NeuralFMU where the FMU is at an arbitrary location inside of the NN. -# Arguents +# Arguments - `fmu` the considered FMU inside the NN - `model` the NN topology (e.g. Flux.chain) - `tspan` simulation time span - - `alg` a numerical ODE solver - - `convertParams` automatically convert ANN parameters to Float64 if not already + - `solver` an ODE Solver (default=`nothing`, heurisitically determine one) # Keyword arguments - - `saveat` time points to save the NeuralFMU output, if empty, solver step size is used (may be non-equidistant) - - `fixstep` forces fixed step integration - - `recordFMUValues` additionally records internal FMU variables (currently not supported because of open issues) + - `recordValues` additionally records internal FMU variables """ function ME_NeuralFMU(fmu::FMU2, model, tspan, solver=nothing; - saveat=[], recordValues = nothing, + saveat=nothing, kwargs...) if !is64(model) model = convert64(model) + logInfo(fmu, "Model is not Float64, but this is necessary for (Neural)FMUs.\nModel parameters are automatically converted to Float64.") end p, re = Flux.destructure(model) @@ -697,16 +574,14 @@ function ME_NeuralFMU(fmu::FMU2, ###### nfmu.fmu = fmu - + nfmu.saved_values = nothing nfmu.recordValues = prepareValueReference(fmu, recordValues) - # abstol=abstol, reltol=reltol, dtmin=dtmin, force_dtmin=force_dtmin, - nfmu.tspan = tspan - nfmu.saveat = saveat nfmu.solver = solver + nfmu.saveat = saveat nfmu.kwargs = kwargs nfmu.parameters = nothing @@ -716,48 +591,61 @@ function ME_NeuralFMU(fmu::FMU2, end """ -Constructs a CS-NeuralFMU where the FMU is at an arbitrary location inside of the NN. +Constructs a CS-NeuralFMU where the FMU is at an arbitrary location inside of the ANN. # Arguents - - `fmu` the considered FMU inside the NN - - `model` the NN topology (e.g. Flux.chain) + - `fmu` the considered FMU inside the ANN + - `model` the ANN topology (e.g. Flux.Chain) - `tspan` simulation time span # Keyword arguments - - `saveat` time points to save the NeuralFMU output, if empty, solver step size is used (may be non-equidistant) + - `recordValues` additionally records FMU variables """ -function CS_NeuralFMU(fmu::Union{FMU2, Vector{<:FMU2}}, +function CS_NeuralFMU(fmu::FMU2, model, tspan; - saveat=[], - recordValues = []) + recordValues=[]) if !is64(model) model = convert64(model) + logInfo(fmu, "Model is not Float64, but this is necessary for (Neural)FMUs.\nModel parameters are automatically converted to Float64.") end - nfmu = nothing - if typeof(fmu) == FMU2 - nfmu = CS_NeuralFMU{FMU2, FMU2Component}() - else - nfmu = CS_NeuralFMU{Vector{FMU2}, Vector{FMU2Component} }() - end + nfmu = CS_NeuralFMU{FMU2, FMU2Component}() nfmu.fmu = fmu + nfmu.model = model + nfmu.tspan = tspan + + return nfmu +end - nfmu.model = model # Chain(model.layers...) +function CS_NeuralFMU(fmus::Vector{<:FMU2}, + model, + tspan; + recordValues=[]) - nfmu.tspan = tspan - nfmu.saveat = saveat + if !is64(model) + model = convert64(model) + for fmu in fmus + logInfo(fmu, "Model is not Float64, but this is necessary for (Neural)FMUs.\nModel parameters are automatically converted to Float64.") + end + end - nfmu + nfmu = CS_NeuralFMU{Vector{FMU2}, Vector{FMU2Component} }() + + nfmu.fmu = fmus + nfmu.model = model + nfmu.tspan = tspan + + return nfmu end function checkExecTime(integrator, nfmu::ME_NeuralFMU, c, max_execution_duration::Real) dist = max(nfmu.execution_start + max_execution_duration - time(), 0.0) if dist <= 0.0 - logWarning(nfmu.fmu, "Reached max execution duration ($(max_execution_duration)), terminating integration ...") + logInfo(nfmu.fmu, "Reached max execution duration ($(max_execution_duration)), terminating integration ...") terminate!(integrator) end @@ -769,6 +657,9 @@ function getComponent(nfmu::NeuralFMU) end """ + + TODO: Signature, Arguments and Keyword-Arguments descriptions. + Evaluates the ME_NeuralFMU in the timespan given during construction or in a custom timespan from `t_start` to `t_stop` for a given start state `x_start`. # Keyword arguments @@ -797,12 +688,15 @@ function (nfmu::ME_NeuralFMU)(x_start::Union{Array{<:Real}, Nothing} = nfmu.x0, sensealg=nfmu.fmu.executionConfig.sensealg, # ToDo: AbstractSensitivityAlgorithm kwargs...) - if saveat[1] != tspan[1] - @warn "NeuralFMU changed time interval, start time is $(tspan[1]), but saveat from constructor gives $(saveat[1]). Please provide correct `saveat` via keyword with matching start/stop time." - end - if saveat[end] != tspan[end] - @warn "NeuralFMU changed time interval, stop time is $(tspan[end]), but saveat from constructor gives $(saveat[end]). Please provide correct `saveat` via keyword with matching start/stop time." - end + # this shouldnt be forced + # if !isnothing(saveat) + # if saveat[1] != tspan[1] + # @warn "NeuralFMU changed time interval, start time is $(tspan[1]), but saveat from constructor gives $(saveat[1]). Please provide correct `saveat` via keyword with matching start/stop time." + # end + # if saveat[end] != tspan[end] + # @warn "NeuralFMU changed time interval, stop time is $(tspan[end]), but saveat from constructor gives $(saveat[end]). Please provide correct `saveat` via keyword with matching start/stop time." + # end + # end recordValues = prepareValueReference(nfmu.fmu, recordValues) @@ -814,6 +708,7 @@ function (nfmu::ME_NeuralFMU)(x_start::Union{Array{<:Real}, Nothing} = nfmu.x0, nfmu.tspan = tspan nfmu.x0 = x_start + nfmu.p = p ignore_derivatives() do @debug "ME_NeuralFMU..." @@ -847,11 +742,6 @@ function (nfmu::ME_NeuralFMU)(x_start::Union{Array{<:Real}, Nothing} = nfmu.x0, push!(callbacks, cb) end - # cb = FunctionCallingCallback((x, t, integrator) -> @info "Start"; # startCallback(integrator, nfmu, c, t); - # funcat=[t_start], - # func_start=true) - # push!(callbacks, cb) - nfmu.fmu.hasStateEvents = (c.fmu.modelDescription.numberOfEventIndicators > 0) nfmu.fmu.hasTimeEvents = (c.eventInfo.nextEventTimeDefined == fmi2True) @@ -897,7 +787,7 @@ function (nfmu::ME_NeuralFMU)(x_start::Union{Array{<:Real}, Nothing} = nfmu.x0, (integrator) -> terminate!(integrator); save_positions=(false, false)) push!(callbacks, terminateCb) - #@info "Setting max execeution time to $(max_execution_duration)" + logInfo(nfmu.fmu, "Setting max execeution time to $(max_execution_duration)") end # custom callbacks @@ -906,7 +796,7 @@ function (nfmu::ME_NeuralFMU)(x_start::Union{Array{<:Real}, Nothing} = nfmu.x0, end if showProgress - c.progressMeter = ProgressMeter.Progress(1000; desc=progressDescr, color=:blue, dt=1.0) #, barglyphs=ProgressMeter.BarGlyphs("[=> ]")) + c.progressMeter = ProgressMeter.Progress(1000; desc=progressDescr, color=:blue, dt=1.0) ProgressMeter.update!(c.progressMeter, 0) # show it! else c.progressMeter = nothing @@ -918,6 +808,7 @@ function (nfmu::ME_NeuralFMU)(x_start::Union{Array{<:Real}, Nothing} = nfmu.x0, func_start=true) push!(callbacks, stepCb) + # [ToDo] Allow for AD-primitives for sensitivity analysis of recorded values if saving c.solution.values = SavedValues(Float64, Tuple{collect(Float64 for i in 1:length(recordValues))...}) c.solution.valueReferences = recordValues @@ -950,8 +841,8 @@ function (nfmu::ME_NeuralFMU)(x_start::Union{Array{<:Real}, Nothing} = nfmu.x0, c.solution.eigenvalues = SavedValues(recordEigenvaluesType, Tuple{dtypes...}) savingCB = nothing - if saveat === nothing - savingCB = SavingCallback((u,t,integrator) -> saveEigenvalues(nfmu, c, u, p, t, integrator, recordEigenvaluesSensitivity), + if isnothing(saveat) + savingCB = SavingCallback((u,t,integrator) -> saveEigenvalues(nfmu, c, u, t, integrator, recordEigenvaluesSensitivity), c.solution.eigenvalues) else savingCB = SavingCallback((u,t,integrator) -> saveEigenvalues(nfmu, c, u, p, t, integrator, recordEigenvaluesSensitivity), @@ -965,30 +856,11 @@ function (nfmu::ME_NeuralFMU)(x_start::Union{Array{<:Real}, Nothing} = nfmu.x0, prob = nothing - if inPlace - ff = ODEFunction{true}((dx, x, p, t) -> fx(nfmu, c, dx, x, p, t), - tgrad=nothing) - prob = ODEProblem{true}(ff, nfmu.x0, nfmu.tspan, p) - else - ff = ODEFunction{false}((x, p, t) -> fx(nfmu, c, x, p, t), - tgrad=nothing) # zero_tgrad) - prob = ODEProblem{false}(ff, nfmu.x0, nfmu.tspan, p) - end - - # if (length(callbacks) == 2) # only start and stop callback, so the system is pure continuous - # startCallback(nfmu, nfmu.tspan[1]) - # c.solution.states = solve(prob, nfmu.args...; sensealg=sensealg, saveat=nfmu.saveat, nfmu.kwargs...) - # stopCallback(nfmu, nfmu.tspan[end]) - # else - #c.solution.states = solve(prob, nfmu.args...; sensealg=sensealg, saveat=nfmu.saveat, callback = CallbackSet(callbacks...), nfmu.kwargs...) + ff = ODEFunction{true}((dx, x, p, t) -> fx(nfmu, c, dx, x, p, t)) # tgrad=nothing + prob = ODEProblem{true}(ff, nfmu.x0, nfmu.tspan, p) if isnothing(sensealg) - # when using state events, we (currently) need AD-through-Solver - if c.fmu.hasStateEvents && c.fmu.executionConfig.handleStateEvents - sensealg = ReverseDiffAdjoint() # Support for multi-state-event simulations, but a little bit slower than QuadratureAdjoint - else # otherwise, we can use the faster Adjoint-over-Solver - sensealg = QuadratureAdjoint(; autojacvec=ReverseDiffVJP()) # Faster than ReverseDiffAdjoint - end + sensealg = ReverseDiffAdjoint() end args = Vector{Any}() @@ -1010,6 +882,8 @@ function (nfmu::ME_NeuralFMU)(x_start::Union{Array{<:Real}, Nothing} = nfmu.x0, # ReverseDiff returns an array instead of an ODESolution, this needs to be corrected if isa(c.solution.states, TrackedArray) + + @assert !isnothing(saveat) "Keyword `saveat` is nothing, please provide the keyword." t = collect(saveat) u = c.solution.states @@ -1024,14 +898,14 @@ function (nfmu::ME_NeuralFMU)(x_start::Union{Array{<:Real}, Nothing} = nfmu.x0, end # ignore_derivatives - # stopCB (Opt B) + # stopCB stopCallback(nfmu, c, t_stop) return c.solution end function (nfmu::ME_NeuralFMU)(x0::Union{Array{<:Real}, Nothing}, - t::Real; - p=nothing) + t::Real; + p=nothing) c = nothing @@ -1039,6 +913,9 @@ function (nfmu::ME_NeuralFMU)(x0::Union{Array{<:Real}, Nothing}, end """ + + ToDo: Docstring for Arguments, Keyword arguments, ... + Evaluates the CS_NeuralFMU in the timespan given during construction or in a custum timespan from `t_start` to `t_stop` with a given time step size `t_step`. Via optional argument `reset`, the FMU is reset every time evaluation is started (default=`true`). @@ -1061,7 +938,7 @@ function (nfmu::CS_NeuralFMU{F, C})(inputFct, c, _ = prepareSolveFMU(nfmu.fmu, c, fmi2TypeCoSimulation, instantiate, freeInstance, terminate, reset, setup, parameters, t_start, t_stop, tolerance; cleanup=true) ts = collect(t_start:t_step:t_stop) - #c.skipNextDoStep = true # skip first fim2DoStep-call + model_input = inputFct.(ts) firstStep = true @@ -1080,7 +957,6 @@ function (nfmu::CS_NeuralFMU{F, C})(inputFct, y = nfmu.model(input) else # flattened, explicite parameters @assert !isnothing(nfmu.re) "Using explicite parameters without destructing the model." - y = nfmu.re(p)(input) end @@ -1095,6 +971,7 @@ function (nfmu::CS_NeuralFMU{F, C})(inputFct, c.solution.values = SavedValues{typeof(ts[1]), typeof(valueStack[1])}(ts, valueStack) + # [ToDo] check if this is still the case for current releases of related libraries # this is not possible in CS (pullbacks are sometimes called after the finished simulation), clean-up happens at the next call # c = finishSolveFMU(nfmu.fmu, c, freeInstance, terminate) @@ -1128,9 +1005,6 @@ function (nfmu::CS_NeuralFMU{Vector{F}, Vector{C}})(inputFct, solution = FMU2Solution(nothing) ts = collect(t_start:t_step:t_stop) - # for c in cs - # c.skipNextDoStep = true - # end model_input = inputFct.(ts) firstStep = true @@ -1151,7 +1025,7 @@ function (nfmu::CS_NeuralFMU{Vector{F}, Vector{C}})(inputFct, y = nfmu.model(input) else # flattened, explicite parameters @assert nfmu.re != nothing "Using explicite parameters without destructing the model." - #_p = collect(ForwardDiff.value(r) for r in p[1]) + if length(p) == 1 y = nfmu.re(p[1])(input) else @@ -1170,6 +1044,7 @@ function (nfmu::CS_NeuralFMU{Vector{F}, Vector{C}})(inputFct, solution.values = SavedValues{typeof(ts[1]), typeof(valueStack[1])}(ts, valueStack) + # [ToDo] check if this is still the case for current releases of related libraries # this is not possible in CS (pullbacks are sometimes called after the finished simulation), clean-up happens at the next call # cs = finishSolveFMU(nfmu.fmu, cs, freeInstance, terminate) @@ -1194,7 +1069,7 @@ function Flux.params(nfmu::CS_NeuralFMU; destructure::Bool=true) end end -function computeGradient(loss, params, gradient, chunk_size, multiObjective::Bool) +function computeGradient!(jac, loss, params, gradient::Symbol, chunk_size::Union{Symbol, Int}, multiObjective::Bool) if gradient == :ForwardDiff @@ -1202,11 +1077,10 @@ function computeGradient(loss, params, gradient, chunk_size, multiObjective::Boo if multiObjective conf = ForwardDiff.JacobianConfig(loss, params) - jac = ForwardDiff.jacobian(loss, params, conf) - return collect(jac[i,:] for i in 1:size(jac)[1]) + ForwardDiff.jacobian!(jac, loss, params, conf) else conf = ForwardDiff.GradientConfig(loss, params) - return [ForwardDiff.gradient(loss, params, conf)] + ForwardDiff.gradient!(jac, loss, params, conf) end elseif chunk_size == :auto_fmiflux @@ -1215,136 +1089,102 @@ function computeGradient(loss, params, gradient, chunk_size, multiObjective::Boo if multiObjective conf = ForwardDiff.JacobianConfig(loss, params, ForwardDiff.Chunk{min(chunk_size, length(params))}()); - jac = ForwardDiff.jacobian(loss, params, conf) - return collect(jac[i,:] for i in 1:size(jac)[1]) + ForwardDiff.jacobian!(jac, loss, params, conf) else conf = ForwardDiff.GradientConfig(loss, params, ForwardDiff.Chunk{min(chunk_size, length(params))}()); - return [ForwardDiff.gradient(loss, params, conf)] + ForwardDiff.gradient!(jac, loss, params, conf) end else if multiObjective conf = ForwardDiff.JacobianConfig(loss, params, ForwardDiff.Chunk{min(chunk_size, length(params))}()); - jac = ForwardDiff.jacobian(loss, params, conf) - return collect(jac[i,:] for i in 1:size(jac)[1]) + ForwardDiff.jacobian!(jac, loss, params, conf) else conf = ForwardDiff.GradientConfig(loss, params, ForwardDiff.Chunk{min(chunk_size, length(params))}()); - return [ForwardDiff.gradient(loss, params, conf)] + ForwardDiff.gradient!(jac, loss, params, conf) end end elseif gradient == :Zygote if multiObjective - jac = Zygote.jacobian(loss, params)[1] - return collect(jac[i,:] for i in 1:size(jac)[1]) + jac[:] = Zygote.jacobian(loss, params)[1] else - return [Zygote.gradient(loss, params)[1]] + jac[:] = Zygote.gradient(loss, params)[1] end elseif gradient == :ReverseDiff if multiObjective - jac = ReverseDiff.jacobian(loss, params) - return collect(jac[i,:] for i in 1:size(jac)[1]) + ReverseDiff.jacobian!(jac, loss, params) else - return [ReverseDiff.gradient(loss, params)] + ReverseDiff.gradient!(jac, loss, params) end elseif gradient == :FiniteDiff if multiObjective - @assert false "FiniteDiff is currently not implemented for multi-objective optimization. Please open an issue on FMIFlux.jl if this is needed." + FiniteDiff.finite_difference_jacobian!(jac, loss, params) else - return [FiniteDiff.finite_difference_gradient(loss, params)] + FiniteDiff.finite_difference_gradient!(jac, loss, params) end else @assert false "Unknown `gradient=$(gradient)`, supported are `:ForwardDiff`, `:Zygote`, `:FiniteDiff` and `:ReverseDiff`." end -end - -# WIP -function trainStep(loss, params, gradient, chunk_size, optim::Optim.AbstractOptimizer, printStep, proceed_on_assert, cb; state=nothing) + ### check gradient is valid - try - if isnothing(state) - state = initial_state(optim, options, d, initial_x) - end - - for j in 1:length(params) - - grad = computeGradient(loss, params[j], gradient, chunk_size) - - @assert !isnothing(grad) "Gradient nothing!" - - update_state!(d, state, optim) - - step = Flux.Optimise.apply!(optim, params[j], grad) - params[j] .-= step + # [Todo] Better! + grads = nothing + if multiObjective + grads = collect(jac[i,:] for i in 1:size(jac)[1]) + else + grads = [jac] + end - if printStep - @info "Grad: Min = $(min(abs.(grad)...)) Max = $(max(abs.(grad)...))" - @info "Step: Min = $(min(abs.(step)...)) Max = $(max(abs.(step)...))" - end - end + has_nan = any(collect(any(isnan.(grad)) for grad in grads)) + has_nothing = any(collect(any(isnothing.(grad)) for grad in grads)) || any(isnothing.(grads)) - catch e + if gradient != :ForwardDiff && (has_nan || has_nothing) + @warn "Gradient determination with $(gradient) failed, because gradient contains `NaNs` and/or `nothing`.\nThis might be because the FMU is throwing redundant events, which is currently not supported.\nTrying ForwardDiff as back-up.\nIf this message gets printed (almost) every step, consider using keyword `gradient=:ForwardDiff` to fix ForwardDiff as sensitivity system." + gradient = :ForwardDiff + computeGradient!(jac, loss, params, gradient, chunk_size, multiObjective) - if proceed_on_assert - @error "Training asserted, but continuing: $e" + if multiObjective + grads = collect(jac[i,:] for i in 1:size(jac)[1]) else - throw(e) + grads = [jac] end end - if cb != nothing - if isa(cb, AbstractArray) - for _cb in cb - _cb() - end - else - cb() - end - end + has_nan = any(collect(any(isnan.(grad)) for grad in grads)) + has_nothing = any(collect(any(isnothing.(grad)) for grad in grads)) || any(isnothing.(grads)) + + @assert !has_nan "Gradient determination with $(gradient) failed, because gradient contains `NaNs`.\nNo back-up options available." + @assert !has_nothing "Gradient determination with $(gradient) failed, because gradient contains `nothing`.\nNo back-up options available." + return nothing end -lk_OptimApply = ReentrantLock() -function trainStep(loss, params, gradient, chunk_size, optim::Flux.Optimise.AbstractOptimiser, printStep, proceed_on_assert, cb, multiObjective) +lk_TrainApply = ReentrantLock() +function trainStep(loss, params, gradient, chunk_size, optim::FMIFlux.AbstractOptimiser, printStep, proceed_on_assert, cb, multiObjective) - global lk_OptimApply + global lk_TrainApply try for j in 1:length(params) - grads = computeGradient(loss, params[j], gradient, chunk_size, multiObjective) - - has_nan = any(collect(any(isnan.(grad)) for grad in grads)) - has_nothing = any(collect(any(isnothing.(grad)) for grad in grads)) || any(isnothing.(grads)) - - if gradient != :ForwardDiff && (has_nan || has_nothing) - @warn "Gradient determination with $(gradient) failed, because gradient contains `NaNs` and/or `nothing`.\nTrying ForwardDiff as back-up.\nIf this message gets printed (almost) every step, consider using keyword `gradient=:ForwardDiff` to fix ForwardDiff as sensitivity system." - gradient = :ForwardDiff - grads = computeGradient(loss, params[j], gradient, chunk_size, multiObjective) - end + step = FMIFlux.apply!(optim, params[j]) - has_nan = any(collect(any(isnan.(grad)) for grad in grads)) - has_nothing = any(collect(any(isnothing.(grad)) for grad in grads)) || any(isnothing.(grads)) + lock(lk_TrainApply) do - @assert !has_nan "Gradient determination with $(gradient) failed, because gradient contains `NaNs`.\nNo back-up options available." - @assert !has_nothing "Gradient determination with $(gradient) failed, because gradient contains `nothing`.\nNo back-up options available." - - lock(lk_OptimApply) do - for grad in grads - step = Flux.Optimise.apply!(optim, params[j], grad) - params[j] .-= step - - if printStep - @info "Grad: Min = $(min(abs.(grad)...)) Max = $(max(abs.(grad)...))" - @info "Step: Min = $(min(abs.(step)...)) Max = $(max(abs.(step)...))" - end + params[j] .-= step + + if printStep + @info "Grad: Min = $(min(abs.(grad)...)) Max = $(max(abs.(grad)...))" + @info "Step: Min = $(min(abs.(step)...)) Max = $(max(abs.(step)...))" end + end end @@ -1397,11 +1237,6 @@ function train!(loss, params::Union{Flux.Params, Zygote.Params, AbstractVector{< @warn "train!(...): Multi-threading is set via flag `multiThreading=true`, but this Julia process does not have multiple threads. This will not result in a speed-up. Please spawn Julia in multi-thread mode to speed-up training." end - if length(params) <= 0 || length(params[1]) <= 0 - @warn "train!(...): Empty parameter array, training on an empty parameter array doesn't make sense." - return - end - _trainStep = (i,) -> trainStep(loss, params, gradient, chunk_size, optim, printStep, proceed_on_assert, cb, multiObjective) if multiThreading @@ -1420,193 +1255,34 @@ function train!(loss, neuralFMU::Union{ME_NeuralFMU, CS_NeuralFMU}, data, optim: train!(loss, params, data, optim; kwargs...) end -# checks gradient determination for all available sensitivity configurations, see: -# https://docs.sciml.ai/SciMLSensitivity/stable/manual/differential_equation_sensitivities/ -using FMIImport.SciMLSensitivity -function checkSensalgs!(loss, neuralFMU::Union{ME_NeuralFMU, CS_NeuralFMU}; - gradients=(:ForwardDiff, :ReverseDiff, :Zygote), # :FiniteDiff is slow ... - max_msg_len=192, chunk_size=32, - OtD_autojacvecs=(false, true, TrackerVJP(), ZygoteVJP(), ReverseDiffVJP()), # EnzymeVJP() deadlocks in the current release xD - OtD_sensealgs=(BacksolveAdjoint, InterpolatingAdjoint, QuadratureAdjoint), - OtD_checkpointings=(true, false), - DtO_sensealgs=(ReverseDiffAdjoint, ZygoteAdjoint, TrackerAdjoint, ForwardDiffSensitivity), - multiObjective::Bool=false, - bestof::Int=2, - timeout_seconds::Real=60.0, - kwargs...) - - params = Flux.params(neuralFMU) - initial_sensalg = neuralFMU.fmu.executionConfig.sensealg - - best_timing = Inf - best_gradient = nothing - best_sensealg = nothing - - printstyled("Mode: Optimize-then-Discretize\n") - for gradient ∈ gradients - printstyled("\tGradient: $(gradient)\n") - - for sensealg ∈ OtD_sensealgs - printstyled("\t\tSensealg: $(sensealg)\n") - for checkpointing ∈ OtD_checkpointings - printstyled("\t\t\tCheckpointing: $(checkpointing)\n") - - if sensealg == QuadratureAdjoint && checkpointing - printstyled("\t\t\t\tQuadratureAdjoint doesn't implement checkpointing, skipping ...\n") - continue - end - - for autojacvec ∈ OtD_autojacvecs - printstyled("\t\t\t\tAutojacvec: $(autojacvec)\n") - - if sensealg ∈ (BacksolveAdjoint, InterpolatingAdjoint) - neuralFMU.fmu.executionConfig.sensealg = sensealg(; autojacvec=autojacvec, chunk_size=chunk_size, checkpointing=checkpointing) - else - neuralFMU.fmu.executionConfig.sensealg = sensealg(; autojacvec=autojacvec, chunk_size=chunk_size) - end - - call = () -> _tryrun(loss, params, gradient, chunk_size, 5, max_msg_len, multiObjective; timeout_seconds=timeout_seconds) - for i in 1:bestof - timing = call() - - if timing < best_timing - best_timing = timing - best_gradient = gradient - best_sensealg = neuralFMU.fmu.executionConfig.sensealg - end - end - - end - end - end +function train!(loss, params::Union{Flux.Params, Zygote.Params, AbstractVector{<:AbstractVector{<:Real}}}, data, optim::Flux.Optimise.AbstractOptimiser; gradient::Symbol=:ReverseDiff, chunk_size::Union{Integer, Symbol}=:auto_fmiflux, multiObjective::Bool=false, kwargs...) + if length(params) <= 0 || length(params[1]) <= 0 + @warn "train!(...): Empty parameter array, training on an empty parameter array doesn't make sense." + return end + + grad_buffer = nothing - printstyled("Mode: Discretize-then-Optimize\n") - for gradient ∈ gradients - printstyled("\tGradient: $(gradient)\n") - for sensealg ∈ DtO_sensealgs - printstyled("\t\tSensealg: $(sensealg)\n") - - if sensealg == ForwardDiffSensitivity - neuralFMU.fmu.executionConfig.sensealg = sensealg(; chunk_size=chunk_size, convert_tspan=true) - else - neuralFMU.fmu.executionConfig.sensealg = sensealg() - end - - call = () -> _tryrun(loss, params, gradient, chunk_size, 3, max_msg_len, multiObjective; timeout_seconds=timeout_seconds) - for i in 1:bestof - timing = call() - - if timing < best_timing - best_timing = timing - best_gradient = gradient - best_sensealg = neuralFMU.fmu.executionConfig.sensealg - end - end + if multiObjective + dim = loss(params[1]) - end + grad_buffer = zeros(Float64, length(params[1]), length(dim)) + else + grad_buffer = zeros(Float64, length(params[1])) end - neuralFMU.fmu.executionConfig.sensealg = initial_sensalg - - printstyled("------------------------------\nBest time: $(best_timing)\nBest gradient: $(best_gradient)\nBest sensealg: $(best_sensealg)\n", color=:blue) - - return nothing -end - -# Thanks to: -# https://discourse.julialang.org/t/help-writing-a-timeout-macro/16591/11 -function timeout(f, arg, seconds, fail) - tsk = @task f(arg...) - schedule(tsk) - Timer(seconds) do timer - istaskdone(tsk) || Base.throwto(tsk, InterruptException()) - end - try - fetch(tsk) - catch _; - fail - end + grad_fun! = (G, p) -> computeGradient!(G, loss, p, gradient, chunk_size, multiObjective) + _optim = FluxOptimiserWrapper(optim, grad_fun!, grad_buffer) + train!(loss, params, data, _optim; gradient=gradient, chunk_size=chunk_size, multiObjective=multiObjective, kwargs...) end -function runGrads(loss, params, gradient, chunk_size, multiObjective) - tstart = time() - grads = computeGradient(loss, params[1], gradient, chunk_size, multiObjective) - timing = time() - tstart - - if length(grads[1]) == 1 - grads = [grads] +function train!(loss, params::Union{Flux.Params, Zygote.Params, AbstractVector{<:AbstractVector{<:Real}}}, data, optim::Optim.AbstractOptimizer; gradient::Symbol=:ReverseDiff, chunk_size::Union{Integer, Symbol}=:auto_fmiflux, multiObjective::Bool=false, kwargs...) + if length(params) <= 0 || length(params[1]) <= 0 + @warn "train!(...): Empty parameter array, training on an empty parameter array doesn't make sense." + return end - - return grads, timing + + grad_fun! = (G, p) -> computeGradient!(G, loss, p, gradient, chunk_size, multiObjective) + _optim = OptimOptimiserWrapper(optim, grad_fun!, loss, params[1]) + train!(loss, params, data, _optim; gradient=gradient, chunk_size=chunk_size, multiObjective=multiObjective, kwargs...) end - -function _tryrun(loss, params, gradient, chunk_size, ts, max_msg_len, multiObjective::Bool=false; print_stdout::Bool=true, print_stderr::Bool=true, timeout_seconds::Real=60.0) - - spacing = "" - for t in ts - spacing *= "\t" - end - - message = "" - color = :black - timing = Inf - - original_stdout = stdout - original_stderr = stderr - (rd_stdout, wr_stdout) = redirect_stdout(); - (rd_stderr, wr_stderr) = redirect_stderr(); - - try - - #grads, timing = timeout(runGrads, (loss, params, gradient, chunk_size, multiObjective), timeout_seconds, ([Inf], -1.0)) - grads, timing = runGrads(loss, params, gradient, chunk_size, multiObjective) - - if timing == -1.0 - message = spacing * "TIMEOUT\n" - color = :red - else - val = collect(sum(abs.(grad)) for grad in grads) - message = spacing * "SUCCESS | $(round(timing; digits=2))s | GradAbsSum: $(round.(val; digits=6))\n" - color = :green - end - - catch e - msg = "$(e)" - if length(msg) > max_msg_len - msg = msg[1:max_msg_len] * "..." - end - - message = spacing * "$(msg)\n" - color = :red - end - - redirect_stdout(original_stdout) - redirect_stderr(original_stderr) - close(wr_stdout) - close(wr_stderr) - - if print_stdout - msg = read(rd_stdout, String) - if length(msg) > 0 - if length(msg) > max_msg_len - msg = msg[1:max_msg_len] * "..." - end - printstyled(spacing * "STDOUT: $(msg)\n", color=:yellow) - end - end - - if print_stderr - msg = read(rd_stderr, String) - if length(msg) > 0 - if length(msg) > max_msg_len - msg = msg[1:max_msg_len] * "..." - end - printstyled(spacing * "STDERR: $(msg)\n", color=:yellow) - end - end - - printstyled(message, color=color) - - return timing -end \ No newline at end of file diff --git a/src/optimiser.jl b/src/optimiser.jl index ac96e242..23367ded 100644 --- a/src/optimiser.jl +++ b/src/optimiser.jl @@ -3,40 +3,74 @@ # Licensed under the MIT license. See LICENSE file in the project root for details. # -import Flux.Optimisers - -struct SoftStart{T} <: Optimisers.AbstractRule - minx::T - maxx::T - steps::UInt - - function SoftStart{T}(minx::T, steps::UInt; maxx::T=1.0) where {T} - inst = new() - inst.minx = minx - inst.maxx = maxx - inst.steps = steps - return inst - end +import Flux +import Optim + +abstract type AbstractOptimiser end + +### Optim.jl ### + +struct OptimOptimiserWrapper{G} <: AbstractOptimiser + optim::Optim.AbstractOptimizer + grad_fun!::G + + state::Union{Optim.AbstractOptimizerState, Nothing} + d::Union{Optim.OnceDifferentiable, Nothing} + options + + function OptimOptimiserWrapper(optim::Optim.AbstractOptimizer, grad_fun!::G, loss, params) where {G} + options = Optim.Options(iterations=1) + autodiff = :finite + inplace = true - function SoftStart(minx::T, steps::UInt; maxx::T=1.0) - return SoftStart{fmi2Real}(minx, steps; maxx=maxx) + d = Optim.promote_objtype(optim, params, autodiff, inplace, loss, grad_fun!) + state = Optim.initial_state(optim, options, d, params) + + return new{G}(optim, grad_fun!, state, d, options) end + end -export SoftStart +export OptimOptimiserWrapper -function Optimisers.apply!(o::SoftStart, state, x, x̄) - step = state +function apply!(optim::OptimOptimiserWrapper, params) - if step > o.steps - newx̄ = o.maxx - else - newx̄ = o.minx * ((o.maxx/o.minx)^(1.0/o.steps*step)) + res = Optim.optimize(optim.d, params, optim.optim, optim.options, optim.state) + step = params .- Optim.minimizer(res) + + return step +end + +### Flux.Optimisers ### + +struct FluxOptimiserWrapper{G} <: AbstractOptimiser + optim::Flux.Optimise.AbstractOptimiser + grad_fun!::G + grad_buffer::Union{AbstractVector{Float64}, AbstractMatrix{Float64}} + multiGrad::Bool + + function FluxOptimiserWrapper(optim::Flux.Optimise.AbstractOptimiser, grad_fun!::G, grad_buffer::AbstractVector{Float64}) where {G} + return new{G}(optim, grad_fun!, grad_buffer, false) + end + + function FluxOptimiserWrapper(optim::Flux.Optimise.AbstractOptimiser, grad_fun!::G, grad_buffer::AbstractMatrix{Float64}) where {G} + return new{G}(optim, grad_fun!, grad_buffer, true) end - nextstate = step + 1 - return nextstate, newx̄ end +export FluxOptimiserWrapper -function Optimisers.init(o::SoftStart, x::AbstractArray) - return 0 -end \ No newline at end of file +function apply!(optim::FluxOptimiserWrapper, params) + + optim.grad_fun!(optim.grad_buffer, params) + + if optim.multiGrad + return collect(Flux.Optimise.apply!(optim.optim, params, optim.grad_buffer[:,i]) for i in 1:size(optim.grad_buffer)[2]) + else + return Flux.Optimise.apply!(optim.optim, params, optim.grad_buffer) + end +end + +### generic FMIFlux.AbstractOptimiser ### + + + \ No newline at end of file diff --git a/test/batching.jl b/test/batching.jl index df176fb2..98e5d558 100644 --- a/test/batching.jl +++ b/test/batching.jl @@ -29,7 +29,7 @@ import FMI.FMIImport: unsense # loss function for training function losssum_single(p) global problem, x0, posData - solution = problem(x0; p=p, showProgress=true) + solution = problem(x0; p=p, showProgress=true, saveat=tData) if !solution.success return Inf @@ -42,7 +42,7 @@ end function losssum_multi(p) global problem, x0, posData - solution = problem(x0; p=p, showProgress=true) + solution = problem(x0; p=p, showProgress=true, saveat=tData) if !solution.success return [Inf, Inf] @@ -72,12 +72,18 @@ end numStates = fmiGetNumberOfStates(fmu) +c1 = CacheLayer() +c2 = CacheRetrieveLayer(c1) + # the "Chain" for training -net = Chain(x -> fmu(;x=x), +net = Chain(x -> fmu(;x=x, dx_refs=:all), + dx -> c1(dx), Dense(numStates, 12, tanh), - Dense(12, numStates, identity)) + Dense(12, 1, identity), + dx -> c2([1], dx[1], [])) -problem = ME_NeuralFMU(fmu, net, (t_start, t_stop); saveat=tData) +solver = Tsit5() +problem = ME_NeuralFMU(fmu, net, (t_start, t_stop), solver; saveat=tData) @test problem != nothing solutionBefore = problem(x0) @@ -93,9 +99,9 @@ optim = Adam(1e-3) FMIFlux.train!(losssum_single, p_net, Iterators.repeated((), NUMSTEPS), optim; cb=()->callb(losssum_single, p_net), gradient=GRADIENT) # multi objective -lastLoss = sum(losssum_multi(p_net[1])) -optim = Adam(1e-3) -FMIFlux.train!(losssum_multi, p_net, Iterators.repeated((), NUMSTEPS), optim; cb=()->callb(losssum_multi, p_net), gradient=GRADIENT, multiObjective=true) +# lastLoss = sum(losssum_multi(p_net[1])) +# optim = Adam(1e-3) +# FMIFlux.train!(losssum_multi, p_net, Iterators.repeated((), NUMSTEPS), optim; cb=()->callb(losssum_multi, p_net), gradient=GRADIENT, multiObjective=true) # check results solutionAfter = problem(x0) diff --git a/test/fmu_params.jl b/test/fmu_params.jl index 732bd4e1..aa40aefd 100644 --- a/test/fmu_params.jl +++ b/test/fmu_params.jl @@ -36,8 +36,9 @@ p = fmi2GetReal(c, p_refs) # loss function for training function losssum(p) - global problem, x0, posData - solution = problem(x0; p=p, showProgress=true) + #@info "$p" + global problem, x0, posData, solution + solution = problem(x0; p=p, showProgress=true, saveat=tData) if !solution.success return Inf @@ -77,10 +78,11 @@ net = Chain(FMUParameterRegistrator(fmu, p_refs, p), optim = Adam(1e-3) solver = Tsit5() -problem = ME_NeuralFMU(fmu, net, (t_start, t_stop), solver; saveat=tData) +problem = ME_NeuralFMU(fmu, net, (t_start, t_stop), solver) +problem.modifiedState = false @test problem != nothing -solutionBefore = problem(x0) +solutionBefore = problem(x0; saveat=tData) @test solutionBefore.success @test length(solutionBefore.states.t) == length(tData) @test solutionBefore.states.t[1] == t_start @@ -89,17 +91,21 @@ solutionBefore = problem(x0) # train it ... p_net = Flux.params(problem) @test length(p_net) == 1 -@test length(p_net[1]) == 12 +@test length(p_net[1]) == 7 iterCB = 0 lastLoss = losssum(p_net[1]) @info "Start-Loss for net: $lastLoss" -@warn "FMU parameter tests disabled." -# FMIFlux.train!(losssum, p_net, Iterators.repeated((), NUMSTEPS), optim; cb=()->callb(p_net), gradient=GRADIENT) +# [ToDo] Discontinuous system? +j_fin = FiniteDiff.finite_difference_gradient(losssum, p_net[1]) +j_fwd = ForwardDiff.gradient(losssum, p_net[1]) +j_rwd = ReverseDiff.gradient(losssum, p_net[1]) + +FMIFlux.train!(losssum, p_net, Iterators.repeated((), NUMSTEPS), optim; cb=()->callb(p_net), gradient=GRADIENT) # check results -solutionAfter = problem(x0) +solutionAfter = problem(x0; saveat=tData) @test solutionAfter.success @test length(solutionAfter.states.t) == length(tData) @test solutionAfter.states.t[1] == t_start diff --git a/test/hybrid_CS.jl b/test/hybrid_CS.jl index e0541868..ecae8c5b 100644 --- a/test/hybrid_CS.jl +++ b/test/hybrid_CS.jl @@ -66,7 +66,7 @@ net = Chain(u -> myFMU(;u_refs=myFMU.modelDescription.inputValueReferences, u=u, Dense(16, 16, tanh; init=Flux.identity_init), Dense(16, numOutputs; init=Flux.identity_init)) -problem = CS_NeuralFMU(myFMU, net, (t_start, t_stop); saveat=tData) +problem = CS_NeuralFMU(myFMU, net, (t_start, t_stop)) @test problem != nothing # train it ... diff --git a/test/hybrid_ME.jl b/test/hybrid_ME.jl index 70c2fa65..a285e088 100644 --- a/test/hybrid_ME.jl +++ b/test/hybrid_ME.jl @@ -16,13 +16,14 @@ t_stop = 5.0 tData = t_start:t_step:t_stop # generate training data -realFMU = fmiLoad("SpringFrictionPendulum1D", EXPORTINGTOOL, EXPORTINGVERSION; type=fmi2TypeCoSimulation) -realSimData = fmiSimulateCS(realFMU, (t_start, t_stop); recordValues=["mass.s", "mass.v"], saveat=tData) +fmu = fmiLoad("SpringFrictionPendulum1D", EXPORTINGTOOL, EXPORTINGVERSION; type=fmi2TypeCoSimulation) +realSimData = fmiSimulateCS(fmu, (t_start, t_stop); recordValues=["mass.s", "mass.v"], saveat=tData) x0 = collect(realSimData.values.saveval[1]) @test x0 == [0.5, 0.0] +fmiUnload(fmu) # load FMU for NeuralFMU -myFMU = fmiLoad("SpringPendulum1D", EXPORTINGTOOL, EXPORTINGVERSION; type=fmi2TypeModelExchange) +fmu = fmiLoad("SpringPendulum1D", EXPORTINGTOOL, EXPORTINGVERSION; type=fmi2TypeModelExchange) # setup traing data velData = fmi2GetSolutionValue(realSimData, "mass.v") @@ -30,7 +31,7 @@ velData = fmi2GetSolutionValue(realSimData, "mass.v") # loss function for training function losssum(p) global problem, x0, posData #, solution - solution = problem(x0; p=p, showProgress=true) + solution = problem(x0; p=p, showProgress=true, saveat=tData) if !solution.success return Inf @@ -57,7 +58,7 @@ function callb(p) end end -numStates = fmiGetNumberOfStates(myFMU) +numStates = fmiGetNumberOfStates(fmu) # some NeuralFMU setups nets = [] @@ -67,87 +68,92 @@ c2 = CacheRetrieveLayer(c1) c3 = CacheLayer() c4 = CacheRetrieveLayer(c3) +init = Flux.glorot_uniform +getVRs = [fmi2StringToValueReference(fmu, "mass.s")] +numGetVRs = length(getVRs) +y = zeros(fmi2Real, numGetVRs) +setVRs = [fmi2StringToValueReference(fmu, "mass.m")] +numSetVRs = length(setVRs) + # 1. default ME-NeuralFMU (learn dynamics and states, almost-neutral setup, parameter count << 100) -net = Chain(Dense(numStates, numStates, identity; init=Flux.identity_init), - x -> myFMU(;x=x), - c3, - Dense(numStates, numStates, identity; init=Flux.identity_init), - x -> c4([1], x[2], [])) +net = Chain(x -> c1(x), + Dense(numStates, 1, identity; init=init), + x -> c2([], x[1], [1]), + x -> fmu(;x=x, dx_refs=:all), + x -> c3(x), + Dense(numStates, 1, identity; init=init), + x -> c4([1], x[1], [])) push!(nets, net) # 2. default ME-NeuralFMU (learn dynamics) -net = Chain(x -> myFMU(;x=x), +net = Chain(x -> fmu(;x=x, dx_refs=:all), x -> c1(x), - Dense(numStates, 16, identity; init=Flux.identity_init), - Dense(16, 16, identity; init=Flux.identity_init), - Dense(16, numStates, identity; init=Flux.identity_init), - x -> c2([1], x[2], [])) + Dense(numStates, 16, tanh; init=init), + Dense(16, 16, tanh; init=init), + Dense(16, 1, identity; init=init), + x -> c2([1], x[1], [])) push!(nets, net) # 3. default ME-NeuralFMU (learn states) net = Chain(x -> c1(x), - Dense(numStates, 16, identity, init=Flux.identity_init), - Dense(16, 16, identity, init=Flux.identity_init), - Dense(16, numStates, identity, init=Flux.identity_init), - x -> c2([1], x[2], []), - x -> myFMU(;x=x)) + Dense(numStates, 16, tanh, init=init), + Dense(16, 16, tanh, init=init), + Dense(16, 1, identity, init=init), + x -> c2([], x[1], [1]), + x -> fmu(;x=x, dx_refs=:all)) push!(nets, net) # 4. default ME-NeuralFMU (learn dynamics and states) net = Chain(x -> c1(x), - Dense(numStates, 16, identity; init=Flux.identity_init), - Dense(16, numStates, identity; init=Flux.identity_init), - x -> c2([1], x[2], []), - x -> myFMU(;x=x), + Dense(numStates, 16, tanh; init=init), + Dense(16, 1, identity; init=init), + x -> c2([], x[1], [1]), + x -> fmu(;x=x, dx_refs=:all), x -> c3(x), - Dense(numStates, 16, identity, init=Flux.identity_init), - Dense(16, 16, identity, init=Flux.identity_init), - Dense(16, numStates, identity, init=Flux.identity_init), - x -> c4([1], x[2], [])) + Dense(numStates, 16, tanh, init=init), + Dense(16, 16, tanh, init=init), + Dense(16, 1, identity, init=init), + x -> c4([1], x[1], [])) push!(nets, net) # 5. NeuralFMU with hard setting time to 0.0 -net = Chain(states -> myFMU(;x=states, t=0.0), +net = Chain(states -> fmu(;x=states, t=0.0, dx_refs=:all), x -> c1(x), - Dense(numStates, 8, identity; init=Flux.identity_init), - Dense(8, 16, identity; init=Flux.identity_init), - Dense(16, numStates, identity; init=Flux.identity_init), - x -> c2([1], x[2], [])) + Dense(numStates, 8, tanh; init=init), + Dense(8, 16, tanh; init=init), + Dense(16, 1, identity; init=init), + x -> c2([1], x[1], [])) push!(nets, net) # 6. NeuralFMU with additional getter -getVRs = [fmi2StringToValueReference(myFMU, "mass.s")] -numGetVRs = length(getVRs) -net = Chain(x -> myFMU(;x=x, y_refs=getVRs), +net = Chain(x -> fmu(;x=x, y_refs=getVRs, y=y, dx_refs=:all), x -> c1(x), - Dense(numStates+numGetVRs, 8, identity; init=Flux.identity_init), - Dense(8, 16, identity; init=Flux.identity_init), - Dense(16, numStates, identity; init=Flux.identity_init), - x -> c2([2], x[2], [])) + Dense(numStates+numGetVRs, 8, tanh; init=init), + Dense(8, 16, tanh; init=init), + Dense(16, 1, identity; init=init), + x -> c2([2], x[1], [])) push!(nets, net) # 7. NeuralFMU with additional setter -setVRs = [fmi2StringToValueReference(myFMU, "mass.m")] -numSetVRs = length(setVRs) -net = Chain(x -> myFMU(;x=x, u_refs=setVRs, u=[1.1]), +net = Chain(x -> fmu(;x=x, u_refs=setVRs, u=[1.1], dx_refs=:all), x -> c1(x), - Dense(numStates, 8, identity; init=Flux.identity_init), - Dense(8, 16, identity; init=Flux.identity_init), - Dense(16, numStates, identity; init=Flux.identity_init), - x -> c2([1], x[2], [])) + Dense(numStates, 8, tanh; init=init), + Dense(8, 16, tanh; init=init), + Dense(16, 1, identity; init=init), + x -> c2([1], x[1], [])) push!(nets, net) # 8. NeuralFMU with additional setter and getter -net = Chain(x -> myFMU(;x=x, u_refs=setVRs, u=[1.1], y_refs=getVRs), +net = Chain(x -> fmu(;x=x, u_refs=setVRs, u=[1.1], y_refs=getVRs, y=y, dx_refs=:all), x -> c1(x), - Dense(numStates+numGetVRs, 8, identity; init=Flux.identity_init), - Dense(8, 16, identity; init=Flux.identity_init), - Dense(16, numStates, identity; init=Flux.identity_init), - x -> c2([2], x[2], [])) + Dense(numStates+numGetVRs, 8, tanh; init=init), + Dense(8, 16, tanh; init=init), + Dense(16, 1, identity; init=init), + x -> c2([2], x[1], [])) push!(nets, net) # 9. an empty NeuralFMU (this does only make sense for debugging) -net = Chain(x -> myFMU(;x=x)) +net = Chain(x -> fmu(x=x, dx_refs=:all)) push!(nets, net) for i in 1:length(nets) @@ -158,27 +164,32 @@ for i in 1:length(nets) solver = Tsit5() net = nets[i] - problem = ME_NeuralFMU(myFMU, net, (t_start, t_stop), solver; saveat=tData) + problem = ME_NeuralFMU(fmu, net, (t_start, t_stop), solver) @test problem != nothing - solutionBefore = problem(x0) + # train it ... + p_net = Flux.params(problem) + @test length(p_net) == 1 + + solutionBefore = problem(x0; p=p_net[1], saveat=tData) if solutionBefore.success @test length(solutionBefore.states.t) == length(tData) @test solutionBefore.states.t[1] == t_start @test solutionBefore.states.t[end] == t_stop end - # train it ... - p_net = Flux.params(problem) - @test length(p_net) == 1 - iterCB = 0 lastLoss = losssum(p_net[1]) @info "Start-Loss for net #$i: $lastLoss" + + if length(p_net[1]) == 0 + @info "The following warning is not an issue, because training on zero parameters must throw a warning:" + end + FMIFlux.train!(losssum, p_net, Iterators.repeated((), NUMSTEPS), optim; cb=()->callb(p_net), gradient=GRADIENT) # check results - solutionAfter = problem(x0) + solutionAfter = problem(x0; p=p_net[1], saveat=tData) if solutionAfter.success @test length(solutionAfter.states.t) == length(tData) @test solutionAfter.states.t[1] == t_start @@ -187,7 +198,6 @@ for i in 1:length(nets) end end -@test length(myFMU.components) <= 1 +@test length(fmu.components) <= 1 -fmiUnload(realFMU) -fmiUnload(myFMU) +fmiUnload(fmu) diff --git a/test/hybrid_ME_dis.jl b/test/hybrid_ME_dis.jl index 22079734..6464e30d 100644 --- a/test/hybrid_ME_dis.jl +++ b/test/hybrid_ME_dis.jl @@ -16,19 +16,19 @@ t_stop = 3.0 tData = t_start:t_step:t_stop # generate training data -fmu = fmiLoad("SpringFrictionPendulum1D", EXPORTINGTOOL, EXPORTINGVERSION; type=:ME) -pdict = Dict("mass.m" => 1.3) -realSimData = fmiSimulate(fmu, (t_start, t_stop); parameters=pdict, recordValues=["mass.s", "mass.v"], saveat=tData) +fmu = fmiLoad("BouncingBall", "ModelicaReferenceFMUs", "0.0.25"; type=:ME) +pdict = Dict("g" => 9.0) +realSimData = fmiSimulate(fmu, (t_start, t_stop); parameters=pdict, recordValues=["h", "v"], saveat=tData) x0 = collect(realSimData.values.saveval[1]) -@test x0 == [0.5, 0.0] +@test x0 == [1.0, 0.0] # setup traing data -velData = fmi2GetSolutionValue(realSimData, "mass.v") +velData = fmi2GetSolutionValue(realSimData, "v") # loss function for training function losssum(p) global problem, x0, posData - solution = problem(x0; p=p) + solution = problem(x0; p=p, saveat=tData) if !solution.success return Inf @@ -52,12 +52,12 @@ function callb(p) @info "[$(iterCB)] Loss: $loss" # This test condition is not good, because when the FMU passes an event, the error might increase. - @test (loss < lastLoss) && (loss != lastLoss) + @test loss <= lastLoss lastLoss = loss end end -vr = fmi2StringToValueReference(fmu, "mass.m") +vr = fmi2StringToValueReference(fmu, "g") numStates = fmiGetNumberOfStates(fmu) @@ -69,116 +69,123 @@ c2 = CacheRetrieveLayer(c1) c3 = CacheLayer() c4 = CacheRetrieveLayer(c3) +init = Flux.glorot_uniform +getVRs = [fmi2StringToValueReference(fmu, "h")] +y = zeros(fmi2Real, length(getVRs)) +numGetVRs = length(getVRs) +setVRs = [fmi2StringToValueReference(fmu, "v")] +numSetVRs = length(setVRs) + # 1. default ME-NeuralFMU (learn dynamics and states, almost-neutral setup, parameter count << 100) net = Chain(x -> c1(x), - Dense(numStates, numStates, identity; init=Flux.identity_init), - x -> c2([1], x[2], []), - x -> fmu(;x=x), + Dense(numStates, 1, identity; init=init), + x -> c2([], x[1], [1]), + x -> fmu(;x=x, dx_refs=:all), x -> c3(x), - Dense(numStates, numStates, identity; init=Flux.identity_init), - x -> c4([1], x[2], [])) + Dense(numStates, 1, identity; init=init), + x -> c4([1], x[1], [])) push!(nets, net) # 2. default ME-NeuralFMU (learn dynamics) -net = Chain(x -> fmu(;x=x), +net = Chain(x -> fmu(;x=x, dx_refs=:all), x -> c1(x), - Dense(numStates, 16, identity; init=Flux.identity_init), - Dense(16, 16, identity; init=Flux.identity_init), - Dense(16, numStates, identity; init=Flux.identity_init), - x -> c2([1], x[2], [])) + Dense(numStates, 16, tanh; init=init), + Dense(16, 16, tanh; init=init), + Dense(16, 1, identity; init=init), + x -> c2([1], x[1], [])) push!(nets, net) # 3. default ME-NeuralFMU (learn states) net = Chain(x -> c1(x), - Dense(numStates, 16, identity, init=Flux.identity_init), - Dense(16, 16, identity, init=Flux.identity_init), - Dense(16, numStates, identity, init=Flux.identity_init), - x -> c2([1], x[2], []), - x -> fmu(;x=x)) + Dense(numStates, 16, tanh, init=init), + Dense(16, 16, tanh, init=init), + Dense(16, 1, identity, init=init), + x -> c2([], x[1], [1]), + x -> fmu(;x=x, dx_refs=:all)) push!(nets, net) # 4. default ME-NeuralFMU (learn dynamics and states) net = Chain(x -> c1(x), - Dense(numStates, 16, identity; init=Flux.identity_init), - Dense(16, numStates, identity; init=Flux.identity_init), - x -> c2([1], x[2], []), - x -> fmu(;x=x), + Dense(numStates, 16, tanh; init=init), + Dense(16, 1, identity; init=init), + x -> c2([], x[1], [1]), + x -> fmu(;x=x, dx_refs=:all), x -> c3(x), - Dense(numStates, 16, identity, init=Flux.identity_init), - Dense(16, 16, identity, init=Flux.identity_init), - Dense(16, numStates, identity, init=Flux.identity_init), - x -> c4([1], x[2], [])) + Dense(numStates, 16, tanh, init=init), + Dense(16, 16, tanh, init=init), + Dense(16, 1, identity, init=init), + x -> c4([1], x[1], [])) push!(nets, net) # 5. NeuralFMU with hard setting time to 0.0 -net = Chain(states -> fmu(;x=states, t=0.0), +net = Chain(states -> fmu(;x=states, t=0.0, dx_refs=:all), x -> c1(x), - Dense(numStates, 8, identity; init=Flux.identity_init), - Dense(8, 16, identity; init=Flux.identity_init), - Dense(16, numStates, identity; init=Flux.identity_init), - x -> c2([1], x[2], [])) + Dense(numStates, 8, tanh; init=init), + Dense(8, 16, tanh; init=init), + Dense(16, 1, identity; init=init), + x -> c2([1], x[1], [])) push!(nets, net) # 6. NeuralFMU with additional getter -getVRs = [fmi2StringToValueReference(fmu, "mass.s")] -numGetVRs = length(getVRs) -net = Chain(x -> fmu(;x=x, y_refs=getVRs), +net = Chain(x -> fmu(;x=x, y_refs=getVRs, y=y, dx_refs=:all), x -> c1(x), - Dense(numStates+numGetVRs, 8, identity; init=Flux.identity_init), - Dense(8, 16, identity; init=Flux.identity_init), - Dense(16, numStates, identity; init=Flux.identity_init), - x -> c2([2], x[2], [])) + Dense(numStates+numGetVRs, 8, tanh; init=init), + Dense(8, 16, tanh; init=init), + Dense(16, 1, identity; init=init), + x -> c2([2], x[1], [])) push!(nets, net) # 7. NeuralFMU with additional setter -setVRs = [fmi2StringToValueReference(fmu, "mass.m")] -numSetVRs = length(setVRs) -net = Chain(x -> fmu(;x=x, u_refs=setVRs, u=[1.1]), +net = Chain(x -> fmu(;x=x, u_refs=setVRs, u=[1.1], dx_refs=:all), x -> c1(x), - Dense(numStates, 8, identity; init=Flux.identity_init), - Dense(8, 16, identity; init=Flux.identity_init), - Dense(16, numStates, identity; init=Flux.identity_init), - x -> c2([1], x[2], [])) + Dense(numStates, 8, tanh; init=init), + Dense(8, 16, tanh; init=init), + Dense(16, 1, identity; init=init), + x -> c2([1], x[1], [])) push!(nets, net) # 8. NeuralFMU with additional setter and getter -net = Chain(x -> fmu(;x=x, u_refs=setVRs, u=[1.1], y_refs=getVRs), +net = Chain(x -> fmu(;x=x, u_refs=setVRs, u=[1.1], y_refs=getVRs, y=y, dx_refs=:all), x -> c1(x), - Dense(numStates+numGetVRs, 8, identity; init=Flux.identity_init), - Dense(8, 16, identity; init=Flux.identity_init), - Dense(16, numStates, identity; init=Flux.identity_init), - x -> c2([2], x[2], [])) + Dense(numStates+numGetVRs, 8, tanh; init=init), + Dense(8, 16, tanh; init=init), + Dense(16, 1, identity; init=init), + x -> c2([2], x[1], [])) +push!(nets, net) + +# 9. an empty NeuralFMU (this does only make sense for debugging) +net = Chain(x -> fmu(x=x, dx_refs=:all)) push!(nets, net) for i in 1:length(nets) @testset "Net setup $(i)/$(length(nets))" begin global nets, problem, lastLoss, iterCB - optim = Adam(1e-4) + optim = Adam(1e-6) solver = Tsit5() net = nets[i] - problem = ME_NeuralFMU(fmu, net, (t_start, t_stop), solver; saveat=tData) + problem = ME_NeuralFMU(fmu, net, (t_start, t_stop), solver) + + # train it ... + p_net = Flux.params(problem) @test problem !== nothing - solutionBefore = problem(x0) + solutionBefore = problem(x0; p=p_net[1], saveat=tData) if solutionBefore.success @test length(solutionBefore.states.t) == length(tData) @test solutionBefore.states.t[1] == t_start @test solutionBefore.states.t[end] == t_stop end - # train it ... - p_net = Flux.params(problem) - iterCB = 0 lastLoss = losssum(p_net[1]) @info "[ $(iterCB)] Loss: $lastLoss" FMIFlux.train!(losssum, p_net, Iterators.repeated((), NUMSTEPS), optim; cb=()->callb(p_net), gradient=GRADIENT) # check results - solutionAfter = problem(x0) + solutionAfter = problem(x0; p=p_net[1], saveat=tData) if solutionAfter.success @test length(solutionAfter.states.t) == length(tData) @test solutionAfter.states.t[1] == t_start diff --git a/test/multi.jl b/test/multi.jl index 3f4d4caa..0c94cfc7 100644 --- a/test/multi.jl +++ b/test/multi.jl @@ -83,7 +83,7 @@ net = Chain( Dense(16, length(fmus[1].modelDescription.outputValueReferences); init=Flux.identity_init), ) -problem = CS_NeuralFMU(fmus, net, (t_start, t_stop); saveat=tData) +problem = CS_NeuralFMU(fmus, net, (t_start, t_stop)) @test problem != nothing solutionBefore = problem(extForce, t_step) diff --git a/test/multi_threading.jl b/test/multi_threading.jl index 69b3801e..cd5c7402 100644 --- a/test/multi_threading.jl +++ b/test/multi_threading.jl @@ -17,6 +17,7 @@ tData = t_start:t_step:t_stop # generate training data realFMU = fmiLoad("SpringFrictionPendulum1D", EXPORTINGTOOL, EXPORTINGVERSION; type=:ME) + pdict = Dict("mass.m" => 1.3) realSimData = fmiSimulate(realFMU, (t_start, t_stop); parameters=pdict, recordValues=["mass.s", "mass.v"], saveat=tData) x0 = collect(realSimData.values.saveval[1]) @@ -31,7 +32,7 @@ velData = fmi2GetSolutionValue(realSimData, "mass.v") # loss function for training function losssum(p) global problem, x0, posData - solution = problem(x0; p=p, showProgress=true) + solution = problem(x0; p=p, showProgress=true, saveat=tData) if !solution.success return Inf @@ -72,15 +73,15 @@ c4 = CacheRetrieveLayer(c3) # 1. Discontinuous ME-NeuralFMU (learn dynamics and states) net = Chain(x -> c1(x), - Dense(numStates, 16, identity; init=Flux.identity_init), - Dense(16, numStates, identity; init=Flux.identity_init), - x -> c2([1], x[2], []), - x -> realFMU(;x=x), + Dense(numStates, 16, identity), + Dense(16, 1, identity), + x -> c2([], x[1], [1]), + x -> realFMU(;x=x, dx_refs=:all), x -> c3(x), - Dense(numStates, 16, identity, init=Flux.identity_init), - Dense(16, 16, identity, init=Flux.identity_init), - Dense(16, numStates, identity, init=Flux.identity_init), - x -> c4([1], x[2], [])) + Dense(numStates, 16, identity), + Dense(16, 16, identity), + Dense(16, 1, identity), + x -> c4([1], x[1], [])) push!(nets, net) for i in 1:length(nets) @@ -112,34 +113,36 @@ for i in 1:length(nets) p_net[1][:] = p_start[:] lastLoss = startLoss st = time() - optim = Adam(1e-4) + optim = Adam(1e-6) FMIFlux.train!(losssum, p_net, Iterators.repeated((), NUMSTEPS), optim; cb=()->callb(p_net), multiThreading=false, gradient=GRADIENT) - dt = round(time()-st; digits=1) + dt = round(time()-st; digits=2) @info "Training time single threaded (not pre-compiled): $(dt)s" p_net[1][:] = p_start[:] lastLoss = startLoss st = time() - optim = Adam(1e-4) + optim = Adam(1e-6) FMIFlux.train!(losssum, p_net, Iterators.repeated((), NUMSTEPS), optim; cb=()->callb(p_net), multiThreading=false, gradient=GRADIENT) - dt = round(time()-st; digits=1) + dt = round(time()-st; digits=2) @info "Training time single threaded (pre-compiled): $(dt)s" - p_net[1][:] = p_start[:] - lastLoss = startLoss - st = time() - optim = Adam(1e-4) - FMIFlux.train!(losssum, p_net, Iterators.repeated((), NUMSTEPS), optim; cb=()->callb(p_net), multiThreading=true, gradient=GRADIENT) - dt = round(time()-st; digits=1) - @info "Training time multi threaded (not pre-compiled): $(dt)s" - - p_net[1][:] = p_start[:] - lastLoss = startLoss - st = time() - optim = Adam(1e-4) - FMIFlux.train!(losssum, p_net, Iterators.repeated((), NUMSTEPS), optim; cb=()->callb(p_net), multiThreading=true, gradient=GRADIENT) - dt = round(time()-st; digits=1) - @info "Training time multi threaded (pre-compiled): $(dt)s" + # [ToDo] currently not implemented + + # p_net[1][:] = p_start[:] + # lastLoss = startLoss + # st = time() + # optim = Adam(1e-6) + # FMIFlux.train!(losssum, p_net, Iterators.repeated((), NUMSTEPS), optim; cb=()->callb(p_net), multiThreading=true, gradient=GRADIENT) + # dt = round(time()-st; digits=2) + # @info "Training time multi threaded x$(Threads.nthreads()) (not pre-compiled): $(dt)s" + + # p_net[1][:] = p_start[:] + # lastLoss = startLoss + # st = time() + # optim = Adam(1e-6) + # FMIFlux.train!(losssum, p_net, Iterators.repeated((), NUMSTEPS), optim; cb=()->callb(p_net), multiThreading=true, gradient=GRADIENT) + # dt = round(time()-st; digits=2) + # @info "Training time multi threaded x$(Threads.nthreads()) (pre-compiled): $(dt)s" # check results solutionAfter = problem(x0) diff --git a/test/optim.jl b/test/optim.jl new file mode 100644 index 00000000..ebc5e351 --- /dev/null +++ b/test/optim.jl @@ -0,0 +1,198 @@ +# +# Copyright (c) 2021 Tobias Thummerer, Lars Mikelsons +# Licensed under the MIT license. See LICENSE file in the project root for details. +# + +using FMI +using Flux +using DifferentialEquations: Tsit5 +using FMIFlux.Optim + +import Random +Random.seed!(1234); + +t_start = 0.0 +t_step = 0.01 +t_stop = 5.0 +tData = t_start:t_step:t_stop + +# generate training data +fmu = fmiLoad("SpringFrictionPendulum1D", EXPORTINGTOOL, EXPORTINGVERSION; type=fmi2TypeCoSimulation) +realSimData = fmiSimulateCS(fmu, (t_start, t_stop); recordValues=["mass.s", "mass.v"], saveat=tData) +x0 = collect(realSimData.values.saveval[1]) +@test x0 == [0.5, 0.0] +fmiUnload(fmu) + +# load FMU for NeuralFMU +fmu = fmiLoad("SpringPendulum1D", EXPORTINGTOOL, EXPORTINGVERSION; type=fmi2TypeModelExchange) + +# setup traing data +velData = fmi2GetSolutionValue(realSimData, "mass.v") + +# loss function for training +function losssum(p) + global problem, x0 + solution = problem(x0; p=p, showProgress=true, saveat=tData) + + if !solution.success + return Inf + end + + velNet = fmi2GetSolutionState(solution, 2; isIndex=true) + + return Flux.Losses.mse(velNet, velData) +end + +# callback function for training +global iterCB = 0 +global lastLoss = 0.0 +function callb(p) + global iterCB += 1 + global lastLoss + + if iterCB % 5 == 0 + loss = losssum(p[1]) + @info "[$(iterCB)] Loss: $loss" + @test loss < lastLoss + lastLoss = loss + end +end + +numStates = fmiGetNumberOfStates(fmu) + +# some NeuralFMU setups +nets = [] + +c1 = CacheLayer() +c2 = CacheRetrieveLayer(c1) +c3 = CacheLayer() +c4 = CacheRetrieveLayer(c3) + +init = Flux.glorot_uniform +getVRs = [fmi2StringToValueReference(fmu, "mass.s")] +numGetVRs = length(getVRs) +y = zeros(fmi2Real, numGetVRs) +setVRs = [fmi2StringToValueReference(fmu, "mass.m")] +numSetVRs = length(setVRs) + +# 1. default ME-NeuralFMU (learn dynamics and states, almost-neutral setup, parameter count << 100) +net = Chain(x -> c1(x), + Dense(numStates, 1, identity; init=init), + x -> c2([], x[1], [1]), + x -> fmu(;x=x, dx_refs=:all), + x -> c3(x), + Dense(numStates, 1, identity; init=init), + x -> c4([1], x[1], [])) +push!(nets, net) + +# 2. default ME-NeuralFMU (learn dynamics) +net = Chain(x -> fmu(;x=x, dx_refs=:all), + x -> c1(x), + Dense(numStates, 16, tanh; init=init), + Dense(16, 16, tanh; init=init), + Dense(16, 1, identity; init=init), + x -> c2([1], x[1], [])) +push!(nets, net) + +# 3. default ME-NeuralFMU (learn states) +net = Chain(x -> c1(x), + Dense(numStates, 16, tanh, init=init), + Dense(16, 16, tanh, init=init), + Dense(16, 1, identity, init=init), + x -> c2([], x[1], [1]), + x -> fmu(;x=x, dx_refs=:all)) +push!(nets, net) + +# 4. default ME-NeuralFMU (learn dynamics and states) +net = Chain(x -> c1(x), + Dense(numStates, 16, tanh; init=init), + Dense(16, 1, identity; init=init), + x -> c2([], x[1], [1]), + x -> fmu(;x=x, dx_refs=:all), + x -> c3(x), + Dense(numStates, 16, tanh, init=init), + Dense(16, 16, tanh, init=init), + Dense(16, 1, identity, init=init), + x -> c4([1], x[1], [])) +push!(nets, net) + +# 5. NeuralFMU with hard setting time to 0.0 +net = Chain(states -> fmu(;x=states, t=0.0, dx_refs=:all), + x -> c1(x), + Dense(numStates, 8, tanh; init=init), + Dense(8, 16, tanh; init=init), + Dense(16, 1, identity; init=init), + x -> c2([1], x[1], [])) +push!(nets, net) + +# 6. NeuralFMU with additional getter +net = Chain(x -> fmu(;x=x, y_refs=getVRs, y=y, dx_refs=:all), + x -> c1(x), + Dense(numStates+numGetVRs, 8, tanh; init=init), + Dense(8, 16, tanh; init=init), + Dense(16, 1, identity; init=init), + x -> c2([2], x[1], [])) +push!(nets, net) + +# 7. NeuralFMU with additional setter +net = Chain(x -> fmu(;x=x, u_refs=setVRs, u=[1.1], dx_refs=:all), + x -> c1(x), + Dense(numStates, 8, tanh; init=init), + Dense(8, 16, tanh; init=init), + Dense(16, 1, identity; init=init), + x -> c2([1], x[1], [])) +push!(nets, net) + +# 8. NeuralFMU with additional setter and getter +net = Chain(x -> fmu(;x=x, u_refs=setVRs, u=[1.1], y_refs=getVRs, y=y, dx_refs=:all), + x -> c1(x), + Dense(numStates+numGetVRs, 8, tanh; init=init), + Dense(8, 16, tanh; init=init), + Dense(16, 1, identity; init=init), + x -> c2([2], x[1], [])) +push!(nets, net) + +# 9. an empty NeuralFMU (this does only make sense for debugging) +net = Chain(x -> fmu(x=x, dx_refs=:all)) +push!(nets, net) + +for i in 1:length(nets) + @testset "Net setup #$i" begin + global nets, problem, lastLoss, iterCB + + optim = GradientDescent(;alphaguess=1e-6) # BFGS() + solver = Tsit5() + + net = nets[i] + problem = ME_NeuralFMU(fmu, net, (t_start, t_stop), solver) + @test problem != nothing + + # train it ... + p_net = Flux.params(problem) + @test length(p_net) == 1 + + solutionBefore = problem(x0; p=p_net[1], showProgress=true, saveat=tData) + if solutionBefore.success + @test length(solutionBefore.states.t) == length(tData) + @test solutionBefore.states.t[1] == t_start + @test solutionBefore.states.t[end] == t_stop + end + + iterCB = 0 + lastLoss = losssum(p_net[1]) + @info "Start-Loss for net #$i: $lastLoss" + FMIFlux.train!(losssum, p_net, Iterators.repeated((), NUMSTEPS), optim; cb=()->callb(p_net), gradient=GRADIENT) + + # check results + solutionAfter = problem(x0; p=p_net[1], showProgress=true, saveat=tData) + if solutionAfter.success + @test length(solutionAfter.states.t) == length(tData) + @test solutionAfter.states.t[1] == t_start + @test solutionAfter.states.t[end] == t_stop + end + end +end + +@test length(fmu.components) <= 1 + +fmiUnload(fmu) diff --git a/test/runtests.jl b/test/runtests.jl index c2cf6874..797fd547 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -11,7 +11,7 @@ using FMIFlux.FMIImport: fmi2StringToValueReference, fmi2ValueReference, prepare using FMIFlux.FMIImport: FMU2_EXECUTION_CONFIGURATION_NO_FREEING, FMU2_EXECUTION_CONFIGURATION_NO_RESET, FMU2_EXECUTION_CONFIGURATION_RESET using FMIFlux: fmi2GetSolutionState, fmi2GetSolutionValue, fmi2GetSolutionTime -exportingToolsWindows = [("Dymola", "2022x")] +exportingToolsWindows = [("Dymola", "2022x")] # [("ModelicaReferenceFMUs", "0.0.25")] exportingToolsLinux = [("Dymola", "2022x")] # number of training steps to perform @@ -33,7 +33,7 @@ function runtests(exportingTool) @info "Testing FMUs exported from $(EXPORTINGTOOL) ($(EXPORTINGVERSION))" @testset "Testing FMUs exported from $(EXPORTINGTOOL) ($(EXPORTINGVERSION))" begin - for _GRADIENT ∈ (:ReverseDiff, :ForwardDiff) + for _GRADIENT ∈ (:ReverseDiff, :ForwardDiff) # , :FiniteDiff) global GRADIENT = _GRADIENT @info "Gradient: $(GRADIENT)" @@ -64,10 +64,10 @@ function runtests(exportingTool) include("train_modes.jl") end - # @info "Multi-threading (multi_threading.jl)" - # @testset "Multi-threading" begin - # include("multi_threading.jl") - # end + @info "Multi-threading (multi_threading.jl)" + @testset "Multi-threading" begin + include("multi_threading.jl") + end @info "CS-NeuralFMU (hybrid_CS.jl)" @testset "CS-NeuralFMU" begin @@ -83,9 +83,19 @@ function runtests(exportingTool) @testset "Batching" begin include("batching.jl") end + + @info "Optimizers from Optim.jl (optim.jl)" + @testset "Optim" begin + include("optim.jl") + end end end + @info "Solution Gradients (solution_gradients.jl)" + @testset "Solution Gradients" begin + include("solution_gradients.jl") + end + @info "Benchmark: Supported sensitivities (supported_sensitivities.jl)" @testset "Benchmark: Supported sensitivities " begin include("supported_sensitivities.jl") diff --git a/test/solution_gradients.jl b/test/solution_gradients.jl new file mode 100644 index 00000000..72f9f7d2 --- /dev/null +++ b/test/solution_gradients.jl @@ -0,0 +1,277 @@ +# +# Copyright (c) 2021 Tobias Thummerer, Lars Mikelsons +# Licensed under the MIT license. See LICENSE file in the project root for details. +# + +using FMI +using Flux +using DifferentialEquations +using FMIFlux, FMIZoo, Test +import FMIFlux.FMISensitivity.SciMLSensitivity.SciMLBase: RightRootFind, LeftRootFind +import FMIFlux: unsense +using FMIFlux.FMISensitivity.SciMLSensitivity.ForwardDiff, FMIFlux.FMISensitivity.SciMLSensitivity.ReverseDiff, FMIFlux.FMISensitivity.SciMLSensitivity.FiniteDiff, FMIFlux.FMISensitivity.SciMLSensitivity.Zygote + +import Random +Random.seed!(5678); + +global solution = nothing +global events = 0 + +ENERGY_LOSS = 0.7 +RADIUS = 0.0 +GRAVITY = 9.81 +DBL_MIN = 2.2250738585072013830902327173324040642192159804623318306e-308 + +NUMEVENTS = 4 + +t_start = 0.0 +t_step = 0.05 +t_stop = 2.0 +tData = t_start:t_step:t_stop +posData = ones(length(tData)) + +x0 = [1.0, 0.0] +numStates = 2 +solver = Tsit5() + +# setup BouncingBallODE +function fx(x) + return [x[2], -GRAVITY] +end + +function fx_bb(dx, x, p, t) + dx[:] = re_bb(p)(x) + return nothing +end + +net_bb = Chain(#Dense([1.0 0.0; 0.0 1.0], [0.0, 0.0], identity), + fx, + Dense([1.0 0.0; 0.0 1.0], [0.0, 0.0], identity)) +p_net_bb, re_bb = Flux.destructure(net_bb) + +ff = ODEFunction{true}(fx_bb) +prob_bb = ODEProblem{true}(ff, x0, (t_start, t_stop), p_net_bb) + +function condition(out, x, t, integrator) + out[1] = x[1]-RADIUS + #out[2] = x[1]-RADIUS +end + +function affect_right!(integrator, idx) + s_new = RADIUS + DBL_MIN + v_new = -integrator.u[2]*ENERGY_LOSS + u_new = [s_new, v_new] + + global events + events += 1 + # @info "[$(events)] New state at $(integrator.t) is $(u_new)" + + integrator.u .= u_new +end +function affect_left!(integrator, idx) + s_new = integrator.u[1] + v_new = -integrator.u[2]*ENERGY_LOSS + u_new = [s_new, v_new] + + global events + events += 1 + # @info "[$(events)] New state at $(integrator.t) is $(u_new)" + + integrator.u .= u_new +end + +rightCb = VectorContinuousCallback(condition, + affect_right!, + 1; + rootfind=RightRootFind, save_positions=(false, false)) +leftCb = VectorContinuousCallback(condition, + affect_left!, + 1; + rootfind=LeftRootFind, save_positions=(false, false)) + +# load FMU for NeuralFMU +fmu = fmiLoad("BouncingBall", "ModelicaReferenceFMUs", "0.0.25"; type=:ME) +fmu.handleEventIndicators = nothing + +net = Chain(#Dense([1.0 0.0; 0.0 1.0], [0.0, 0.0], identity), + x -> fmu(;x=x, dx_refs=:all), + Dense([1.0 0.0; 0.0 1.0], [0.0; 0.0], identity)) + +prob = ME_NeuralFMU(fmu, net, (t_start, t_stop)) +prob.modifiedState = false + +# ANNs + +function losssum(p; sensealg=nothing) + global posData + posNet = mysolve(p; sensealg=sensealg) + + return Flux.Losses.mae(posNet, posData) +end + +function losssum_bb(p; sensealg=nothing, root=:Right) + global posData + posNet = mysolve_bb(p; sensealg=sensealg, root=root) + + return Flux.Losses.mae(posNet, posData) +end + +function mysolve(p; sensealg=nothing) + global solution, events # write + global prob, x0, posData, solver # read-only + events = 0 + + solution = prob(x0; p=p, solver=solver, saveat=tData) + + return collect(u[1] for u in solution.states.u) +end + +function mysolve_bb(p; sensealg=nothing, root=:Right) + global solution # write + global prob_bb, solver, events # read + events = 0 + + callback = nothing + if root == :Right + callback = CallbackSet(rightCb) + else + callback = CallbackSet(leftCb) + end + solution = solve(prob_bb; p=p, alg=solver, saveat=tData, callback=callback, sensealg=sensealg) + + if !isa(solution, AbstractArray) + if solution.retcode != FMI.ReturnCode.Success + @error "Solution failed!" + return Inf + end + + return collect(u[1] for u in solution.u) + else + return solution[1,:] # collect(solution[:,i] for i in 1:size(solution)[2]) + end +end + +p_net = Flux.params(prob)[1] + +using FMIFlux.FMISensitivity.SciMLSensitivity +sensealg = ReverseDiffAdjoint() + +c = nothing +c, x0 = FMIFlux.prepareSolveFMU(prob.fmu, c, fmi2TypeModelExchange, nothing, nothing, nothing, nothing, nothing, prob.parameters, prob.tspan[1], prob.tspan[end], nothing; x0=prob.x0, handleEvents=FMIFlux.handleEvents, cleanup=true) + +### START CHECK CONDITIONS + +function condition_bb_check(x) + buffer = similar(x, 1) + condition(buffer, x, t_start, nothing) + return buffer +end +function condition_nfmu_check(x) + buffer = similar(x, 1) + FMIFlux.condition!(prob, FMIFlux.getComponent(prob), buffer, x, t_start, nothing, [UInt32(1)]) + return buffer +end +jac_fwd1 = ForwardDiff.jacobian(condition_bb_check, x0) +jac_fwd2 = ForwardDiff.jacobian(condition_nfmu_check, x0) + +jac_rwd1 = ReverseDiff.jacobian(condition_bb_check, x0) +jac_rwd2 = ReverseDiff.jacobian(condition_nfmu_check, x0) + +jac_fin1 = FiniteDiff.finite_difference_jacobian(condition_bb_check, x0) +jac_fin2 = FiniteDiff.finite_difference_jacobian(condition_nfmu_check, x0) + +atol = 1e-8 +@test isapprox(jac_fin1, jac_fwd1; atol=atol) +@test isapprox(jac_fin1, jac_rwd1; atol=atol) +@test isapprox(jac_fin2, jac_fwd2; atol=atol) +@test isapprox(jac_fin2, jac_rwd2; atol=atol) + +### START CHECK AFFECT + +# import SciMLSensitivity: u_modified! +# import FMI: fmi2SimulateME +# function u_modified!(::NamedTuple, ::Any) +# return nothing +# end +# function affect_bb_check(x) +# integrator = (t=t_start, u=x) +# affect_right!(integrator, 1) +# return integrator.u +# end +# function affect_nfmu_check(x) +# integrator = (t=t_start, u=x, opts=(internalnorm=(a,b)->1.0,) ) +# #FMIFlux.affectFMU!(prob, FMIFlux.getComponent(prob), integrator, 1) +# integrator.u[1] = DBL_MIN +# integrator.u[2] = -0.7 * integrator.u[2] +# return integrator.u +# end +# t_first_event_time = 0.451523640985728 +# x_first_event_right = [2.2250738585072014e-308, 3.1006128426489954] + +# jac_con1 = ForwardDiff.jacobian(affect_bb_check, x0) +# jac_con2 = ForwardDiff.jacobian(affect_nfmu_check, x0) + +# jac_con1 = ReverseDiff.jacobian(affect_bb_check, x0) +# jac_con2 = ReverseDiff.jacobian(affect_nfmu_check, x0) + +### + +# Solution (plain) +losssum(p_net; sensealg=sensealg) +#@test length(solution.events) == NUMEVENTS + +losssum_bb(p_net_bb; sensealg=sensealg) +@test events == NUMEVENTS + +# Solution FWD +grad_fwd1 = ForwardDiff.gradient(p -> losssum(p; sensealg=sensealg), p_net) +#@test length(solution.events) == NUMEVENTS + +grad_fwd2 = ForwardDiff.gradient(p -> losssum_bb(p; sensealg=sensealg), p_net_bb) +@test events == NUMEVENTS + +# Solution ReverseDiff +grad_rwd1 = ReverseDiff.gradient(p -> losssum(p; sensealg=sensealg), p_net) +#@test length(solution.events) == NUMEVENTS + +grad_rwd2 = ReverseDiff.gradient(p -> losssum_bb(p; sensealg=sensealg), p_net_bb) +@test events == NUMEVENTS + +# Ground Truth +grad_fin1 = FiniteDiff.finite_difference_gradient(p -> losssum(p; sensealg=sensealg), p_net, Val{:central}; absstep=1e-8) +grad_fin2 = FiniteDiff.finite_difference_gradient(p -> losssum_bb(p; sensealg=sensealg), p_net_bb, Val{:central}; absstep=1e-8) + +atol = 1e-5 +@test isapprox(grad_fin1, grad_fwd1; atol=atol) +@test isapprox(grad_fin2, grad_fwd2; atol=atol) + +@test isapprox(grad_fin1, grad_rwd1; atol=atol) +@test isapprox(grad_fin2, grad_rwd2; atol=atol) + +# Jacobian Test + +jac_fwd1 = ForwardDiff.jacobian(p -> mysolve_bb(p; sensealg=sensealg), p_net) +jac_fwd2 = ForwardDiff.jacobian(p -> mysolve(p; sensealg=sensealg), p_net) + +jac_rwd1 = ReverseDiff.jacobian(p -> mysolve_bb(p; sensealg=sensealg), p_net) +jac_rwd2 = ReverseDiff.jacobian(p -> mysolve(p; sensealg=sensealg), p_net) + +# [TODO] why this?! +jac_rwd1[2:end,:] = jac_rwd1[2:end,:] .- jac_rwd1[1:end-1,:] +jac_rwd2[2:end,:] = jac_rwd2[2:end,:] .- jac_rwd2[1:end-1,:] + +jac_fin1 = FiniteDiff.finite_difference_jacobian(p -> mysolve_bb(p; sensealg=sensealg), p_net) +jac_fin2 = FiniteDiff.finite_difference_jacobian(p -> mysolve(p; sensealg=sensealg), p_net) + +### + +atol = 1e-4 +@test isapprox(jac_fin1, jac_fwd1; atol=atol) +@test isapprox(jac_fin1, jac_rwd1; atol=atol) + +@test isapprox(jac_fin2, jac_fwd2; atol=atol) +@test isapprox(jac_fin2, jac_rwd2; atol=atol) + +### + +fmiUnload(fmu) diff --git a/test/supported_sensitivities.jl b/test/supported_sensitivities.jl index 47b09b3d..a1b70ae6 100644 --- a/test/supported_sensitivities.jl +++ b/test/supported_sensitivities.jl @@ -5,15 +5,18 @@ using FMI using Flux +using FMI.DifferentialEquations import Random Random.seed!(5678); +# boundaries t_start = 0.0 t_step = 0.1 t_stop = 3.0 tData = t_start:t_step:t_stop -velData = sin.(tData) +posData = ones(length(tData)) +tspan = (t_start, t_stop) # load FMU for NeuralFMU fmu = fmiLoad("SpringFrictionPendulum1D", EXPORTINGTOOL, EXPORTINGVERSION; type=:ME) @@ -40,18 +43,20 @@ net = Chain(x -> c1(x), # loss function for training function losssum(p) global nfmu, x0, posData - solution = nfmu(x0; p=p) + solution = nfmu(x0; p=p, saveat=tData) if !solution.success return Inf end - velNet = fmi2GetSolutionState(solution, 2; isIndex=true) + posNet = fmi2GetSolutionState(solution, 1; isIndex=true) - return FMIFlux.Losses.mse(velNet, velData) + return FMIFlux.Losses.mse(posNet, posData) end -nfmu = ME_NeuralFMU(fmu, net, (t_start, t_stop); saveat=tData) +solver = Tsit5() +nfmu = ME_NeuralFMU(fmu, net, (t_start, t_stop), solver; saveat=tData) +nfmu.modifiedState = false FMIFlux.checkSensalgs!(losssum, nfmu) diff --git a/test/train_modes.jl b/test/train_modes.jl index 3ca2162f..b359750a 100644 --- a/test/train_modes.jl +++ b/test/train_modes.jl @@ -33,7 +33,7 @@ velData = fmi2GetSolutionValue(realSimData, "mass.v") # loss function for training function losssum(p) global problem, x0, posData - solution = problem(x0; p=p) + solution = problem(x0; p=p, saveat=tData) if !solution.success return Inf @@ -94,19 +94,20 @@ for handleEvents in [true, false] c1 = CacheLayer() c2 = CacheRetrieveLayer(c1) - net = Chain(states -> myFMU(;x=states), + + net = Chain(states -> myFMU(;x=states, dx_refs=:all), dx -> c1(dx), - Dense(numStates, 16, tanh; init=Flux.identity_init), - Dense(16, 1, identity; init=Flux.identity_init), + Dense(numStates, 16, tanh), + Dense(16, 1, identity), dx -> c2([1], dx[1], []) ) optim = Adam(1e-8) solver = Tsit5() - problem = ME_NeuralFMU(myFMU, net, (t_start, t_stop), solver; saveat=tData) + problem = ME_NeuralFMU(myFMU, net, (t_start, t_stop), solver) @test problem != nothing - solutionBefore = problem(x0) + solutionBefore = problem(x0; saveat=tData) if solutionBefore.success @test length(solutionBefore.states.t) == length(tData) @test solutionBefore.states.t[1] == t_start @@ -124,7 +125,7 @@ for handleEvents in [true, false] FMIFlux.train!(losssum, p_net, Iterators.repeated((), NUMSTEPS), optim; cb=()->callb(p_net), gradient=GRADIENT) # check results - solutionAfter = problem(x0) + solutionAfter = problem(x0; saveat=tData) if solutionAfter.success @test length(solutionAfter.states.t) == length(tData) @test solutionAfter.states.t[1] == t_start