-
Notifications
You must be signed in to change notification settings - Fork 3
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge branch 'main' into volume_preserving_feedforward
- Loading branch information
Showing
54 changed files
with
758 additions
and
246 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,19 @@ | ||
# pre-push git hook that runs all tests before pushing | ||
|
||
red='\033[0;31m' | ||
green='\033[0;32m' | ||
no_color='\033[0m' | ||
|
||
reponame=$(basename `git rev-parse --show-toplevel`) | ||
|
||
|
||
echo "\nRunning pre-push hook\n" | ||
echo "Testing $reponame" | ||
julia --project=@. -e "using Pkg; Pkg.test(\"GeometricMachineLearning\")" | ||
|
||
if [[ $? -ne 0 ]]; then | ||
echo "\n${red}ERROR - Tests must pass before push!\n${no_color}" | ||
exit 1 | ||
fi | ||
|
||
echo "\n${green}Git hook was SUCCESSFUL!${no_color}\n" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,52 @@ | ||
@kernel function cpu_inverse_kernel!(B, A) | ||
k = @index(Global) | ||
@views A_temp = A[:, :, k] | ||
@views B_temp = B[:, :, k] | ||
|
||
B_temp .= inv(A_temp) | ||
|
||
nothing | ||
end | ||
|
||
function cpu_inverse(A::AbstractArray) | ||
B = zero(A) | ||
backend = KernelAbstractions.get_backend(A) | ||
|
||
cpu_inverse! = cpu_inverse_kernel!(backend) | ||
cpu_inverse!(B, A, ndrange=size(A, 3)) | ||
|
||
B | ||
end | ||
|
||
@kernel function cpu_inverse_pullback_kernel!(dA, A, dB) | ||
k = @index(Global) | ||
@views A_temp = A[:, :, k] | ||
@views dA_temp = dA[:, :, k] | ||
@views dB_temp = dB[:, :, k] | ||
|
||
copy!(dA_temp, Zygote.pullback(inv, A_temp)[2](dB_temp)[1]) | ||
|
||
nothing | ||
end | ||
|
||
function ChainRulesCore.rrule(::typeof(cpu_inverse), A::AbstractArray) | ||
B = cpu_inverse(A) | ||
|
||
function cpu_inverse_pullback(dB::AbstractArray) | ||
dA = zero(dB) | ||
backend = KernelAbstractions.get_backend(dB) | ||
|
||
cpu_inverse_pullback! = cpu_inverse_pullback_kernel!(backend) | ||
cpu_inverse_pullback!(dA, A, dB, ndrange=size(dB, 3)) | ||
|
||
return NoTangent(), dA | ||
end | ||
|
||
return B, cpu_inverse_pullback | ||
end | ||
|
||
function cpu_tensor_cayley(A::AbstractArray) | ||
one_A = init_output(A) | ||
|
||
tensor_tensor_mul(one_A - A, cpu_inverse(one_A + A)) | ||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,42 @@ | ||
@kernel function inv22_kernel!(ˍ₋out::AT, A::AT) where {T, AT<:AbstractArray{T, 3}} | ||
k = @index(Global) | ||
begin | ||
@inbounds begin | ||
ˍ₋out[1, 1, k] = (/)((getindex)(A, 2, 2, k), (+)((*)((getindex)(A, 1, 1, k), (getindex)(A, 2, 2, k)), (*)((*)(-1, (getindex)(A, 1, 2, k)), (getindex)(A, 2, 1, k)))) | ||
ˍ₋out[2, 1, k] = (/)((*)(-1, (getindex)(A, 2, 1, k)), (+)((*)((getindex)(A, 1, 1, k), (getindex)(A, 2, 2, k)), (*)((*)(-1, (getindex)(A, 1, 2, k)), (getindex)(A, 2, 1, k)))) | ||
ˍ₋out[1, 2, k] = (/)((*)(-1, (getindex)(A, 1, 2, k)), (+)((*)((getindex)(A, 1, 1, k), (getindex)(A, 2, 2, k)), (*)((*)(-1, (getindex)(A, 1, 2, k)), (getindex)(A, 2, 1, k)))) | ||
ˍ₋out[2, 2, k] = (/)((getindex)(A, 1, 1, k), (+)((*)((getindex)(A, 1, 1, k), (getindex)(A, 2, 2, k)), (*)((*)(-1, (getindex)(A, 1, 2, k)), (getindex)(A, 2, 1, k)))) | ||
nothing | ||
end | ||
end | ||
end | ||
|
||
function tensor_inverse2(A::AbstractArray{T, 3}) where T | ||
out = similar(A) | ||
|
||
tensor_inverse2!(out, A) | ||
|
||
out | ||
end | ||
|
||
function tensor_inverse2!(out::AbstractArray{T, 3}, A::AbstractArray{T, 3}) where T | ||
@assert size(A, 1) == size(A, 2) == 2 | ||
@assert size(A) == size(out) | ||
|
||
backend = get_backend(out) | ||
inv22! = inv22_kernel!(backend) | ||
|
||
inv22!(out, A, ndrange = size(A, 3)) | ||
|
||
nothing | ||
end | ||
|
||
function ChainRulesCore.rrule(::typeof(tensor_inverse2), A::AT) where {T, AT<:AbstractArray{T, 3}} | ||
out = tensor_inverse2(A) | ||
|
||
function tensor_inverse_pullback(out_diff::AT) | ||
|
||
NoTangent(), - tensor_transpose_tensor_mul(out, tensor_tensor_mul(out_diff, tensor_transpose(out))) | ||
end | ||
out, tensor_inverse_pullback | ||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,47 @@ | ||
@kernel function inv33_kernel!(ˍ₋out::AT, A::AT) where {T, AT<:AbstractArray{T, 3}} | ||
k = @index(Global) | ||
begin | ||
@inbounds begin | ||
ˍ₋out[1, 1, k] = (/)((+)((*)((getindex)(A, 2, 2, k), (getindex)(A, 3, 3, k)), (*)((*)(-1, (getindex)(A, 2, 3, k)), (getindex)(A, 3, 2, k))), (+)((+)((+)((+)((+)((*)((*)((getindex)(A, 1, 1, k), (getindex)(A, 2, 2, k)), (getindex)(A, 3, 3, k)), (*)((*)((*)(-1, (getindex)(A, 1, 1, k)), (getindex)(A, 2, 3, k)), (getindex)(A, 3, 2, k))), (*)((*)((*)(-1, (getindex)(A, 1, 2, k)), (getindex)(A, 2, 1, k)), (getindex)(A, 3, 3, k))), (*)((*)((getindex)(A, 1, 2, k), (getindex)(A, 2, 3, k)), (getindex)(A, 3, 1, k))), (*)((*)((getindex)(A, 1, 3, k), (getindex)(A, 2, 1, k)), (getindex)(A, 3, 2, k))), (*)((*)((*)(-1, (getindex)(A, 1, 3, k)), (getindex)(A, 2, 2, k)), (getindex)(A, 3, 1, k)))) | ||
ˍ₋out[2, 1, k] = (/)((+)((*)((*)(-1, (getindex)(A, 2, 1, k)), (getindex)(A, 3, 3, k)), (*)((getindex)(A, 2, 3, k), (getindex)(A, 3, 1, k))), (+)((+)((+)((+)((+)((*)((*)((getindex)(A, 1, 1, k), (getindex)(A, 2, 2, k)), (getindex)(A, 3, 3, k)), (*)((*)((*)(-1, (getindex)(A, 1, 1, k)), (getindex)(A, 2, 3, k)), (getindex)(A, 3, 2, k))), (*)((*)((*)(-1, (getindex)(A, 1, 2, k)), (getindex)(A, 2, 1, k)), (getindex)(A, 3, 3, k))), (*)((*)((getindex)(A, 1, 2, k), (getindex)(A, 2, 3, k)), (getindex)(A, 3, 1, k))), (*)((*)((getindex)(A, 1, 3, k), (getindex)(A, 2, 1, k)), (getindex)(A, 3, 2, k))), (*)((*)((*)(-1, (getindex)(A, 1, 3, k)), (getindex)(A, 2, 2, k)), (getindex)(A, 3, 1, k)))) | ||
ˍ₋out[3, 1, k] = (/)((+)((*)((getindex)(A, 2, 1, k), (getindex)(A, 3, 2, k)), (*)((*)(-1, (getindex)(A, 2, 2, k)), (getindex)(A, 3, 1, k))), (+)((+)((+)((+)((+)((*)((*)((getindex)(A, 1, 1, k), (getindex)(A, 2, 2, k)), (getindex)(A, 3, 3, k)), (*)((*)((*)(-1, (getindex)(A, 1, 1, k)), (getindex)(A, 2, 3, k)), (getindex)(A, 3, 2, k))), (*)((*)((*)(-1, (getindex)(A, 1, 2, k)), (getindex)(A, 2, 1, k)), (getindex)(A, 3, 3, k))), (*)((*)((getindex)(A, 1, 2, k), (getindex)(A, 2, 3, k)), (getindex)(A, 3, 1, k))), (*)((*)((getindex)(A, 1, 3, k), (getindex)(A, 2, 1, k)), (getindex)(A, 3, 2, k))), (*)((*)((*)(-1, (getindex)(A, 1, 3, k)), (getindex)(A, 2, 2, k)), (getindex)(A, 3, 1, k)))) | ||
ˍ₋out[1, 2, k] = (/)((+)((*)((*)(-1, (getindex)(A, 1, 2, k)), (getindex)(A, 3, 3, k)), (*)((getindex)(A, 1, 3, k), (getindex)(A, 3, 2, k))), (+)((+)((+)((+)((+)((*)((*)((getindex)(A, 1, 1, k), (getindex)(A, 2, 2, k)), (getindex)(A, 3, 3, k)), (*)((*)((*)(-1, (getindex)(A, 1, 1, k)), (getindex)(A, 2, 3, k)), (getindex)(A, 3, 2, k))), (*)((*)((*)(-1, (getindex)(A, 1, 2, k)), (getindex)(A, 2, 1, k)), (getindex)(A, 3, 3, k))), (*)((*)((getindex)(A, 1, 2, k), (getindex)(A, 2, 3, k)), (getindex)(A, 3, 1, k))), (*)((*)((getindex)(A, 1, 3, k), (getindex)(A, 2, 1, k)), (getindex)(A, 3, 2, k))), (*)((*)((*)(-1, (getindex)(A, 1, 3, k)), (getindex)(A, 2, 2, k)), (getindex)(A, 3, 1, k)))) | ||
ˍ₋out[2, 2, k] = (/)((+)((*)((getindex)(A, 1, 1, k), (getindex)(A, 3, 3, k)), (*)((*)(-1, (getindex)(A, 1, 3, k)), (getindex)(A, 3, 1, k))), (+)((+)((+)((+)((+)((*)((*)((getindex)(A, 1, 1, k), (getindex)(A, 2, 2, k)), (getindex)(A, 3, 3, k)), (*)((*)((*)(-1, (getindex)(A, 1, 1, k)), (getindex)(A, 2, 3, k)), (getindex)(A, 3, 2, k))), (*)((*)((*)(-1, (getindex)(A, 1, 2, k)), (getindex)(A, 2, 1, k)), (getindex)(A, 3, 3, k))), (*)((*)((getindex)(A, 1, 2, k), (getindex)(A, 2, 3, k)), (getindex)(A, 3, 1, k))), (*)((*)((getindex)(A, 1, 3, k), (getindex)(A, 2, 1, k)), (getindex)(A, 3, 2, k))), (*)((*)((*)(-1, (getindex)(A, 1, 3, k)), (getindex)(A, 2, 2, k)), (getindex)(A, 3, 1, k)))) | ||
ˍ₋out[3, 2, k] = (/)((+)((*)((*)(-1, (getindex)(A, 1, 1, k)), (getindex)(A, 3, 2, k)), (*)((getindex)(A, 1, 2, k), (getindex)(A, 3, 1, k))), (+)((+)((+)((+)((+)((*)((*)((getindex)(A, 1, 1, k), (getindex)(A, 2, 2, k)), (getindex)(A, 3, 3, k)), (*)((*)((*)(-1, (getindex)(A, 1, 1, k)), (getindex)(A, 2, 3, k)), (getindex)(A, 3, 2, k))), (*)((*)((*)(-1, (getindex)(A, 1, 2, k)), (getindex)(A, 2, 1, k)), (getindex)(A, 3, 3, k))), (*)((*)((getindex)(A, 1, 2, k), (getindex)(A, 2, 3, k)), (getindex)(A, 3, 1, k))), (*)((*)((getindex)(A, 1, 3, k), (getindex)(A, 2, 1, k)), (getindex)(A, 3, 2, k))), (*)((*)((*)(-1, (getindex)(A, 1, 3, k)), (getindex)(A, 2, 2, k)), (getindex)(A, 3, 1, k)))) | ||
ˍ₋out[1, 3, k] = (/)((+)((*)((getindex)(A, 1, 2, k), (getindex)(A, 2, 3, k)), (*)((*)(-1, (getindex)(A, 1, 3, k)), (getindex)(A, 2, 2, k))), (+)((+)((+)((+)((+)((*)((*)((getindex)(A, 1, 1, k), (getindex)(A, 2, 2, k)), (getindex)(A, 3, 3, k)), (*)((*)((*)(-1, (getindex)(A, 1, 1, k)), (getindex)(A, 2, 3, k)), (getindex)(A, 3, 2, k))), (*)((*)((*)(-1, (getindex)(A, 1, 2, k)), (getindex)(A, 2, 1, k)), (getindex)(A, 3, 3, k))), (*)((*)((getindex)(A, 1, 2, k), (getindex)(A, 2, 3, k)), (getindex)(A, 3, 1, k))), (*)((*)((getindex)(A, 1, 3, k), (getindex)(A, 2, 1, k)), (getindex)(A, 3, 2, k))), (*)((*)((*)(-1, (getindex)(A, 1, 3, k)), (getindex)(A, 2, 2, k)), (getindex)(A, 3, 1, k)))) | ||
ˍ₋out[2, 3, k] = (/)((+)((*)((*)(-1, (getindex)(A, 1, 1, k)), (getindex)(A, 2, 3, k)), (*)((getindex)(A, 1, 3, k), (getindex)(A, 2, 1, k))), (+)((+)((+)((+)((+)((*)((*)((getindex)(A, 1, 1, k), (getindex)(A, 2, 2, k)), (getindex)(A, 3, 3, k)), (*)((*)((*)(-1, (getindex)(A, 1, 1, k)), (getindex)(A, 2, 3, k)), (getindex)(A, 3, 2, k))), (*)((*)((*)(-1, (getindex)(A, 1, 2, k)), (getindex)(A, 2, 1, k)), (getindex)(A, 3, 3, k))), (*)((*)((getindex)(A, 1, 2, k), (getindex)(A, 2, 3, k)), (getindex)(A, 3, 1, k))), (*)((*)((getindex)(A, 1, 3, k), (getindex)(A, 2, 1, k)), (getindex)(A, 3, 2, k))), (*)((*)((*)(-1, (getindex)(A, 1, 3, k)), (getindex)(A, 2, 2, k)), (getindex)(A, 3, 1, k)))) | ||
ˍ₋out[3, 3, k] = (/)((+)((*)((getindex)(A, 1, 1, k), (getindex)(A, 2, 2, k)), (*)((*)(-1, (getindex)(A, 1, 2, k)), (getindex)(A, 2, 1, k))), (+)((+)((+)((+)((+)((*)((*)((getindex)(A, 1, 1, k), (getindex)(A, 2, 2, k)), (getindex)(A, 3, 3, k)), (*)((*)((*)(-1, (getindex)(A, 1, 1, k)), (getindex)(A, 2, 3, k)), (getindex)(A, 3, 2, k))), (*)((*)((*)(-1, (getindex)(A, 1, 2, k)), (getindex)(A, 2, 1, k)), (getindex)(A, 3, 3, k))), (*)((*)((getindex)(A, 1, 2, k), (getindex)(A, 2, 3, k)), (getindex)(A, 3, 1, k))), (*)((*)((getindex)(A, 1, 3, k), (getindex)(A, 2, 1, k)), (getindex)(A, 3, 2, k))), (*)((*)((*)(-1, (getindex)(A, 1, 3, k)), (getindex)(A, 2, 2, k)), (getindex)(A, 3, 1, k)))) | ||
nothing | ||
end | ||
end | ||
end | ||
|
||
function tensor_inverse3(A::AbstractArray{T, 3}) where T | ||
out = similar(A) | ||
|
||
tensor_inverse3!(out, A) | ||
|
||
out | ||
end | ||
|
||
function tensor_inverse3!(out::AbstractArray{T, 3}, A::AbstractArray{T, 3}) where T | ||
@assert size(A, 1) == size(A, 2) == 3 | ||
@assert size(A) == size(out) | ||
|
||
backend = get_backend(out) | ||
inv33! = inv33_kernel!(backend) | ||
|
||
inv33!(out, A, ndrange = size(A, 3)) | ||
|
||
nothing | ||
end | ||
|
||
function ChainRulesCore.rrule(::typeof(tensor_inverse3), A::AT) where {T, AT<:AbstractArray{T, 3}} | ||
out = tensor_inverse3(A) | ||
|
||
function tensor_inverse_pullback(out_diff::AT) | ||
|
||
NoTangent(), - tensor_transpose_tensor_mul(out, tensor_tensor_mul(out_diff, tensor_transpose(out))) | ||
end | ||
out, tensor_inverse_pullback | ||
end |
Oops, something went wrong.