Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix gradient issues with kernelmatrix_diag and use ChainRulesCore #208

Merged
merged 50 commits into from
Mar 25, 2021
Merged
Show file tree
Hide file tree
Changes from 43 commits
Commits
Show all changes
50 commits
Select commit Hold shift + click to select a range
e525614
Use broadcasting instead of map for kerneldiagmatrix
theogf Dec 9, 2020
e56492a
Removed method for transformedkernel
theogf Dec 9, 2020
35a6306
Restored functions and applied suggestions
theogf Dec 14, 2020
25e5efd
Added tests for diagmatrix
theogf Dec 14, 2020
2f85ebc
Put changes to the right file and removed utils_AD.jl
theogf Dec 14, 2020
cae225f
Apply suggestions from code review
theogf Dec 14, 2020
3f16f07
Added colwise and fixed kerneldiagmatrix
theogf Dec 15, 2020
8c0d0a2
Added colwise for RowVecs and ColVecs
theogf Dec 16, 2020
13a10fd
Removed definition relying on Distances.colwise!
theogf Dec 21, 2020
78a2078
Merge branch 'master' into fix_diagmat
theogf Mar 16, 2021
5ca94e7
Readapt to kernelmatrix_diag
theogf Mar 16, 2021
2c60abd
Fixes for Zygote
theogf Mar 16, 2021
9214211
Remove type piracy
theogf Mar 16, 2021
87edbc8
Adding some adjoints (not everything fixed yet)
theogf Mar 17, 2021
f65556b
Fixed adjoint for polynomials
theogf Mar 17, 2021
48e2dcb
Add ChainRulesCore for defining rrule
theogf Mar 17, 2021
6cc803d
Replace broadcast by map
theogf Mar 17, 2021
0e30941
Missing return for style
theogf Mar 17, 2021
61869b1
Fixing ZygoteRules
theogf Mar 22, 2021
06bd4f0
Renamed zygote_adjoints to chainrules
theogf Mar 22, 2021
8e1e516
Apply formatting suggestions
theogf Mar 22, 2021
aaa16de
Added forward rule for Euclidean distance
theogf Mar 22, 2021
52b1ae5
Corrected rules for Row/ColVecs constructors
theogf Mar 22, 2021
4067a42
Added ZygoteRules back for the "map hack"
theogf Mar 22, 2021
641ebee
Corrected the rrules
theogf Mar 22, 2021
13d1e39
Type stable frule
theogf Mar 22, 2021
4675c2f
Corrected tests
theogf Mar 23, 2021
0b97c1a
Adapted the use of Distances.jl
theogf Mar 23, 2021
ad9838e
Added methods to make nn work
theogf Mar 23, 2021
650dc08
Missing kernelmatrix_diag
theogf Mar 23, 2021
1703db1
Formatting suggestions
theogf Mar 23, 2021
e2cd167
Added methods for FBM
theogf Mar 23, 2021
01ffac0
Last fix on Delta
theogf Mar 23, 2021
9bfb6eb
Potential fix for Euclidean
theogf Mar 23, 2021
f3fa4bc
Missing Distances.
theogf Mar 23, 2021
a0c2a64
Wrong file naming
theogf Mar 23, 2021
ff5a66b
Correct formatting
theogf Mar 23, 2021
8157b4c
Better error message
theogf Mar 23, 2021
e6bfdb1
Moar formatting
theogf Mar 23, 2021
db5e7b8
Applied suggestions
theogf Mar 24, 2021
a44a762
Fixed the dims issue with pairwise
theogf Mar 24, 2021
72889dd
Fixed formatting
theogf Mar 24, 2021
25549c1
Missing @thunk
theogf Mar 24, 2021
bbe5c7c
Putting back Composite to Any
theogf Mar 24, 2021
e08dbf4
add @thunk for -delta a
theogf Mar 24, 2021
48bd681
Update src/chainrules.jl
theogf Mar 25, 2021
3298d34
Update KernelFunctions.jl
theogf Mar 25, 2021
0b99771
Apply suggestions from code review
theogf Mar 25, 2021
c26edf3
Update Project.toml
theogf Mar 25, 2021
647862a
Merge branch 'master' into fix_diagmat
theogf Mar 25, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ uuid = "ec8451be-7e33-11e9-00cf-bbf324bd1392"
version = "0.8.24"

[deps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"
Distances = "b4f34e82-e78d-54a5-968a-f98e89d6e8f7"
Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
Expand All @@ -17,8 +18,9 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444"

[compat]
ChainRulesCore = "0.9"
Compat = "3.7"
Distances = "0.9.1, 0.10"
Distances = "0.10"
Functors = "0.1"
Requires = "1.0.1"
SpecialFunctions = "0.8, 0.9, 0.10, 1"
Expand Down
8 changes: 5 additions & 3 deletions src/KernelFunctions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -55,12 +55,13 @@ export IndependentMOKernel, LatentFactorMOKernel
export tensor, ⊗

using Compat
using ChainRulesCore: ChainRulesCore, Composite, Zero, One, DoesNotExist, NO_FIELDS, @thunk
using Requires
using Distances, LinearAlgebra
using Functors
using SpecialFunctions: loggamma, besselk, polygamma
using ZygoteRules: @adjoint, pullback
using StatsFuns: logtwo
using ZygoteRules: ZygoteRules
using StatsFuns: logtwo, twoπ
using StatsBase
using TensorCore

Expand Down Expand Up @@ -112,7 +113,8 @@ include(joinpath("mokernels", "moinput.jl"))
include(joinpath("mokernels", "independent.jl"))
include(joinpath("mokernels", "slfm.jl"))

include("zygote_adjoints.jl")
include("chainrules.jl")
include("zygoterules.jl")

include("test_utils.jl")

Expand Down
11 changes: 11 additions & 0 deletions src/basekernels/fbm.jl
Original file line number Diff line number Diff line change
Expand Up @@ -69,3 +69,14 @@ function kernelmatrix!(
K .= _fbm.(_mod(x), _mod(y)', K, κ.h)
return K
end

function kernelmatrix_diag(κ::FBMKernel, x::AbstractVector)
modx = _mod(x)
modxx = colwise(SqEuclidean(), x)
return _fbm.(modx, modx, modxx, κ.h)
end

function kernelmatrix_diag(κ::FBMKernel, x::AbstractVector, y::AbstractVector)
modxy = colwise(SqEuclidean(), x, y)
return _fbm.(_mod(x), _mod(y), modxy, κ.h)
end
4 changes: 4 additions & 0 deletions src/basekernels/gabor.jl
Original file line number Diff line number Diff line change
Expand Up @@ -72,3 +72,7 @@ function kernelmatrix(κ::GaborKernel, x::AbstractVector, y::AbstractVector)
end

kernelmatrix_diag(κ::GaborKernel, x::AbstractVector) = kernelmatrix_diag(κ.kernel, x)

function kernelmatrix_diag(κ::GaborKernel, x::AbstractVector, y::AbstractVector)
return kernelmatrix_diag(κ.kernel, x, y)
end
26 changes: 26 additions & 0 deletions src/basekernels/nn.jl
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,19 @@ function kernelmatrix(::NeuralNetworkKernel, x::ColVecs)
return asin.(XX ./ sqrt.(X_2_1' * X_2_1))
end

function kernelmatrix_diag(::NeuralNetworkKernel, x::ColVecs)
x_2 = vec(sum(x.X .* x.X; dims=1))
return asin.(x_2 ./ (x_2 .+ 1))
end

function kernelmatrix_diag(::NeuralNetworkKernel, x::ColVecs, y::ColVecs)
validate_inputs(x, y)
x_2 = vec(sum(x.X .* x.X; dims=1) .+ 1)
y_2 = vec(sum(y.X .* y.X; dims=1) .+ 1)
xy = vec(sum(x.X' .* y.X'; dims=2))
return asin.(xy ./ sqrt.(x_2 .* y_2))
end

function kernelmatrix(::NeuralNetworkKernel, x::RowVecs, y::RowVecs)
validate_inputs(x, y)
X_2 = sum(x.X .* x.X; dims=2)
Expand All @@ -65,4 +78,17 @@ function kernelmatrix(::NeuralNetworkKernel, x::RowVecs)
return asin.(XX ./ sqrt.(X_2_1 * X_2_1'))
end

function kernelmatrix_diag(::NeuralNetworkKernel, x::RowVecs)
x_2 = vec(sum(x.X .* x.X; dims=2))
return asin.(x_2 ./ (x_2 .+ 1))
end

function kernelmatrix_diag(::NeuralNetworkKernel, x::RowVecs, y::RowVecs)
validate_inputs(x, y)
x_2 = vec(sum(x.X .* x.X; dims=2) .+ 1)
y_2 = vec(sum(y.X .* y.X; dims=2) .+ 1)
xy = vec(sum(x.X .* y.X; dims=2))
return asin.(xy ./ sqrt.(x_2 .* y_2))
end

Base.show(io::IO, ::NeuralNetworkKernel) = print(io, "Neural Network Kernel")
160 changes: 160 additions & 0 deletions src/chainrules.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,160 @@
## Forward Rules

# Note that this is type piracy as the derivative should be NaN for x == y.
function ChainRulesCore.frule(
(_, Δx, Δy), d::Distances.Euclidean, x::AbstractVector, y::AbstractVector
)
Δ = x - y
D = sqrt(sum(abs2, Δ))
if !iszero(D)
Δ ./= D
end
return D, dot(Δ, Δx) - dot(Δ, Δy)
end

## Reverse Rules Delta

function ChainRulesCore.rrule(dist::Delta, x::AbstractVector, y::AbstractVector)
d = dist(x, y)
function evaluate_pullback(::Any)
theogf marked this conversation as resolved.
Show resolved Hide resolved
return NO_FIELDS, Zero(), Zero()
end
return d, evaluate_pullback
end

function ChainRulesCore.rrule(
::typeof(Distances.pairwise), d::Delta, X::AbstractMatrix, Y::AbstractMatrix; dims=2
)
P = Distances.pairwise(d, X, Y; dims=dims)
function pairwise_pullback(::AbstractMatrix)
return NO_FIELDS, NO_FIELDS, Zero(), Zero()
end
return P, pairwise_pullback
end

function ChainRulesCore.rrule(
::typeof(Distances.pairwise), d::Delta, X::AbstractMatrix; dims=2
)
P = Distances.pairwise(d, X; dims=dims)
function pairwise_pullback(::AbstractMatrix)
return NO_FIELDS, NO_FIELDS, Zero()
end
return P, pairwise_pullback
end

function ChainRulesCore.rrule(
::typeof(Distances.colwise), d::Delta, X::AbstractMatrix, Y::AbstractMatrix
)
C = Distances.colwise(d, X, Y)
function colwise_pullback(::AbstractVector)
return NO_FIELDS, NO_FIELDS, Zero(), Zero()
end
return C, colwise_pullback
end

## Reverse Rules DotProduct

function ChainRulesCore.rrule(dist::DotProduct, x::AbstractVector, y::AbstractVector)
d = dist(x, y)
function evaluate_pullback(Δ::Any)
return NO_FIELDS, Δ .* y, Δ .* x
end
return d, evaluate_pullback
end

function ChainRulesCore.rrule(
::typeof(Distances.pairwise),
d::DotProduct,
X::AbstractMatrix,
Y::AbstractMatrix;
dims=2,
)
P = Distances.pairwise(d, X, Y; dims=dims)
function pairwise_pullback_cols(Δ::AbstractMatrix)
if dims == 1
return NO_FIELDS, NO_FIELDS, Δ * Y, Δ' * X
else
return NO_FIELDS, NO_FIELDS, Y * Δ', X * Δ
end
end
return P, pairwise_pullback_cols
end

function ChainRulesCore.rrule(
::typeof(Distances.pairwise), d::DotProduct, X::AbstractMatrix; dims=2
)
P = Distances.pairwise(d, X; dims=dims)
function pairwise_pullback_cols(Δ::AbstractMatrix)
if dims == 1
return NO_FIELDS, NO_FIELDS, 2 * Δ * X
else
return NO_FIELDS, NO_FIELDS, 2 * X * Δ
end
end
return P, pairwise_pullback_cols
end

function ChainRulesCore.rrule(
::typeof(Distances.colwise), d::DotProduct, X::AbstractMatrix, Y::AbstractMatrix
)
C = Distances.colwise(d, X, Y)
function colwise_pullback(Δ::AbstractVector)
return NO_FIELDS, NO_FIELDS, Δ' .* Y, Δ' .* X
end
return C, colwise_pullback
end

## Reverse Rules Sinus

function ChainRulesCore.rrule(s::Sinus, x::AbstractVector, y::AbstractVector)
d = x - y
sind = sinpi.(d)
abs2_sind_r = abs2.(sind) ./ s.r
val = sum(abs2_sind_r)
gradx = twoπ .* cospi.(d) .* sind ./ (s.r .^ 2)
function evaluate_pullback(Δ::Any)
return (r=-2Δ .* abs2_sind_r,), Δ * gradx, -Δ * gradx
end
return val, evaluate_pullback
end

## Reverse Rulse SqMahalanobis

function ChainRulesCore.rrule(
dist::Distances.SqMahalanobis, a::AbstractVector, b::AbstractVector
)
d = dist(a, b)
function SqMahalanobis_pullback(Δ::Real)
B_Bᵀ = dist.qmat + transpose(dist.qmat)
a_b = a - b
δa = @thunk((B_Bᵀ * a_b) * Δ)
return (qmat=(a_b * a_b') * Δ,), δa, -δa
end
return d, SqMahalanobis_pullback
end

## Reverse Rules for matrix wrappers

function ChainRulesCore.rrule(::Type{<:ColVecs}, X::AbstractMatrix)
ColVecs_pullback(Δ::Composite{<:ColVecs}) = (NO_FIELDS, Δ.X)
function ColVecs_pullback(::AbstractVector{<:AbstractVector{<:Real}})
return error(
"Pullback on AbstractVector{<:AbstractVector}.\n" *
"This might happen if you try to use gradients on the generic `kernelmatrix` or `kernelmatrix_diag`.\n" *
"To solve this issue overload `kernelmatrix(_diag)` for your kernel for `ColVecs`",
)
end
return ColVecs(X), ColVecs_pullback
end

function ChainRulesCore.rrule(::Type{<:RowVecs}, X::AbstractMatrix)
RowVecs_pullback(Δ::Composite{<:RowVecs}) = (NO_FIELDS, Δ.X)
function RowVecs_pullback(::AbstractVector{<:AbstractVector{<:Real}})
return error(
"Pullback on AbstractVector{<:AbstractVector}.\n" *
"This might happen if you try to use gradients on the generic `kernelmatrix` or `kernelmatrix_diag`.\n" *
"To solve this issue overload `kernelmatrix(_diag)` for your kernel for `RowVecs`",
)
end
return RowVecs(X), RowVecs_pullback
end
5 changes: 3 additions & 2 deletions src/distances/delta.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
struct Delta <: Distances.PreMetric end
# Delta is not following the PreMetric rules since d(x, x) == 1
struct Delta <: Distances.UnionPreMetric end

@inline function Distances._evaluate(::Delta, a::AbstractVector, b::AbstractVector)
@boundscheck if length(a) != length(b)
Expand All @@ -12,7 +13,7 @@ struct Delta <: Distances.PreMetric end
return a == b
end

Distances.result_type(::Delta, Ta::Type, Tb::Type) = promote_type(Ta, Tb)
Distances.result_type(::Delta, Ta::Type, Tb::Type) = Bool

@inline (dist::Delta)(a::AbstractArray, b::AbstractArray) = Distances._evaluate(dist, a, b)
@inline (dist::Delta)(a::Number, b::Number) = a == b
4 changes: 2 additions & 2 deletions src/distances/dotproduct.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
struct DotProduct <: Distances.PreMetric end
# struct DotProduct <: Distances.UnionSemiMetric end
## DotProduct is not following the PreMetric rules since d(x, x) != 0 and d(x, y) >= 0 for all x, y
struct DotProduct <: Distances.UnionPreMetric end

@inline function Distances._evaluate(::DotProduct, a::AbstractVector, b::AbstractVector)
@boundscheck if length(a) != length(b)
Expand Down
31 changes: 31 additions & 0 deletions src/distances/pairwise.jl
Original file line number Diff line number Diff line change
Expand Up @@ -29,3 +29,34 @@ function pairwise!(
)
return Distances.pairwise!(out, d, reshape(x, :, 1), reshape(y, :, 1); dims=1)
end

# Also defines the colwise method for abstractvectors

function colwise(d::PreMetric, x::AbstractVector)
return zeros(Distances.result_type(d, x, x), length(x)) # Valid since d(x,x) == 0 by definition
end

## The following is a hack for DotProduct and Delta to still work
function colwise(d::Distances.UnionPreMetric, x::ColVecs)
return Distances.colwise(d, x.X, x.X)
end

function colwise(d::Distances.UnionPreMetric, x::RowVecs)
return Distances.colwise(d, x.X', x.X')
end

function colwise(d::Distances.UnionPreMetric, x::AbstractVector)
return map(d, x, x)
end

function colwise(d::PreMetric, x::ColVecs, y::ColVecs)
return Distances.colwise(d, x.X, y.X)
end

function colwise(d::PreMetric, x::RowVecs, y::RowVecs)
return Distances.colwise(d, x.X', y.X')
end

function colwise(d::PreMetric, x::AbstractVector, y::AbstractVector)
return map(d, x, y)
end
3 changes: 1 addition & 2 deletions src/distances/sinus.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
struct Sinus{T} <: Distances.SemiMetric
# struct Sinus{T} <: Distances.UnionSemiMetric
struct Sinus{T} <: Distances.UnionSemiMetric
r::Vector{T}
end

Expand Down
2 changes: 1 addition & 1 deletion src/generic.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,4 +6,4 @@ Base.iterate(k::Kernel, ::Any) = nothing
printshifted(io::IO, o, shift::Int) = print(io, o)

# Fallback implementation of evaluate for `SimpleKernel`s.
(k::SimpleKernel)(x, y) = kappa(k, evaluate(metric(k), x, y))
(k::SimpleKernel)(x, y) = kappa(k, metric(k)(x, y))
4 changes: 4 additions & 0 deletions src/kernels/kernelproduct.jl
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,10 @@ function kernelmatrix_diag(κ::KernelProduct, x::AbstractVector)
return reduce(hadamard, kernelmatrix_diag(k, x) for k in κ.kernels)
end

function kernelmatrix_diag(κ::KernelProduct, x::AbstractVector, y::AbstractVector)
return reduce(hadamard, kernelmatrix_diag(k, x, y) for k in κ.kernels)
end

function Base.show(io::IO, κ::KernelProduct)
return printshifted(io, κ, 0)
end
Expand Down
4 changes: 4 additions & 0 deletions src/kernels/kernelsum.jl
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,10 @@ function kernelmatrix_diag(κ::KernelSum, x::AbstractVector)
return sum(kernelmatrix_diag(k, x) for k in κ.kernels)
end

function kernelmatrix_diag(κ::KernelSum, x::AbstractVector, y::AbstractVector)
return sum(kernelmatrix_diag(k, x, y) for k in κ.kernels)
end

function Base.show(io::IO, κ::KernelSum)
return printshifted(io, κ, 0)
end
Expand Down
5 changes: 5 additions & 0 deletions src/kernels/kerneltensorproduct.jl
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,11 @@ function kernelmatrix_diag(k::KernelTensorProduct, x::AbstractVector)
return mapreduce(kernelmatrix_diag, hadamard, k.kernels, slices(x))
end

function kernelmatrix_diag(k::KernelTensorProduct, x::AbstractVector, y::AbstractVector)
validate_domain(k, x)
return mapreduce(kernelmatrix_diag, hadamard, k.kernels, slices(x), slices(y))
end

Base.show(io::IO, kernel::KernelTensorProduct) = printshifted(io, kernel, 0)

function Base.:(==)(x::KernelTensorProduct, y::KernelTensorProduct)
Expand Down
4 changes: 4 additions & 0 deletions src/kernels/scaledkernel.jl
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,10 @@ function kernelmatrix_diag(κ::ScaledKernel, x::AbstractVector)
return κ.σ² .* kernelmatrix_diag(κ.kernel, x)
end

function kernelmatrix_diag(κ::ScaledKernel, x::AbstractVector, y::AbstractVector)
return κ.σ² .* kernelmatrix_diag(κ.kernel, x, y)
end

function kernelmatrix!(
K::AbstractMatrix, κ::ScaledKernel, x::AbstractVector, y::AbstractVector
)
Expand Down
Loading