-
Notifications
You must be signed in to change notification settings - Fork 3
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #64 from JuliaGNI/increase_test_coverage_for_layers
Increase test coverage for layers
- Loading branch information
Showing
15 changed files
with
60 additions
and
11 deletions.
There are no files selected for viewing
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,23 @@ | ||
using GeometricMachineLearning, Test | ||
using NNlib: softmax | ||
using LinearAlgebra: I | ||
|
||
"""
    compare_attention_to_mha(N, batch_size=10, T=Float32)

Test that a single-head `MultiHeadAttention` layer and an `Attention` layer (with
`softmax`) give approximately equal outputs when they share parameters.

The `Attention` parameters are drawn with `initialparameters`; the multi-head
parameters reuse its `PQ` and `PK` for `head_1` and set `PV` to the `N × N` identity
(converted to the type of `PK`). The comparison is run for both
`add_connection=false` and `add_connection=true`, on a random `N × N` matrix and on a
random `N × N × batch_size` array of element type `T`.
"""
function compare_attention_to_mha(N, batch_size=10, T=Float32)
    mha_plain = MultiHeadAttention(N, 1; add_connection=false)
    att_plain = Attention(N, softmax; add_connection=false)
    mha_connected = MultiHeadAttention(N, 1; add_connection=true)
    att_connected = Attention(N, softmax; add_connection=true)

    # Parameters come from the attention layer; the multi-head layer reuses them.
    att_ps = initialparameters(CPU(), T, att_plain)
    identity_pv = typeof(att_ps.PK)(I(N))
    mha_ps = (
        PQ=(head_1=att_ps.PQ,),
        PK=(head_1=att_ps.PK,),
        PV=(head_1=identity_pv,),
    )

    single_input = rand(T, N, N)
    batched_input = rand(T, N, N, batch_size)

    # Without the additional connection.
    @test mha_plain(single_input, mha_ps) ≈ att_plain(single_input, att_ps)
    @test mha_plain(batched_input, mha_ps) ≈ att_plain(batched_input, att_ps)

    # With the additional connection.
    @test mha_connected(single_input, mha_ps) ≈ att_connected(single_input, att_ps)
    @test mha_connected(batched_input, mha_ps) ≈ att_connected(batched_input, att_ps)
end
|
||
# Run the comparison for N = 10 with the default batch size (10) and element type (Float32).
compare_attention_to_mha(10) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,20 @@ | ||
using GeometricMachineLearning, Test, Zygote | ||
|
||
"""
    test_symplecticity(N=4, N2=20, T=Float32)

Build the chain `PSDLayer(N, N2)`, `GradientQ(N2, 2N2, tanh)`,
`GradientP(N2, 2N2, tanh)`, `PSDLayer(N2, N)` and run two tests on it:

1. The Jacobian `J` of the model at a random vector of length `N` satisfies
   `J' * 𝕁 * J ≈ 𝕁` (with `rtol=0.1`), where `𝕁 = SymplecticPotential(N ÷ 2)`.
2. The first column of the model applied to a random `N × N` matrix matches the
   model applied to the first column of that matrix.
"""
function test_symplecticity(N=4, N2=20, T=Float32)
    network = Chain(
        PSDLayer(N, N2),
        GradientQ(N2, 2 * N2, tanh),
        GradientP(N2, 2 * N2, tanh),
        PSDLayer(N2, N),
    )
    params = initialparameters(CPU(), T, network)
    point = rand(T, N)
    columns = rand(T, N, N)

    # The first and last PSD layers must share the same weight; otherwise they map
    # to different symplectic potentials.
    params[4].weight.A = params[1].weight.A

    jac = first(Zygote.jacobian(z -> network(z, params), point))
    𝕁 = SymplecticPotential(N ÷ 2)

    @test jac' * 𝕁 * jac ≈ 𝕁 rtol = 0.1
    @test network(columns, params)[:, 1] ≈ network(columns[:, 1], params)
end
|
||
# Sweep over even system dimensions N = 2, 4, …, 20 and, for each, over even
# hidden dimensions N2 = 2N, 2N + 2, …, 4N.
# The loop variables have to be forwarded: calling `test_symplecticity()` without
# arguments would re-run the default (N = 4, N2 = 20) case on every iteration and
# never exercise the other dimensions.
for N in 2:2:20
    for N2 in (2 * N):2:(4 * N)
        test_symplecticity(N, N2)
    end
end
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters