Add NeuralTangentKernel Loss #506

Open · wants to merge 1 commit into master
54 changes: 54 additions & 0 deletions src/pinns_pde_solve.jl

@@ -336,7 +336,61 @@ SciMLBase.@add_kwonly function MiniMaxAdaptiveLoss(reweight_every; pde_max_optim
end


"""

A way of adaptively reweighing the components of the loss function by using the values of the Jacobian of the
NTK predictions at the current point (infinite width assumption).

* 'u_pred' : the predictions of the Network on the training data at current state
* 'r_pred' : the predictions of the Network on the boundary conditions at current state
* 'kernel_size' : the size of the kernel used to compute the Jacobian for the NTK matrix
* `reweight_every`: how often to reweight the PDE and BC loss functions, measured in iterations. reweighting is cheap since it re-uses the value of loss functions generated during the main optimisation loop,
* `pde_max_optimiser`: a Flux.Optimise.AbstractOptimiser that is used internally to maximize the weights of the PDE loss functions,
* `bc_max_optimiser`: a Flux.Optimise.AbstractOptimiser that is used internally to maximize the weights of the BC loss functions,
* `pde_loss_weights`: either a scalar (which will be broadcast) or vector the size of the number of PDE equations, which describes the initial weight the respective PDE loss has in the full loss sum,
* `bc_loss_weights`: either a scalar (which will be broadcast) or vector the size of the number of BC equations, which describes the initial weight the respective BC loss has in the full loss sum,
* `additional_loss_weights`: a scalar which describes the weight the additional loss function has in the full loss sum, this is currently not adaptive and will be constant with this adaptive loss,

Adapted from paper
When and Why PINNs Fail to Train: A Neural Tangent kernel perspective
https://arxiv.org/pdf/2007.14527.pdf

"""


mutable struct NeuralTangentKernelAdaptiveLoss{T <: Real,
                                               PDE_OPT <: Flux.Optimise.AbstractOptimiser,
                                               BC_OPT <: Flux.Optimise.AbstractOptimiser} <: AbstractAdaptiveLoss
    reweight_every::Int64
    pde_max_optimiser::PDE_OPT
    bc_max_optimiser::BC_OPT
    pde_loss_weights::Vector{T}
    bc_loss_weights::Vector{T}
    additional_loss_weights::Vector{T}

    SciMLBase.@add_kwonly function NeuralTangentKernelAdaptiveLoss{T, PDE_OPT, BC_OPT}(reweight_every;
            pde_max_optimiser=Flux.ADAM(1e-4), bc_max_optimiser=Flux.ADAM(0.5),
            u_pred::Vector{T}, r_pred::Vector{T}, kernel_size::Int,
            pde_loss_weights=1, bc_loss_weights=1,
            additional_loss_weights=1) where {T <: Real, PDE_OPT <: Flux.Optimise.AbstractOptimiser, BC_OPT <: Flux.Optimise.AbstractOptimiser}
        # Per-block Jacobians of the PDE residual and boundary predictions w.r.t. the network parameters
        Jr, Ju = compute_jacobian(r_pred, pde_loss_weights), compute_jacobian(u_pred, bc_loss_weights)
        # NTK Gram matrices K = J * J' at the residual and boundary points
        Kr, Ku = compute_ntk(Jr, kernel_size, Jr, kernel_size), compute_ntk(Ju, kernel_size, Ju, kernel_size)
        # Trace-ratio weights from the NTK paper (`tr` is `LinearAlgebra.tr`):
        # lambda_r = tr(K)/tr(Kr) and lambda_u = tr(K)/tr(Ku), with tr(K) = tr(Kr) + tr(Ku)
        lambda_r, lambda_u = (tr(Kr) + tr(Ku)) / tr(Kr), (tr(Kr) + tr(Ku)) / tr(Ku)
        new(convert(Int64, reweight_every), convert(PDE_OPT, pde_max_optimiser), convert(BC_OPT, bc_max_optimiser),
            lambda_r * vectorify(pde_loss_weights, T), lambda_u * vectorify(bc_loss_weights, T),
            vectorify(additional_loss_weights, T))
    end
end
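
#=
`compute_jacobian` is called in the constructor above but is not included in this
diff. Below is a minimal sketch (not part of the PR) of what such a helper could
look like, assuming access to the network `phi`, a flat parameter vector `θ`, and
the sample points `xs` (all hypothetical names, with a different signature than
the call above). It returns a one-element list whose entry is the
(n_points × n_params) Jacobian of the predictions with respect to θ, which is the
block shape `compute_ntk` below expects.

    using Zygote  # already a NeuralPDE dependency via Flux
    function compute_jacobian_sketch(phi, θ, xs)
        J, = Zygote.jacobian(p -> vec(phi(xs, p)), θ)
        return [J]
    end
=#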

SciMLBase.@add_kwonly function NeuralTangentKernelAdaptiveLoss(reweight_every;
        pde_max_optimiser=Flux.ADAM(1e-4), bc_max_optimiser=Flux.ADAM(0.5),
        u_pred::Vector{T}, r_pred::Vector{T}, kernel_size::Int,
        pde_loss_weights=1, bc_loss_weights=1, additional_loss_weights=1) where {T <: Real}
    NeuralTangentKernelAdaptiveLoss{Float64, typeof(pde_max_optimiser), typeof(bc_max_optimiser)}(
        reweight_every; pde_max_optimiser=pde_max_optimiser, bc_max_optimiser=bc_max_optimiser,
        u_pred=u_pred, r_pred=r_pred, kernel_size=kernel_size,
        pde_loss_weights=pde_loss_weights, bc_loss_weights=bc_loss_weights,
        additional_loss_weights=additional_loss_weights)
end
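
#=
Example construction (illustrative only; `u_pred_example` and `r_pred_example`
are stand-ins for the network's current boundary and residual predictions):

    ntk_loss = NeuralTangentKernelAdaptiveLoss(50;
        u_pred = u_pred_example, r_pred = r_pred_example,
        kernel_size = length(u_pred_example))
=#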


function compute_ntk(J1_list::Vector, D1::Int, J2_list::Vector, D2::Int)
    # Accumulate K = Σ_k J1_k * J2_k' over the per-parameter-block Jacobians;
    # each block is reshaped so that its first dimension is the number of sample points.
    Ker = zeros(D1, D2)
    for k in 1:length(J1_list)
        Ker += reshape(J1_list[k], D1, :) * transpose(reshape(J2_list[k], D2, :))
    end
    return Ker
end
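
#=
A self-contained check of `compute_ntk` and the trace-ratio weights (not part of
the PR): two fake Jacobian blocks for 3 residual points and 3 boundary points
give 3×3 Gram matrices, from which lambda_r and lambda_u are computed exactly as
in the constructor above.

    using LinearAlgebra
    Jr = [randn(3, 8)]                 # residual Jacobian block: 3 points × 8 params
    Ju = [randn(3, 8)]                 # boundary Jacobian block
    Kr = compute_ntk(Jr, 3, Jr, 3)     # 3×3 residual NTK Gram matrix
    Ku = compute_ntk(Ju, 3, Ju, 3)     # 3×3 boundary NTK Gram matrix
    lambda_r = (tr(Kr) + tr(Ku)) / tr(Kr)
    lambda_u = (tr(Kr) + tr(Ku)) / tr(Ku)
=#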

"""
Create dictionary: variable => unique number for variable
