diff --git a/.githooks/pre-push b/.githooks/pre-push index 0e3dc747d..634ca0d8b 100755 --- a/.githooks/pre-push +++ b/.githooks/pre-push @@ -1,4 +1,4 @@ -# pre-push git hook that runs all tests before pushing +# pre-push git hook that runs all tests before pushin red='\033[0;31m' green='\033[0;32m' diff --git a/Project.toml b/Project.toml index 897325709..712c961d5 100644 --- a/Project.toml +++ b/Project.toml @@ -31,7 +31,7 @@ UpdateJulia = "770da0de-323d-4d28-9202-0e205c1e0aff" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [compat] -AbstractNeuralNetworks = "0.4" +AbstractNeuralNetworks = "0.5" BandedMatrices = "1" ChainRules = "1" ChainRulesCore = "1" diff --git a/docs/Project.toml b/docs/Project.toml index 49c8aad52..a1f358b64 100644 --- a/docs/Project.toml +++ b/docs/Project.toml @@ -1,5 +1,6 @@ [deps] Bibliography = "f1be7e48-bf82-45af-a471-ae754a193061" +BrenierTwoFluid = "698bc5df-bacc-4e45-9592-41ae9e406d75" CairoMakie = "13f3f980-e62b-5c42-98c6-ff1f3baf88f0" Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4" DocumenterCitations = "daee34ce-89f3-4625-b898-19384cb65244" diff --git a/docs/src/tutorials/adjusting_the_loss_function.md b/docs/src/tutorials/adjusting_the_loss_function.md index 359c832ed..a691696f5 100644 --- a/docs/src/tutorials/adjusting_the_loss_function.md +++ b/docs/src/tutorials/adjusting_the_loss_function.md @@ -14,6 +14,7 @@ We again consider training a SympNet on the data coming from a harmonic oscillat ```@example change_loss using GeometricMachineLearning # hide +using GeometricMachineLearning: params # hide using GeometricIntegrators: integrate, ImplicitMidpoint # hide using GeometricProblems.HarmonicOscillator: hodeproblem import Random # hide @@ -45,7 +46,7 @@ function network_parameter_norm(params::NeuralNetworkParameters) sum([network_parameter_norm(params[key]) for key in keys(params)]) end -network_parameter_norm(nn.params) +network_parameter_norm(params(nn)) ``` We now implement a custom loss such that: @@ -80,7 +81,7 @@ print(loss_array[end]) We see that the norm of the parameters is lower: ```@example change_loss -network_parameter_norm(nn_custom.params) +network_parameter_norm(params(nn_custom)) ``` We can also compare the solutions of the two networks: diff --git a/docs/src/tutorials/grassmann_layer.md b/docs/src/tutorials/grassmann_layer.md index 9a29d9f74..627d23351 100644 --- a/docs/src/tutorials/grassmann_layer.md +++ b/docs/src/tutorials/grassmann_layer.md @@ -96,6 +96,7 @@ Before we can use the Wasserstein distance however to train the neural network w ```@example rosenbrock using GeometricMachineLearning # hide +using GeometricMachineLearning: params # hide using Zygote, BrenierTwoFluid using LinearAlgebra: norm # hide import Random # hide @@ -111,7 +112,7 @@ nothing # hide We then *lift* the neural network parameters via [`GlobalSection`](@ref). 
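The documentation hunks above all make the same substitution: the field access `nn.params` is replaced by the `params` accessor that the package now imports from AbstractNeuralNetworks 0.5. A minimal sketch of the pattern in isolation; the two-layer `Chain` is made up for illustration, and `params` is assumed to need an explicit import, since the tutorials bring it in with `using GeometricMachineLearning: params`:

```julia
using GeometricMachineLearning
using GeometricMachineLearning: params  # accessor imported from AbstractNeuralNetworks

model = Chain(Dense(5, 3, tanh), Dense(3, 2, identity))
nn = NeuralNetwork(model)   # CPU and Float64 defaults, as in the doctests below

ps = params(nn)             # preferred over reaching into the nn.params field
keys(ps)                    # parameters are stored per layer, e.g. (:L1, :L2)
```

Going through the accessor keeps the tutorials working even if the internal field layout of `NeuralNetwork` changes again.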
```@example rosenbrock -λY = GlobalSection(nn.params) +λY = GlobalSection(params(nn)) nothing # hide ``` @@ -280,9 +281,9 @@ CairoMakie.activate!() # hide const training_steps = 80 loss_array = zeros(training_steps) for i in 1:training_steps - val, dp = compute_gradient(nn.params) + val, dp = compute_gradient(params(nn)) loss_array[i] = val - optimization_step!(optimizer, λY, nn.params, dp.params) + optimization_step!(optimizer, λY, params(nn), dp.params) end ``` diff --git a/docs/src/tutorials/mnist/mnist_tutorial.md b/docs/src/tutorials/mnist/mnist_tutorial.md index 0dcddf077..b2111e5aa 100644 --- a/docs/src/tutorials/mnist/mnist_tutorial.md +++ b/docs/src/tutorials/mnist/mnist_tutorial.md @@ -86,7 +86,7 @@ Here we have chosen a [`ClassificationTransformer`](@ref), i.e. a composition of We now have to initialize the neural network weights. This is done with the constructor for `NeuralNetwork`: ```@example mnist -backend = GeometricMachineLearning.get_backend(dl) +backend = GeometricMachineLearning.networkbackend(dl) T = eltype(dl) nn1 = NeuralNetwork(model1, backend, T) nn2 = NeuralNetwork(model2, backend, T) diff --git a/docs/src/tutorials/volume_preserving_attention.md b/docs/src/tutorials/volume_preserving_attention.md index 19fa1feeb..58d5ed9cc 100644 --- a/docs/src/tutorials/volume_preserving_attention.md +++ b/docs/src/tutorials/volume_preserving_attention.md @@ -6,7 +6,7 @@ In here we demonstrate the differences between the two approaches for computing ```@example volume_preserving_attention using GeometricMachineLearning # hide -using GeometricMachineLearning: FeedForwardLoss, TransformerLoss # hide +using GeometricMachineLearning: FeedForwardLoss, TransformerLoss, params # hide import Random # hide Random.seed!(123) # hide @@ -199,15 +199,15 @@ initial_condition = dl.input[:, 1:seq_length, 2] function make_networks_neural_network_integrators(nn_skew, nn_arb, nn_comp) nn_skew = NeuralNetwork(GeometricMachineLearning.DummyTransformer(seq_length), nn_skew.model, - nn_skew.params, + params(nn_skew), CPU()) nn_arb = NeuralNetwork(GeometricMachineLearning.DummyTransformer(seq_length), nn_arb.model, - nn_arb.params, + params(nn_arb), CPU()) nn_comp = NeuralNetwork(GeometricMachineLearning.DummyNNIntegrator(), nn_comp.model, - nn_comp.params, + params(nn_comp), CPU()) nn_skew, nn_arb, nn_comp diff --git a/src/GeometricMachineLearning.jl b/src/GeometricMachineLearning.jl index 75ecdab86..31e294bd8 100644 --- a/src/GeometricMachineLearning.jl +++ b/src/GeometricMachineLearning.jl @@ -29,6 +29,7 @@ module GeometricMachineLearning import AbstractNeuralNetworks: GlorotUniform import AbstractNeuralNetworks: params, architecture, model, dim import AbstractNeuralNetworks: AbstractPullback, NetworkLoss, _compute_loss + import AbstractNeuralNetworks: networkbackend # export params, architetcure, model export dim import NNlib: σ, sigmoid, softmax diff --git a/src/architectures/autoencoder.jl b/src/architectures/autoencoder.jl index 91de1b68c..6f55fac70 100644 --- a/src/architectures/autoencoder.jl +++ b/src/architectures/autoencoder.jl @@ -82,11 +82,10 @@ We show how to make an encoder from a custom architecture: ```jldoctest using GeometricMachineLearning -using GeometricMachineLearning: UnknownEncoder +using GeometricMachineLearning: UnknownEncoder, params model = Chain(Dense(5, 3, tanh; use_bias = false), Dense(3, 2, identity; use_bias = false)) -params = NeuralNetworkParameters(initialparameters(model)) -nn = NeuralNetwork(UnknownEncoder(5, 2, 2), model, params, CPU()) +nn = 
NeuralNetwork(UnknownEncoder(5, 2, 2), model, params(NeuralNetwork(model)), CPU()) typeof(nn) <: NeuralNetwork{<:GeometricMachineLearning.Encoder} @@ -173,7 +172,7 @@ end function encoder_parameters(nn::NeuralNetwork{<:AutoEncoder}) n_encoder_layers = length(encoder_model(nn.architecture).layers) keys = Tuple(Symbol.(["L$(i)" for i in 1:n_encoder_layers])) - NeuralNetworkParameters(NamedTuple{keys}(Tuple([nn.params[key] for key in keys]))) + NeuralNetworkParameters(NamedTuple{keys}(Tuple([params(nn)[key] for key in keys]))) end # """ @@ -183,13 +182,13 @@ end # """ function decoder_parameters(nn::NeuralNetwork{<:AutoEncoder}) n_decoder_layers = length(decoder_model(nn.architecture).layers) - all_keys = keys(nn.params) - # "old keys" are the ones describing the correct parameters in nn.params + all_keys = keys(params(nn)) + # "old keys" are the ones describing the correct parameters in params(nn) keys_old = Tuple(Symbol.(["L$(i)" for i in (length(all_keys) - (n_decoder_layers - 1)):length(all_keys)])) n_keys = length(keys_old) # "new keys" are the ones describing the keys in the new NamedTuple keys_new = Tuple(Symbol.(["L$(i)" for i in 1:n_keys])) - NeuralNetworkParameters(NamedTuple{keys_new}(Tuple([nn.params[key] for key in keys_old]))) + NeuralNetworkParameters(NamedTuple{keys_new}(Tuple([params(nn)[key] for key in keys_old]))) end function Chain(arch::AutoEncoder) @@ -205,14 +204,14 @@ function encoder(nn::NeuralNetwork{<:AutoEncoder}) NeuralNetwork( UnknownEncoder(nn.architecture.full_dim, nn.architecture.reduced_dim, nn.architecture.n_encoder_blocks), encoder_model(nn.architecture), encoder_parameters(nn), - get_backend(nn)) + networkbackend(nn)) end function _encoder(nn::NeuralNetwork, full_dim::Integer, reduced_dim::Integer) NeuralNetwork( UnknownEncoder(full_dim, reduced_dim, length(nn.model.layers)), nn.model, - nn.params, - get_backend(nn)) + params(nn), + networkbackend(nn)) end function input_dimension(::AbstractExplicitLayer{M, N}) where {M, N} @@ -242,11 +241,11 @@ end Obtain the *decoder* from an [`AutoEncoder`](@ref) neural network. 
""" function decoder(nn::NeuralNetwork{<:AutoEncoder}) - NeuralNetwork(UnknownDecoder(nn.architecture.full_dim, nn.architecture.reduced_dim, nn.architecture.n_encoder_blocks), decoder_model(nn.architecture), decoder_parameters(nn), get_backend(nn)) + NeuralNetwork(UnknownDecoder(nn.architecture.full_dim, nn.architecture.reduced_dim, nn.architecture.n_encoder_blocks), decoder_model(nn.architecture), decoder_parameters(nn), networkbackend(nn)) end function _decoder(nn::NeuralNetwork, full_dim::Integer, reduced_dim::Integer) - NeuralNetwork(UnknownDecoder(full_dim, reduced_dim, length(nn.model.layers)), nn.model, nn.params, get_backend(nn)) + NeuralNetwork(UnknownDecoder(full_dim, reduced_dim, length(nn.model.layers)), nn.model, params(nn), networkbackend(nn)) end @doc raw""" @@ -263,9 +262,9 @@ function decoder(nn::NeuralNetwork) end function encoder(nn::NeuralNetwork{<:SymplecticCompression}) - NeuralNetwork(UnknownSymplecticEncoder(nn.architecture.full_dim, nn.architecture.reduced_dim, nn.architecture.n_encoder_blocks), encoder_model(nn.architecture), encoder_parameters(nn), get_backend(nn)) + NeuralNetwork(UnknownSymplecticEncoder(nn.architecture.full_dim, nn.architecture.reduced_dim, nn.architecture.n_encoder_blocks), encoder_model(nn.architecture), encoder_parameters(nn), networkbackend(nn)) end function decoder(nn::NeuralNetwork{<:SymplecticCompression}) - NeuralNetwork(UnknownSymplecticDecoder(nn.architecture.full_dim, nn.architecture.reduced_dim, nn.architecture.n_encoder_blocks), decoder_model(nn.architecture), decoder_parameters(nn), get_backend(nn)) + NeuralNetwork(UnknownSymplecticDecoder(nn.architecture.full_dim, nn.architecture.reduced_dim, nn.architecture.n_encoder_blocks), decoder_model(nn.architecture), decoder_parameters(nn), networkbackend(nn)) end \ No newline at end of file diff --git a/src/architectures/hamiltonian_neural_network.jl b/src/architectures/hamiltonian_neural_network.jl index f81518e2f..c26d51307 100644 --- a/src/architectures/hamiltonian_neural_network.jl +++ b/src/architectures/hamiltonian_neural_network.jl @@ -25,10 +25,10 @@ function Chain(nn::HamiltonianNeuralNetwork) end # gradient of the Hamiltonian Neural Network -gradient(nn::AbstractNeuralNetwork{<:HamiltonianNeuralNetwork}, x, params = nn.params) = Zygote.gradient(ξ -> sum(nn(ξ, params)), x)[1] +gradient(nn::AbstractNeuralNetwork{<:HamiltonianNeuralNetwork}, x, params = params(nn)) = Zygote.gradient(ξ -> sum(nn(ξ, params)), x)[1] # vector field of the Hamiltonian Neural Network -function vectorfield(nn::AbstractNeuralNetwork{<:HamiltonianNeuralNetwork}, x, params = nn.params) +function vectorfield(nn::AbstractNeuralNetwork{<:HamiltonianNeuralNetwork}, x, params = params(nn)) n_dim = length(x)÷2 I = Diagonal(ones(n_dim)) Z = zeros(n_dim,n_dim) diff --git a/src/architectures/lagrangian_neural_network.jl b/src/architectures/lagrangian_neural_network.jl index 23787247f..d527742ef 100644 --- a/src/architectures/lagrangian_neural_network.jl +++ b/src/architectures/lagrangian_neural_network.jl @@ -29,14 +29,14 @@ end # gradient of the Lagrangian Neural Network -∇L(nn::NeuralNetwork{<:LagrangianNeuralNetwork}, x, params = nn.params) = Zygote.gradient(x->sum(nn(x, params)), x)[1] +∇L(nn::NeuralNetwork{<:LagrangianNeuralNetwork}, x, params = params(nn)) = Zygote.gradient(x->sum(nn(x, params)), x)[1] # hessian of the Lagrangian Neural Network -∇∇L(nn::NeuralNetwork{<:LagrangianNeuralNetwork}, q, q̇, params = nn.params) = Zygote.hessian(x->sum(nn(x, params)),[q...,q̇...]) 
+∇∇L(nn::NeuralNetwork{<:LagrangianNeuralNetwork}, q, q̇, params = params(nn)) = Zygote.hessian(x->sum(nn(x, params)),[q...,q̇...]) -∇q̇∇q̇L(nn::NeuralNetwork{<:LagrangianNeuralNetwork}, q, q̇, params = nn.params) = ∇∇L(nn, q, q̇, params)[(1+length(q̇)):end,(1+length(q̇)):end] +∇q̇∇q̇L(nn::NeuralNetwork{<:LagrangianNeuralNetwork}, q, q̇, params = params(nn)) = ∇∇L(nn, q, q̇, params)[(1+length(q̇)):end,(1+length(q̇)):end] -∇q∇q̇L(nn::NeuralNetwork{<:LagrangianNeuralNetwork}, q, q̇, params = nn.params) = ∇∇L(nn, q, q̇, params)[1:length(q),(1+length(q̇)):end] +∇q∇q̇L(nn::NeuralNetwork{<:LagrangianNeuralNetwork}, q, q̇, params = params(nn)) = ∇∇L(nn, q, q̇, params)[1:length(q),(1+length(q̇)):end] diff --git a/src/architectures/neural_network_integrator.jl b/src/architectures/neural_network_integrator.jl index 131cb65a9..02376eac8 100644 --- a/src/architectures/neural_network_integrator.jl +++ b/src/architectures/neural_network_integrator.jl @@ -22,7 +22,7 @@ abstract type NeuralNetworkIntegrator <: Architecture end function Base.iterate(nn::NeuralNetwork{<:NeuralNetworkIntegrator}, ics::AT; n_points = 100) where {T, AT<:AbstractVector{T}} n_dim = length(ics) - backend = KernelAbstractions.get_backend(ics) + backend = networkbackend(ics) # Array to store the predictions valuation = KernelAbstractions.allocate(backend, T, n_dim, n_points) @@ -97,7 +97,7 @@ The number of integration steps that should be performed. function Base.iterate(nn::NeuralNetwork{<:NeuralNetworkIntegrator}, ics::BT; n_points = 100) where {T, AT<:AbstractVector{T}, BT<:NamedTuple{(:q, :p), Tuple{AT, AT}}} n_dim2 = length(ics.q) - backend = KernelAbstractions.get_backend(ics.q) + backend = networkbackend(ics.q) # Array to store the predictions valuation = (q = KernelAbstractions.allocate(backend, T, n_dim2, n_points), p = KernelAbstractions.allocate(backend, T, n_dim2, n_points)) diff --git a/src/architectures/psd.jl b/src/architectures/psd.jl index d47b51345..5268916fa 100644 --- a/src/architectures/psd.jl +++ b/src/architectures/psd.jl @@ -57,8 +57,8 @@ function solve!(nn::NeuralNetwork{<:PSDArch}, input::AbstractMatrix) half_of_dimension_in_big_space = nn.architecture.full_dim ÷ 2 @views input_qp = hcat(input[1 : half_of_dimension_in_big_space, :], input[(half_of_dimension_in_big_space + 1) : end, :]) U_term = svd(input_qp).U - @views nn.params[1].weight.A .= U_term[:, 1 : nn.architecture.reduced_dim ÷ 2] - @views nn.params[2].weight.A .= U_term[:, 1 : nn.architecture.reduced_dim ÷ 2] + @views params(nn)[1].weight.A .= U_term[:, 1 : nn.architecture.reduced_dim ÷ 2] + @views params(nn)[2].weight.A .= U_term[:, 1 : nn.architecture.reduced_dim ÷ 2] AutoEncoderLoss()(nn, input) end diff --git a/src/architectures/symplectic_autoencoder.jl b/src/architectures/symplectic_autoencoder.jl index 79a58bd6d..e15d70973 100644 --- a/src/architectures/symplectic_autoencoder.jl +++ b/src/architectures/symplectic_autoencoder.jl @@ -151,10 +151,10 @@ end function encoder(nn::NeuralNetwork{<:SymplecticAutoencoder}) arch = NonLinearSymplecticEncoder(nn.architecture.full_dim, nn.architecture.reduced_dim, nn.architecture.n_encoder_layers, nn.architecture.n_encoder_blocks, nn.architecture.sympnet_upscale, nn.architecture.activation) - NeuralNetwork(arch, encoder_model(nn.architecture), encoder_parameters(nn), get_backend(nn)) + NeuralNetwork(arch, encoder_model(nn.architecture), encoder_parameters(nn), networkbackend(nn)) end function decoder(nn::NeuralNetwork{<:SymplecticAutoencoder}) arch = NonLinearSymplecticDecoder(nn.architecture.full_dim, 
nn.architecture.reduced_dim, nn.architecture.n_decoder_layers, nn.architecture.n_decoder_blocks, nn.architecture.sympnet_upscale, nn.architecture.activation) - NeuralNetwork(arch, decoder_model(nn.architecture), decoder_parameters(nn), get_backend(nn)) + NeuralNetwork(arch, decoder_model(nn.architecture), decoder_parameters(nn), networkbackend(nn)) end \ No newline at end of file diff --git a/src/architectures/transformer_integrator.jl b/src/architectures/transformer_integrator.jl index 2239c78b4..e45e84c40 100644 --- a/src/architectures/transformer_integrator.jl +++ b/src/architectures/transformer_integrator.jl @@ -48,7 +48,7 @@ function Base.iterate(nn::NeuralNetwork{<:TransformerIntegrator}, ics::NamedTupl seq_length = typeof(nn.architecture) <: StandardTransformerIntegrator ? size(ics.q, 2) : nn.architecture.seq_length n_dim = size(ics.q, 1) - backend = KernelAbstractions.get_backend(ics.q) + backend = networkbackend(ics.q) n_iterations = Int(ceil((n_points - seq_length) / prediction_window)) # Array to store the predictions @@ -84,7 +84,7 @@ function Base.iterate(nn::NeuralNetwork{<:TransformerIntegrator}, ics::AT; n_poi end n_dim = size(ics, 1) - backend = KernelAbstractions.get_backend(ics) + backend = networkbackend(ics) n_iterations = Int(ceil((n_points - seq_length) / prediction_window)) # Array to store the predictions diff --git a/src/arrays/grassmann_lie_algebra_horizontal.jl b/src/arrays/grassmann_lie_algebra_horizontal.jl index 3b4dd4266..cb7414c63 100644 --- a/src/arrays/grassmann_lie_algebra_horizontal.jl +++ b/src/arrays/grassmann_lie_algebra_horizontal.jl @@ -80,7 +80,7 @@ end Base.parent(A::GrassmannLieAlgHorMatrix) = (A.B, ) Base.size(A::GrassmannLieAlgHorMatrix) = (A.N, A.N) -KernelAbstractions.get_backend(B::GrassmannLieAlgHorMatrix) = KernelAbstractions.get_backend(B.B) +networkbackend(B::GrassmannLieAlgHorMatrix) = networkbackend(B.B) function Base.getindex(A::GrassmannLieAlgHorMatrix{T}, i::Integer, j::Integer) where {T} if i ≤ A.n diff --git a/src/arrays/lower_triangular.jl b/src/arrays/lower_triangular.jl index 4e3141b59..7f1e3bb9e 100644 --- a/src/arrays/lower_triangular.jl +++ b/src/arrays/lower_triangular.jl @@ -81,7 +81,7 @@ end function map_to_lo(A::AbstractMatrix{T}) where T n = size(A, 1) @assert size(A, 2) == n - backend = KernelAbstractions.get_backend(A) + backend = networkbackend(A) S = KernelAbstractions.zeros(backend, T, n * (n - 1) ÷ 2) assign_Skew_val! = assign_Skew_val_kernel!(backend) for i in 2:n diff --git a/src/arrays/skew_symmetric.jl b/src/arrays/skew_symmetric.jl index 790b27312..f3ae66e4d 100644 --- a/src/arrays/skew_symmetric.jl +++ b/src/arrays/skew_symmetric.jl @@ -110,7 +110,7 @@ end function Base.:+(A::SkewSymMatrix{T}, B::AbstractMatrix{T}) where T @assert size(A) == size(B) - backend = KernelAbstractions.get_backend(B) + backend = networkbackend(B) addition! = addition_kernel!(backend) C = KernelAbstractions.allocate(backend, T, size(A)...) addition!(C, A.S, B; ndrange = size(A)) @@ -215,7 +215,7 @@ LinearAlgebra.rmul!(C::SkewSymMatrix, α::Real) = mul!(C, C, α) function Base.:*(A::SkewSymMatrix{T}, B::AbstractMatrix{T}) where T m1, m2 = size(B) @assert m1 == A.n - backend = KernelAbstractions.get_backend(A) + backend = networkbackend(A) C = KernelAbstractions.allocate(backend, T, A.n, m2) skew_mat_mul! 
= skew_mat_mul_kernel!(backend) @@ -245,7 +245,7 @@ function Base.:*(A::SkewSymMatrix, b::AbstractVector{T}) where T end function Base.one(A::SkewSymMatrix{T}) where T - backend = KernelAbstractions.get_backend(A.S) + backend = networkbackend(A.S) unit_matrix = KernelAbstractions.zeros(backend, T, A.n, A.n) write_ones! = write_ones_kernel!(backend) write_ones!(unit_matrix, ndrange=A.n) @@ -290,8 +290,8 @@ function Base.zero(A::SkewSymMatrix) SkewSymMatrix(zero(A.S), A.n) end -function KernelAbstractions.get_backend(A::SkewSymMatrix) - KernelAbstractions.get_backend(A.S) +function networkbackend(A::SkewSymMatrix) + networkbackend(A.S) end function assign!(B::SkewSymMatrix{T}, C::SkewSymMatrix{T}) where T @@ -311,7 +311,7 @@ function map_to_Skew(A::AbstractMatrix{T}) where T n = size(A, 1) @assert size(A, 2) == n A_skew = T(.5)*(A - A') - backend = KernelAbstractions.get_backend(A) + backend = networkbackend(A) S = if n != 1 KernelAbstractions.zeros(backend, T, n * (n - 1) ÷ 2) else diff --git a/src/arrays/stiefel_lie_algebra_horizontal.jl b/src/arrays/stiefel_lie_algebra_horizontal.jl index fdedf35ed..dd144b09f 100644 --- a/src/arrays/stiefel_lie_algebra_horizontal.jl +++ b/src/arrays/stiefel_lie_algebra_horizontal.jl @@ -273,8 +273,8 @@ function Base.zero(B::StiefelLieAlgHorMatrix) ) end -function KernelAbstractions.get_backend(B::StiefelLieAlgHorMatrix) - KernelAbstractions.get_backend(B.B) +function networkbackend(B::StiefelLieAlgHorMatrix) + networkbackend(B.B) end # assign funciton; also implement this for other arrays! @@ -302,7 +302,7 @@ function assign!(A::AbstractArray, B::AbstractArray) end function Base.one(B::StiefelLieAlgHorMatrix{T}) where T - backend = get_backend(B) + backend = networkbackend(B) oneB = KernelAbstractions.zeros(backend, T, B.N, B.N) write_ones! = write_ones_kernel!(backend) write_ones!(oneB; ndrange = B.N) diff --git a/src/arrays/stiefel_projection.jl b/src/arrays/stiefel_projection.jl index 253eab913..5633fd76a 100644 --- a/src/arrays/stiefel_projection.jl +++ b/src/arrays/stiefel_projection.jl @@ -29,7 +29,7 @@ Extract necessary information from `A` and build an instance of `StiefelProjecti Necessary information here referes to the backend, the data type and the size of the matrix. """ function StiefelProjection(A::AbstractMatrix{T}) where T - StiefelProjection(KernelAbstractions.get_backend(A), T, size(A)...) + StiefelProjection(networkbackend(A), T, size(A)...) 
end @doc raw""" @@ -58,7 +58,7 @@ true ``` """ function StiefelProjection(B::AbstractLieAlgHorMatrix{T}) where T - StiefelProjection(KernelAbstractions.get_backend(B), T, B.N, B.n) + StiefelProjection(networkbackend(B), T, B.N, B.n) end @kernel function assign_ones_for_stiefel_projection_kernel!(A::AbstractArray{T}) where T @@ -79,6 +79,6 @@ Base.vcat(E::StiefelProjection{T}, A::AbstractVecOrMat{T}) where {T<:Number} = v Base.hcat(A::AbstractVecOrMat{T}, E::StiefelProjection{T}) where {T<:Number} = hcat(A, E.A) Base.hcat(E::StiefelProjection{T}, A::AbstractVecOrMat{T}) where {T<:Number} = hcat(E.A, A) -function KernelAbstractions.get_backend(E::StiefelProjection) - KernelAbstractions.get_backend(E.A) +function networkbackend(E::StiefelProjection) + networkbackend(E.A) end \ No newline at end of file diff --git a/src/arrays/symmetric.jl b/src/arrays/symmetric.jl index 165097ebc..f6525323d 100644 --- a/src/arrays/symmetric.jl +++ b/src/arrays/symmetric.jl @@ -100,7 +100,7 @@ function map_to_S(A::AbstractMatrix{T}) where {T <: Number} n = size(A, 1) @assert size(A, 2) == n A_sym = T(.5)*(A + A') - backend = KernelAbstractions.get_backend(A) + backend = networkbackend(A) S = KernelAbstractions.zeros(backend, T, n*(n+1)÷2) assign_S_val! = assign_S_val_kernel!(backend) for i in 1:n @@ -224,7 +224,7 @@ function LinearAlgebra.mul!(C::AbstractMatrix, A::SymmetricMatrix, B::AbstractMa @assert A.n == size(B, 1) @assert size(B, 2) == size(C, 2) @assert A.n == size(C, 1) - backend = KernelAbstractions.get_backend(A.S) + backend = networkbackend(A.S) symmetric_mat_mul! = symmetric_mat_mul_kernel!(backend) symmetric_mat_mul!(C, A.S, B, A.n, ndrange=size(C)) end @@ -244,13 +244,13 @@ end function LinearAlgebra.mul!(c::AbstractVector, A::SymmetricMatrix, b::AbstractVector) @assert A.n == length(c) == length(b) - backend = KernelAbstractions.get_backend(A.S) + backend = networkbackend(A.S) symmetric_vector_mul! = symmetric_vector_mul_kernel!(backend) symmetric_vector_mul!(c, A.S, b, A.n, ndrange=size(c)) end function Base.:*(A::SymmetricMatrix{T}, B::AbstractMatrix{T}) where T - backend = KernelAbstractions.get_backend(A.S) + backend = networkbackend(A.S) C = KernelAbstractions.allocate(backend, T, A.n, size(B, 2)) LinearAlgebra.mul!(C, A, B) C @@ -263,14 +263,14 @@ function Base.:*(A::SymmetricMatrix{T}, B::SymmetricMatrix{T}) where T end function Base.:*(A::SymmetricMatrix{T}, b::AbstractVector{T}) where T - backend = KernelAbstractions.get_backend(A.S) + backend = networkbackend(A.S) c = KernelAbstractions.allocate(backend, T, A.n) LinearAlgebra.mul!(c, A, b) c end function Base.one(A::SymmetricMatrix{T}) where T - backend = KernelAbstractions.get_backend(A.S) + backend = networkbackend(A.S) unit_matrix = KernelAbstractions.zeros(backend, T, A.n, A.n) write_ones! = write_ones_kernel!(backend) write_ones!(unit_matrix, ndrange=A.n) diff --git a/src/arrays/triangular.jl b/src/arrays/triangular.jl index ce3a909c4..4faa2fe9b 100644 --- a/src/arrays/triangular.jl +++ b/src/arrays/triangular.jl @@ -88,7 +88,7 @@ LinearAlgebra.mul!(C::AT, α::Real, A::AT) where AT <: AbstractTriangular = mul! LinearAlgebra.rmul!(C::AT, α::Real) where AT <: AbstractTriangular = mul!(C, C, α) function Base.one(A::AbstractTriangular{T}) where T - backend = KernelAbstractions.get_backend(A.S) + backend = networkbackend(A.S) unit_matrix = KernelAbstractions.zeros(backend, T, A.n, A.n) write_ones! 
= write_ones_kernel!(backend) write_ones!(unit_matrix, ndrange=A.n) @@ -132,8 +132,8 @@ function Base.zero(A::AT) where AT <: AbstractTriangular AT(zero(A.S), A.n) end -function KernelAbstractions.get_backend(A::AbstractTriangular) - KernelAbstractions.get_backend(A.S) +function networkbackend(A::AbstractTriangular) + networkbackend(A.S) end function assign!(B::AT, C::AT) where AT <: AbstractTriangular diff --git a/src/arrays/upper_triangular.jl b/src/arrays/upper_triangular.jl index 9134ef23c..07c2482d2 100644 --- a/src/arrays/upper_triangular.jl +++ b/src/arrays/upper_triangular.jl @@ -81,7 +81,7 @@ end function map_to_up(A::AbstractMatrix{T}) where T n = size(A, 1) @assert size(A, 2) == n - backend = KernelAbstractions.get_backend(A) + backend = networkbackend(A) S = KernelAbstractions.zeros(backend, T, n * (n - 1) ÷ 2) assign_Skew_val! = assign_Skew_val_kernel!(backend) for i in 2:n diff --git a/src/backends/lux.jl b/src/backends/lux.jl index ddeebc3e6..60462d594 100644 --- a/src/backends/lux.jl +++ b/src/backends/lux.jl @@ -43,7 +43,7 @@ function apply(nn::LuxNeuralNetwork, x, params::NamedTuple) return y end -apply(nn::LuxNeuralNetwork, x) = apply(nn, x, nn.params) +apply(nn::LuxNeuralNetwork, x) = apply(nn, x, params(nn)) (nn::LuxNeuralNetwork)(x, args...) = apply(nn, x, args...) diff --git a/src/data_loader/batch.jl b/src/data_loader/batch.jl index b0d665f6d..a575ab31b 100644 --- a/src/data_loader/batch.jl +++ b/src/data_loader/batch.jl @@ -231,7 +231,7 @@ GeometricMachineLearning.convert_input_and_batch_indices_to_array(dl, batch, bat ``` """ function convert_input_and_batch_indices_to_array(dl::DataLoader{T, BT}, batch::Batch, batch_indices_tuple::Vector{Tuple{Int, Int}}) where {T, AT<:AbstractArray{T, 3}, BT<:NamedTuple{(:q, :p), Tuple{AT, AT}}} - backend = KernelAbstractions.get_backend(dl.input.q) + backend = networkbackend(dl.input.q) # the batch size is smaller for the last batch _batch_size = length(batch_indices_tuple) @@ -254,7 +254,7 @@ function convert_input_and_batch_indices_to_array(dl::DataLoader{T, BT}, batch:: end function convert_input_and_batch_indices_to_array(dl::DataLoader{T, BT}, batch::Batch, batch_indices_tuple::Vector{Tuple{Int, Int}}) where {T, BT<:AbstractArray{T, 3}} - backend = KernelAbstractions.get_backend(dl.input) + backend = networkbackend(dl.input) # the batch size is smaller for the last batch _batch_size = length(batch_indices_tuple) @@ -275,7 +275,7 @@ function convert_input_and_batch_indices_to_array(dl::DataLoader{T, BT}, batch:: end function convert_input_and_batch_indices_to_array(dl::DataLoader{T, BT, Nothing, :RegularData}, batch::Batch, batch_indices_tuple::Vector{Tuple{Int, Int}}) where {T, AT<:AbstractArray{T, 3}, BT<:NamedTuple{(:q, :p), Tuple{AT, AT}}} - backend = KernelAbstractions.get_backend(dl.input.q) + backend = networkbackend(dl.input.q) # the batch size is smaller for the last batch _batch_size = length(batch_indices_tuple) @@ -292,7 +292,7 @@ function convert_input_and_batch_indices_to_array(dl::DataLoader{T, BT, Nothing, end function convert_input_and_batch_indices_to_array(dl::DataLoader{T, BT, Nothing, :RegularData}, batch::Batch, batch_indices_tuple::Vector{Tuple{Int, Int}}) where {T, BT<:AbstractArray{T, 3}} - backend = KernelAbstractions.get_backend(dl.input) + backend = networkbackend(dl.input) # the batch size is smaller for the last batch _batch_size = length(batch_indices_tuple) @@ -318,7 +318,7 @@ function convert_input_and_batch_indices_to_array(dl::DataLoader{T, BT, OT}, ::B end function 
convert_input_and_batch_indices_to_array(dl::DataLoader{T, BT, BT}, batch::Batch, batch_indices_tuple::Vector{Tuple{Int, Int}}) where {T, BT<:AbstractArray{T, 3}} - backend = KernelAbstractions.get_backend(dl.input) + backend = networkbackend(dl.input) # the batch size is smaller for the last batch _batch_size = length(batch_indices_tuple) diff --git a/src/data_loader/data_loader.jl b/src/data_loader/data_loader.jl index 26b841efa..70e39ea72 100644 --- a/src/data_loader/data_loader.jl +++ b/src/data_loader/data_loader.jl @@ -436,7 +436,7 @@ By default this inherits the autoencoder property form `dl`. See the docstring for [`DataLoader(data::AbstractArray{<:Number, 3})`](@ref). """ function DataLoader(dl::DataLoader{T1, <:QPTOAT, Nothing, Type}, - backend::KernelAbstractions.Backend=KernelAbstractions.get_backend(dl), + backend::KernelAbstractions.Backend=networkbackend(dl), T::DataType=T1; autoencoder = nothing ) where {T1, Type} @@ -456,7 +456,7 @@ function DataLoader(dl::DataLoader{T1, <:QPTOAT, Nothing, Type}, end new_input = - if backend == KernelAbstractions.get_backend(dl) + if backend == networkbackend(dl) input else map_to_new_backend(input, backend) @@ -473,7 +473,7 @@ function DataLoader(dl::DataLoader{T1, <:QPTOAT, Nothing, Type}, end function DataLoader(dl::DataLoader, T::DataType; kwargs...) - DataLoader(dl, KernelAbstractions.get_backend(dl), T; kwargs...) + DataLoader(dl, networkbackend(dl), T; kwargs...) end @doc raw""" @@ -486,7 +486,7 @@ This needs an instance of [`DataLoader`](@ref) that stores the *test data*. function accuracy(model::Chain, ps::NeuralNetworkParameters, dl::DataLoader{T, AT, BT}) where {T, T1<:Integer, AT<:AbstractArray{T}, BT<:AbstractArray{T1}} output_tensor = model(dl.input, ps) output_estimate = assign_output_estimate(output_tensor, dl.output_time_steps) - backend = KernelAbstractions.get_backend(output_estimate) + backend = networkbackend(output_estimate) tensor_of_maximum_elements = KernelAbstractions.zeros(backend, T1, size(output_estimate)...) ind = argmax(output_estimate, dims=1) # get tensor of maximum elements @@ -501,11 +501,11 @@ Compute the accuracy of a neural network classifier. This is like [`accuracy(::Chain, ::Tuple, ::DataLoader)`](@ref), but for a `NeuralNetwork`. """ -accuracy(nn::NeuralNetwork, dl::DataLoader) = accuracy(nn.model, nn.params, dl) +accuracy(nn::NeuralNetwork, dl::DataLoader) = accuracy(nn.model, params(nn), dl) Base.eltype(::DataLoader{T}) where T = T -KernelAbstractions.get_backend(dl::DataLoader) = KernelAbstractions.get_backend(dl.input) -function KernelAbstractions.get_backend(dl::DataLoader{T, <:QPT{T}}) where T - KernelAbstractions.get_backend(dl.input.q) +networkbackend(dl::DataLoader) = networkbackend(dl.input) +function networkbackend(dl::DataLoader{T, <:QPT{T}}) where T + networkbackend(dl.input.q) end \ No newline at end of file diff --git a/src/data_loader/mnist_utils.jl b/src/data_loader/mnist_utils.jl index 9506f2684..ef91f717f 100644 --- a/src/data_loader/mnist_utils.jl +++ b/src/data_loader/mnist_utils.jl @@ -41,7 +41,7 @@ onehotbatch(target) ``` """ function onehotbatch(target::AbstractVector{T}) where {T<:Integer} - backend = KernelAbstractions.get_backend(target) + backend = networkbackend(target) output = KernelAbstractions.zeros(backend, T, 10, length(target)) assign_val! 
= assign_val_kernel!(backend) assign_val!(output, target, ndrange=length(target)) @@ -131,7 +131,7 @@ The sizes of the first and second axis of the output of `split_and_flatten` are """ function split_and_flatten(input::AbstractArray{T, 3}; patch_length::Integer=7, number_of_patches::Integer=16) where T @assert size(input, 1) * size(input, 2) == (patch_length ^ 2) * number_of_patches - backend = KernelAbstractions.get_backend(input) + backend = networkbackend(input) output = KernelAbstractions.allocate(backend, T, patch_length^2, number_of_patches, size(input, 3)) split_and_flatten! = split_and_flatten_kernel!(backend) split_and_flatten!(output, input, patch_length, number_of_patches, ndrange=size(input)) diff --git a/src/data_loader/optimize.jl b/src/data_loader/optimize.jl index 67366847d..d2ad6dd07 100644 --- a/src/data_loader/optimize.jl +++ b/src/data_loader/optimize.jl @@ -96,11 +96,11 @@ function (o::Optimizer)(nn::NeuralNetwork, n_epochs::Integer, loss::NetworkLoss, _pullback::AbstractPullback = ZygotePullback(loss); show_progress = true) - Λ = GlobalSection(nn.params) + Λ = GlobalSection(params(nn)) progress_object = show_progress == true ? ProgressMeter.Progress(n_epochs; enabled=true) : nothing loss_array = zeros(n_epochs) for i in 1:n_epochs - loss_array[i] = optimize_for_one_epoch!(o, nn.model, nn.params, dl, batch, _pullback, Λ) + loss_array[i] = optimize_for_one_epoch!(o, nn.model, params(nn), dl, batch, _pullback, Λ) show_progress == true ? ProgressMeter.next!(progress_object; showvalues = [(:TrainingLoss, loss_array[i])]) : nothing end diff --git a/src/data_loader/tensor_assign.jl b/src/data_loader/tensor_assign.jl index 379820bba..b74a8a92f 100644 --- a/src/data_loader/tensor_assign.jl +++ b/src/data_loader/tensor_assign.jl @@ -58,9 +58,9 @@ If `prediction_window` is equal to `sequence_length`, then this is not needed. """ function assign_output_estimate(full_output::AbstractArray{T, 3}, prediction_window::Int) where T sys_dim, seq_length, batch_size = size(full_output) - backend = KernelAbstractions.get_backend(full_output) + backend = networkbackend(full_output) output_estimate = KernelAbstractions.allocate(backend, T, sys_dim, prediction_window, batch_size) - assign_output_estimate! = assign_output_estimate_kernel!(KernelAbstractions.get_backend(full_output)) + assign_output_estimate! = assign_output_estimate_kernel!(networkbackend(full_output)) assign_output_estimate!(output_estimate, full_output, seq_length, prediction_window, ndrange=size(output_estimate)) output_estimate end @@ -74,10 +74,10 @@ end end function augment_zeros(output_diff::AbstractArray{T, 3}, seq_length) where T sys_dim, prediction_window, batch_size = size(output_diff) - backend = KernelAbstractions.get_backend(output_diff) + backend = networkbackend(output_diff) dim, prediction_window, batch_size = size(output_diff) zero_tensor = KernelAbstractions.zeros(backend, T, sys_dim, seq_length, batch_size) - augment_zeros! = augment_zeros_kernel!(KernelAbstractions.get_backend(output_diff)) + augment_zeros! = augment_zeros_kernel!(networkbackend(output_diff)) augment_zeros!(zero_tensor, output_diff, seq_length, prediction_window, ndrange=size(output_diff)) zero_tensor end diff --git a/src/kernels/assign_q_and_p.jl b/src/kernels/assign_q_and_p.jl index ff312b1db..e58f8c293 100644 --- a/src/kernels/assign_q_and_p.jl +++ b/src/kernels/assign_q_and_p.jl @@ -38,7 +38,7 @@ end # The output is a `Tuple` containing `q` and `p`. 
# """ function assign_q_and_p(x::AbstractVector, N::Int) - backend = KernelAbstractions.get_backend(x) + backend = networkbackend(x) q = KernelAbstractions.allocate(backend, eltype(x), N) p = KernelAbstractions.allocate(backend, eltype(x), N) q_kernel! = assign_first_half!(backend) @@ -49,7 +49,7 @@ function assign_q_and_p(x::AbstractVector, N::Int) end function assign_q_and_p(x::AbstractMatrix, N::Int) - backend = KernelAbstractions.get_backend(x) + backend = networkbackend(x) q = KernelAbstractions.allocate(backend, eltype(x), N, size(x,2)) p = KernelAbstractions.allocate(backend, eltype(x), N, size(x,2)) q_kernel! = assign_first_half!(backend) @@ -60,7 +60,7 @@ function assign_q_and_p(x::AbstractMatrix, N::Int) end function assign_q_and_p(x::AbstractArray{T, 3}, N::Int) where T - backend = KernelAbstractions.get_backend(x) + backend = networkbackend(x) q = KernelAbstractions.allocate(backend, T, N, size(x,2), size(x,3)) p = KernelAbstractions.allocate(backend, T, N, size(x,2), size(x,3)) q_kernel! = assign_first_half!(backend) diff --git a/src/kernels/exponentials/tensor_exponential.jl b/src/kernels/exponentials/tensor_exponential.jl index a541283d4..db755de45 100644 --- a/src/kernels/exponentials/tensor_exponential.jl +++ b/src/kernels/exponentials/tensor_exponential.jl @@ -23,7 +23,7 @@ function init_output(B::AbstractArray{T, 3}) where T end function assign_ones!(output::AbstractArray{T, 3}) where T - backend = KernelAbstractions.get_backend(output) + backend = networkbackend(output) assign_ones_backend! = assign_ones_kernel!(backend) dims = (size(output,1), size(output,3)) assign_ones_backend!(output, ndrange=dims) diff --git a/src/kernels/inverses/cpu_inverse.jl b/src/kernels/inverses/cpu_inverse.jl index 128ede5b8..d05aa2673 100644 --- a/src/kernels/inverses/cpu_inverse.jl +++ b/src/kernels/inverses/cpu_inverse.jl @@ -10,7 +10,7 @@ end function cpu_inverse(A::AbstractArray) B = zero(A) - backend = KernelAbstractions.get_backend(A) + backend = networkbackend(A) cpu_inverse! = cpu_inverse_kernel!(backend) cpu_inverse!(B, A, ndrange=size(A, 3)) @@ -34,7 +34,7 @@ function ChainRulesCore.rrule(::typeof(cpu_inverse), A::AbstractArray) function cpu_inverse_pullback(dB::AbstractArray) dA = zero(dB) - backend = KernelAbstractions.get_backend(dB) + backend = networkbackend(dB) cpu_inverse_pullback! = cpu_inverse_pullback_kernel!(backend) cpu_inverse_pullback!(dA, A, dB, ndrange=size(dB, 3)) diff --git a/src/kernels/inverses/inverse_2x2.jl b/src/kernels/inverses/inverse_2x2.jl index f5933863d..2212ab5ad 100644 --- a/src/kernels/inverses/inverse_2x2.jl +++ b/src/kernels/inverses/inverse_2x2.jl @@ -23,7 +23,7 @@ function tensor_inverse2!(out::AbstractArray{T, 3}, A::AbstractArray{T, 3}) wher @assert size(A, 1) == size(A, 2) == 2 @assert size(A) == size(out) - backend = get_backend(out) + backend = networkbackend(out) inv22! = inv22_kernel!(backend) inv22!(out, A, ndrange = size(A, 3)) diff --git a/src/kernels/inverses/inverse_3x3.jl b/src/kernels/inverses/inverse_3x3.jl index eed5c7a65..5a82dffb9 100644 --- a/src/kernels/inverses/inverse_3x3.jl +++ b/src/kernels/inverses/inverse_3x3.jl @@ -28,7 +28,7 @@ function tensor_inverse3!(out::AbstractArray{T, 3}, A::AbstractArray{T, 3}) wher @assert size(A, 1) == size(A, 2) == 3 @assert size(A) == size(out) - backend = get_backend(out) + backend = networkbackend(out) inv33! 
= inv33_kernel!(backend) inv33!(out, A, ndrange = size(A, 3)) diff --git a/src/kernels/inverses/inverse_4x4.jl b/src/kernels/inverses/inverse_4x4.jl index fe3f4156b..82fe84779 100644 --- a/src/kernels/inverses/inverse_4x4.jl +++ b/src/kernels/inverses/inverse_4x4.jl @@ -33,7 +33,7 @@ function tensor_inverse4!(out::AbstractArray{T, 3}, A::AbstractArray{T, 3}) wher @assert size(A, 1) == size(A, 2) == 4 @assert size(A) == size(out) - backend = get_backend(out) + backend = networkbackend(out) inv44! = inv44_kernel!(backend) inv44!(out, A, ndrange = size(A, 3)) diff --git a/src/kernels/inverses/inverse_5x5.jl b/src/kernels/inverses/inverse_5x5.jl index 70de89eb4..e363cf90d 100644 --- a/src/kernels/inverses/inverse_5x5.jl +++ b/src/kernels/inverses/inverse_5x5.jl @@ -42,7 +42,7 @@ function tensor_inverse5!(out::AbstractArray{T, 3}, A::AbstractArray{T, 3}) wher @assert size(A, 1) == size(A, 2) == 5 @assert size(A) == size(out) - backend = get_backend(out) + backend = networkbackend(out) inv55! = inv55_kernel!(backend) inv55!(out, A, ndrange = size(A, 3)) diff --git a/src/kernels/inverses/tensor_mat_skew_sym_assign.jl b/src/kernels/inverses/tensor_mat_skew_sym_assign.jl index 2aeb19339..26e7653bb 100644 --- a/src/kernels/inverses/tensor_mat_skew_sym_assign.jl +++ b/src/kernels/inverses/tensor_mat_skew_sym_assign.jl @@ -17,7 +17,7 @@ end function tensor_mat_skew_sym_assign!(C::AbstractArray{T, 3}, Z::AbstractArray{T, 3}, A::AbstractMatrix{T}) where {T} - backend = KernelAbstractions.get_backend(Z) + backend = networkbackend(Z) tensor_mat_skew_sym_assign_k! = tensor_mat_skew_sym_assign_kernel!(backend) @@ -83,7 +83,7 @@ tensor_mat_skew_sym_assign(Z, A) ``` """ function tensor_mat_skew_sym_assign(Z::AT, A::AbstractMatrix{T})::AT where {T, AT <: AbstractArray{T, 3}} - backend = KernelAbstractions.get_backend(Z) + backend = networkbackend(Z) C = KernelAbstractions.zeros(backend, T, size(Z, 2), size(Z, 2), size(Z, 3)) diff --git a/src/kernels/kernel_ad_routines/mat_tensor_mul.jl b/src/kernels/kernel_ad_routines/mat_tensor_mul.jl index 13d736eda..a65aacf99 100644 --- a/src/kernels/kernel_ad_routines/mat_tensor_mul.jl +++ b/src/kernels/kernel_ad_routines/mat_tensor_mul.jl @@ -53,7 +53,7 @@ function ChainRulesCore.rrule(::typeof(lo_mat_mul), S::AbstractVector{T}, A::Abs C = lo_mat_mul(S, A, n) function lo_mat_mul_pullback(dC::AbstractArray{T, 3}) f̄ = NoTangent() - backend = KernelAbstractions.get_backend(dC) + backend = networkbackend(dC) lower_da! = lower_da_kernel!(backend) lower_ds! = lower_ds_kernel!(backend) @@ -119,7 +119,7 @@ function ChainRulesCore.rrule(::typeof(up_mat_mul), S::AbstractVector{T}, A::Abs C = up_mat_mul(S, A, n) function up_mat_mul_pullback(dC::AbstractArray{T, 3}) f̄ = NoTangent() - backend = KernelAbstractions.get_backend(dC) + backend = networkbackend(dC) upper_da! = upper_da_kernel!(backend) upper_ds! = upper_ds_kernel!(backend) @@ -152,7 +152,7 @@ function ChainRulesCore.rrule(::typeof(skew_mat_mul), S::AbstractVector{T}, A::A C = skew_mat_mul(S, A, n) function skew_mat_mul_pullback(dC::AbstractArray{T, 3}) f̄ = NoTangent() - backend = KernelAbstractions.get_backend(dC) + backend = networkbackend(dC) lower_da! = lower_da_kernel!(backend) lower_ds! = lower_ds_kernel!(backend) upper_da! 
= upper_da_kernel!(backend) @@ -228,7 +228,7 @@ end function ChainRulesCore.rrule(::typeof(symmetric_mat_mul), S::AbstractVector{T}, A::AbstractArray{T, 3}, n::Int) where T C = symmetric_mat_mul(S, A, n) function symmetric_mat_mul_pullback(dC::AbstractArray{T, 3}) - backend = KernelAbstractions.get_backend(dC) + backend = networkbackend(dC) symmetric_da! = symmetric_da_kernel!(backend) symmetric_ds! = symmetric_ds_kernel!(backend) diff --git a/src/kernels/kernel_ad_routines/tensor_mat_mul.jl b/src/kernels/kernel_ad_routines/tensor_mat_mul.jl index b857a8175..a359ce633 100644 --- a/src/kernels/kernel_ad_routines/tensor_mat_mul.jl +++ b/src/kernels/kernel_ad_routines/tensor_mat_mul.jl @@ -64,7 +64,7 @@ end function ChainRulesCore.rrule(::typeof(symmetric_mat_right_mul), A::AbstractArray{T, 3}, S::AbstractVector{T}, n::Int) where T C = symmetric_mat_right_mul(A, S, n) function symmetric_mat_mul_pullback(dC::AbstractArray{T, 3}) - backend = KernelAbstractions.get_backend(dC) + backend = networkbackend(dC) symmetric_right_da! = symmetric_right_da_kernel!(backend) symmetric_right_ds! = symmetric_right_ds_kernel!(backend) diff --git a/src/kernels/kernel_ad_routines/tensor_mat_skew_sym_assign.jl b/src/kernels/kernel_ad_routines/tensor_mat_skew_sym_assign.jl index e24541ad0..5c899ad12 100644 --- a/src/kernels/kernel_ad_routines/tensor_mat_skew_sym_assign.jl +++ b/src/kernels/kernel_ad_routines/tensor_mat_skew_sym_assign.jl @@ -36,7 +36,7 @@ function ChainRulesCore.rrule(::typeof(tensor_mat_skew_sym_assign), Z::AbstractA B = tensor_mat_skew_sym_assign(Z, A) function tensor_mat_skew_sym_assign_pullback(dB::AbstractArray{T, 3}) f̄ = NoTangent() - backend = KernelAbstractions.get_backend(dB) + backend = networkbackend(dB) dz! = dz_kernel!(backend) da! = da_kernel!(backend) diff --git a/src/kernels/kernel_ad_routines/vec_tensor_mul.jl b/src/kernels/kernel_ad_routines/vec_tensor_mul.jl index 735a59d0c..b5a62cfda 100644 --- a/src/kernels/kernel_ad_routines/vec_tensor_mul.jl +++ b/src/kernels/kernel_ad_routines/vec_tensor_mul.jl @@ -22,7 +22,7 @@ end function tensor_scalar_product(x::AbstractArray{T, 3}, b_diff::AbstractArray{T, 3}) where T a_size = size(x, 1) - backend = KernelAbstractions.get_backend(x) + backend = networkbackend(x) a_diff = KernelAbstractions.zeros(backend, T, a_size) tensor_scalar_product! = tensor_scalar_product_kernel!(backend) tensor_scalar_product!(a_diff, x, b_diff, size(x, 2), size(x, 3), ndrange=size(a_diff)) diff --git a/src/kernels/mat_tensor_mul.jl b/src/kernels/mat_tensor_mul.jl index c11deabc8..f5bd8543e 100644 --- a/src/kernels/mat_tensor_mul.jl +++ b/src/kernels/mat_tensor_mul.jl @@ -24,7 +24,7 @@ function mat_tensor_mul!(C::AbstractArray{<:Number, 3}, A::AbstractMatrix, B::Ab @assert eltype(C) == eltype(A) == eltype(B) @assert size(A)[2] == size(B)[1] - backend = KernelAbstractions.get_backend(A) + backend = networkbackend(A) kernel! = mat_tensor_mul_kernel!(backend) kernel!(C, A, B, ndrange=size(C)) end @@ -65,7 +65,7 @@ function mat_tensor_mul(A::AbstractMatrix, B::AbstractArray{<:Number, 3}) T = eltype(A) sizeA = size(A) sizeB = size(B) - backend = KernelAbstractions.get_backend(A) + backend = networkbackend(A) C = KernelAbstractions.zeros(backend, T, sizeA[1], sizeB[2], sizeB[3]) mat_tensor_mul!(C, A, B) C @@ -87,7 +87,7 @@ end end function symmetric_mat_mul!(C::AbstractArray{T, 3}, S::AbstractVector{T}, B::AbstractArray{T, 3}, n::Int) where T - backend = KernelAbstractions.get_backend(C) + backend = networkbackend(C) symmetric_mat_mul_k! 
= symmetric_mat_mul_kernel!(backend) symmetric_mat_mul_k!(C, S, B, n, ndrange=size(C)) @@ -130,7 +130,7 @@ end end function lo_mat_mul!(C::AbstractArray{T, 3}, S::AbstractVector{T}, B::AbstractArray{T, 3}, n::Int) where T - backend = KernelAbstractions.get_backend(C) + backend = networkbackend(C) lo_mat_mul_k! = lo_mul_kernel!(backend) lo_mat_mul_k!(C, S, B, n, ndrange=size(C)) @@ -173,7 +173,7 @@ end end function up_mat_mul!(C::AbstractArray{T, 3}, S::AbstractVector{T}, B::AbstractArray{T, 3}, n::Int) where T - backend = KernelAbstractions.get_backend(C) + backend = networkbackend(C) up_mat_mul_k! = up_mul_kernel!(backend) up_mat_mul_k!(C, S, B, n, ndrange=size(C)) @@ -218,7 +218,7 @@ end end function skew_mat_mul!(C::AbstractArray{T, 3}, S::AbstractVector{T}, B::AbstractArray{T, 3}, n::Int) where T - backend = KernelAbstractions.get_backend(C) + backend = networkbackend(C) skew_mat_mul_k! = skew_mat_mul_kernel!(backend) skew_mat_mul_k!(C, S, B, n, ndrange=size(C)) diff --git a/src/kernels/tensor_mat_mul.jl b/src/kernels/tensor_mat_mul.jl index 86c1dbfee..3969ea0a1 100644 --- a/src/kernels/tensor_mat_mul.jl +++ b/src/kernels/tensor_mat_mul.jl @@ -21,7 +21,7 @@ The function [`tensor_mat_mul`](@ref) calls `tensor_mat_mul!` internally. function tensor_mat_mul!(C::AbstractArray{<:Number, 3}, A::AbstractArray{<:Number, 3}, B::AbstractMatrix) @assert size(A)[2] == size(B)[1] - backend = KernelAbstractions.get_backend(A) + backend = networkbackend(A) kernel! = tensor_mat_mul_kernel!(backend) kernel!(C, A, B, ndrange=size(C)) end @@ -63,7 +63,7 @@ function tensor_mat_mul(A::AbstractArray{<:Number, 3}, B::AbstractMatrix) sizeA = size(A); sizeB = size(B) @assert sizeA[2] == sizeB[1] tensor_shape = (sizeA[1], sizeB[2], sizeA[3]) - backend = get_backend(A) + backend = networkbackend(A) C = KernelAbstractions.zeros(backend, T, tensor_shape...) tensor_mat_mul!(C, A, B) C @@ -87,7 +87,7 @@ end end function symmetric_mat_right_mul!(C::AbstractArray{T, 3}, B::AbstractArray{T, 3}, S::AbstractVector{T}, n::Int) where T - backend = KernelAbstractions.get_backend(C) + backend = networkbackend(C) symmetric_mat_right_mul_k! = symmetric_mat_right_mul_kernel!(backend) symmetric_mat_right_mul_k!(C, B, S, n, ndrange = size(C)) diff --git a/src/kernels/tensor_tensor_mul.jl b/src/kernels/tensor_tensor_mul.jl index 48fa6739b..628c4c78d 100644 --- a/src/kernels/tensor_tensor_mul.jl +++ b/src/kernels/tensor_tensor_mul.jl @@ -16,13 +16,13 @@ function tensor_tensor_mul!(c, a, b) @assert size(a)[3] == size(b)[3] @assert size(a)[2] == size(b)[1] - backend = KernelAbstractions.get_backend(a) + backend = networkbackend(a) kernel! = tensor_tensor_mul_kernel!(backend) kernel!(c, a, b, ndrange=size(c)) end function tensor_tensor_mul(A::AbstractArray{T, 3}, B::AbstractArray{T, 3}) where T - backend = KernelAbstractions.get_backend(A) + backend = networkbackend(A) C = KernelAbstractions.zeros(backend, T, size(A)[1], size(B)[2], size(A)[3]) tensor_tensor_mul!(C, A, B) C diff --git a/src/kernels/tensor_tensor_transpose_mul.jl b/src/kernels/tensor_tensor_transpose_mul.jl index c74c151d6..635944802 100644 --- a/src/kernels/tensor_tensor_transpose_mul.jl +++ b/src/kernels/tensor_tensor_transpose_mul.jl @@ -20,13 +20,13 @@ function tensor_tensor_transpose_mul!(C, A, B) @assert size(A)[3] == size(B)[3] @assert size(A)[2] == size(B)[2] - backend = KernelAbstractions.get_backend(A) + backend = networkbackend(A) kernel! 
= tensor_tensor_transpose_mul_kernel!(backend) kernel!(C, A, B, ndrange=size(C)) end function tensor_tensor_transpose_mul(A::AbstractArray{T, 3}, B::AbstractArray{T, 3}) where T - backend = KernelAbstractions.get_backend(A) + backend = networkbackend(A) C = KernelAbstractions.zeros(backend, T, size(A)[1], size(B)[1], size(A)[3]) tensor_tensor_transpose_mul!(C, A, B) C diff --git a/src/kernels/tensor_transpose.jl b/src/kernels/tensor_transpose.jl index beb37368e..f2aba941d 100644 --- a/src/kernels/tensor_transpose.jl +++ b/src/kernels/tensor_transpose.jl @@ -15,14 +15,14 @@ function tensor_transpose!(C, A) @assert sizeA[1] == sizeC[2] @assert sizeA[2] == sizeC[1] - backend = KernelAbstractions.get_backend(A) + backend = networkbackend(A) kernel! = tensor_transpose_kernel!(backend) kernel!(C, A, ndrange=size(C)) end function tensor_transpose(A::AbstractArray{T, 3}) where T sizeA = size(A) - backend = KernelAbstractions.get_backend(A) + backend = networkbackend(A) C = KernelAbstractions.zeros(backend, T, sizeA[2], sizeA[1], sizeA[3]) tensor_transpose!(C, A) C diff --git a/src/kernels/tensor_transpose_mat_mul.jl b/src/kernels/tensor_transpose_mat_mul.jl index b753ea3c5..8f22a74d3 100644 --- a/src/kernels/tensor_transpose_mat_mul.jl +++ b/src/kernels/tensor_transpose_mat_mul.jl @@ -17,13 +17,13 @@ end function tensor_transpose_mat_mul!(C, A, B) @assert axes(A, 1) == axes(B, 1) - backend = KernelAbstractions.get_backend(A) + backend = networkbackend(A) kernel! = tensor_transpose_mat_mul_kernel!(backend) kernel!(C, A, B, ndrange=size(C)) end function tensor_transpose_mat_mul(A::AbstractArray{T, 3}, B::AbstractMatrix{T}) where T - backend = KernelAbstractions.get_backend(A) + backend = networkbackend(A) C = KernelAbstractions.zeros(backend, T, size(A)[2], size(B)[2], size(A)[3]) tensor_transpose_mat_mul!(C, A, B) C diff --git a/src/kernels/tensor_transpose_tensor_mul.jl b/src/kernels/tensor_transpose_tensor_mul.jl index 85da3afc2..5c1f78386 100644 --- a/src/kernels/tensor_transpose_tensor_mul.jl +++ b/src/kernels/tensor_transpose_tensor_mul.jl @@ -20,13 +20,13 @@ function tensor_transpose_tensor_mul!(C, A, B) @assert size(A)[3] == size(B)[3] @assert size(A)[1] == size(B)[1] - backend = KernelAbstractions.get_backend(A) + backend = networkbackend(A) kernel! = tensor_transpose_tensor_mul_kernel!(backend) kernel!(C, A, B, ndrange=size(C)) end function tensor_transpose_tensor_mul(A::AbstractArray{T, 3}, B::AbstractArray{T, 3}) where T - backend = KernelAbstractions.get_backend(A) + backend = networkbackend(A) C = KernelAbstractions.zeros(backend, T, size(A)[2], size(B)[2], size(A)[3]) tensor_transpose_tensor_mul!(C, A, B) C diff --git a/src/kernels/tensor_transpose_tensor_transpose_mul.jl b/src/kernels/tensor_transpose_tensor_transpose_mul.jl index 55e7c5c81..6905d1713 100644 --- a/src/kernels/tensor_transpose_tensor_transpose_mul.jl +++ b/src/kernels/tensor_transpose_tensor_transpose_mul.jl @@ -20,13 +20,13 @@ function tensor_transpose_tensor_transpose_mul!(C, A, B) @assert axes(A, 3) == axes(B, 3) @assert axes(A, 1) == axes(B, 2) - backend = KernelAbstractions.get_backend(A) + backend = networkbackend(A) kernel! 
= tensor_transpose_tensor_transpose_mul_kernel!(backend) kernel!(C, A, B, ndrange=size(C)) end function tensor_transpose_tensor_transpose_mul(A::AbstractArray{T, 3}, B::AbstractArray{T, 3}) where T - backend = KernelAbstractions.get_backend(A) + backend = networkbackend(A) C = KernelAbstractions.zeros(backend, T, size(A)[2], size(B)[1], size(A)[3]) tensor_transpose_tensor_transpose_mul!(C, A, B) C diff --git a/src/kernels/vec_tensor_mul.jl b/src/kernels/vec_tensor_mul.jl index 9df4a0e5f..f0d92405e 100644 --- a/src/kernels/vec_tensor_mul.jl +++ b/src/kernels/vec_tensor_mul.jl @@ -5,7 +5,7 @@ end function vec_tensor_mul(a::AbstractVector{T}, x::AbstractArray{T, 3}) where T b = similar(x) - backend = KernelAbstractions.get_backend(x) + backend = networkbackend(x) vec_tensor_mul! = vec_tensor_mul_kernel!(backend) vec_tensor_mul!(b, a, x, ndrange=size(x)) b diff --git a/src/layers/bias_layer.jl b/src/layers/bias_layer.jl index 85ba6cc33..b92671010 100644 --- a/src/layers/bias_layer.jl +++ b/src/layers/bias_layer.jl @@ -6,7 +6,7 @@ function BiasLayer(M::Int) BiasLayer{M, M}() end -function initialparameters(::BiasLayer{M, M}, backend::Backend, ::Type{T}; rng::AbstractRNG = Random.default_rng(), init_bias = GlorotUniform()) where {M, T} +function initialparameters(rng::AbstractRNG, init_bias::AbstractNeuralNetworks.Initializer, ::BiasLayer{M, M}, backend::Backend, ::Type{T}) where {M, T} q_part = KernelAbstractions.zeros(backend, T, M÷2) p_part = KernelAbstractions.zeros(backend, T, M÷2) init_bias(rng, q_part) diff --git a/src/layers/classification.jl b/src/layers/classification.jl index 6d146fbd8..fe9c5d1c0 100644 --- a/src/layers/classification.jl +++ b/src/layers/classification.jl @@ -70,7 +70,7 @@ function ClassificationLayer(input_dim::Integer, output_dim::Integer, activation ClassificationLayer{input_dim, output_dim, average, typeof(activation)}(activation) end -function initialparameters(::ClassificationLayer{M, N}, device::KernelAbstractions.Backend, T::Type; rng::Random.AbstractRNG=Random.default_rng(), init_weight! = GlorotUniform()) where {M, N} +function initialparameters(rng::Random.AbstractRNG, init_weight!::AbstractNeuralNetworks.Initializer, ::ClassificationLayer{M, N}, device::KernelAbstractions.Backend, ::Type{T}) where {M, N, T} weight = KernelAbstractions.allocate(device, T, N, M) init_weight!(rng, weight) (weight=weight, ) diff --git a/src/layers/grassmann_layer.jl b/src/layers/grassmann_layer.jl index 2f5826b7a..381309095 100644 --- a/src/layers/grassmann_layer.jl +++ b/src/layers/grassmann_layer.jl @@ -16,8 +16,10 @@ function GrassmannLayer(n::Integer, N::Integer) GrassmannLayer{n, N}() end -function AbstractNeuralNetworks.initialparameters(d::GrassmannLayer{N,M}, backend::KernelAbstractions.Backend, ::Type{T}; rng::AbstractRNG=Random.default_rng()) where {M,N,T} - (weight = N > M ? rand(backend, rng, GrassmannManifold{T}, N, M) : rand(backend, rng, GrassmannManifold{T}, M, N), ) +function initialparameters(rng::AbstractRNG, init::AbstractNeuralNetworks.Initializer, ::GrassmannLayer{N,M}, backend::NeuralNetworkBackend, ::Type{T}) where {M,N,T} + weight = N > M ? 
KernelAbstractions.allocate(backend, T, N, M) : KernelAbstractions.allocate(backend, T, M, N) + init(rng, weight) + (weight = GrassmannManifold(assign_columns(typeof(weight)(qr!(weight).Q), size(weight)...)), ) end function parameterlength(::GrassmannLayer{M, N}) where {M, N} diff --git a/src/layers/linear_symplectic_attention.jl b/src/layers/linear_symplectic_attention.jl index 5b15baec9..0bad20400 100644 --- a/src/layers/linear_symplectic_attention.jl +++ b/src/layers/linear_symplectic_attention.jl @@ -13,9 +13,10 @@ The coefficients of a [`LinearSymplecticAttention`](@ref) layer is a [`Symmetric ```jldoctest using GeometricMachineLearning +using GeometricMachineLearning: params l = LinearSymplecticAttentionQ(3, 5) -ps = initialparameters(l, CPU(), Float32) +ps = params(NeuralNetwork(Chain(l))).L1 typeof(ps.A) <: SymmetricMatrix @@ -69,7 +70,7 @@ end parameterlength(l::LinearSymplecticAttention) = (l.seq_length + 1) * l.seq_length ÷ 2 -function initialparameters(l::LinearSymplecticAttention, backend::KernelAbstractions.Backend, T::Type; rng::AbstractRNG=Random.default_rng(), initializer::AbstractNeuralNetworks.AbstractInitializer=GlorotUniform()) +function initialparameters(rng::AbstractRNG, initializer::AbstractNeuralNetworks.Initializer, l::LinearSymplecticAttention, backend::KernelAbstractions.Backend, T::Type) S = KernelAbstractions.allocate(backend, T, parameterlength(l)) initializer(rng, S) (A = SymmetricMatrix(S, l.seq_length), ) diff --git a/src/layers/multi_head_attention.jl b/src/layers/multi_head_attention.jl index 35a490e27..0c8f0e942 100644 --- a/src/layers/multi_head_attention.jl +++ b/src/layers/multi_head_attention.jl @@ -38,7 +38,7 @@ function parameterlength(d::MultiHeadAttention{M, M, true}) where M Int(3*M^2 - 3*M*(M + d.n_heads)/(2*d.n_heads)) end -function initialparameters(d::MultiHeadAttention{M, M, false}, backend::KernelAbstractions.Backend, T::Type; rng::AbstractRNG=Random.default_rng(), initializer::AbstractNeuralNetworks.AbstractInitializer=GlorotUniform()) where {M} +function initialparameters(rng::AbstractRNG, initializer::AbstractNeuralNetworks.Initializer, d::MultiHeadAttention{M, M, false}, backend::KernelAbstractions.Backend, T::Type) where {M} # number of "hidden" dimension (dimension of projection) Dₕ = M ÷ d.n_heads # projections for queries, keys and values. @@ -70,7 +70,7 @@ function initialparameters(d::MultiHeadAttention{M, M, false}, backend::KernelAb end -function initialparameters(d::MultiHeadAttention{M, M, true}, backend::KernelAbstractions.Backend, T::Type; rng::AbstractRNG=Random.default_rng(), initializer::AbstractNeuralNetworks.AbstractInitializer=GlorotUniform()) where {M} +function initialparameters(rng::AbstractRNG, initializer::AbstractNeuralNetworks.Initializer, d::MultiHeadAttention{M, M, true}, backend::KernelAbstractions.Backend, ::Type{T}) where {M, T} # number of "hidden" dimension (dimension of projection) Dₕ = M ÷ d.n_heads # projections for queries, keys and vectors. 
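The layer files in this part of the diff all migrate `initialparameters` to the AbstractNeuralNetworks 0.5 convention: the random number generator and the initializer move from keyword arguments to the two leading positional arguments. A hedged sketch of a method under the new convention, using a made-up `ScalingLayer` purely to show the argument order (that `AbstractExplicitLayer` and `Initializer` can be used this way is an assumption based on the annotations in the hunks above):

```julia
import AbstractNeuralNetworks: initialparameters
using AbstractNeuralNetworks: AbstractExplicitLayer, Initializer
using KernelAbstractions, Random

# hypothetical layer type, only here to illustrate the new signature
struct ScalingLayer{M, N} <: AbstractExplicitLayer{M, N} end

# 0.5 convention: (rng, initializer, layer, backend, eltype), no keyword arguments required
function initialparameters(rng::Random.AbstractRNG,
                           init::Initializer,
                           ::ScalingLayer{M, M},
                           backend::KernelAbstractions.Backend,
                           ::Type{T}) where {M, T}
    scale = KernelAbstractions.allocate(backend, T, M)
    init(rng, scale)    # initializers fill the allocated array in place
    (scale = scale, )
end
```

User code normally no longer calls `initialparameters` directly; as the updated doctests above show, a layer's parameters are obtained by wrapping it in a network, e.g. `params(NeuralNetwork(Chain(l))).L1`.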
@@ -82,13 +82,13 @@ function initialparameters(d::MultiHeadAttention{M, M, true}, backend::KernelAbs key = Symbol("head_"*string(head)) PQ = merge(PQ, - NamedTuple{(key, )}((rand(backend, rng, StiefelManifold{T}, M, Dₕ), )) + NamedTuple{(key, )}(values(initialparameters(rng, initializer, StiefelLayer{M, Dₕ}(), backend, T))) ) PK = merge(PK, - NamedTuple{(key, )}((rand(backend, rng, StiefelManifold{T}, M, Dₕ), )) + NamedTuple{(key, )}(values(initialparameters(rng, initializer, StiefelLayer{M, Dₕ}(), backend, T))) ) PV = merge(PV, - NamedTuple{(key, )}((rand(backend, rng, StiefelManifold{T}, M, Dₕ), )) + NamedTuple{(key, )}(values(initialparameters(rng, initializer, StiefelLayer{M, Dₕ}(), backend, T))) ) end (PQ=PQ, PK=PK, PV=PV) diff --git a/src/layers/psd_like_layer.jl b/src/layers/psd_like_layer.jl index f8117542e..74e672612 100644 --- a/src/layers/psd_like_layer.jl +++ b/src/layers/psd_like_layer.jl @@ -26,8 +26,10 @@ function parameterlength(::PSDLayer{M, N}) where {M, N} N > M ? Int(M2 * (N2 - (M2 + 1) / 2)) : Int(N2 * (M2 - (N2 + 1) / 2)) end -function initialparameters(::PSDLayer{M, N}, backend::KernelAbstractions.Backend, T::Type; rng::AbstractRNG=Random.default_rng()) where {M, N} - (weight = N > M ? rand(backend, rng, StiefelManifold{T}, N ÷ 2, M ÷ 2) : rand(backend, rng, StiefelManifold{T}, M ÷ 2, N ÷ 2), ) +function initialparameters(rng::AbstractRNG, initializer::AbstractNeuralNetworks.Initializer, ::PSDLayer{M, N}, backend::KernelAbstractions.Backend, T::Type) where {M, N} + weight = N > M ? KernelAbstractions.allocate(backend, T, N ÷ 2, M ÷ 2) : KernelAbstractions.allocate(backend, T, M ÷ 2, N ÷ 2) + initializer(rng, weight) + (weight = StiefelManifold(assign_columns(typeof(weight)(qr!(weight).Q), size(weight)...)), ) end function (::PSDLayer{M, N})(qp::NamedTuple{(:q, :p), Tuple{AT1, AT2}}, ps::NamedTuple) where {M, N, AT1 <: AbstractArray, AT2 <: AbstractArray} diff --git a/src/layers/resnet.jl b/src/layers/resnet.jl index bc63d2cbf..af3915f33 100644 --- a/src/layers/resnet.jl +++ b/src/layers/resnet.jl @@ -22,7 +22,7 @@ function ResNetLayer(dim::Integer, activation=identity; use_bias::Bool=true) return ResNetLayer{dim, dim, use_bias, typeof(activation)}(activation) end -function initialparameters(::ResNetLayer{M, M, use_bias}, backend::KernelAbstractions.Backend, T::Type; rng::Random.AbstractRNG=Random.default_rng(), init_weight = GlorotUniform(), init_bias = ZeroInitializer()) where {M, use_bias} +function initialparameters(rng::Random.AbstractRNG, init_weight::AbstractNeuralNetworks.Initializer, ::ResNetLayer{M, M, use_bias}, backend::KernelAbstractions.Backend, ::Type{T}; init_bias = ZeroInitializer()) where {M, use_bias, T} if use_bias weight = KernelAbstractions.allocate(backend, T, M, M) bias = KernelAbstractions.allocate(backend, T, M) diff --git a/src/layers/stiefel_layer.jl b/src/layers/stiefel_layer.jl index 7944b1ea0..fd6286bde 100644 --- a/src/layers/stiefel_layer.jl +++ b/src/layers/stiefel_layer.jl @@ -7,8 +7,10 @@ function StiefelLayer(n::Integer, N::Integer) StiefelLayer{n, N}() end -function AbstractNeuralNetworks.initialparameters(::StiefelLayer{M,N}, backend::KernelAbstractions.Backend, ::Type{T}; rng::AbstractRNG=Random.default_rng()) where {M,N,T} - (weight = N > M ? 
rand(backend, rng, StiefelManifold{T}, N, M) : rand(backend, rng, StiefelManifold{T}, M, N), ) +function initialparameters(rng::AbstractRNG, initializer::AbstractNeuralNetworks.Initializer, ::StiefelLayer{M,N}, backend::KernelAbstractions.Backend, ::Type{T}) where {M,N,T} + weight = N > M ? KernelAbstractions.allocate(backend, T, N, M) : KernelAbstractions.allocate(backend, T, M, N) + initializer(rng, weight) + (weight = StiefelManifold(assign_columns(typeof(weight)(qr!(weight).Q), size(weight)...)),) end function parameterlength(::StiefelLayer{M, N}) where {M, N} diff --git a/src/layers/sympnets.jl b/src/layers/sympnets.jl index bd07a51e1..ba6e2c151 100644 --- a/src/layers/sympnets.jl +++ b/src/layers/sympnets.jl @@ -189,7 +189,7 @@ function Gradient(dim::Int, dim2::Int=dim, activation = identity; full_grad::Boo end end -function initialparameters(d::GradientLayer{M, M}, backend::Backend, ::Type{T}; rng::AbstractRNG = Random.default_rng(), init_weight = GlorotUniform(), init_bias = ZeroInitializer(), init_scale = GlorotUniform()) where {M, T} +function initialparameters(rng::AbstractRNG, init_weight::AbstractNeuralNetworks.Initializer, d::GradientLayer{M, M}, backend::Backend, ::Type{T}; init_bias = ZeroInitializer(), init_scale = GlorotUniform()) where {M, T} K = KernelAbstractions.allocate(backend, T, d.second_dim÷2, M÷2) b = KernelAbstractions.allocate(backend, T, d.second_dim÷2) a = KernelAbstractions.allocate(backend, T, d.second_dim÷2) @@ -199,13 +199,13 @@ function initialparameters(d::GradientLayer{M, M}, backend::Backend, ::Type{T}; return (weight=K, bias=b, scale=a) end -function initialparameters(::ActivationLayer{M, M}, backend::Backend, ::Type{T}; rng::AbstractRNG = Random.default_rng(), init_scale = GlorotUniform()) where {M, T} +function initialparameters(rng::AbstractRNG, init_scale::AbstractNeuralNetworks.Initializer, ::ActivationLayer{M, M}, backend::Backend, ::Type{T}) where {M, T} a = KernelAbstractions.zeros(backend, T, M ÷ 2) init_scale(rng, a) return (scale = a,) end -function initialparameters(::LinearLayer{M, M}, backend::Backend, ::Type{T}; rng::AbstractRNG = Random.default_rng(), init_weight = GlorotUniform()) where {M, T} +function initialparameters(rng::AbstractRNG, init_weight::AbstractNeuralNetworks.Initializer, ::LinearLayer{M, M}, backend::Backend, ::Type{T}) where {M, T} S = KernelAbstractions.allocate(backend, T, (M ÷ 2) * (M ÷ 2 + 1) ÷ 2) init_weight(rng, S) (weight=SymmetricMatrix(S, M ÷ 2), ) diff --git a/src/layers/volume_preserving_attention.jl b/src/layers/volume_preserving_attention.jl index 13cc372ad..62d1f26e5 100644 --- a/src/layers/volume_preserving_attention.jl +++ b/src/layers/volume_preserving_attention.jl @@ -68,13 +68,13 @@ function parameterlength(::VolumePreservingAttention{M, M, :arbitrary}) where {M M ^2 end -function initialparameters(d::VolumePreservingAttention{M, M, :skew_sym}, backend::KernelAbstractions.Backend, T::Type; rng::AbstractRNG=Random.default_rng(), initializer!::AbstractNeuralNetworks.AbstractInitializer=GlorotUniform()) where {M} +function initialparameters(rng::AbstractRNG, initializer!::AbstractNeuralNetworks.Initializer, d::VolumePreservingAttention{M, M, :skew_sym}, backend::KernelAbstractions.Backend, T::Type) where {M} V = KernelAbstractions.allocate(backend, T, parameterlength(d)) initializer!(rng, V) (A = SkewSymMatrix(V, M), ) end -function initialparameters(::VolumePreservingAttention{M, M, :arbitrary}, backend::KernelAbstractions.Backend, T::Type; rng::AbstractRNG=Random.default_rng(), 
initializer!::AbstractNeuralNetworks.AbstractInitializer=GlorotUniform()) where {M} +function initialparameters(rng::AbstractRNG, initializer!::AbstractNeuralNetworks.Initializer, ::VolumePreservingAttention{M, M, :arbitrary}, backend::KernelAbstractions.Backend, T::Type) where {M} A = KernelAbstractions.allocate(backend, T, M, M) initializer!(rng, A) (A = A, ) diff --git a/src/layers/volume_preserving_feedforward.jl b/src/layers/volume_preserving_feedforward.jl index eac95d0d4..561836387 100644 --- a/src/layers/volume_preserving_feedforward.jl +++ b/src/layers/volume_preserving_feedforward.jl @@ -126,7 +126,7 @@ end parameterlength(::VolumePreservingFeedForwardLayer{M, M, :no_bias}) where M = M * (M - 1) ÷ 2 parameterlength(::VolumePreservingFeedForwardLayer{M, M, :bias}) where M = M * (M - 1) ÷ 2 + M -function initialparameters(d::VolumePreservingLowerLayer{M, M, :bias}, backend::Backend, ::Type{T}; rng::AbstractRNG = Random.default_rng(), init_weight! = GlorotUniform(), init_bias! = ZeroInitializer()) where {M, T} +function initialparameters(rng::AbstractRNG, init_weight!::AbstractNeuralNetworks.Initializer, d::VolumePreservingLowerLayer{M, M, :bias}, backend::Backend, ::Type{T}; init_bias! = ZeroInitializer()) where {M, T} S = KernelAbstractions.allocate(backend, T, parameterlength(d) - M) b = KernelAbstractions.allocate(backend, T, M) init_weight!(rng, S) @@ -135,7 +135,7 @@ function initialparameters(d::VolumePreservingLowerLayer{M, M, :bias}, backend:: (weight = LowerTriangular(S, M), bias = b) end -function initialparameters(d::VolumePreservingUpperLayer{M, M, :bias}, backend::Backend, ::Type{T}; rng::AbstractRNG = Random.default_rng(), init_weight! = GlorotUniform(), init_bias! = ZeroInitializer()) where {M, T} +function initialparameters(rng::AbstractRNG, init_weight!::AbstractNeuralNetworks.Initializer, d::VolumePreservingUpperLayer{M, M, :bias}, backend::Backend, ::Type{T}; init_bias! = ZeroInitializer()) where {M, T} S = KernelAbstractions.allocate(backend, T, parameterlength(d) - M) b = KernelAbstractions.allocate(backend, T, M) init_weight!(rng, S) @@ -144,14 +144,14 @@ function initialparameters(d::VolumePreservingUpperLayer{M, M, :bias}, backend:: (weight = UpperTriangular(S, M), bias = b) end -function initialparameters(d::VolumePreservingLowerLayer{M, M, :no_bias}, backend::Backend, ::Type{T}; rng::AbstractRNG = Random.default_rng(), init_weight! = GlorotUniform(), init_bias! = ZeroInitializer()) where {M, T} +function initialparameters(rng::AbstractRNG, init_weight!::AbstractNeuralNetworks.Initializer, d::VolumePreservingLowerLayer{M, M, :no_bias}, backend::Backend, ::Type{T}; init_bias! = ZeroInitializer()) where {M, T} S = KernelAbstractions.allocate(backend, T, parameterlength(d)) init_weight!(rng, S) (weight = LowerTriangular(S, M), ) end -function initialparameters(d::VolumePreservingUpperLayer{M, M, :no_bias}, backend::Backend, ::Type{T}; rng::AbstractRNG = Random.default_rng(), init_weight! = GlorotUniform(), init_bias! = ZeroInitializer()) where {M, T} +function initialparameters(rng::AbstractRNG, init_weight!::AbstractNeuralNetworks.Initializer, d::VolumePreservingUpperLayer{M, M, :no_bias}, backend::Backend, ::Type{T}; init_bias! = ZeroInitializer()) where {M, T} S = KernelAbstractions.allocate(backend, T, parameterlength(d)) init_weight!(rng, S) diff --git a/src/loss/losses.jl b/src/loss/losses.jl index 257a68e1f..1b21f6773 100644 --- a/src/loss/losses.jl +++ b/src/loss/losses.jl @@ -192,7 +192,7 @@ This loss does not have any parameters. 
struct AutoEncoderLoss <: NetworkLoss end function (loss::AutoEncoderLoss)(nn::NeuralNetwork, input) - loss(nn.model, nn.params, input, input) + loss(nn.model, params(nn), input, input) end function (loss::AutoEncoderLoss)(model::Union{Chain, AbstractExplicitLayer}, ps::Union{NeuralNetworkParameters, NamedTuple}, input) diff --git a/src/manifolds/abstract_manifold.jl b/src/manifolds/abstract_manifold.jl index 1f958c197..f58ee6603 100644 --- a/src/manifolds/abstract_manifold.jl +++ b/src/manifolds/abstract_manifold.jl @@ -11,7 +11,7 @@ abstract type Manifold{T} <: AbstractMatrix{T} end end function assign_columns(Q::AbstractMatrix{T}, N::Integer, n::Integer) where T - backend = KernelAbstractions.get_backend(Q) + backend = networkbackend(Q) Y = KernelAbstractions.allocate(backend, T, N, n) assign_columns! = assign_columns_kernel!(backend) assign_columns!(Y, Q, ndrange=size(Y)) diff --git a/src/manifolds/grassmann_manifold.jl b/src/manifolds/grassmann_manifold.jl index a2d3fae37..81e5e687d 100644 --- a/src/manifolds/grassmann_manifold.jl +++ b/src/manifolds/grassmann_manifold.jl @@ -76,7 +76,7 @@ See the documentation for [`global_section(Y::StiefelManifold{T}) where T`](@ref """ function global_section(Y::GrassmannManifold{T}) where T N, n = size(Y) - backend = KernelAbstractions.get_backend(Y) + backend = networkbackend(Y) A = KernelAbstractions.allocate(backend, T, N, N-n) randn!(A) A = A - Y.A * (Y.A' * A) diff --git a/src/manifolds/stiefel_manifold.jl b/src/manifolds/stiefel_manifold.jl index fb19860ec..5caf4afc0 100644 --- a/src/manifolds/stiefel_manifold.jl +++ b/src/manifolds/stiefel_manifold.jl @@ -129,7 +129,7 @@ qr!(A).Q """ function global_section(Y::StiefelManifold{T}) where T N, n = size(Y) - backend = KernelAbstractions.get_backend(Y) + backend = networkbackend(Y) A = KernelAbstractions.allocate(backend, T, N, N-n) randn!(A) A = A - Y.A * (Y.A' * A) diff --git a/src/map_to_cpu.jl b/src/map_to_cpu.jl index 4ff8d0b35..8ca85e0de 100644 --- a/src/map_to_cpu.jl +++ b/src/map_to_cpu.jl @@ -29,6 +29,6 @@ function map_to_cpu(A::SymmetricMatrix{T}) where T end function map_to_cpu(nn::NeuralNetwork{AT, MT, <:Any, BT}) where {AT, MT, BT} - ps = map_to_cpu(nn.params) + ps = map_to_cpu(params(nn)) NeuralNetwork{AT, MT, typeof(ps), BT}(nn.architecture, nn.model, ps, nn.backend) end \ No newline at end of file diff --git a/src/optimizers/bfgs_cache.jl b/src/optimizers/bfgs_cache.jl index 77987deae..1bbcf26b4 100644 --- a/src/optimizers/bfgs_cache.jl +++ b/src/optimizers/bfgs_cache.jl @@ -43,13 +43,14 @@ end # """ function initialize_hessian_inverse(B::AbstractArray{T}) where T length_of_array = length(vec(B)) - backend = KernelAbstractions.get_backend(B) + backend = networkbackend(B) H = KernelAbstractions.zeros(backend, T, length_of_array, length_of_array) assign_diagonal_ones! = assign_diagonal_ones_kernel!(backend) assign_diagonal_ones!(H, ndrange=length_of_array) H end +setup_bfgs_cache(ps::NeuralNetworkParameters) = setup_bfgs_cache(params(ps)) setup_bfgs_cache(ps::NamedTuple) = apply_toNT(setup_bfgs_cache, ps) setup_bfgs_cache(ps::Tuple) = Tuple([setup_bfgs_cache(x) for x in ps]) setup_bfgs_cache(B::AbstractArray) = BFGSCache(B) \ No newline at end of file diff --git a/src/optimizers/manifold_related/retractions.jl b/src/optimizers/manifold_related/retractions.jl index ebf14e6e3..3dab15aad 100644 --- a/src/optimizers/manifold_related/retractions.jl +++ b/src/optimizers/manifold_related/retractions.jl @@ -87,7 +87,7 @@ See [`geodesic(::StiefelLieAlgHorMatrix)`](@ref). 
function geodesic(B::GrassmannLieAlgHorMatrix) T = eltype(B) E = StiefelProjection(B) - backend = KernelAbstractions.get_backend(B) + backend = networkbackend(B) zero_mat = KernelAbstractions.zeros(backend, T, B.n, B.n) B̂ = hcat(vcat(zero_mat, B.B), E) B̄ = hcat(vcat(one(zero_mat), zero_mat), vcat(zero(B.B'), -B.B'))' @@ -176,7 +176,7 @@ See [`cayley(::StiefelLieAlgHorMatrix)`](@ref). function cayley(B::GrassmannLieAlgHorMatrix) T = eltype(B) E = StiefelProjection(B) - backend = KernelAbstractions.get_backend(B) + backend = networkbackend(B) 𝕆 = KernelAbstractions.zeros(backend, T, B.n, B.n) 𝕀_small = one(𝕆) 𝕀_small2 = hcat(vcat(𝕀_small, 𝕆), vcat(𝕆, 𝕀_small)) diff --git a/src/optimizers/optimizer.jl b/src/optimizers/optimizer.jl index ab90857b0..c306fbc36 100644 --- a/src/optimizers/optimizer.jl +++ b/src/optimizers/optimizer.jl @@ -65,7 +65,7 @@ function Optimizer(method::OptimizerMethod, nn_params::Union{NeuralNetworkParame end function Optimizer(method::OptimizerMethod, nn::NeuralNetwork; kwargs...) - Optimizer(method, nn.params; kwargs...) + Optimizer(method, params(nn); kwargs...) end Optimizer(nn::NeuralNetwork, m::OptimizerMethod; kwargs...) = Optimizer(m, nn; kwargs...) @@ -107,6 +107,11 @@ function optimization_step!(o::Optimizer, λY::NamedTuple, ps::NeuralNetworkPara end end +# take care of Zygote idiosyncrasies +function optimization_step!(o::Optimizer, λY::NamedTuple, ps::NeuralNetworkParameters, dx::NamedTuple{(:params,), Tuple{NT}}) where {NT <: NamedTuple} + optimization_step!(o, λY, ps, dx.params) +end + @doc raw""" optimization_step!(o, λY, ps, dx) @@ -127,9 +132,10 @@ All the arguments are given as `NamedTuple`s as the neural network weights are ```jldoctest using GeometricMachineLearning +using GeometricMachineLearning: params l = StiefelLayer(3, 5) -ps = initialparameters(l, Float32) +ps = params(NeuralNetwork(Chain(l), Float32)).L1 cache = apply_toNT(MomentumCache, ps) o = Optimizer(MomentumOptimizer(), cache, 0, geodesic) λY = GlobalSection(ps) @@ -147,8 +153,6 @@ _test_nt(λY) & _test_nt(ps) & _test_nt(cache) & _test_nt(dx) true ``` -Note that we used `initialparameters` here instead of `NeuralNetwork` (as we do usually). - # Extended help The derivatives `dx` here are usually obtained via an AD routine by differentiating a loss function, i.e. `dx` is ``\nabla_xL``. 
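A minimal end-to-end sketch (assuming a small `Dense` chain and `FeedForwardLoss`; shapes and hyperparameters are arbitrary): in the setups touched by this diff, Zygote returns the gradient with respect to `NeuralNetworkParameters` wrapped in a `NamedTuple` with a single `:params` field, which the method added above forwards to the `NamedTuple`-based routine.

```julia
using GeometricMachineLearning
using GeometricMachineLearning: params, FeedForwardLoss
import Zygote

nn = NeuralNetwork(Chain(Dense(2, 2, tanh)), Float64)
loss = FeedForwardLoss()
input, output = rand(2, 10), rand(2, 10)

# Zygote wraps the gradient with respect to `NeuralNetworkParameters`
# in a NamedTuple with a single `:params` field ...
dx = Zygote.gradient(ps -> loss(nn.model, ps, input, output), params(nn))[1]

o = Optimizer(AdamOptimizer(), params(nn))
λY = GlobalSection(params(nn))

# ... which the method above unwraps before dispatching to the NamedTuple version.
optimization_step!(o, λY, params(nn), dx)
```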
""" diff --git a/src/optimizers/optimizer_caches.jl b/src/optimizers/optimizer_caches.jl index b80951b9c..9f9dd14ab 100644 --- a/src/optimizers/optimizer_caches.jl +++ b/src/optimizers/optimizer_caches.jl @@ -109,12 +109,12 @@ setup_gradient_cache(B::AbstractArray{<:Number}) = GradientCache(B) function Base.zero(Y::StiefelManifold{T}) where T N, n = size(Y) - backend = KernelAbstractions.get_backend(Y.A) + backend = networkbackend(Y.A) zeros(backend, StiefelLieAlgHorMatrix{T}, N, n) end function Base.zero(Y::GrassmannManifold{T}) where T N, n = size(Y) - backend = KernelAbstractions.get_backend(Y.A) + backend = networkbackend(Y.A) zeros(backend, GrassmannLieAlgHorMatrix{T}, N, n) end diff --git a/src/training/train.jl b/src/training/train.jl index 69b2712b9..56507b196 100644 --- a/src/training/train.jl +++ b/src/training/train.jl @@ -1,9 +1,9 @@ const DEFAULT_NRUNS = 1000 # The loss gradient function working for all types of arguments -loss_gradient(nn::NeuralNetwork{<:Architecture}, ti::AbstractTrainingMethod, data::AbstractTrainingData, index_batch, params = nn.params) = Zygote.gradient(p -> loss(ti, nn, data, index_batch, p), params)[1] +loss_gradient(nn::NeuralNetwork{<:Architecture}, ti::AbstractTrainingMethod, data::AbstractTrainingData, index_batch, params = params(nn)) = Zygote.gradient(p -> loss(ti, nn, data, index_batch, p), params)[1] -loss_gradient(loss, index_batch, params = nn.params) = Zygote.gradient(p -> loss(p, index_batch), params)[1] +loss_gradient(loss, index_batch, params = params(nn)) = Zygote.gradient(p -> loss(p, index_batch), params)[1] #loss_gradient(nn::SymbolicNeuralNetwork, ti::AbstractTrainingMethod, data::AbstractTrainingData, index_batch, params = params(nn)) = #mapreduce(args->∇loss_single(ti, nn, get_loss(ti, nn, data, args)..., params), +, index_batch) diff --git a/src/training_method/hnn_exact_method.jl b/src/training_method/hnn_exact_method.jl index 3b3fffcc2..27c06807b 100644 --- a/src/training_method/hnn_exact_method.jl +++ b/src/training_method/hnn_exact_method.jl @@ -2,12 +2,12 @@ struct HnnExactMethod <: HnnTrainingMethod end ExactHnn(;sqdist = sqeuclidean) = TrainingMethod{HnnExactMethod, DerivativePhaseSpaceSymbol, SampledData, typeof(sqdist)}(sqdist) -function loss_single(::TrainingMethod{HnnExactMethod}, nn::NeuralNetwork{<:HamiltonianNeuralNetwork}, qₙ, pₙ, q̇ₙ, ṗₙ, params = nn.params) +function loss_single(::TrainingMethod{HnnExactMethod}, nn::NeuralNetwork{<:HamiltonianNeuralNetwork}, qₙ, pₙ, q̇ₙ, ṗₙ, params = params(nn)) dH = vectorfield(nn, [qₙ...,pₙ...], params) sqeuclidean(dH[1],q̇ₙ) + sqeuclidean(dH[2],ṗₙ) end get_loss(::TrainingMethod{HnnExactMethod}, ::AbstractNeuralNetwork{<:HamiltonianNeuralNetwork}, data::TrainingData{<:DataSymbol{<:DerivativePhaseSpaceSymbol}}, args) = (Zygote.ignore_derivatives(get_data(data,:q, args...)), Zygote.ignore_derivatives(get_data(data,:p, args...)), Zygote.ignore_derivatives(get_data(data,:q̇, args...)), Zygote.ignore_derivatives(get_data(data,:ṗ, args...))) -loss(ti::TrainingMethod{HnnExactMethod}, nn::NeuralNetwork{<:HamiltonianNeuralNetwork}, data::TrainingData{<:DataSymbol{<:DerivativePhaseSpaceSymbol}}, index_batch = eachindex(ti, data), params = nn.params) = +loss(ti::TrainingMethod{HnnExactMethod}, nn::NeuralNetwork{<:HamiltonianNeuralNetwork}, data::TrainingData{<:DataSymbol{<:DerivativePhaseSpaceSymbol}}, index_batch = eachindex(ti, data), params = params(nn)) = mapreduce(args->loss_single(Zygote.ignore_derivatives(ti), nn, get_loss(ti, nn, data, args)..., params), +, index_batch) diff --git 
a/src/training_method/lnn_exact_method.jl b/src/training_method/lnn_exact_method.jl index 1c5f8d144..4179be243 100644 --- a/src/training_method/lnn_exact_method.jl +++ b/src/training_method/lnn_exact_method.jl @@ -2,11 +2,11 @@ struct LnnExactMethod <: LnnTrainingMethod end ExactLnn(;sqdist = sqeuclidean) = TrainingMethod{LnnExactMethod, PosVeloAccSymbol, SampledData, typeof(sqdist)}(sqdist) -function loss_single(::TrainingMethod{LnnExactMethod}, nn::AbstractNeuralNetwork{<:LagrangianNeuralNetwork}, qₙ, q̇ₙ, q̈ₙ, params = nn.params) +function loss_single(::TrainingMethod{LnnExactMethod}, nn::AbstractNeuralNetwork{<:LagrangianNeuralNetwork}, qₙ, q̇ₙ, q̈ₙ, params = params(nn)) abs(sum(∇q∇q̇L(nn,qₙ, q̇ₙ, params))) #inv(∇q̇∇q̇L(nn, qₙ, q̇ₙ, params))*(∇qL(nn, qₙ, q̇ₙ, params) - ∇q∇q̇L(nn, qₙ, q̇ₙ, params)) end -loss(ti::TrainingMethod{<:LnnExactMethod}, nn::AbstractNeuralNetwork{<:LagrangianNeuralNetwork}, data::TrainingData{<:DataSymbol{<:PosVeloAccSymbol}}, index_batch = eachindex(ti, data), params = nn.params) = +loss(ti::TrainingMethod{<:LnnExactMethod}, nn::AbstractNeuralNetwork{<:LagrangianNeuralNetwork}, data::TrainingData{<:DataSymbol{<:PosVeloAccSymbol}}, index_batch = eachindex(ti, data), params = params(nn)) = mapreduce(args->loss_single(Zygote.ignore_derivatives(ti), nn, get_loss(ti, nn, data, args)..., params),+, index_batch) get_loss(::TrainingMethod{<:LnnExactMethod}, ::AbstractNeuralNetwork{<:LagrangianNeuralNetwork}, data::TrainingData{<:DataSymbol{<:PosVeloAccSymbol}}, args) = (get_data(data, :q,args...), get_data(data, :q̇, args...), get_data(data, :q̈, args...)) \ No newline at end of file diff --git a/src/training_method/symplectic_euler.jl b/src/training_method/symplectic_euler.jl index ccc0df327..5f55eebad 100644 --- a/src/training_method/symplectic_euler.jl +++ b/src/training_method/symplectic_euler.jl @@ -9,19 +9,19 @@ SEulerA(;sqdist = sqeuclidean) = TrainingMethod{SymplecticEulerA, PhaseSpaceSymb SEulerB(;sqdist = sqeuclidean) = TrainingMethod{SymplecticEulerB, PhaseSpaceSymbol, TrajectoryData, typeof(sqdist)}(sqdist) -function loss_single(::TrainingMethod{SymplecticEulerA}, nn::AbstractNeuralNetwork{<:HamiltonianNeuralNetwork}, qₙ, qₙ₊₁, pₙ, pₙ₊₁, Δt, params = nn.params) +function loss_single(::TrainingMethod{SymplecticEulerA}, nn::AbstractNeuralNetwork{<:HamiltonianNeuralNetwork}, qₙ, qₙ₊₁, pₙ, pₙ₊₁, Δt, params = params(nn)) dH = vectorfield(nn, [qₙ₊₁...,pₙ...], params) sqeuclidean(dH[1],(qₙ₊₁-qₙ)/Δt) + sqeuclidean(dH[2],(pₙ₊₁-pₙ)/Δt) end -function loss_single(::TrainingMethod{SymplecticEulerB}, nn::AbstractNeuralNetwork{<:HamiltonianNeuralNetwork}, qₙ, qₙ₊₁, pₙ, pₙ₊₁, Δt, params = nn.params) +function loss_single(::TrainingMethod{SymplecticEulerB}, nn::AbstractNeuralNetwork{<:HamiltonianNeuralNetwork}, qₙ, qₙ₊₁, pₙ, pₙ₊₁, Δt, params = params(nn)) dH = vectorfield(nn, [qₙ...,pₙ₊₁...], params) sqeuclidean(dH[1],(qₙ₊₁-qₙ)/Δt) + sqeuclidean(dH[2],(pₙ₊₁-pₙ)/Δt) end get_loss(::TrainingMethod{<:SymplecticEuler}, ::AbstractNeuralNetwork{<:HamiltonianNeuralNetwork}, data::TrainingData{<:DataSymbol{<:PhaseSpaceSymbol}}, args) = (get_data(data,:q, args...), get_data(data,:q, next(args...)...), get_data(data,:p, args...), get_data(data,:p,next(args...)...), get_Δt(data)) -loss(ti::TrainingMethod{<:SymplecticEuler}, nn::AbstractNeuralNetwork{<:HamiltonianNeuralNetwork}, data::TrainingData{<:DataSymbol{<:PhaseSpaceSymbol}}, index_batch = eachindex(ti, data), params = nn.params) = +loss(ti::TrainingMethod{<:SymplecticEuler}, nn::AbstractNeuralNetwork{<:HamiltonianNeuralNetwork}, 
data::TrainingData{<:DataSymbol{<:PhaseSpaceSymbol}}, index_batch = eachindex(ti, data), params = params(nn)) = mapreduce(args->loss_single(Zygote.ignore_derivatives(ti), nn, get_loss(ti, nn, data, args)..., params),+, index_batch) min_length_batch(::SymplecticEuler) = 2 diff --git a/src/training_method/sympnet_basic_method.jl b/src/training_method/sympnet_basic_method.jl index ea06bee0c..6a37c382b 100644 --- a/src/training_method/sympnet_basic_method.jl +++ b/src/training_method/sympnet_basic_method.jl @@ -2,7 +2,7 @@ struct BasicSympNetMethod <: SympNetTrainingMethod end BasicSympNet(;sqdist = sqeuclidean) = TrainingMethod{BasicSympNetMethod, PhaseSpaceSymbol, TrajectoryData, typeof(sqdist)}(sqdist) -function loss_single(::TrainingMethod{BasicSympNetMethod}, nn::AbstractNeuralNetwork{<:SympNet}, qₙ, pₙ, qₙ₊₁, pₙ₊₁, params = nn.params) +function loss_single(::TrainingMethod{BasicSympNetMethod}, nn::AbstractNeuralNetwork{<:SympNet}, qₙ, pₙ, qₙ₊₁, pₙ₊₁, params = params(nn)) q̃ₙ₊₁,p̃ₙ₊₁ = nn([qₙ...,pₙ...],params) sqeuclidean(q̃ₙ₊₁,qₙ₊₁) + sqeuclidean(p̃ₙ₊₁,pₙ₊₁) end @@ -10,7 +10,7 @@ end get_loss(::TrainingMethod{<:BasicSympNetMethod}, ::AbstractNeuralNetwork{<:SympNet}, data::TrainingData{<:DataSymbol{<:PhaseSpaceSymbol}}, args) = (Zygote.ignore_derivatives(get_data(data,:q, args...)), Zygote.ignore_derivatives(get_data(data,:p, args...)), Zygote.ignore_derivatives(get_data(data,:q, next(args...)...)), Zygote.ignore_derivatives(get_data(data,:p, next(args...)...))) -loss(ti::TrainingMethod{<:BasicSympNetMethod}, nn::AbstractNeuralNetwork{<:SympNet}, data::TrainingData{<:DataSymbol{<:PhaseSpaceSymbol}}, index_batch = eachindex(ti, data), params = nn.params) = +loss(ti::TrainingMethod{<:BasicSympNetMethod}, nn::AbstractNeuralNetwork{<:SympNet}, data::TrainingData{<:DataSymbol{<:PhaseSpaceSymbol}}, index_batch = eachindex(ti, data), params = params(nn)) = mapreduce(args->loss_single(Zygote.ignore_derivatives(ti), nn, get_loss(ti, nn, data, args)..., params),+, index_batch) min_length_batch(::BasicSympNetMethod) = 2 diff --git a/src/training_method/variational_method.jl b/src/training_method/variational_method.jl index 5712fc17b..851d52d54 100644 --- a/src/training_method/variational_method.jl +++ b/src/training_method/variational_method.jl @@ -5,15 +5,15 @@ struct VariationalTrapezMethod <: VariationalMethod end VariaMidPoint(;sqdist = sqeuclidean) = TrainingMethod{VariationalMidPointMethod, PositionSymbol, TrajectoryData, typeof(sqdist)}(sqdist) # discrete langrangian -discrete_lagrangian(::TrainingMethod{VariationalMidPointMethod}, nn::NeuralNetwork{<:LagrangianNeuralNetwork}, qₙ, qₙ₊₁, Δt, params = nn.params) = nn([(qₙ₊₁+qₙ)/2..., (qₙ₊₁-qₙ)/Δt...], params) +discrete_lagrangian(::TrainingMethod{VariationalMidPointMethod}, nn::NeuralNetwork{<:LagrangianNeuralNetwork}, qₙ, qₙ₊₁, Δt, params = params(nn)) = nn([(qₙ₊₁+qₙ)/2..., (qₙ₊₁-qₙ)/Δt...], params) # gradient of discrete Lagrangian -DL(ti::TrainingMethod{<:VariationalMethod}, nn::NeuralNetwork{<:LagrangianNeuralNetwork}, qₙ, qₙ₊₁, Δt, params = nn.params) = Zygote.gradient((qₙ,qₙ₊₁)->discrete_lagrangian(ti, nn, qₙ, qₙ₊₁, Δt, params), qₙ, qₙ₊₁) -DL₁(ti::TrainingMethod{<:VariationalMethod}, nn::NeuralNetwork{<:LagrangianNeuralNetwork}, qₙ, qₙ₊₁, Δt, params = nn.params) = DL(ti, nn, qₙ, qₙ₊₁, Δt, params)[1:length(qₙ)] -DL₂(ti::TrainingMethod{<:VariationalMethod}, nn::NeuralNetwork{<:LagrangianNeuralNetwork}, qₙ, qₙ₊₁, Δt, params = nn.params) = DL(ti, nn, qₙ, qₙ₊₁, Δt, params)[1+length(qₙ):end] +DL(ti::TrainingMethod{<:VariationalMethod}, 
nn::NeuralNetwork{<:LagrangianNeuralNetwork}, qₙ, qₙ₊₁, Δt, params = params(nn)) = Zygote.gradient((qₙ,qₙ₊₁)->discrete_lagrangian(ti, nn, qₙ, qₙ₊₁, Δt, params), qₙ, qₙ₊₁) +DL₁(ti::TrainingMethod{<:VariationalMethod}, nn::NeuralNetwork{<:LagrangianNeuralNetwork}, qₙ, qₙ₊₁, Δt, params = params(nn)) = DL(ti, nn, qₙ, qₙ₊₁, Δt, params)[1:length(qₙ)] +DL₂(ti::TrainingMethod{<:VariationalMethod}, nn::NeuralNetwork{<:LagrangianNeuralNetwork}, qₙ, qₙ₊₁, Δt, params = params(nn)) = DL(ti, nn, qₙ, qₙ₊₁, Δt, params)[1+length(qₙ):end] -function loss_single(ti::TrainingMethod{<:VariationalMethod}, nn::AbstractNeuralNetwork{<:LagrangianNeuralNetwork}, qₙ, qₙ₊₁, qₙ₊₂, Δt, params = nn.params) +function loss_single(ti::TrainingMethod{<:VariationalMethod}, nn::AbstractNeuralNetwork{<:LagrangianNeuralNetwork}, qₙ, qₙ₊₁, qₙ₊₂, Δt, params = params(nn)) DL1 = DL₁(ti, nn, qₙ₊₁, qₙ₊₂, Δt, params) DL2 = DL₂(ti, nn, qₙ, qₙ₊₁, Δt,params) sqeuclidean(DL1,-DL2) @@ -22,6 +22,6 @@ end get_loss(::TrainingMethod{<:VariationalMidPointMethod}, ::AbstractNeuralNetwork{<:LagrangianNeuralNetwork}, data::TrainingData{<:DataSymbol{<:PositionSymbol}}, args) = (get_data(data,:q, args...), get_data(data,:q, next(args...)...), get_data(data,:q,next(next(args...)...)...), get_Δt(data)) -loss(ti::TrainingMethod{<:VariationalMidPointMethod}, nn::AbstractNeuralNetwork{<:LagrangianNeuralNetwork}, data::TrainingData{<:DataSymbol{<:PositionSymbol}}, index_batch = eachindex(ti, data), params = nn.params) = +loss(ti::TrainingMethod{<:VariationalMidPointMethod}, nn::AbstractNeuralNetwork{<:LagrangianNeuralNetwork}, data::TrainingData{<:DataSymbol{<:PositionSymbol}}, index_batch = eachindex(ti, data), params = params(nn)) = mapreduce(args->loss_single(Zygote.ignore_derivatives(ti), nn, get_loss(ti, nn, data, args)..., params),+, index_batch) min_length_batch(::VariationalMethod) = 3 \ No newline at end of file diff --git a/test/attention_layer/apply_multi_head_attention.jl b/test/attention_layer/apply_multi_head_attention.jl index bb2aa001b..247dc1af7 100644 --- a/test/attention_layer/apply_multi_head_attention.jl +++ b/test/attention_layer/apply_multi_head_attention.jl @@ -11,7 +11,7 @@ function compare_attention_to_mha(N, batch_size=10, T=Float32) model₃ = MultiHeadAttention(N, 1, add_connection=true) model₄ = Attention(N, softmax, add_connection=true) - ps₂ = initialparameters(model₂, CPU(), T) + ps₂ = NeuralNetwork(model₂, CPU(), T).params ps₁ = (PQ=(head_1=ps₂.PQ,), PK=(head_1=ps₂.PK,), PV=(head_1=typeof(ps₂.PK)(I(N)),)) mat = rand(T, N, N) diff --git a/test/attention_layer/attention_setup.jl b/test/attention_layer/attention_setup.jl index b43a8c66b..beaa0b98d 100644 --- a/test/attention_layer/attention_setup.jl +++ b/test/attention_layer/attention_setup.jl @@ -5,13 +5,13 @@ import Random Random.seed!(1234) function volume_preserving_attention_tests(N, T=Float32) - model₁ = VolumePreservingAttention(N, N, skew_sym = false) - model₂ = VolumePreservingAttention(N, N, skew_sym = true) + model₁ = Chain(VolumePreservingAttention(N, N, skew_sym = false)) + model₂ = Chain(VolumePreservingAttention(N, N, skew_sym = true)) - ps₁ = initialparameters(model₁, CPU(), T) - ps₂ = initialparameters(model₂, CPU(), T) - @test typeof(ps₁.A) <: AbstractMatrix{T} - @test typeof(ps₂.A) <: SkewSymMatrix{T} + ps₁ = NeuralNetwork(model₁, CPU(), T).params + ps₂ = NeuralNetwork(model₂, CPU(), T).params + @test typeof(ps₁.L1.A) <: AbstractMatrix{T} + @test typeof(ps₂.L1.A) <: SkewSymMatrix{T} # check if the layers are volume preserving A = randn(T, N, N) diff --git 
a/test/cuda/resnet.jl b/test/cuda/resnet.jl index c9e09e63a..cc6a2892e 100644 --- a/test/cuda/resnet.jl +++ b/test/cuda/resnet.jl @@ -2,6 +2,6 @@ using GeometricMachineLearning, CUDA, Test, KernelAbstractions using GeometricMachineLearning: ResNet model = ResNet(4, tanh) -ps = initialparameters(CUDABackend(), Float32, model) +ps = NeuralNetwork(CUDABackend(), Float32, model).params @test typeof(ps.weight) <: CuArray @test typeof(ps.bias) <: CuArray diff --git a/test/cuda/stiefel_manifold.jl b/test/cuda/stiefel_manifold.jl index 5c5680155..b0fe47bf1 100644 --- a/test/cuda/stiefel_manifold.jl +++ b/test/cuda/stiefel_manifold.jl @@ -15,7 +15,7 @@ end function test_optimizer(T, N, n) model = Chain(StiefelLayer(N, n), StiefelLayer(n, n)) - ps = initialparameters(backend, T, model) + ps = NeuralNetwork(model, backend, T).params @test typeof(ps[1].weight) <: StiefelManifold{T, <:CuArray{T, 2}} @test typeof(ps[2].weight) <: StiefelManifold{T, <:CuArray{T, 2}} diff --git a/test/data_loader/data_loader_optimization_step.jl b/test/data_loader/data_loader_optimization_step.jl index f016d98d2..820962f7e 100644 --- a/test/data_loader/data_loader_optimization_step.jl +++ b/test/data_loader/data_loader_optimization_step.jl @@ -12,7 +12,7 @@ function test_data_loader(sys_dim, n_time_steps, n_params, T=Float32) # first argument is sys_dim, second is number of heads, third is number of units model = Transformer(dl.input_dim, 2, 1) - ps = initialparameters(model, CPU(), T) + ps = NeuralNetwork(model, CPU(), T).params loss = GeometricMachineLearning.TransformerLoss(n_time_steps) dx = Zygote.gradient(ps -> loss(model, ps, dl.input, dl.input), ps)[1] ps_copy = deepcopy(ps) diff --git a/test/data_loader/mnist_utils.jl b/test/data_loader/mnist_utils.jl index d2791448a..1666216a6 100644 --- a/test/data_loader/mnist_utils.jl +++ b/test/data_loader/mnist_utils.jl @@ -62,11 +62,11 @@ function test_optimizer_for_classification_layer(; dim₁=28, dim₂=28, number_ dl = DataLoader(generate_dummy_mnist(dim₁, dim₂, number_images, T)...; patch_length=patch_length) activation_function(x) = tanh.(x) - model = ClassificationLayer(patch_length * patch_length, 10, activation_function) + model = Chain(ClassificationLayer(patch_length * patch_length, 10, activation_function)) - ps = initialparameters(model, CPU(), T) + ps = NeuralNetwork(model, CPU(), T).params loss = FeedForwardLoss() - loss_dl(model::GeometricMachineLearning.AbstractExplicitLayer, ps::Union{Tuple, NamedTuple}, dl::DataLoader) = loss(model, ps, dl.input, dl.output) + loss_dl(model::GeometricMachineLearning.Chain, ps::Union{Tuple, NamedTuple, NeuralNetworkParameters}, dl::DataLoader) = loss(model, ps, dl.input, dl.output) loss₁ = loss_dl(model, ps, dl) opt = Optimizer(GradientOptimizer(), ps) diff --git a/test/data_loader/optimizer_functor_with_adam.jl b/test/data_loader/optimizer_functor_with_adam.jl index 52d312f93..38660cf0c 100644 --- a/test/data_loader/optimizer_functor_with_adam.jl +++ b/test/data_loader/optimizer_functor_with_adam.jl @@ -16,7 +16,6 @@ function create_dummy_mnist(; T=Float32, dim₁=6, dim₂=6, n_images=10) rand(T, dim₁, dim₂, n_images), Int.(floor.(10*rand(T, n_images))) end - function test_optimization_with_adam(;T=Float32, dim₁=6, dim₂=6, n_images=10, patch_length=3) dl = DataLoader(create_dummy_mnist(T=T, dim₁=dim₁, dim₂=dim₂, n_images=n_images)...; patch_length=patch_length) @@ -26,7 +25,7 @@ function test_optimization_with_adam(;T=Float32, dim₁=6, dim₂=6, n_images=10 # input dim is dim₁ / patch_length * dim₂ / pach_length; the transformer is 
called with dim₁ / patch_length and two layers model = Chain(Transformer(dl.input_dim, patch_length, 2; Stiefel=true), ClassificationLayer(dl.input_dim, 10, σ)) - ps = initialparameters(model, CPU(), Float32) + ps = NeuralNetwork(model, CPU(), Float32).params loss = FeedForwardLoss() diff --git a/test/layers/classification.jl b/test/layers/classification.jl index 4158f400c..b2cd65e4b 100644 --- a/test/layers/classification.jl +++ b/test/layers/classification.jl @@ -4,8 +4,8 @@ import Random Random.seed!(1234) function test_set_up_and_application(T=Float32, sys_dim=49, output_dim=10, seq_length=16, batch_size=32; average=false) - d = ClassificationLayer(sys_dim, output_dim, σ, average=average) - ps = initialparameters(d, CPU(), T) + d = Chain(ClassificationLayer(sys_dim, output_dim, σ, average=average)) + ps = NeuralNetwork(d, CPU(), T).params output₁ = d(rand(T, sys_dim, seq_length), ps) output₂ = d(rand(T, sys_dim, seq_length, batch_size), ps) @test size(output₁) == (10, 1) diff --git a/test/layers/gradient_layer_tests.jl b/test/layers/gradient_layer_tests.jl index 9896f4d53..a30e7d984 100644 --- a/test/layers/gradient_layer_tests.jl +++ b/test/layers/gradient_layer_tests.jl @@ -6,8 +6,8 @@ import Random, Zygote Random.seed!(1234) function test_gradient_layer_application(T, M, N, batch_size=10) - dummy_model = GradientLayerQ(M, N, tanh) - ps = initialparameters(dummy_model, CPU(), T) + dummy_model = Chain(GradientLayerQ(M, N, tanh)) + ps = NeuralNetwork(dummy_model, CPU(), T).params x = rand(T, M) x_applied = dummy_model(x, ps) @@ -22,7 +22,7 @@ end function test_gradient_layer_derivative_and_update(T, M, N, batch_size=10) dummy_model = Chain(GradientLayerP(M, N, tanh), GradientLayerQ(M, N, tanh)) - ps = initialparameters(dummy_model, CPU(), T) + ps = NeuralNetwork(dummy_model, CPU(), T).params o = Optimizer(AdamOptimizer(T(0.1), T(.9), T(0.999), T(3e-7)), ps) # test for vector @@ -37,7 +37,6 @@ function test_gradient_layer_derivative_and_update(T, M, N, batch_size=10) optimization_step!(o, λY, ps, gs) end - types = (Float32, Float64) for T in types for M in 4:2:10 diff --git a/test/layers/manifold_layers.jl b/test/layers/manifold_layers.jl index 488dc0478..69a5c2abe 100644 --- a/test/layers/manifold_layers.jl +++ b/test/layers/manifold_layers.jl @@ -5,10 +5,10 @@ Random.seed!(1234) function stiefel_layer_test(T, M, N, tol=1f-1) model = Chain(StiefelLayer(M, N), StiefelLayer(N, N)) - ps = initialparameters(model, T) + ps = NeuralNetwork(model, T).params o = Optimizer(AdamOptimizer(T(1f0), T(5f-1), T(5f-1), T(3f-7)),ps) - dx = (L1 = (weight=rand(T,N,M),), L2 = (weight=rand(T,N,N),)) + dx = (L1 = (weight = rand(T, N, M),), L2 = (weight=rand(T, N, N),)) ps_copy = deepcopy(ps) λY = GlobalSection(ps) optimization_step!(o, λY, ps, dx) @@ -21,7 +21,7 @@ end function grassmann_layer_test(T, M, N, tol=1f-1) model = Chain(GrassmannLayer(M, N), StiefelLayer(N, N)) - ps = initialparameters(model, T) + ps = NeuralNetwork(model, T).params o = Optimizer(AdamOptimizer(T(1f0), T(5f-1), T(5f-1), T(3f-7)),ps) dx = (L1 = (weight=rand(T,N,M),), L2 = (weight=rand(T,N,N),)) diff --git a/test/layers/resnet_tests.jl b/test/layers/resnet_tests.jl index 80817ddab..ddea427ff 100644 --- a/test/layers/resnet_tests.jl +++ b/test/layers/resnet_tests.jl @@ -5,10 +5,10 @@ import Random Random.seed!(1234) function test_resnet(M, batch_size=10, T=Float32) - model₁ = ResNetLayer(M, tanh, use_bias=false) - model₂ = ResNetLayer(M, tanh, use_bias=true) - ps₁ = initialparameters(model₁, CPU(), T) - ps₂ = 
initialparameters(model₂, CPU(), T) + model₁ = Chain(ResNetLayer(M, tanh, use_bias=false)) + model₂ = Chain(ResNetLayer(M, tanh, use_bias=true)) + ps₁ = NeuralNetwork(model₁, CPU(), T).params + ps₂ = NeuralNetwork(model₂, CPU(), T).params @test parameterlength(model₁) == M^2 @test parameterlength(model₂) == M*(M+1) A = randn(T, M, M*2, batch_size) diff --git a/test/layers/sympnet_layers_test.jl b/test/layers/sympnet_layers_test.jl index de05ec5e8..5fbd83361 100644 --- a/test/layers/sympnet_layers_test.jl +++ b/test/layers/sympnet_layers_test.jl @@ -10,9 +10,9 @@ function sympnet_tests(N, N2=2*N, second_dim=10, third_dim=10, T=Float32) model₁ = Chain(LinearLayerQ(N), LinearLayerP(N)) model₂ = Chain(ActivationLayerQ(N, tanh), ActivationLayerP(N, tanh)) model₃ = Chain(GradientLayerQ(N, N2, tanh), GradientLayerP(N, N2, tanh)) - ps₁ = initialparameters(model₁, CPU(), T) - ps₂ = initialparameters(model₂, CPU(), T) - ps₃ = initialparameters(model₃, CPU(), T) + ps₁ = NeuralNetwork(model₁, CPU(), T).params + ps₂ = NeuralNetwork(model₂, CPU(), T).params + ps₃ = NeuralNetwork(model₃, CPU(), T).params # evaluate functions x_vec = rand(T, N) diff --git a/test/layers/sympnet_upscaling.jl b/test/layers/sympnet_upscaling.jl index ad0479b54..ce670d38b 100644 --- a/test/layers/sympnet_upscaling.jl +++ b/test/layers/sympnet_upscaling.jl @@ -2,7 +2,7 @@ using GeometricMachineLearning, Test, Zygote function test_symplecticity(N=4, N2=20, T=Float32) model = Chain(PSDLayer(N, N2), GradientQ(N2, 2*N2, tanh), GradientP(N2, 2*N2, tanh), PSDLayer(N2, N)) - ps = initialparameters(model, CPU(), T) + ps = NeuralNetwork(model, CPU(), T).params x = rand(T, N) ten = rand(T, N, N) # the first and last PSD layer need to have the same weight! (else they map to a different symplectic potential) diff --git a/test/layers/volume_preserving_feedforward.jl b/test/layers/volume_preserving_feedforward.jl index c3e532131..34e82c2ee 100644 --- a/test/layers/volume_preserving_feedforward.jl +++ b/test/layers/volume_preserving_feedforward.jl @@ -14,10 +14,10 @@ function test_volume_preserving_feedforward(dim₁ = 5; T::Type=Float32) layer₃ = VolumePreservingUpperLayer(dim₁; use_bias = false) layer₄ = VolumePreservingUpperLayer(dim₁; use_bias = true) - ps₁ = initialparameters(layer₁, CPU(), T) - ps₂ = initialparameters(layer₂, CPU(), T) - ps₃ = initialparameters(layer₃, CPU(), T) - ps₄ = initialparameters(layer₄, CPU(), T) + ps₁ = NeuralNetwork(Chain(layer₁), CPU(), T).params.L1 + ps₂ = NeuralNetwork(Chain(layer₂), CPU(), T).params.L1 + ps₃ = NeuralNetwork(Chain(layer₃), CPU(), T).params.L1 + ps₄ = NeuralNetwork(Chain(layer₄), CPU(), T).params.L1 # test if application to matrix and tensor gives same result test_vector = rand(T, dim₁) diff --git a/test/linear_symplectic_attention.jl b/test/linear_symplectic_attention.jl index 035ef9b81..f5308ca57 100644 --- a/test/linear_symplectic_attention.jl +++ b/test/linear_symplectic_attention.jl @@ -7,8 +7,8 @@ Random.seed!(123) function test_application_of_lsa(n::Integer=4, seq_length::Integer=5, T=Float64) l₁ = LinearSymplecticAttentionQ(n, seq_length) l₂ = LinearSymplecticAttentionP(n, seq_length) - ps₁ = initialparameters(l₁, CPU(), T) - ps₂ = initialparameters(l₂, CPU(), T) + ps₁ = NeuralNetwork(Chain(l₁), CPU(), T).params.L1 + ps₂ = NeuralNetwork(Chain(l₂), CPU(), T).params.L1 # test for NamedTuple as input nt = (q = rand(T, n, seq_length), p = rand(T, n, seq_length)) diff --git a/test/network_losses/losses_and_optimization.jl b/test/network_losses/losses_and_optimization.jl index 
b92a74b61..29f516e11 100644 --- a/test/network_losses/losses_and_optimization.jl +++ b/test/network_losses/losses_and_optimization.jl @@ -9,11 +9,11 @@ const sin_vector = sin.(0:0.01:2π) const dl = DataLoader(reshape(sin_vector, 1, length(sin_vector), 1)) function setup_network(dl::DataLoader{T}) where T - arch = Chain(Dense(1, 5, tanh), ResNetLayer(5, tanh), Dense(5, 1, identity)) + arch = Chain(Dense(dl.input_dim, 5, tanh), ResNetLayer(5, tanh), Dense(5, 1, identity)) NeuralNetwork(arch, CPU(), T) end -function train_network(; n_epochs=5) +function train_network(; n_epochs=10) nn = setup_network(dl) loss = FeedForwardLoss() diff --git a/test/optimizers/optimizer_convergence_tests/psd_optim.jl b/test/optimizers/optimizer_convergence_tests/psd_optim.jl index 3e012f562..6135a83c2 100644 --- a/test/optimizers/optimizer_convergence_tests/psd_optim.jl +++ b/test/optimizers/optimizer_convergence_tests/psd_optim.jl @@ -36,7 +36,7 @@ function svd_test(A, n, train_steps=1000, tol=1e-1; retraction=cayley) err_best = norm(A - U_result*U_result'*A) model = Chain(PSDLayer(2*N, 2*n), PSDLayer(2*n, 2*N)) - ps = initialparameters(model, CPU(), Float64) + ps = NeuralNetwork(model, CPU(), Float64).params o₁ = Optimizer(GradientOptimizer(0.01), ps; retraction = retraction) o₂ = Optimizer(MomentumOptimizer(0.01), ps; retraction = retraction) @@ -56,7 +56,7 @@ function svd_test(A, n, train_steps=1000, tol=1e-1; retraction=cayley) @test norm((err₃ - err_best)/err_best) < tol end -function train_network!(o::Optimizer, model::Chain, ps::NamedTuple, A::AbstractMatrix, train_steps, tol) +function train_network!(o::Optimizer, model::Chain, ps::NeuralNetworkParameters, A::AbstractMatrix, train_steps, tol) error(ps) = norm(A - model(A, ps)) for _ in 1:train_steps diff --git a/test/optimizers/optimizer_convergence_tests/svd_optim.jl b/test/optimizers/optimizer_convergence_tests/svd_optim.jl index 7f20dcfad..e260c40ff 100644 --- a/test/optimizers/optimizer_convergence_tests/svd_optim.jl +++ b/test/optimizers/optimizer_convergence_tests/svd_optim.jl @@ -32,7 +32,7 @@ function svd_test(A, n, train_steps=1000, tol=1e-1; retraction=cayley) err_best = norm(A - U_result*U_result'*A) model = Chain(StiefelLayer(N, n), StiefelLayer(n, N)) - ps = initialparameters(model, CPU(), Float64) + ps = NeuralNetwork(model, CPU(), Float64).params o₁ = Optimizer(GradientOptimizer(0.01), ps; retraction = retraction) o₂ = Optimizer(MomentumOptimizer(0.01), ps; retraction = retraction) @@ -52,7 +52,7 @@ function svd_test(A, n, train_steps=1000, tol=1e-1; retraction=cayley) @test norm((err₃ - err_best)/err_best) < tol end -function train_network!(o::Optimizer, model::Chain, ps::NamedTuple, A::AbstractMatrix, train_steps, tol) +function train_network!(o::Optimizer, model::Chain, ps::NeuralNetworkParameters, A::AbstractMatrix, train_steps, tol) error(ps) = norm(A - model(A, ps)) for _ in 1:train_steps @@ -60,7 +60,7 @@ function train_network!(o::Optimizer, model::Chain, ps::NamedTuple, A::AbstractM λY = GlobalSection(ps) optimization_step!(o, λY, ps, dx) end - ps[1].weight, ps[2].weight, error(ps) + ps.L1.weight, ps.L2.weight, error(ps) end for retraction in (geodesic, cayley) diff --git a/test/optimizers/utils/optimization_step.jl b/test/optimizers/utils/optimization_step.jl index aa9884e59..3e3a8cd78 100644 --- a/test/optimizers/utils/optimization_step.jl +++ b/test/optimizers/utils/optimization_step.jl @@ -1,13 +1,13 @@ using GeometricMachineLearning, Test, LinearAlgebra, KernelAbstractions using AbstractNeuralNetworks: AbstractExplicitLayer 
-import GeometricMachineLearning: initialparameters +import GeometricMachineLearning: NeuralNetwork import Random Random.seed!(1234) function optimization_step_test(N, n, T) model = Chain(StiefelLayer(N, n), Dense(N, N, tanh)) - ps = initialparameters(model, KernelAbstractions.CPU(), T) + ps = NeuralNetwork(model, KernelAbstractions.CPU(), T).params # gradient dx = (L1 = (weight=rand(Float32, N, n),), L2 = (W=rand(Float32, N, N), b=rand(Float32, N))) m = AdamOptimizer() @@ -18,7 +18,7 @@ function optimization_step_test(N, n, T) λY = GlobalSection(ps) optimization_step!(o, λY, ps, dx) @test typeof(ps[1].weight) <: StiefelManifold - for (layers1, layers2) in zip(ps, ps2) + for (layers1, layers2) in zip(values(ps), values(ps2)) for key in keys(layers1) @test norm(layers1[key] - layers2[key]) > T(1f-6) end diff --git a/test/reduced_system.jl b/test/reduced_system.jl index d7eb39281..ee2f952a1 100644 --- a/test/reduced_system.jl +++ b/test/reduced_system.jl @@ -17,7 +17,7 @@ function set_up_reduced_systems(reduced_dim::Integer, integrator) model1 = PSDArch(dl.input_dim, reduced_dim) # Here the number of decoder blocks is set manually because the default is too big! - model2 = SymplecticAutoencoder(dl.input_dim, reduced_dim; activation = x -> log(1. + exp(x)), n_encoder_layers = 20, n_decoder_blocks = 2) + model2 = SymplecticAutoencoder(dl.input_dim, reduced_dim; activation = x -> log(1. + exp(x)), n_encoder_layers = 20, n_decoder_layers = 10, n_decoder_blocks = 2) nn1 = NeuralNetwork(model1) diff --git a/test/transformer_related/multi_head_attention_stiefel_optim_cache.jl b/test/transformer_related/multi_head_attention_stiefel_optim_cache.jl index b2bae9f38..216ba0c64 100644 --- a/test/transformer_related/multi_head_attention_stiefel_optim_cache.jl +++ b/test/transformer_related/multi_head_attention_stiefel_optim_cache.jl @@ -44,8 +44,8 @@ TODO: """ function test_cache_setups_for_optimizer_for_multihead_attention_layer(T::Type, dim::Int, n_heads::Int) @assert dim % n_heads == 0 - model = MultiHeadAttention(dim, n_heads, Stiefel=true) - ps = initialparameters(model, CPU(), T) + model = Chain(MultiHeadAttention(dim, n_heads, Stiefel=true)) + ps = NeuralNetwork(model, CPU(), T).params o₁ = Optimizer(AdamOptimizer(), ps) o₂ = Optimizer(MomentumOptimizer(), ps) diff --git a/test/transformer_related/multi_head_attention_stiefel_retraction.jl b/test/transformer_related/multi_head_attention_stiefel_retraction.jl index 1f8ef8c60..400ab6478 100644 --- a/test/transformer_related/multi_head_attention_stiefel_retraction.jl +++ b/test/transformer_related/multi_head_attention_stiefel_retraction.jl @@ -32,9 +32,9 @@ check_retraction_cayley(B::MomentumCache) = check_retraction_cayley(B.B) This is a test for that checks if the retractions (geodesic and Cayley for now) map from `StiefelLieAlgHorMatrix` to `StiefelManifold` when used with `MultiHeadAttention`. 
""" function test_multi_head_attention_retraction(T::Type, dim, n_heads, tol=eps(T), backend=KernelAbstractions.CPU()) - model = MultiHeadAttention(dim, n_heads, Stiefel=true) + model = Chain(MultiHeadAttention(dim, n_heads, Stiefel=true)) - ps = initialparameters(model, backend, T) + ps = NeuralNetwork(model, backend, T).params cache = init_optimizer_cache(MomentumOptimizer(), ps) check_retraction_geodesic(cache) diff --git a/test/transformer_related/multi_head_attention_stiefel_setup.jl b/test/transformer_related/multi_head_attention_stiefel_setup.jl index cd3deaebd..36ad4bac3 100644 --- a/test/transformer_related/multi_head_attention_stiefel_setup.jl +++ b/test/transformer_related/multi_head_attention_stiefel_setup.jl @@ -12,6 +12,7 @@ function check_setup(A::AbstractMatrix{T}, tol=T(10)*eps(T)) where T @test check(A) < tol end check_setup(ps::NamedTuple) = apply_toNT(check_setup, ps) +check_setup(ps::NeuralNetworkParameters) = check_setup(ps.params) @doc raw""" This checks for an arbitrary matrix ``B\in\mathbb{R}^{N\times{}N}`` if ``B\in\mathfrak{g}^\mathrm{hor}``. @@ -27,8 +28,8 @@ check_grad_setup(B::MomentumCache) = check_grad_setup(B.B) Check if `initialparameters` and `init_optimizer_cache` do the right thing for `MultiHeadAttentionLayer`. """ function check_multi_head_attention_stiefel_setup(T::Type, N::Int, n::Int) - model = MultiHeadAttention(N, n, Stiefel=true) - ps = initialparameters(model, KernelAbstractions.CPU(), T) + model = Chain(MultiHeadAttention(N, n, Stiefel=true)) + ps = NeuralNetwork(model, KernelAbstractions.CPU(), T).params check_setup(ps) diff --git a/test/transformer_related/transformer_application.jl b/test/transformer_related/transformer_application.jl index 4ba88f7c3..3a3228390 100644 --- a/test/transformer_related/transformer_application.jl +++ b/test/transformer_related/transformer_application.jl @@ -11,8 +11,8 @@ function transformer_application_test(T, dim, n_heads, L, seq_length=8, batch_si model₁ = Chain(Transformer(dim, n_heads, L, Stiefel=false), ResNetLayer(dim)) model₂ = Chain(Transformer(dim, n_heads, L, Stiefel=true), ResNetLayer(dim)) - ps₁ = initialparameters(model₁, KernelAbstractions.CPU(), T) - ps₂ = initialparameters(model₂, KernelAbstractions.CPU(), T) + ps₁ = NeuralNetwork(model₁, KernelAbstractions.CPU(), T).params + ps₂ = NeuralNetwork(model₂, KernelAbstractions.CPU(), T).params input₁ = rand(T, dim, seq_length, batch_size) input₂ = rand(T, dim, seq_length) diff --git a/test/transformer_related/transformer_gradient.jl b/test/transformer_related/transformer_gradient.jl index 85178a5c2..e52be727f 100644 --- a/test/transformer_related/transformer_gradient.jl +++ b/test/transformer_related/transformer_gradient.jl @@ -11,8 +11,8 @@ function transformer_gradient_test(T, dim, n_heads, L, seq_length=8, batch_size= model₁ = Chain(Transformer(dim, n_heads, L, Stiefel=false), ResNetLayer(dim)) model₂ = Chain(Transformer(dim, n_heads, L, Stiefel=true), ResNetLayer(dim)) - ps₁ = initialparameters(model₁, KernelAbstractions.CPU(), T) - ps₂ = initialparameters(model₂, KernelAbstractions.CPU(), T) + ps₁ = NeuralNetwork(model₁, KernelAbstractions.CPU(), T).params + ps₂ = NeuralNetwork(model₂, KernelAbstractions.CPU(), T).params input₁ = rand(T, dim, seq_length, batch_size) input₂ = rand(T, dim, seq_length, batch_size) @@ -24,10 +24,10 @@ function transformer_gradient_test(T, dim, n_heads, L, seq_length=8, batch_size= grad₃ = Zygote.gradient(ps -> loss₂(ps, input₁), ps₂)[1] grad₄ = Zygote.gradient(ps -> loss₂(ps, input₂), ps₂)[1] - @test typeof(grad₁) 
== typeof(ps₁) - @test typeof(grad₂) == typeof(ps₁) - @test typeof(grad₃) != typeof(ps₂) - @test typeof(grad₄) != typeof(ps₂) + @test typeof(NeuralNetworkParameters(grad₁.params)) == typeof(ps₁) + @test typeof(NeuralNetworkParameters(grad₂.params)) == typeof(ps₁) + @test typeof(NeuralNetworkParameters(grad₃.params)) != typeof(ps₂) + @test typeof(NeuralNetworkParameters(grad₄.params)) != typeof(ps₂) end transformer_gradient_test(Float32, 10, 5, 4) \ No newline at end of file diff --git a/test/transformer_related/transformer_optimizer.jl b/test/transformer_related/transformer_optimizer.jl index de854f66a..d6907d577 100644 --- a/test/transformer_related/transformer_optimizer.jl +++ b/test/transformer_related/transformer_optimizer.jl @@ -11,7 +11,7 @@ function transformer_gradient_test(T, dim, n_heads, L, seq_length=8, batch_size= model = Chain(Transformer(dim, n_heads, L, Stiefel=true), ResNetLayer(dim)) model = Transformer(dim, n_heads, L, Stiefel=true) - ps = initialparameters(model, KernelAbstractions.CPU(), T) + ps = NeuralNetwork(model, KernelAbstractions.CPU(), T).params input = rand(T, dim, seq_length, batch_size) diff --git a/test/transformer_related/transformer_setup.jl b/test/transformer_related/transformer_setup.jl index 8ac553eae..47164e3f8 100644 --- a/test/transformer_related/transformer_setup.jl +++ b/test/transformer_related/transformer_setup.jl @@ -5,7 +5,7 @@ This function tests the setup of the transformer with Stiefel weights. """ function transformer_setup_test(dim, n_heads, L, T) model = Transformer(dim, n_heads, L, Stiefel=true) - ps = initialparameters(model, KernelAbstractions.CPU(), T) + ps = NeuralNetwork(model, KernelAbstractions.CPU(), T).params @test typeof(ps[1].PQ.head_1) <: StiefelManifold end
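A usage sketch summarizing the API migration performed throughout these tests and docs, assuming the exports shown elsewhere in this diff (`params` and `networkbackend` are imported explicitly since they are accessed qualified in the docs; the `L1`/`W` parameter keys and the toy shapes are illustrative): parameters come from the `NeuralNetwork` constructor instead of `initialparameters`, are read with `params(nn)` instead of `nn.params`, and backends are queried with `networkbackend` instead of `KernelAbstractions.get_backend`.

```julia
using GeometricMachineLearning
using GeometricMachineLearning: params, networkbackend

model = Chain(Dense(4, 8, tanh), Dense(8, 4, identity))

# old: ps = initialparameters(model, CPU(), Float32)
nn = NeuralNetwork(model, CPU(), Float32)
ps = params(nn)                              # old: nn.params

# old: backend = KernelAbstractions.get_backend(x)
backend = networkbackend(rand(Float32, 4, 10))

typeof(ps.L1.W) <: AbstractMatrix{Float32}   # parameters are keyed L1, L2, ...
```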