Increase test coverage for layers #64

Merged
merged 14 commits on Sep 1, 2023
7 files renamed without changes.
2 changes: 2 additions & 0 deletions src/layers/attention_layer.jl
@@ -110,6 +110,8 @@ function upper_triangular_asymmetrize(A::AbstractArray{T, 3}) where T
output
end

### The functions from here on are needed for computing the derivative.

@kernel function assign_upper_triangular_kernel!(output, input)
i,j,k = @index(Global, NTuple)
if i < j
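For context, here is a plain-Julia sketch of what upper_triangular_asymmetrize appears to compute, inferred from the tests added later in this PR (skew-symmetric slices are fixed points). This is a hedged reference, not the library's kernel implementation, and the name upper_triangular_asymmetrize_reference is made up for illustration:

# Keep the strict upper triangle of each slice and mirror it with a sign
# flip into the lower triangle; the diagonal stays zero.
function upper_triangular_asymmetrize_reference(A::AbstractArray{T, 3}) where T
    output = zero(A)
    for k in axes(A, 3), j in axes(A, 2), i in axes(A, 1)
        if i < j
            output[i, j, k] = A[i, j, k]
        elseif i > j
            output[i, j, k] = -A[j, i, k]
        end
    end
    output
end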
8 changes: 0 additions & 8 deletions src/layers/psd_like_layer.jl
@@ -37,12 +37,4 @@ function (::PSDLayer{M, N})(x::AbstractArray{T, 3}, ps::NamedTuple) where {M, N, T}

q, p = assign_q_and_p(x, dim÷2)
N > M ? vcat(mat_tensor_mul(ps.weight,q), mat_tensor_mul(ps.weight,p)) : vcat(mat_tensor_mul(ps.weight', q), mat_tensor_mul(ps.weight', p))
end

function retraction(::PSDLayer{N, M, Geodesic}, B::NamedTuple{(:weight,),Tuple{AT}}) where {N, M, AT<:StiefelLieAlgHorMatrix}
geodesic(B)
end

function retraction(::PSDLayer{N, M, Cayley}, B::NamedTuple{(:weight,),Tuple{AT}}) where {N, M, AT<:StiefelLieAlgHorMatrix}
cayley(B)
end
2 changes: 1 addition & 1 deletion src/optimizers/manifold_related/retractions.jl
@@ -76,4 +76,4 @@ function cayley(B::StiefelLieAlgHorMatrix{T}) where T
E + hcat(vcat(T(.25)*B.A, T(.5)*B.B), vcat(T(0.5)*unit, zero(B.B)))*(exponent \ vcat(unit, T(0.5)*B.A))
)
)
end
end
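For reference, this retraction is based on the Cayley transform Cay(B) = (𝕀 - B/2)⁻¹(𝕀 + B/2) (a standard formula, not taken from this source); the block structure of StiefelLieAlgHorMatrix lets the code above apply it to the embedded point E while only solving a small linear system. A dense, naive sketch of the underlying transform that the structured version should agree with:

using LinearAlgebra: I

# Hedged dense reference: Cayley transform of a square matrix B.
cayley_dense(B::AbstractMatrix) = (I - B/2) \ (I + B/2)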
23 changes: 23 additions & 0 deletions test/attention_layer/apply_multi_head_attention.jl
@@ -0,0 +1,23 @@
using GeometricMachineLearning, Test
using NNlib: softmax
using LinearAlgebra: I

function compare_attention_to_mha(N, batch_size=10, T=Float32)
model₁ = MultiHeadAttention(N, 1, add_connection=false)
model₂ = Attention(N, softmax, add_connection=false)
model₃ = MultiHeadAttention(N, 1, add_connection=true)
model₄ = Attention(N, softmax, add_connection=true)

ps₂ = initialparameters(CPU(), T, model₂)
# a single head with identity value projection makes MultiHeadAttention act
# like plain softmax Attention, so reuse ps₂'s PQ and PK for head_1
ps₁ = (PQ=(head_1=ps₂.PQ,), PK=(head_1=ps₂.PK,), PV=(head_1=typeof(ps₂.PK)(I(N)),))

mat = rand(T, N, N)
ten = rand(T, N, N, batch_size)
@test isapprox(model₁(mat, ps₁), model₂(mat, ps₂))
@test isapprox(model₁(ten, ps₁), model₂(ten, ps₂))

@test isapprox(model₃(mat, ps₁), model₄(mat, ps₂))
@test isapprox(model₃(ten, ps₁), model₄(ten, ps₂))
end

compare_attention_to_mha(10)
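The file only exercises N = 10; a cheap extension (a sketch mirroring the dimension sweeps used elsewhere in this test suite) would sweep a few sizes and batch sizes:

for N in (5, 10, 20)
    compare_attention_to_mha(N, 5)
end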
11 changes: 10 additions & 1 deletion test/attention_layer/attention_setup.jl
@@ -1,22 +1,31 @@
using GeometricMachineLearning, Test
using GeometricMachineLearning: upper_triangular_asymmetrize
using GeometricMachineLearning: orthonormal_activation
using LinearAlgebra: det

function attention_tests(N, T=Float32)
model₁ = Attention(N, Stiefel=false)
model₂ = Attention(N, Stiefel=true)
model₃ = Attention(N, orthonormal_activation, Stiefel=false)
# same as model₁, but with the add connection
model₄ = Attention(N, Stiefel=false, add_connection=true)

ps₁ = initialparameters(CPU(), T, model₁)
ps₂ = initialparameters(CPU(), T, model₂)
ps₃ = initialparameters(CPU(), T, model₃)
@test typeof(ps₂.PQ) <: StiefelManifold
@test typeof(ps₂.PK) <: StiefelManifold

A = randn(N, N)
det₁ = det(A)
det₂ = det(model₁(A, ps₁))
det₃ = det(model₂(A, ps₂))
det₄ = det(model₃(A, ps₃))
@test isapprox(det₁, det₂)
@test isapprox(det₂, det₃)
@test isapprox(det₃, det₄)

# model₄ has the add connection, so subtracting the input recovers model₁
@test isapprox(model₁(A, ps₁), model₄(A, ps₁)-A)

A = reshape(rand(SkewSymMatrix{T}, N), N, N, 1)
@test isapprox(A, upper_triangular_asymmetrize(A))
3 changes: 2 additions & 1 deletion test/custom_ad_rules/kernel_pullbacks.jl
@@ -1,4 +1,4 @@
using GeometricMachineLearning: tensor_mat_mul, mat_tensor_mul, tensor_tensor_mul, tensor_transpose_tensor_mul, assign_q_and_p, tensor_transpose, assign_matrix, assign_tensor, assign_output_estimate, vec_tensor_mul
using GeometricMachineLearning: tensor_mat_mul, mat_tensor_mul, tensor_tensor_mul, tensor_transpose_tensor_mul, assign_q_and_p, tensor_transpose, assign_matrix, assign_tensor, assign_output_estimate, vec_tensor_mul, upper_triangular_asymmetrize
using ChainRulesTestUtils
using Printf

@@ -18,6 +18,7 @@ function main(first_dim, second_dim, third_dim, third_tensor_dim)
test_rrule(assign_tensor, rand(first_dim, second_dim), third_tensor_dim, 1)
test_rrule(assign_output_estimate, rand(first_dim, second_dim, third_tensor_dim), 1)
test_rrule(vec_tensor_mul, rand(first_dim), rand(first_dim, second_dim, third_tensor_dim))
test_rrule(upper_triangular_asymmetrize, rand(first_dim, first_dim, third_tensor_dim))
end

const dim_range = 10
20 changes: 20 additions & 0 deletions test/layers/sympnet_upscaling.jl
@@ -0,0 +1,20 @@
using GeometricMachineLearning, Test, Zygote

function test_symplecticity(N=4, N2=20, T=Float32)
model = Chain(PSDLayer(N, N2), GradientQ(N2, 2*N2, tanh), GradientP(N2, 2*N2, tanh), PSDLayer(N2, N))
ps = initialparameters(CPU(), T, model)
x = rand(T, N)
ten = rand(T, N, N)
# the first and last PSD layers need to have the same weight (otherwise they map to different symplectic potentials)
ps[4].weight.A = ps[1].weight.A
jacobian_matrix = Zygote.jacobian(x -> model(x, ps), x)[1]
𝕁 = SymplecticPotential(N÷2)
@test isapprox(jacobian_matrix'*𝕁*jacobian_matrix, 𝕁, rtol=0.1)
@test isapprox(model(ten, ps)[:,1], model(ten[:,1], ps))
end

for N=2:2:20
for N2=2*N:2:4*N
test_symplecticity(N, N2)
end
end
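For reference, the Jacobian test above checks the standard symplecticity condition (∇f)ᵀ𝕁(∇f) = 𝕁. The sketch below spells it out in plain Julia, under the assumption (ours, not stated in the diff) that SymplecticPotential(n) returns the canonical Poisson matrix:

using LinearAlgebra

# Canonical Poisson matrix for n degrees of freedom (assumed layout).
poisson_matrix(n) = [zeros(n, n) I(n); -I(n) zeros(n, n)]

# f is symplectic at x iff its Jacobian Jf satisfies Jf' * 𝕁 * Jf ≈ 𝕁.
is_symplectic(Jf, 𝕁; rtol=0.1) = isapprox(Jf' * 𝕁 * Jf, 𝕁; rtol=rtol)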
2 changes: 2 additions & 0 deletions test/runtests.jl
@@ -7,6 +7,7 @@ using SafeTestsets
@safetestset "Manifolds (Grassmann): " begin include("manifolds/grassmann_manifold.jl") end
@safetestset "Gradient Layer " begin include("layers/gradient_layer_tests.jl") end
@safetestset "SympNet Layers " begin include("layers/sympnet_layers_test.jl") end
@safetestset "Test symplecticity of upscaling layer " begin include("layers/sympnet_layers_test.jl") end
@safetestset "Hamiltonian Neural Network " begin include("hamiltonian_neural_network_tests.jl") end
@safetestset "Manifold Neural Network Layers " begin include("layers/manifold_layers.jl") end
@safetestset "Custom AD rules for kernels " begin include("custom_ad_rules/kernel_pullbacks.jl") end
@@ -19,6 +20,7 @@ using SafeTestsets
@safetestset "Transformer Networks #6 " begin include("transformer_related/transformer_gradient.jl") end
@safetestset "Transformer Networks #7 " begin include("transformer_related/transformer_optimizer.jl") end
@safetestset "Attention layer #1 " begin include("attention_layer/attention_setup.jl") end
@safetestset "(MultiHead)Attention " begin include("attention_layer/apply_multi_head_attention.jl") end
@safetestset "Optimizer #1 " begin include("optimizers/utils/global_sections.jl") end
@safetestset "Optimizer #2 " begin include("optimizers/utils/optimization_step.jl") end
@safetestset "Optimizer #3 " begin include("optimizers/utils/modified_exponential.jl") end