Resolve type stability of evaluation of KernelSum #459
Conversation
Codecov Report: Base: 68.82% // Head: 68.87% // Increases project coverage by +0.04%.
Additional details and impacted files:

```diff
@@            Coverage Diff             @@
##           master     #459      +/-   ##
==========================================
+ Coverage   68.82%   68.87%   +0.04%
==========================================
  Files          52       52
  Lines        1344     1346       +2
==========================================
+ Hits          925      927       +2
  Misses        419      419
```
There seem to be some unrelated failures.
This is weird. I'm finding that type-stability issues are resolved with the following implementations:

```julia
function kernelmatrix(κ::KernelSum, x::AbstractVector)
    return sum(map(Base.Fix2(kernelmatrix, x), κ.kernels))
end

function kernelmatrix(κ::KernelSum, x::AbstractVector, y::AbstractVector)
    return sum(map(k -> kernelmatrix(k, x, y), κ.kernels))
end

function kernelmatrix_diag(κ::KernelSum, x::AbstractVector)
    return sum(map(Base.Fix2(kernelmatrix_diag, x), κ.kernels))
end

function kernelmatrix_diag(κ::KernelSum, x::AbstractVector, y::AbstractVector)
    return sum(map(k -> kernelmatrix_diag(k, x, y), κ.kernels))
end
```

For some reason the implementation that uses the generator doesn't infer properly, but I guess this is related to
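For reference, a sketch of the kind of generator-based variant being referred to; the name `kernelmatrix_via_generator` is made up here for contrast, and the pre-PR code may have differed in detail:

```julia
# Generator-based variant: `sum` over a Generator does not infer a
# concrete return type when κ.kernels is a heterogeneous tuple of
# kernels, whereas `sum(map(...))` over the tuple does.
function kernelmatrix_via_generator(κ::KernelSum, x::AbstractVector)
    return sum(kernelmatrix(k, x) for k in κ.kernels)
end
```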
Hmm, that'd then be similar to the hack I used in JuliaGaussianProcesses/GPLikelihoods.jl#90, right?
Indeed!
I'd rather not use that here because the allocations might increase too much. A loopy version is probably also not so great for AD.
Agreed -- I mean,
Anyway, this seems to work on nightly, presumably due to JuliaLang/julia#45789, but unfortunately that hasn't been backported to 1.8.
Ahh I see. Well, I'm happy to wait for 1.9 if it will just fix this -- @simsurace, is this causing you unacceptable performance issues, or is it something that you could live with for a few months?
Hmm, I just did another minimal example on nightly and it didn't pass the test. I guess things might be more complicated after all.
Okay. I'll have a think about this. It might be that there's a straightforward, albeit slightly more verbose, way to implement this that we're not seeing.
The following fails on a nightly build on GitHub:
To the extent that this is similar (JET.jl shows type inference problems that look about the same as the kernel examples above), this is still a Julia issue. Could someone here verify this error?
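For anyone trying to reproduce, a minimal sketch of such a check with JET (the kernel choice here is arbitrary, not the exact code from the comment above):

```julia
using JET, KernelFunctions

k = RBFKernel() + ExponentialKernel()
x = rand(10)

# Reports runtime-dispatch and inference problems in the call tree:
JET.@report_opt kernelmatrix(k, x)
```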
On a related note, I was thinking about a totally different implementation that would optimize away, e.g., multiple uses of the same distance matrix and similar things, using some kind of computational graph. Or is this something the Julia compiler is supposed to do on its own? I don't know whether it would help with type stability, though, as long as those kinds of lazy constructions have type inference issues.
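To make the idea concrete, a hypothetical sketch (not KernelFunctions API) of reusing one squared-distance matrix across several kernels:

```julia
using Distances

X = rand(2, 50)                          # 50 two-dimensional inputs
D2 = pairwise(SqEuclidean(), X; dims=2)  # computed once, shared below

# Both kernels are functions of the (squared) distance, so neither
# needs its own pairwise pass:
K_rbf = exp.(-D2 ./ 2)    # RBF kernel matrix
K_exp = exp.(-sqrt.(D2))  # exponential kernel matrix
K_sum = K_rbf .+ K_exp
```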
Aha! This seems to work nicely:

```julia
_sum(f::Tf, x::Tuple) where {Tf} = f(x[1]) + _sum(f, Base.tail(x))
_sum(f::Tf, x::Tuple{Tx}) where {Tf, Tx} = f(x[1])

function kernelmatrix(κ::KernelSum, x::AbstractVector)
    return _sum(Base.Fix2(kernelmatrix, x), κ.kernels)
end

function kernelmatrix(κ::KernelSum, x::AbstractVector, y::AbstractVector)
    return _sum(k -> kernelmatrix(k, x, y), κ.kernels)
end

function kernelmatrix_diag(κ::KernelSum, x::AbstractVector)
    return _sum(Base.Fix2(kernelmatrix_diag, x), κ.kernels)
end

function kernelmatrix_diag(κ::KernelSum, x::AbstractVector, y::AbstractVector)
    return _sum(k -> kernelmatrix_diag(k, x, y), κ.kernels)
end
```

It's just a recursive implementation of `sum`.
Hmm, how's that going to be with AD?
Seems to be fine (according to the unit tests). Generally speaking, small tuples + recursion are going to be fine with Zygote, I believe.
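A tiny sanity check of that claim, using the `_sum` defined above (the closure here is arbitrary):

```julia
using Zygote

f(x) = _sum(abs2, (x, 2x, 3x))  # equals 14x^2, via the recursive tuple sum

# Zygote unrolls the small-tuple recursion without trouble:
Zygote.gradient(f, 1.0)  # -> (28.0,)
```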
I'm happy with this subject to CI passing and a bump of the patch version. Once they're both sorted, I'm happy to merge :)
I just did some benchmarks, looks fine to me.

```julia
using BenchmarkTools
using KernelFunctions
using Test
using Zygote

k = RBFKernel() + RBFKernel() * ExponentialKernel()
@inferred k(0.0, 1.0)

x = rand(100)

@btime kernelmatrix($k, $x)
# before: 207.625 μs (33 allocations: 631.05 KiB)
# after:  212.833 μs (28 allocations: 630.91 KiB)

@btime Zygote.pullback(kernelmatrix, $k, $x)
# before: 331.417 μs (402 allocations: 2.54 MiB)
# after:  335.875 μs (235 allocations: 2.54 MiB)

out, pb = Zygote.pullback(kernelmatrix, k, x)
@btime $pb($out)
# before: 270.625 μs (545 allocations: 1.09 MiB)
# after:  199.208 μs (325 allocations: 1.09 MiB)

k = sum(rand() * RBFKernel() ∘ ScaleTransform(rand()) for _ in 1:20)
@inferred k(0.0, 1.0)

x = rand(100)

@btime kernelmatrix($k, $x)
# before: 1.670 ms (258 allocations: 6.08 MiB)
# after:  1.646 ms (258 allocations: 6.08 MiB)

@btime Zygote.pullback(kernelmatrix, $k, $x)
# before: 2.979 ms (3144 allocations: 18.44 MiB)
# after:  2.861 ms (2901 allocations: 18.41 MiB)

out, pb = Zygote.pullback(kernelmatrix, k, x)
@btime $pb($out)
# before: 1.510 ms (3519 allocations: 6.38 MiB)
# after:  1.483 ms (3223 allocations: 6.35 MiB)
```
EDIT: hold on, I now can't reproduce it. I still get type inference errors. I think the
Ok, seems to be fine after all. There are now 15 additional tests being counted. Closes #458.
This PR looks like it's basically ready to go, other than syncing it up with master. @simsurace, is there anything else that you think needs doing?
I don't think so.
Cool. Do you have a few minutes to merge the master branch into this branch so that we can merge this back into master, or are you happy for me to handle it from here?
@simsurace I have merged in the changes from master and re-bumped the patch version. Will merge when CI passes.
Eurgh. Somehow the build has been broken. Going to need to figure out how before merging.
Okay, I'm pretty sure that the failures aren't related to this PR, so I'm going to merge when CI (other than BaseKernels) passes.
Summary
This is an attempt to patch #458. Since this is probably some deeper Julia type inference issue, this solution is likely to be temporary.
Proposed changes
For now, I just added tests that expose the type instability. For those kernels, and only for those, the issue can be fixed by including an evaluation of `sum` in the module initialization function.
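A sketch of what that workaround looks like; the exact call evaluated is an assumption here, not the literal PR code:

```julia
# Module initialization function: forcing one evaluation of `sum` at
# load time was enough to make the affected methods infer correctly.
function __init__()
    sum(identity, (1.0, 2.0))  # placeholder call; the real one differs
    return nothing
end
```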
What alternatives have you considered?
Rewriting `KernelFunctions.jl/src/kernels/kernelsum.jl` (line 46 in ce7923f), either by using `do` or `let` blocks, or by passing an `init` argument to `sum`. None of those made the tests pass, but they work if the function is redefined in the REPL after the module has been loaded.
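For illustration, a sketch of one variant of that kind, reconstructed from the description above (the `init` value is assumed, not taken from the PR):

```julia
# Passing `init` to `sum` instead of relying on tuple reduction; this
# reportedly still failed inference inside the module, but worked when
# the method was redefined in the REPL after loading:
function kernelmatrix(κ::KernelSum, x::AbstractVector)
    return sum(k -> kernelmatrix(k, x), κ.kernels; init=zeros(length(x), length(x)))
end
```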
Breaking changes
None.