Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix gradient issues with kernelmatrix_diag and use ChainRulesCore #208

Merged
merged 50 commits into from
Mar 25, 2021

Conversation

theogf
Copy link
Member

@theogf theogf commented Dec 9, 2020

This PR aims at solving the AD issues coming from kernelmatrix_diag, add tests for it and incidentally solve AD issue for when x==y.
This solves the issue on #203

@willtebbutt
Copy link
Member

willtebbutt commented Dec 9, 2020

This makes me a little nervous. The benefit of calling kerneldiagmatrix from inside kerneldiagmatrix, is that if there's a specialised method in the base kernel for kerneldiagmatrix (e.g. return all zeros), that still gets called.

I would have thought that a better solution would be to require _map(κ.transform, x) to be defined for each transform, so that it happens efficiently.

edit: see e.g. what I do in Stheno.

@willtebbutt
Copy link
Member

willtebbutt commented Dec 9, 2020

Actually, I'm surprised that this problem is happening at all. Looking into it now.

test/utils_AD.jl Outdated Show resolved Hide resolved
@theogf
Copy link
Member Author

theogf commented Dec 14, 2020

@willtebbutt It seems your fix on defining methods for SimpleKernel fixed the issue. I added AD test for kerneldiagmatrix now

test/test_utils.jl Outdated Show resolved Hide resolved
test/test_utils.jl Outdated Show resolved Hide resolved
@devmotion
Copy link
Member

It seems the error is not fixed?

@theogf
Copy link
Member Author

theogf commented Dec 15, 2020

It seems the error is not fixed?

Yep, I tried @willtebbutt suggestion but it does not solve the problem...

src/matrix/kernelmatrix.jl Outdated Show resolved Hide resolved
src/matrix/kernelmatrix.jl Outdated Show resolved Hide resolved
@theogf
Copy link
Member Author

theogf commented Mar 16, 2021

Back to this annoying problem!
So I seem to have found a solution (for Zygote at least), but it has a small touch of piracy :D.
Distances.jl accepts now iterators/generator for pairwise which makes our task easier. However it's not the case for colwise which only takes matrices... I dispatched correctly for ColVecs and RowVecs using x.X and x.X' but then for the generic case of AbstractVector I had to implement our own version -> type piracy... I will anyway open an issue on Distances.jl to see if it is possible to add the option.

@devmotion
Copy link
Member

I guess the correct way would be to define our own colwise function, analogously to pairwise. And then fall back to the definitions in Distances where it is possible.

src/zygote_adjoints.jl Outdated Show resolved Hide resolved
@theogf
Copy link
Member Author

theogf commented Mar 23, 2021

Current fails are just random, it's ready to review/merge

Copy link
Member

@devmotion devmotion left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks good overall! I had a few questions, and I think we should use ChainRulesTestUtils to properly test the ChainRules definitions.

src/chainrules.jl Outdated Show resolved Hide resolved
src/chainrules.jl Show resolved Hide resolved
src/chainrules.jl Outdated Show resolved Hide resolved
src/chainrules.jl Outdated Show resolved Hide resolved
src/chainrules.jl Outdated Show resolved Hide resolved
src/chainrules.jl Outdated Show resolved Hide resolved
src/chainrules.jl Outdated Show resolved Hide resolved
src/chainrules.jl Outdated Show resolved Hide resolved
src/distances/pairwise.jl Outdated Show resolved Hide resolved
test/chainrules.jl Show resolved Hide resolved
src/chainrules.jl Outdated Show resolved Hide resolved
src/zygoterules.jl Outdated Show resolved Hide resolved
test/zygoterules.jl Outdated Show resolved Hide resolved
src/chainrules.jl Outdated Show resolved Hide resolved
test/zygoterules.jl Outdated Show resolved Hide resolved
test/zygoterules.jl Outdated Show resolved Hide resolved
src/chainrules.jl Outdated Show resolved Hide resolved
Co-authored-by: David Widmann <[email protected]>
src/chainrules.jl Outdated Show resolved Hide resolved
src/chainrules.jl Outdated Show resolved Hide resolved
src/chainrules.jl Outdated Show resolved Hide resolved
src/chainrules.jl Outdated Show resolved Hide resolved
theogf and others added 2 commits March 25, 2021 11:39
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
src/chainrules.jl Outdated Show resolved Hide resolved
src/chainrules.jl Outdated Show resolved Hide resolved
src/chainrules.jl Outdated Show resolved Hide resolved
src/chainrules.jl Outdated Show resolved Hide resolved
@theogf
Copy link
Member Author

theogf commented Mar 25, 2021

Should we make a patch version bump followed by a minor version bump where we remove the deprecation warnings?

@willtebbutt
Copy link
Member

Sorry @theogf , which deprecations are you referring to?

@theogf
Copy link
Member Author

theogf commented Mar 25, 2021

kerneldiagmatrix vs kernelmatrix_diag and some others like degree vs d etc...

@willtebbutt
Copy link
Member

Oh I see. Yeah, patch followed by minor release seems reasonable to me.

Copy link
Member

@devmotion devmotion left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks good to me 👍

@theogf theogf changed the title Fix issue with gradient from kernelmatrix_diag Fix gradient issues with kernelmatrix_diag and use ChainRulesCore Mar 25, 2021
@theogf theogf merged commit aa2099e into master Mar 25, 2021
@theogf theogf deleted the fix_diagmat branch March 25, 2021 14:12
bmharsha added a commit to bmharsha/KernelFunctions.jl that referenced this pull request Aug 4, 2021
We are encountering following warning during Julia REPL startup if we include `KernelFunctions` in the default SYSIMG, this commit fixes that issue

```
┌ Warning: Error requiring `PDMats` from `KernelFunctions`
│   exception =
│    SystemError: opening file "/home/bmharsha/.julia/packages/KernelFunctions/AxuTC/src/matrix/kernelpdmat.jl": No such file or directory
│    Stacktrace:
│      [1] systemerror(p::String, errno::Int32; extrainfo::Nothing)
│        @ Base ./error.jl:168
│      [2] #systemerror#62
│        @ ./error.jl:167 [inlined]
│      [3] systemerror
│        @ ./error.jl:167 [inlined]
│      [4] open(fname::String; lock::Bool, read::Nothing, write::Nothing, create::Nothing, truncate::Nothing, append::Nothing)
│        @ Base ./iostream.jl:293
│      [5] open
│        @ ./iostream.jl:282 [inlined]
│      [6] open(f::Base.var"JuliaGaussianProcesses#326#327"{String}, args::String; kwargs::Base.Iterators.Pairs{Union{}, Union{}, Tuple{}, NamedTuple{(), Tuple{}}})
│        @ Base ./io.jl:328
│      [7] open
│        @ ./io.jl:328 [inlined]
│      [8] read
│        @ ./io.jl:434 [inlined]
│      [9] _include(mapexpr::Function, mod::Module, _path::String)
│        @ Base ./loading.jl:1166
│     [10] include(mod::Module, _path::String)
│        @ Base ./Base.jl:386
│     [11] include(x::String)
│        @ KernelFunctions ~/.julia/packages/KernelFunctions/AxuTC/src/KernelFunctions.jl:1
│     [12] top-level scope
│        @ ~/.julia/packages/KernelFunctions/AxuTC/src/KernelFunctions.jl:124
│     [13] eval
│        @ ./boot.jl:360 [inlined]
│     [14] eval
│        @ ~/.julia/packages/KernelFunctions/AxuTC/src/KernelFunctions.jl:1 [inlined]
│     [15] (::KernelFunctions.var"JuliaGaussianProcesses#209#215")()
│        @ KernelFunctions ~/.julia/packages/Requires/7Ncym/src/require.jl:99
│     [16] err(f::Any, listener::Module, modname::String)
│        @ Requires ~/.julia/packages/Requires/7Ncym/src/require.jl:47
│     [17] (::KernelFunctions.var"JuliaGaussianProcesses#208#214")()
│        @ KernelFunctions ~/.julia/packages/Requires/7Ncym/src/require.jl:98
│     [18] withpath(f::Any, path::String)
│        @ Requires ~/.julia/packages/Requires/7Ncym/src/require.jl:37
│     [19] (::KernelFunctions.var"JuliaGaussianProcesses#207#213")()
│        @ KernelFunctions ~/.julia/packages/Requires/7Ncym/src/require.jl:97
│     [20] listenpkg(f::Any, pkg::Base.PkgId)
│        @ Requires ~/.julia/packages/Requires/7Ncym/src/require.jl:20
│     [21] macro expansion
│        @ ~/.julia/packages/Requires/7Ncym/src/require.jl:95 [inlined]
│     [22] __init__()
│        @ KernelFunctions ~/.julia/packages/KernelFunctions/AxuTC/src/KernelFunctions.jl:123
└ @ Requires ~/.julia/packages/Requires/7Ncym/src/require.jl:49
```
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants