Skip to content

Commit

Permalink
Use ParamHandling's positive_definite (#40)
Browse files Browse the repository at this point in the history
* Use Paramhandling's `positive_definite`

* Cheeky README update

* Use a less disturbing initialisation

* Update README.md

Co-authored-by: st-- <[email protected]>

Co-authored-by: st-- <[email protected]>
  • Loading branch information
rossviljoen and st-- authored Aug 31, 2021
1 parent 86a55f6 commit 7c5c5c7
Show file tree
Hide file tree
Showing 3 changed files with 27 additions and 52 deletions.
5 changes: 3 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# SparseGPs

[![CI](https://github.com/rossviljoen/SparseGPs.jl/actions/workflows/CI.yml/badge.svg)](https://github.com/rossviljoen/SparseGPs.jl/actions/workflows/CI.yml)
[![Codecov](https://codecov.io/gh/rossviljoen/SparseGPs.jl/branch/master/graph/badge.svg)](https://codecov.io/gh/rossviljoen/SparseGPs.jl)
[![Docs](https://img.shields.io/badge/docs-dev-blue.svg)](https://JuliaGaussianProcesses.github.io/SparseGPs.jl/dev)
[![CI](https://github.com/JuliaGaussianProcesses/SparseGPs.jl/actions/workflows/CI.yml/badge.svg)](https://github.com/JuliaGaussianProcesses/SparseGPs.jl/actions/workflows/CI.yml)
[![Codecov](https://codecov.io/gh/JuliaGaussianProcesses/SparseGPs.jl/branch/master/graph/badge.svg)](https://codecov.io/gh/JuliaGaussianProcesses/SparseGPs.jl)
[![Code Style: Blue](https://img.shields.io/badge/code%20style-blue-4495d1.svg)](https://github.com/invenia/BlueStyle)
[![ColPrac: Contributor's Guide on Collaborative Practices for Community Packages](https://img.shields.io/badge/ColPrac-Contributor's%20Guide-blueviolet)](https://github.com/SciML/ColPrac)
47 changes: 23 additions & 24 deletions examples/b-classification/Manifest.toml
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,9 @@ uuid = "0dad84c5-d112-42e6-8d28-ef12dabb789f"

[[ArrayInterface]]
deps = ["IfElse", "LinearAlgebra", "Requires", "SparseArrays", "Static"]
git-tree-sha1 = "cdb00a6fb50762255021e5571cf95df3e1797a51"
git-tree-sha1 = "baf4ef9082070477046bd98306952292bfcb0af9"
uuid = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
version = "3.1.23"
version = "3.1.25"

[[Artifacts]]
uuid = "56f22d72-fd6d-98f1-02f0-08ddc0907c33"
Expand All @@ -52,9 +52,9 @@ version = "1.0.6+5"

[[CPUSummary]]
deps = ["Hwloc", "IfElse", "Static"]
git-tree-sha1 = "147bcca99e098c0da48d7d9e108210704138f0f9"
git-tree-sha1 = "ed720e2622820bf584d4ad90e6fcb93d95170b44"
uuid = "2a0fbf3d-bb9c-48f3-b0a9-814d99fd7ab9"
version = "0.1.2"
version = "0.1.3"

[[Cairo_jll]]
deps = ["Artifacts", "Bzip2_jll", "Fontconfig_jll", "FreeType2_jll", "Glib_jll", "JLLWrappers", "LZO_jll", "Libdl", "Pixman_jll", "Pkg", "Xorg_libXext_jll", "Xorg_libXrender_jll", "Zlib_jll", "libpng_jll"]
Expand All @@ -76,9 +76,9 @@ version = "0.10.13"

[[CloseOpenIntervals]]
deps = ["ArrayInterface", "Static"]
git-tree-sha1 = "4fcacb5811c9e4eb6f9adde4afc0e9c4a7a92f5a"
git-tree-sha1 = "ce9c0d07ed6e1a4fecd2df6ace144cbd29ba6f37"
uuid = "fb6a15b2-703c-40df-9091-08a04967cfa9"
version = "0.1.1"
version = "0.1.2"

[[ColorSchemes]]
deps = ["ColorTypes", "Colors", "FixedPointNumbers", "Random"]
Expand Down Expand Up @@ -277,10 +277,9 @@ uuid = "559328eb-81f9-559d-9380-de523a88c83c"
version = "1.0.10+0"

[[Functors]]
deps = ["MacroTools"]
git-tree-sha1 = "4cd9e70bf8fce05114598b663ad79dfe9ae432b3"
git-tree-sha1 = "39007773fd6097164ab537f78d3ac78ad2b8b695"
uuid = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
version = "0.2.3"
version = "0.2.4"

[[Future]]
deps = ["Random"]
Expand Down Expand Up @@ -689,10 +688,10 @@ uuid = "90014a1f-27ba-587c-ab20-58faa44d9150"
version = "0.11.1"

[[ParameterHandling]]
deps = ["Bijectors", "Compat", "IterTools", "LinearAlgebra", "SparseArrays", "Test"]
git-tree-sha1 = "d7a3c24820467785e8506f3276697c948e9a8bb5"
deps = ["Bijectors", "ChainRulesCore", "Compat", "IterTools", "LinearAlgebra", "SparseArrays", "Test"]
git-tree-sha1 = "532a3d7851ee7cd06ee6c2b3d98a333fe79af076"
uuid = "2412ca09-6db7-441c-8e3a-88d5709968c5"
version = "0.3.3"
version = "0.3.5"

[[Parameters]]
deps = ["OrderedCollections", "UnPack"]
Expand Down Expand Up @@ -724,9 +723,9 @@ version = "2.0.1"

[[PlotUtils]]
deps = ["ColorSchemes", "Colors", "Dates", "Printf", "Random", "Reexport", "Statistics"]
git-tree-sha1 = "501c20a63a34ac1d015d5304da0e645f42d91c9f"
git-tree-sha1 = "c67334c786157d6ef091ce622b365d3d60b1e2c4"
uuid = "995b91a9-d308-5afd-9ec6-746e21dbc043"
version = "1.0.11"
version = "1.0.12"

[[Plots]]
deps = ["Base64", "Contour", "Dates", "FFMPEG", "FixedPointNumbers", "GR", "GeometryBasics", "JSON", "Latexify", "LinearAlgebra", "Measures", "NaNMath", "PlotThemes", "PlotUtils", "Printf", "REPL", "Random", "RecipesBase", "RecipesPipeline", "Reexport", "Requires", "Scratch", "Showoff", "SparseArrays", "Statistics", "StatsBase", "UUIDs"]
Expand Down Expand Up @@ -783,15 +782,15 @@ version = "1.1.2"

[[RecipesPipeline]]
deps = ["Dates", "NaNMath", "PlotUtils", "RecipesBase"]
git-tree-sha1 = "2a7a2469ed5d94a98dea0e85c46fa653d76be0cd"
git-tree-sha1 = "32efa73dece357e9c834cae8af00265752c80061"
uuid = "01d81517-befc-4cb6-b9ec-a95719d0359c"
version = "0.3.4"
version = "0.3.5"

[[RecursiveArrayTools]]
deps = ["ArrayInterface", "ChainRulesCore", "DocStringExtensions", "LinearAlgebra", "RecipesBase", "Requires", "StaticArrays", "Statistics", "ZygoteRules"]
git-tree-sha1 = "82efc2429a2b2e72daf2322dbdf5fc60df6dc51f"
git-tree-sha1 = "00bede2eb099dcc1ddc3f9ec02180c326b420ee2"
uuid = "731186ca-8d62-57ce-b412-fbd966d074cd"
version = "2.17.1"
version = "2.17.2"

[[RecursiveFactorization]]
deps = ["LinearAlgebra", "LoopVectorization", "Polyester", "StrideArraysCore", "TriangularSolve"]
Expand All @@ -800,9 +799,9 @@ uuid = "f2c3362d-daeb-58d1-803e-2bc74f2840b4"
version = "0.2.2"

[[Reexport]]
git-tree-sha1 = "5f6c21241f0f655da3952fd60aa18477cf96c220"
git-tree-sha1 = "45e428421666073eab6f2da5c9d310d99bb12f9b"
uuid = "189a3867-3050-52da-a836-e630ba90ab69"
version = "1.1.0"
version = "1.2.2"

[[Requires]]
deps = ["UUIDs"]
Expand Down Expand Up @@ -928,9 +927,9 @@ version = "0.1.18"

[[StructArrays]]
deps = ["Adapt", "DataAPI", "StaticArrays", "Tables"]
git-tree-sha1 = "000e168f5cc9aded17b6999a560b7c11dda69095"
git-tree-sha1 = "1700b86ad59348c0f9f68ddc95117071f947072d"
uuid = "09ab397b-f2b6-538f-b94a-2f83cf4a842a"
version = "0.6.0"
version = "0.6.1"

[[SuiteSparse]]
deps = ["Libdl", "LinearAlgebra", "Serialization", "SparseArrays"]
Expand Down Expand Up @@ -1003,9 +1002,9 @@ uuid = "4ec0a83e-493e-50e2-b9ac-8f72acf5a8f5"

[[VectorizationBase]]
deps = ["ArrayInterface", "CPUSummary", "HostCPUFeatures", "Hwloc", "IfElse", "Libdl", "LinearAlgebra", "Static"]
git-tree-sha1 = "0e940546f8ad51f53966c866db14ff9b58be24e0"
git-tree-sha1 = "5e6e23728d6c8d26d2826f6cb2cd21892a958a43"
uuid = "3d5dd08c-fd9d-11e8-17fa-ed2836048c2f"
version = "0.20.34"
version = "0.20.38"

[[Wayland_jll]]
deps = ["Artifacts", "Expat_jll", "JLLWrappers", "Libdl", "Libffi_jll", "Pkg", "XML2_jll"]
Expand Down
27 changes: 1 addition & 26 deletions examples/b-classification/script.jl
Original file line number Diff line number Diff line change
Expand Up @@ -81,39 +81,14 @@ plot!(x_true, mean.(lgp.lik.(f_true)); seriescolor="red", label="True mean")
# [ParameterHandling.jl](https://github.com/invenia/ParameterHandling.jl)
# readme.

# First, we need to define a quick and dirty positive definite matrix type for
# ParameterHandling.jl - this code can safely be ignored.

struct PDMatrix{TA}
A::TA
end

function pdmatrix(A::AbstractMatrix)
return PDMatrix(A)
end

function ParameterHandling.value(P::PDMatrix)
A = copy(P.A)
return A'A
end

function ParameterHandling.flatten(::Type{T}, P::PDMatrix) where {T}
v, unflatten_to_Array = flatten(T, P.A)
function unflatten_PDmatrix(v_new::Vector{T})
A = unflatten_to_Array(v_new)
return PDMatrix(A)
end
return v, unflatten_PDmatrix
end;

# Initialise the parameters

M = 15 # number of inducing points
raw_initial_params = (
k=(var=positive(rand()), precision=positive(rand())),
z=bounded.(range(0.1, 5.9; length=M), 0.0, 6.0), # constrain z to simplify optimisation
m=zeros(M),
A=pdmatrix(4 * Matrix{Float64}(I, M, M)),
A=positive_definite(Matrix{Float64}(I, M, M)),
);

# `flatten` takes the `NamedTuple` of parameters and returns a flat vector of
Expand Down

0 comments on commit 7c5c5c7

Please sign in to comment.