Skip to content

Commit

Permalink
Merge pull request #64 from JuliaGNI/increase_test_coverage_for_layers
Browse files Browse the repository at this point in the history
Increase test coverage for layers
  • Loading branch information
michakraus authored Sep 1, 2023
2 parents 52a6606 + 7601d55 commit c2f79b5
Show file tree
Hide file tree
Showing 15 changed files with 60 additions and 11 deletions.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
2 changes: 2 additions & 0 deletions src/layers/attention_layer.jl
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,8 @@ function upper_triangular_asymmetrize(A::AbstractArray{T, 3}) where T
output
end

### the functions starting from here are needed for computing the derivative.

@kernel function assign_upper_triangular_kernel!(output, input)
i,j,k = @index(Global, NTuple)
if i < j
Expand Down
8 changes: 0 additions & 8 deletions src/layers/psd_like_layer.jl
Original file line number Diff line number Diff line change
Expand Up @@ -37,12 +37,4 @@ function (::PSDLayer{M, N})(x::AbstractArray{T, 3}, ps::NamedTuple) where {M, N,

q, p = assign_q_and_p(x, dim÷2)
N > M ? vcat(mat_tensor_mul(ps.weight,q), mat_tensor_mul(ps.weight,p)) : vcat(mat_tensor_mul(ps.weight', q), mat_tensor_mul(ps.weight', p))
end

function retraction(::PSDLayer{N, M, Geodesic}, B::NamedTuple{(:weight,),Tuple{AT}}) where {N, M, AT<:StiefelLieAlgHorMatrix}
geodesic(B)
end

function retraction(::PSDLayer{N, M, Cayley}, B::NamedTuple{(:weight,),Tuple{AT}}) where {N, M, AT<:StiefelLieAlgHorMatrix}
cayley(B)
end
2 changes: 1 addition & 1 deletion src/optimizers/manifold_related/retractions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -76,4 +76,4 @@ function cayley(B::StiefelLieAlgHorMatrix{T}) where T
E + hcat(vcat(T(.25)*B.A, T(.5)*B.B), vcat(T(0.5)*unit, zero(B.B)))*(exponent \ vcat(unit, T(0.5)*B.A))
)
)
end
end
23 changes: 23 additions & 0 deletions test/attention_layer/apply_multi_head_attention.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
using GeometricMachineLearning, Test
using NNlib: softmax
using LinearAlgebra: I

function compare_attention_to_mha(N, batch_size=10, T=Float32)
model₁ = MultiHeadAttention(N, 1, add_connection=false)
model₂ = Attention(N, softmax, add_connection=false)
model₃ = MultiHeadAttention(N, 1, add_connection=true)
model₄ = Attention(N, softmax, add_connection=true)

ps₂ = initialparameters(CPU(), T, model₂)
ps₁ = (PQ=(head_1=ps₂.PQ,), PK=(head_1=ps₂.PK,), PV=(head_1=typeof(ps₂.PK)(I(N)),))

mat = rand(T, N, N)
ten = rand(T, N, N, batch_size)
@test isapprox(model₁(mat, ps₁), model₂(mat, ps₂))
@test isapprox(model₁(ten, ps₁), model₂(ten, ps₂))

@test isapprox(model₃(mat, ps₁), model₄(mat, ps₂))
@test isapprox(model₃(ten, ps₁), model₄(ten, ps₂))
end

compare_attention_to_mha(10)
11 changes: 10 additions & 1 deletion test/attention_layer/attention_setup.jl
Original file line number Diff line number Diff line change
@@ -1,22 +1,31 @@
using GeometricMachineLearning, Test
using GeometricMachineLearning: upper_triangular_asymmetrize
using GeometricMachineLearning: orthonormal_activation
using LinearAlgebra: det

function attention_tests(N, T=Float32)
model₁ = Attention(N, Stiefel=false)
model₂ = Attention(N, Stiefel=true)
model₃ = Attention(N, orthonormal_activation, Stiefel=false)
# same as model₁, but with the add connection
model₄ = Attention(N, Stiefel=false, add_connection=true)

ps₁ = initialparameters(CPU(), T, model₁)
ps₂ = initialparameters(CPU(), T, model₂)
ps₃ = initialparameters(CPU(), T, model₃)
@test typeof(ps₂.PQ) <: StiefelManifold
@test typeof(ps₂.PQ) <: StiefelManifold
@test typeof(ps₂.PK) <: StiefelManifold

A = randn(N, N)
det₁ = det(A)
det₂ = det(model₁(A, ps₁))
det₃ = det(model₂(A, ps₂))
det₄ = det(model₃(A, ps₃))
@test isapprox(det₁, det₂)
@test isapprox(det₂, det₃)
@test isapprox(det₃, det₄)

@test isapprox(model₁(A, ps₁), model₄(A, ps₁)-A)

A = reshape(rand(SkewSymMatrix{T}, N), N, N, 1)
@test isapprox(A, upper_triangular_asymmetrize(A))
Expand Down
3 changes: 2 additions & 1 deletion test/custom_ad_rules/kernel_pullbacks.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
using GeometricMachineLearning: tensor_mat_mul, mat_tensor_mul, tensor_tensor_mul, tensor_transpose_tensor_mul, assign_q_and_p, tensor_transpose, assign_matrix, assign_tensor, assign_output_estimate, vec_tensor_mul
using GeometricMachineLearning: tensor_mat_mul, mat_tensor_mul, tensor_tensor_mul, tensor_transpose_tensor_mul, assign_q_and_p, tensor_transpose, assign_matrix, assign_tensor, assign_output_estimate, vec_tensor_mul, upper_triangular_asymmetrize
using ChainRulesTestUtils
using Printf

Expand All @@ -18,6 +18,7 @@ function main(first_dim, second_dim, third_dim, third_tensor_dim)
test_rrule(assign_tensor, rand(first_dim, second_dim), third_tensor_dim, 1)
test_rrule(assign_output_estimate, rand(first_dim, second_dim, third_tensor_dim), 1)
test_rrule(vec_tensor_mul, rand(first_dim), rand(first_dim, second_dim, third_tensor_dim))
test_rrule(upper_triangular_asymmetrize, rand(first_dim, first_dim, third_tensor_dim))
end

const dim_range = 10
Expand Down
20 changes: 20 additions & 0 deletions test/layers/sympnet_upscaling.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
using GeometricMachineLearning, Test, Zygote

function test_symplecticity(N=4, N2=20, T=Float32)
model = Chain(PSDLayer(N, N2), GradientQ(N2, 2*N2, tanh), GradientP(N2, 2*N2, tanh), PSDLayer(N2, N))
ps = initialparameters(CPU(), T, model)
x = rand(T, N)
ten = rand(T, N, N)
# the first and last PSD layer need to have the same weight! (else they map to a different symplectic potential)
ps[4].weight.A = ps[1].weight.A
jacobian_matrix = Zygote.jacobian(x -> model(x, ps), x)[1]
𝕁 = SymplecticPotential(N÷2)
@test isapprox(jacobian_matrix'*𝕁*jacobian_matrix, 𝕁, rtol=0.1)
@test isapprox(model(ten, ps)[:,1], model(ten[:,1], ps))
end

for N=2:2:20
for N2=2*N:2:4*N
test_symplecticity()
end
end
2 changes: 2 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ using SafeTestsets
@safetestset "Manifolds (Grassmann): " begin include("manifolds/grassmann_manifold.jl") end
@safetestset "Gradient Layer " begin include("layers/gradient_layer_tests.jl") end
@safetestset "SympNet Layers " begin include("layers/sympnet_layers_test.jl") end
@safetestset "Test symplecticity of upscaling layer " begin include("layers/sympnet_layers_test.jl") end
@safetestset "Hamiltonian Neural Network " begin include("hamiltonian_neural_network_tests.jl") end
@safetestset "Manifold Neural Network Layers " begin include("layers/manifold_layers.jl") end
@safetestset "Custom AD rules for kernels " begin include("custom_ad_rules/kernel_pullbacks.jl") end
Expand All @@ -19,6 +20,7 @@ using SafeTestsets
@safetestset "Transformer Networks #6 " begin include("transformer_related/transformer_gradient.jl") end
@safetestset "Transformer Networks #7 " begin include("transformer_related/transformer_optimizer.jl") end
@safetestset "Attention layer #1 " begin include("attention_layer/attention_setup.jl") end
@safetestset "(MultiHead)Attention " begin include("attention_layer/apply_multi_head_attention.jl") end
@safetestset "Optimizer #1 " begin include("optimizers/utils/global_sections.jl") end
@safetestset "Optimizer #2 " begin include("optimizers/utils/optimization_step.jl") end
@safetestset "Optimizer #3 " begin include("optimizers/utils/modified_exponential.jl") end
Expand Down

0 comments on commit c2f79b5

Please sign in to comment.