Resolve type stability of evaluation of KernelSum #459

Merged
merged 12 commits into from
Sep 26, 2022

Conversation

simsurace
Member

@simsurace simsurace commented Jun 23, 2022

Summary
This is an attempt to patch #458. Since this is probably some deeper Julia type inference issue, this solution is likely to be temporary.

Proposed changes
For now, I just added tests that expose the type instability. For those kernels, and only for those, the issue can be fixed by including an evaluation of sum in the module initialization function.

What alternatives have you considered?
Re-writing

(κ::KernelSum)(x, y) = sum(k(x, y) for k in κ.kernels)

either by using do or let blocks, or passing an init argument to sum.
None of those made the tests pass, but they work if the function is redefined in the REPL, after the module has been loaded.
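
For reference, the rewrites listed above can be sketched on a toy type. `ToySum` and the `eval_*` names are hypothetical stand-ins for `KernelSum` and are not part of KernelFunctions.jl. The sketch only shows what the four formulations look like; it makes no claim about which of them infers on a given Julia version.

```julia
# Hypothetical stand-in for KernelSum (not the KernelFunctions.jl type).
struct ToySum{Tk}
    kernels::Tk
end

# Generator form, mirroring the definition quoted above.
eval_generator(κ::ToySum, x, y) = sum(k(x, y) for k in κ.kernels)

# `do`-block form: passes a function to `sum` instead of building a generator.
function eval_do(κ::ToySum, x, y)
    return sum(κ.kernels) do k
        k(x, y)
    end
end

# `let`-block form: rebinds the captured variables before creating the closure.
function eval_let(κ::ToySum, x, y)
    return let x = x, y = y
        sum(k -> k(x, y), κ.kernels)
    end
end

# `init` form: supplies the starting value of the reduction explicitly.
eval_init(κ::ToySum, x, y) = sum(k -> k(x, y), κ.kernels; init=zero(x))

# Two toy "kernels" stored in a tuple, so each element has its own type.
κ = ToySum(((x, y) -> x * y, (x, y) -> x + y))
println(eval_generator(κ, 2.0, 3.0))
```

All four variants compute the same value; whether each one is type-stable can be checked separately with `Test.@inferred`.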

Breaking changes
None.

@simsurace simsurace marked this pull request as draft June 23, 2022 18:04
@codecov

codecov bot commented Jun 23, 2022

Codecov Report

Base: 68.82% // Head: 68.87% // Increases project coverage by +0.04% 🎉

Coverage data is based on head (db9cfd9) compared to base (4ce3e87).
Patch coverage: 100.00% of modified lines in pull request are covered.

Additional details and impacted files
@@            Coverage Diff             @@
##           master     #459      +/-   ##
==========================================
+ Coverage   68.82%   68.87%   +0.04%     
==========================================
  Files          52       52              
  Lines        1344     1346       +2     
==========================================
+ Hits          925      927       +2     
  Misses        419      419              
Impacted Files Coverage Δ
src/kernels/kernelsum.jl 100.00% <100.00%> (ø)


Review thread on src/KernelFunctions.jl (outdated, resolved)
@simsurace
Member Author

There seem to be some unrelated failures.

@willtebbutt
Member

This is weird. I'm finding that type-stability issues are resolved with the following implementations:

function kernelmatrix(κ::KernelSum, x::AbstractVector)
    return sum(map(Base.Fix2(kernelmatrix, x), κ.kernels))
end

function kernelmatrix(κ::KernelSum, x::AbstractVector, y::AbstractVector)
    return sum(map(k -> kernelmatrix(k, x, y), κ.kernels))
end

function kernelmatrix_diag(κ::KernelSum, x::AbstractVector)
    return sum(map(Base.Fix2(kernelmatrix_diag, x), κ.kernels))
end

function kernelmatrix_diag(κ::KernelSum, x::AbstractVector, y::AbstractVector)
    return sum(map(k -> kernelmatrix_diag(k, x, y), κ.kernels))
end

For some reason the implementation that uses the generator doesn't infer properly, but summing over the output of map seems to be fine. This is annoying because it doesn't work properly.

mapreduce also doesn't infer

I guess this is related to map being heavily optimised for Tuples, but perhaps mapreduce and whatever generator is produced in the current implementation aren't?
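
To make the distinction concrete, here is a small sketch on a heterogeneous tuple of functions. The names `fs`, `vals`, and `gen` are illustrative only. `map` over a tuple returns a tuple whose element types are part of its type, whereas the generator is a lazy `Base.Generator` wrapping the same tuple. The sketch shows the two objects and does not assert how either is inferred.

```julia
# A tuple of functions, each with its own concrete type.
fs = (sin, x -> x^2, sqrt)

# `map` over a tuple is eager and returns a tuple of results.
vals = map(f -> f(4.0), fs)
println(typeof(vals))

# The generator form is lazy; the results only exist once it is iterated.
gen = (f(4.0) for f in fs)
println(gen isa Base.Generator)

# Both reductions produce the same number.
total_map = sum(vals)
total_gen = sum(gen)
println(total_map ≈ total_gen)
```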

@simsurace
Member Author

Hmm, that'd then be similar to the hack I used in JuliaGaussianProcesses/GPLikelihoods.jl#90, right?

@willtebbutt
Member

Indeed!

@simsurace
Member Author

I'd rather not use that here because the allocations might increase too much. A loopy version is probably also not so great for AD.

@willtebbutt
Member

> I'd rather not use that here because the allocations might increase too much. A loopy version is probably also not so great for AD.

Agreed -- I mean, map ought to be completely fine for AD. It would only increase the allocations when outside of the context of reverse-mode AD when the total amount of memory allocated is less of an issue anyway. Just seems weird that there isn't a decent type-stable version of mapreduce for tuples anywhere in the ecosystem.

@simsurace
Member Author

simsurace commented Aug 30, 2022

Anyway, this seems to work on nightly, presumably due to JuliaLang/julia#45789, but unfortunately that hasn't been backported to 1.8.

@willtebbutt
Member

Ahh I see. Well I'm happy to wait for 1.9 if it will just fix this -- @simsurace is this causing you unacceptable performance issues, or is it something that you could live with for a few months?

@simsurace
Member Author

Hmm I just did another minimal example on nightly and it didn't pass the test. I guess things might be more complicated after all.
I don't really know how big of a performance issue this is. But hard-coding my kernel in terms of a custom type, where the sum is also hard-coded gave a moderate improvement and made the loss function fully type-stable.

@willtebbutt
Member

Okay. I'll have a think about this. It might be that there's a straightforward, albeit slightly more verbose, way to implement this that we're not seeing..

@simsurace
Member Author

The following fails on a nightly build on GitHub:

using Test

struct FunctionSum{Tf}
    functions::Tf
end

(F::FunctionSum)(x) = sum(f -> f(x), F.functions)

F = FunctionSum((x -> sqrt(x), FunctionSum((x -> x^2, x -> x^3))))
@inferred F(0.1)

To the extent that this is similar (JET.jl shows type inference problems that look about the same as in the kernel examples above), this is still a Julia issue. Could someone here verify this error?
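
For anyone trying to reproduce this without the test throwing, the inferred return type can also be queried directly. The sketch below repeats the `FunctionSum` example and uses `Base.return_types`; the variable names `inferred` and `is_stable` are illustrative. The result depends on the Julia version, so no expected output is given.

```julia
struct FunctionSum{Tf}
    functions::Tf
end

(F::FunctionSum)(x) = sum(f -> f(x), F.functions)

F = FunctionSum((x -> sqrt(x), FunctionSum((x -> x^2, x -> x^3))))

# Ask the compiler which return types it infers for F(::Float64).
inferred = Base.return_types(F, (Float64,))

# A concrete inferred type indicates that the call is type-stable.
is_stable = all(isconcretetype, inferred)

println("inferred: ", inferred, ", type-stable: ", is_stable)
```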

@simsurace
Member Author

> Okay. I'll have a think about this. It might be that there's a straightforward, albeit slightly more verbose, way to implement this that we're not seeing..

On a related note, I was thinking about a totally different implementation that would optimize away e.g. multiple uses of the same distance matrix and similar things, using some kind of computational graph. Or is this supposed to be something that the Julia compiler can do on its own? I don't know whether it would help type stability though as long as those kinds of lazy constructions have type inference issues.

@willtebbutt
Member

willtebbutt commented Aug 30, 2022

Aha! This seems to work nicely:

_sum(f::Tf, x::Tuple) where {Tf} = f(x[1]) + _sum(f, Base.tail(x))

_sum(f::Tf, x::Tuple{Tx}) where {Tf, Tx} = f(x[1])

function kernelmatrix(κ::KernelSum, x::AbstractVector)
    return _sum(Base.Fix2(kernelmatrix, x), κ.kernels)
end

function kernelmatrix(κ::KernelSum, x::AbstractVector, y::AbstractVector)
    return _sum(k -> kernelmatrix(k, x, y), κ.kernels)
end

function kernelmatrix_diag(κ::KernelSum, x::AbstractVector)
    return _sum(Base.Fix2(kernelmatrix_diag, x), κ.kernels)
end

function kernelmatrix_diag(κ::KernelSum, x::AbstractVector, y::AbstractVector)
    return _sum(k -> kernelmatrix_diag(k, x, y), κ.kernels)
end

It's just a recursive implementation of sum(f, x) for the case in which x is a Tuple.
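
As a self-contained illustration, the same recursion can be applied to the `FunctionSum` toy example from earlier in this thread. This is a sketch: it only demonstrates that the recursion computes the expected value, and whether it also infers should be checked with `Test.@inferred` on the Julia version in use.

```julia
# Recursive sum of f over a tuple: peel off the head, recurse on the tail.
_sum(f::Tf, x::Tuple) where {Tf} = f(x[1]) + _sum(f, Base.tail(x))

# Base case: a one-element tuple. An empty tuple is not supported here.
_sum(f::Tf, x::Tuple{Tx}) where {Tf,Tx} = f(x[1])

struct FunctionSum{Tf}
    functions::Tf
end

# Evaluate every stored function at x and add up the results.
(F::FunctionSum)(x) = _sum(f -> f(x), F.functions)

# Nested sum: sqrt(x) + (x^2 + x^3).
F = FunctionSum((x -> sqrt(x), FunctionSum((x -> x^2, x -> x^3))))

println(F(4.0))
```

Because the tuple type shrinks by one element at each step, every recursive call dispatches on a different concrete tuple type, which is the property the recursion relies on.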

@simsurace
Member Author

Hmm, how's that going to be with AD?

@willtebbutt
Member

Seems to be fine (according to the unit tests). Generally speaking small tuples + recursion are going to be fine with Zygote I believe.

Review thread on test/kernels/kernelsum.jl (outdated, resolved)
Review thread on src/kernels/kernelsum.jl (outdated, resolved)
Member

@willtebbutt willtebbutt left a comment

I'm happy with this subject to CI passing and a bump of the patch version. Once they're both sorted, I'm happy to merge :)

Review thread on test/kernels/kernelsum.jl (resolved)
@simsurace
Member Author

simsurace commented Aug 30, 2022

I just did some benchmarks, looks fine to me.

using BenchmarkTools
using KernelFunctions
using Test
using Zygote

k = RBFKernel() + RBFKernel() * ExponentialKernel()

@inferred k(0., 1.)

x = rand(100)
@btime kernelmatrix($k, $x)
# before: 207.625 μs (33 allocations: 631.05 KiB)
# after:  212.833 μs (28 allocations: 630.91 KiB)
@btime Zygote.pullback(kernelmatrix, $k, $x)
# before: 331.417 μs (402 allocations: 2.54 MiB)
# after:  335.875 μs (235 allocations: 2.54 MiB)
out, pb = Zygote.pullback(kernelmatrix, k, x)
@btime $pb($out)
# before: 270.625 μs (545 allocations: 1.09 MiB)
# after:  199.208 μs (325 allocations: 1.09 MiB)




k = sum(rand() * RBFKernel() ∘ ScaleTransform(rand()) for _ in 1:20)

@inferred k(0., 1.)

x = rand(100)
@btime kernelmatrix($k, $x)
# before: 1.670 ms (258 allocations: 6.08 MiB)
# after:  1.646 ms (258 allocations: 6.08 MiB)
@btime Zygote.pullback(kernelmatrix, $k, $x)
# before: 2.979 ms (3144 allocations: 18.44 MiB)
# after:  2.861 ms (2901 allocations: 18.41 MiB)
out, pb = Zygote.pullback(kernelmatrix, k, x)
@btime $pb($out)
# before: 1.510 ms (3519 allocations: 6.38 MiB)
# after:  1.483 ms (3223 allocations: 6.35 MiB)

@simsurace simsurace marked this pull request as ready for review August 30, 2022 20:06
@simsurace simsurace requested a review from devmotion August 30, 2022 20:06
@simsurace simsurace changed the title Resolve type stability of evaluation of KernelSum (WIP) Resolve type stability of evaluation of KernelSum Aug 30, 2022
Co-authored-by: Will Tebbutt <[email protected]>
@simsurace
Member Author

simsurace commented Aug 30, 2022

EDIT: hold on, I now can't reproduce it. I still get type inference errors. I think the check_type_stability function is being compiled away.

@simsurace
Member Author

simsurace commented Aug 31, 2022

Ok, seems to be fine after all. There are now 15 additional tests being counted. Closes #458.

@willtebbutt
Member

This PR looks like it's basically ready to go, other than syncing it up with master. @simsurace is there anything else that you think needs doing?

@simsurace
Member Author

I don't think so.

@willtebbutt
Member

Cool. Do you have a few minutes to merge in the master branch to this branch so that we can merge this back into master, or are you happy for me to handle it from here?

@willtebbutt
Member

@simsurace I have merged in changes from master and re-bumped the patch version. Will merge when CI passes.

@willtebbutt
Member

Eurgh. Somehow the build has been broken. Going to need to figure out how before merging.

@willtebbutt
Member

Okay, I'm pretty sure that the failures aren't related to this PR, so I'm going to merge when CI (other than BaseKernels) passes

@willtebbutt willtebbutt merged commit e42d89d into JuliaGaussianProcesses:master Sep 26, 2022
@simsurace simsurace deleted the ss/kernelsum_typestab branch October 5, 2022 08:08

3 participants