Skip to content

Commit

Permalink
Merge branch 'main' into volume_preserving_feedforward
Browse files Browse the repository at this point in the history
  • Loading branch information
benedict-96 committed Apr 9, 2024
2 parents 9723b51 + 908bc21 commit c2d7eb6
Show file tree
Hide file tree
Showing 54 changed files with 758 additions and 246 deletions.
19 changes: 19 additions & 0 deletions .githooks/pre-push
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
# pre-push git hook that runs all tests before pushing

red='\033[0;31m'
green='\033[0;32m'
no_color='\033[0m'

reponame=$(basename `git rev-parse --show-toplevel`)


echo "\nRunning pre-push hook\n"
echo "Testing $reponame"
julia --project=@. -e "using Pkg; Pkg.test(\"GeometricMachineLearning\")"

if [[ $? -ne 0 ]]; then
echo "\n${red}ERROR - Tests must pass before push!\n${no_color}"
exit 1
fi

echo "\n${green}Git hook was SUCCESSFUL!${no_color}\n"
12 changes: 8 additions & 4 deletions .github/workflows/Documenter.yml
Original file line number Diff line number Diff line change
Expand Up @@ -12,20 +12,24 @@ jobs:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v3
- run: |
- name: Install packages for generating tikz images
run: |
sudo apt-get install imagemagick
sudo apt-get install poppler-utils
sudo apt-get install texlive-xetex
sudo apt-get install texlive-science
make all -C docs/src/tikz
- name: Make tikz images
run: make all -C docs/src/tikz
- uses: julia-actions/setup-julia@latest
with:
version: '1'
- run: |
- name: Install BrenierTwoFluid package
run: |
cd docs
make install_brenier_two_fluid test_docs
cd ..
julia --project=docs docs/make.jl html_output
- name: Make docs (call julia documenter)
run: julia --project=docs docs/make.jl html_output
- uses: julia-actions/julia-buildpkg@v1
- uses: julia-actions/julia-docdeploy@v1
env:
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/Latex.yml
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ jobs:
LatexDocs:
strategy:
matrix:
os: [ubuntu-latest, macos-latest]
os: [ubuntu-latest]
runs-on: ${{ matrix.os }}
steps:
- uses: actions/checkout@v4
Expand Down
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ TimerOutputs = "a759f4b9-e2f1-59dc-863e-4aeb61b1ea8f"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[compat]
AbstractNeuralNetworks = "0.1"
AbstractNeuralNetworks = "0.2"
BandedMatrices = "0.17, 1"
ChainRules = "1"
ChainRulesCore = "1"
Expand Down
9 changes: 9 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -63,3 +63,12 @@ More examples like this can be found in the docs.
- Brantner B. Generalizing Adam To Manifolds For Efficiently Training Transformers. arXiv preprint arXiv:2305.16901, 2023.
- Brantner B., Kraus M. Symplectic Autoencoders for Model Reduction of Hamiltonian Systems. arXiv preprint arXiv:2312.10004, 2023.
- Brantner B., Romemont G., Kraus M., Li Z. Structure-Preserving Transformers for Learning Parametrized Hamiltonian Systems. arXiv preprint arXiv:2312.11166, 2023.


## Development

We are using git hooks, e.g., to enforce that all tests pass before pushing.
In order to activate these hooks, the following command must be executed once:
```
git config core.hooksPath .githooks
```
6 changes: 2 additions & 4 deletions docs/src/tikz/Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,8 @@ convert_with_pdftocairo: $(MYDIR)/*.pdf
done

png:
pdftocairo -png -r 500 -transp -singlefile logo_with_name.pdf logo_with_name
pdftocairo -png -r 500 -transp -singlefile logo_with_name_dark.pdf logo_with_name_dark
pdftocairo -png -r $(res1) -transp -singlefile adam_optimizer.pdf adam_optimizer
pdftocairo -png -r $(res1) -transp -singlefile adam_optimizer_dark.pdf adam_optimizer_dark
pdftocairo -png -r $(res1) -transp -singlefile general_optimization.pdf general_optimization
Expand All @@ -57,10 +59,6 @@ png:
pdftocairo -png -r $(res1) -transp -singlefile tensor_dark.pdf tensor_dark
pdftocairo -png -r $(res4) -transp -singlefile tensor_sampling.pdf tensor_sampling
pdftocairo -png -r $(res4) -transp -singlefile tensor_sampling_dark.pdf tensor_sampling_dark
pdftocairo -png -r $(res4) -transp -singlefile skew_sym_visualization.pdf skew_sym_visualization
pdftocairo -png -r $(res4) -transp -singlefile skew_sym_visualization_dark.pdf skew_sym_visualization_dark
pdftocairo -png -r $(res1) -transp -singlefile vp_feedforward.pdf vp_feedforward
pdftocairo -png -r $(res1) -transp -singlefile vp_feedforward_dark.pdf vp_feedforward_dark

logo:
cp logo_with_name.png ../assets/logo.png
Expand Down
13 changes: 10 additions & 3 deletions src/GeometricMachineLearning.jl
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,13 @@ module GeometricMachineLearning
include("kernels/mat_tensor_mul.jl")
include("kernels/tensor_transpose.jl")
include("kernels/exponentials/tensor_exponential.jl")
include("kernels/inverses/inverse_kernel.jl")
include("kernels/inverses/cpu_inverse.jl")
include("kernels/inverses/inverse_2x2.jl")
include("kernels/inverses/inverse_3x3.jl")
include("kernels/inverses/inverse_4x4.jl")
include("kernels/inverses/inverse_5x5.jl")
include("kernels/inverses/tensor_cayley.jl")
include("kernels/inverses/tensor_mat_skew_sym_assign.jl")
include("kernels/vec_tensor_mul.jl")

include("kernels/kernel_ad_routines/assign_q_and_p.jl")
Expand All @@ -83,6 +89,7 @@ module GeometricMachineLearning
include("kernels/kernel_ad_routines/tensor_tensor_mul.jl")
include("kernels/kernel_ad_routines/tensor_transpose_tensor_mul.jl")
include("kernels/kernel_ad_routines/tensor_transpose.jl")
include("kernels/kernel_ad_routines/tensor_mat_skew_sym_assign.jl")
include("kernels/kernel_ad_routines/vec_tensor_mul.jl")
# export tensor_mat_mul

Expand Down Expand Up @@ -137,7 +144,7 @@ module GeometricMachineLearning
include("layers/stiefel_layer.jl")
include("layers/grassmann_layer.jl")
include("layers/multi_head_attention.jl")
include("layers/attention_layer.jl")
include("layers/volume_preserving_attention.jl")
include("layers/transformer.jl")
include("layers/psd_like_layer.jl")
include("layers/classification.jl")
Expand All @@ -147,7 +154,7 @@ module GeometricMachineLearning
export StiefelLayer, GrassmannLayer, ManifoldLayer
export PSDLayer
export MultiHeadAttention
export Attention
export VolumePreservingAttention
export ResNet
export Transformer
export Classification
Expand Down
52 changes: 52 additions & 0 deletions src/kernels/inverses/cpu_inverse.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
@kernel function cpu_inverse_kernel!(B, A)
k = @index(Global)
@views A_temp = A[:, :, k]
@views B_temp = B[:, :, k]

B_temp .= inv(A_temp)

nothing
end

function cpu_inverse(A::AbstractArray)
B = zero(A)
backend = KernelAbstractions.get_backend(A)

cpu_inverse! = cpu_inverse_kernel!(backend)
cpu_inverse!(B, A, ndrange=size(A, 3))

B
end

@kernel function cpu_inverse_pullback_kernel!(dA, A, dB)
k = @index(Global)
@views A_temp = A[:, :, k]
@views dA_temp = dA[:, :, k]
@views dB_temp = dB[:, :, k]

copy!(dA_temp, Zygote.pullback(inv, A_temp)[2](dB_temp)[1])

nothing
end

function ChainRulesCore.rrule(::typeof(cpu_inverse), A::AbstractArray)
B = cpu_inverse(A)

function cpu_inverse_pullback(dB::AbstractArray)
dA = zero(dB)
backend = KernelAbstractions.get_backend(dB)

cpu_inverse_pullback! = cpu_inverse_pullback_kernel!(backend)
cpu_inverse_pullback!(dA, A, dB, ndrange=size(dB, 3))

return NoTangent(), dA
end

return B, cpu_inverse_pullback
end

function cpu_tensor_cayley(A::AbstractArray)
one_A = init_output(A)

tensor_tensor_mul(one_A - A, cpu_inverse(one_A + A))
end
42 changes: 42 additions & 0 deletions src/kernels/inverses/inverse_2x2.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
@kernel function inv22_kernel!(ˍ₋out::AT, A::AT) where {T, AT<:AbstractArray{T, 3}}
k = @index(Global)
begin
@inbounds begin
ˍ₋out[1, 1, k] = (/)((getindex)(A, 2, 2, k), (+)((*)((getindex)(A, 1, 1, k), (getindex)(A, 2, 2, k)), (*)((*)(-1, (getindex)(A, 1, 2, k)), (getindex)(A, 2, 1, k))))
ˍ₋out[2, 1, k] = (/)((*)(-1, (getindex)(A, 2, 1, k)), (+)((*)((getindex)(A, 1, 1, k), (getindex)(A, 2, 2, k)), (*)((*)(-1, (getindex)(A, 1, 2, k)), (getindex)(A, 2, 1, k))))
ˍ₋out[1, 2, k] = (/)((*)(-1, (getindex)(A, 1, 2, k)), (+)((*)((getindex)(A, 1, 1, k), (getindex)(A, 2, 2, k)), (*)((*)(-1, (getindex)(A, 1, 2, k)), (getindex)(A, 2, 1, k))))
ˍ₋out[2, 2, k] = (/)((getindex)(A, 1, 1, k), (+)((*)((getindex)(A, 1, 1, k), (getindex)(A, 2, 2, k)), (*)((*)(-1, (getindex)(A, 1, 2, k)), (getindex)(A, 2, 1, k))))
nothing
end
end
end

function tensor_inverse2(A::AbstractArray{T, 3}) where T
out = similar(A)

tensor_inverse2!(out, A)

out
end

function tensor_inverse2!(out::AbstractArray{T, 3}, A::AbstractArray{T, 3}) where T
@assert size(A, 1) == size(A, 2) == 2
@assert size(A) == size(out)

backend = get_backend(out)
inv22! = inv22_kernel!(backend)

inv22!(out, A, ndrange = size(A, 3))

nothing
end

function ChainRulesCore.rrule(::typeof(tensor_inverse2), A::AT) where {T, AT<:AbstractArray{T, 3}}
out = tensor_inverse2(A)

function tensor_inverse_pullback(out_diff::AT)

NoTangent(), - tensor_transpose_tensor_mul(out, tensor_tensor_mul(out_diff, tensor_transpose(out)))
end
out, tensor_inverse_pullback
end
47 changes: 47 additions & 0 deletions src/kernels/inverses/inverse_3x3.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
@kernel function inv33_kernel!(ˍ₋out::AT, A::AT) where {T, AT<:AbstractArray{T, 3}}
k = @index(Global)
begin
@inbounds begin
ˍ₋out[1, 1, k] = (/)((+)((*)((getindex)(A, 2, 2, k), (getindex)(A, 3, 3, k)), (*)((*)(-1, (getindex)(A, 2, 3, k)), (getindex)(A, 3, 2, k))), (+)((+)((+)((+)((+)((*)((*)((getindex)(A, 1, 1, k), (getindex)(A, 2, 2, k)), (getindex)(A, 3, 3, k)), (*)((*)((*)(-1, (getindex)(A, 1, 1, k)), (getindex)(A, 2, 3, k)), (getindex)(A, 3, 2, k))), (*)((*)((*)(-1, (getindex)(A, 1, 2, k)), (getindex)(A, 2, 1, k)), (getindex)(A, 3, 3, k))), (*)((*)((getindex)(A, 1, 2, k), (getindex)(A, 2, 3, k)), (getindex)(A, 3, 1, k))), (*)((*)((getindex)(A, 1, 3, k), (getindex)(A, 2, 1, k)), (getindex)(A, 3, 2, k))), (*)((*)((*)(-1, (getindex)(A, 1, 3, k)), (getindex)(A, 2, 2, k)), (getindex)(A, 3, 1, k))))
ˍ₋out[2, 1, k] = (/)((+)((*)((*)(-1, (getindex)(A, 2, 1, k)), (getindex)(A, 3, 3, k)), (*)((getindex)(A, 2, 3, k), (getindex)(A, 3, 1, k))), (+)((+)((+)((+)((+)((*)((*)((getindex)(A, 1, 1, k), (getindex)(A, 2, 2, k)), (getindex)(A, 3, 3, k)), (*)((*)((*)(-1, (getindex)(A, 1, 1, k)), (getindex)(A, 2, 3, k)), (getindex)(A, 3, 2, k))), (*)((*)((*)(-1, (getindex)(A, 1, 2, k)), (getindex)(A, 2, 1, k)), (getindex)(A, 3, 3, k))), (*)((*)((getindex)(A, 1, 2, k), (getindex)(A, 2, 3, k)), (getindex)(A, 3, 1, k))), (*)((*)((getindex)(A, 1, 3, k), (getindex)(A, 2, 1, k)), (getindex)(A, 3, 2, k))), (*)((*)((*)(-1, (getindex)(A, 1, 3, k)), (getindex)(A, 2, 2, k)), (getindex)(A, 3, 1, k))))
ˍ₋out[3, 1, k] = (/)((+)((*)((getindex)(A, 2, 1, k), (getindex)(A, 3, 2, k)), (*)((*)(-1, (getindex)(A, 2, 2, k)), (getindex)(A, 3, 1, k))), (+)((+)((+)((+)((+)((*)((*)((getindex)(A, 1, 1, k), (getindex)(A, 2, 2, k)), (getindex)(A, 3, 3, k)), (*)((*)((*)(-1, (getindex)(A, 1, 1, k)), (getindex)(A, 2, 3, k)), (getindex)(A, 3, 2, k))), (*)((*)((*)(-1, (getindex)(A, 1, 2, k)), (getindex)(A, 2, 1, k)), (getindex)(A, 3, 3, k))), (*)((*)((getindex)(A, 1, 2, k), (getindex)(A, 2, 3, k)), (getindex)(A, 3, 1, k))), (*)((*)((getindex)(A, 1, 3, k), (getindex)(A, 2, 1, k)), (getindex)(A, 3, 2, k))), (*)((*)((*)(-1, (getindex)(A, 1, 3, k)), (getindex)(A, 2, 2, k)), (getindex)(A, 3, 1, k))))
ˍ₋out[1, 2, k] = (/)((+)((*)((*)(-1, (getindex)(A, 1, 2, k)), (getindex)(A, 3, 3, k)), (*)((getindex)(A, 1, 3, k), (getindex)(A, 3, 2, k))), (+)((+)((+)((+)((+)((*)((*)((getindex)(A, 1, 1, k), (getindex)(A, 2, 2, k)), (getindex)(A, 3, 3, k)), (*)((*)((*)(-1, (getindex)(A, 1, 1, k)), (getindex)(A, 2, 3, k)), (getindex)(A, 3, 2, k))), (*)((*)((*)(-1, (getindex)(A, 1, 2, k)), (getindex)(A, 2, 1, k)), (getindex)(A, 3, 3, k))), (*)((*)((getindex)(A, 1, 2, k), (getindex)(A, 2, 3, k)), (getindex)(A, 3, 1, k))), (*)((*)((getindex)(A, 1, 3, k), (getindex)(A, 2, 1, k)), (getindex)(A, 3, 2, k))), (*)((*)((*)(-1, (getindex)(A, 1, 3, k)), (getindex)(A, 2, 2, k)), (getindex)(A, 3, 1, k))))
ˍ₋out[2, 2, k] = (/)((+)((*)((getindex)(A, 1, 1, k), (getindex)(A, 3, 3, k)), (*)((*)(-1, (getindex)(A, 1, 3, k)), (getindex)(A, 3, 1, k))), (+)((+)((+)((+)((+)((*)((*)((getindex)(A, 1, 1, k), (getindex)(A, 2, 2, k)), (getindex)(A, 3, 3, k)), (*)((*)((*)(-1, (getindex)(A, 1, 1, k)), (getindex)(A, 2, 3, k)), (getindex)(A, 3, 2, k))), (*)((*)((*)(-1, (getindex)(A, 1, 2, k)), (getindex)(A, 2, 1, k)), (getindex)(A, 3, 3, k))), (*)((*)((getindex)(A, 1, 2, k), (getindex)(A, 2, 3, k)), (getindex)(A, 3, 1, k))), (*)((*)((getindex)(A, 1, 3, k), (getindex)(A, 2, 1, k)), (getindex)(A, 3, 2, k))), (*)((*)((*)(-1, (getindex)(A, 1, 3, k)), (getindex)(A, 2, 2, k)), (getindex)(A, 3, 1, k))))
ˍ₋out[3, 2, k] = (/)((+)((*)((*)(-1, (getindex)(A, 1, 1, k)), (getindex)(A, 3, 2, k)), (*)((getindex)(A, 1, 2, k), (getindex)(A, 3, 1, k))), (+)((+)((+)((+)((+)((*)((*)((getindex)(A, 1, 1, k), (getindex)(A, 2, 2, k)), (getindex)(A, 3, 3, k)), (*)((*)((*)(-1, (getindex)(A, 1, 1, k)), (getindex)(A, 2, 3, k)), (getindex)(A, 3, 2, k))), (*)((*)((*)(-1, (getindex)(A, 1, 2, k)), (getindex)(A, 2, 1, k)), (getindex)(A, 3, 3, k))), (*)((*)((getindex)(A, 1, 2, k), (getindex)(A, 2, 3, k)), (getindex)(A, 3, 1, k))), (*)((*)((getindex)(A, 1, 3, k), (getindex)(A, 2, 1, k)), (getindex)(A, 3, 2, k))), (*)((*)((*)(-1, (getindex)(A, 1, 3, k)), (getindex)(A, 2, 2, k)), (getindex)(A, 3, 1, k))))
ˍ₋out[1, 3, k] = (/)((+)((*)((getindex)(A, 1, 2, k), (getindex)(A, 2, 3, k)), (*)((*)(-1, (getindex)(A, 1, 3, k)), (getindex)(A, 2, 2, k))), (+)((+)((+)((+)((+)((*)((*)((getindex)(A, 1, 1, k), (getindex)(A, 2, 2, k)), (getindex)(A, 3, 3, k)), (*)((*)((*)(-1, (getindex)(A, 1, 1, k)), (getindex)(A, 2, 3, k)), (getindex)(A, 3, 2, k))), (*)((*)((*)(-1, (getindex)(A, 1, 2, k)), (getindex)(A, 2, 1, k)), (getindex)(A, 3, 3, k))), (*)((*)((getindex)(A, 1, 2, k), (getindex)(A, 2, 3, k)), (getindex)(A, 3, 1, k))), (*)((*)((getindex)(A, 1, 3, k), (getindex)(A, 2, 1, k)), (getindex)(A, 3, 2, k))), (*)((*)((*)(-1, (getindex)(A, 1, 3, k)), (getindex)(A, 2, 2, k)), (getindex)(A, 3, 1, k))))
ˍ₋out[2, 3, k] = (/)((+)((*)((*)(-1, (getindex)(A, 1, 1, k)), (getindex)(A, 2, 3, k)), (*)((getindex)(A, 1, 3, k), (getindex)(A, 2, 1, k))), (+)((+)((+)((+)((+)((*)((*)((getindex)(A, 1, 1, k), (getindex)(A, 2, 2, k)), (getindex)(A, 3, 3, k)), (*)((*)((*)(-1, (getindex)(A, 1, 1, k)), (getindex)(A, 2, 3, k)), (getindex)(A, 3, 2, k))), (*)((*)((*)(-1, (getindex)(A, 1, 2, k)), (getindex)(A, 2, 1, k)), (getindex)(A, 3, 3, k))), (*)((*)((getindex)(A, 1, 2, k), (getindex)(A, 2, 3, k)), (getindex)(A, 3, 1, k))), (*)((*)((getindex)(A, 1, 3, k), (getindex)(A, 2, 1, k)), (getindex)(A, 3, 2, k))), (*)((*)((*)(-1, (getindex)(A, 1, 3, k)), (getindex)(A, 2, 2, k)), (getindex)(A, 3, 1, k))))
ˍ₋out[3, 3, k] = (/)((+)((*)((getindex)(A, 1, 1, k), (getindex)(A, 2, 2, k)), (*)((*)(-1, (getindex)(A, 1, 2, k)), (getindex)(A, 2, 1, k))), (+)((+)((+)((+)((+)((*)((*)((getindex)(A, 1, 1, k), (getindex)(A, 2, 2, k)), (getindex)(A, 3, 3, k)), (*)((*)((*)(-1, (getindex)(A, 1, 1, k)), (getindex)(A, 2, 3, k)), (getindex)(A, 3, 2, k))), (*)((*)((*)(-1, (getindex)(A, 1, 2, k)), (getindex)(A, 2, 1, k)), (getindex)(A, 3, 3, k))), (*)((*)((getindex)(A, 1, 2, k), (getindex)(A, 2, 3, k)), (getindex)(A, 3, 1, k))), (*)((*)((getindex)(A, 1, 3, k), (getindex)(A, 2, 1, k)), (getindex)(A, 3, 2, k))), (*)((*)((*)(-1, (getindex)(A, 1, 3, k)), (getindex)(A, 2, 2, k)), (getindex)(A, 3, 1, k))))
nothing
end
end
end

function tensor_inverse3(A::AbstractArray{T, 3}) where T
out = similar(A)

tensor_inverse3!(out, A)

out
end

function tensor_inverse3!(out::AbstractArray{T, 3}, A::AbstractArray{T, 3}) where T
@assert size(A, 1) == size(A, 2) == 3
@assert size(A) == size(out)

backend = get_backend(out)
inv33! = inv33_kernel!(backend)

inv33!(out, A, ndrange = size(A, 3))

nothing
end

function ChainRulesCore.rrule(::typeof(tensor_inverse3), A::AT) where {T, AT<:AbstractArray{T, 3}}
out = tensor_inverse3(A)

function tensor_inverse_pullback(out_diff::AT)

NoTangent(), - tensor_transpose_tensor_mul(out, tensor_tensor_mul(out_diff, tensor_transpose(out)))
end
out, tensor_inverse_pullback
end
Loading

0 comments on commit c2d7eb6

Please sign in to comment.