Skip to content

Commit

Permalink
Merge pull request #173 from ReactiveBayes/grad_normal_wishart
Browse files Browse the repository at this point in the history
Add gradient of  Normal Wishart
  • Loading branch information
Nimrais authored Jan 14, 2025
2 parents 13249ec + ad1edcc commit 23343fa
Show file tree
Hide file tree
Showing 4 changed files with 28 additions and 6 deletions.
7 changes: 6 additions & 1 deletion src/common.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
using StatsFuns: logistic
using StatsFuns: softmax, softmax!
using SpecialFunctions: gamma, loggamma
using Distributions

import ForwardDiff

Expand Down Expand Up @@ -45,4 +46,8 @@ function binomial_prod(n, p, x)
end
end

mvdigamma(η, p) = sum(digamma+ (one(d) - d) / 2) for d in 1:p)
mvdigamma(η,p) = sum( digamma+ (one(d) - d)/2) for d=1:p)

abstract type VectorMatrixvariate <: VariateForm end
const VectorMatrixDistribution{S<:ValueSupport} = Distribution{VectorMatrixvariate, S}
const ContinuousMultivariateMatrixvariateDistribution = Distribution{VectorMatrixvariate, Continuous}
21 changes: 20 additions & 1 deletion src/distributions/mv_normal_wishart.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ A multivariate normal-Wishart distribution, where `T` is the element type of the
- `ν::N`: The degrees of freedom of the Wishart distribution
"""
struct MvNormalWishart{T, M <: AbstractArray{T}, V <: AbstractMatrix{T}, K <: Real, N <: Real} <:
ContinuousMatrixDistribution
ContinuousMultivariateMatrixvariateDistribution
μ::M
Ψ::V
κ::K
Expand Down Expand Up @@ -196,6 +196,25 @@ getlogpartition(::NaturalParametersSpace, ::Type{MvNormalWishart}) = (η) -> beg
return (term1 + term2 + term3 + term4) + (d / 2)log2π
end

getgradlogpartition(::NaturalParametersSpace, ::Type{MvNormalWishart}) = (η) -> begin
η1, η2, η3, η4 = unpack_parameters(MvNormalWishart, η)
d = length(η1)
const1 = -(d+2η4)/2
kronecker = kron(η1, η1')
veckronecker = vec(kronecker)

const2 = cholinv(-2η2 + kronecker/(2η3))
vconst2 = vec(const2)
kronright = kron(Eye(d), η1) / (2η3)
kronleft = kron(η1, Eye(d)) / (2η3)
dη2 = -2*const1*const2
dη1 = const1*(kronright + kronleft)'*vconst2
dη3 = (-d/(2η3)) - const1*dot(vconst2,veckronecker/(2η3^2))
dη4 = -logdet(-2η2 + kronecker/(2η3)) + d*log(2) + mvdigamma((d + 2 * η4) * (1 / 2),d)

return vcat(dη1,vec(dη2),dη3,dη4)
end

getfisherinformation(::NaturalParametersSpace, ::Type{MvNormalWishart}) =
(η) -> begin
η1, η2, η3, η4 = unpack_parameters(MvNormalWishart, η)
Expand Down
2 changes: 1 addition & 1 deletion src/exponential_family.jl
Original file line number Diff line number Diff line change
Expand Up @@ -678,6 +678,7 @@ check_logpdf(ef::ExponentialFamilyDistribution, x) = check_logpdf(variate_form(t
check_logpdf(::Type{Univariate}, ::Type{<:Number}, ::Type{<:Number}, ef, x) = (PointBasedLogpdfCall(), x)
check_logpdf(::Type{Multivariate}, ::Type{<:AbstractVector}, ::Type{<:Number}, ef, x) = (PointBasedLogpdfCall(), x)
check_logpdf(::Type{Matrixvariate}, ::Type{<:AbstractMatrix}, ::Type{<:Number}, ef, x) = (PointBasedLogpdfCall(), x)
check_logpdf(::Type{VectorMatrixvariate}, ::Type{<:AbstractVector}, ::Type{<:Tuple}, ef, x) = (PointBasedLogpdfCall(), x)

function _vlogpdf(ef, container)
_logpartition = logpartition(ef)
Expand All @@ -688,7 +689,6 @@ check_logpdf(::Type{Univariate}, ::Type{<:AbstractVector}, ::Type{<:Number}, ef,
check_logpdf(::Type{Multivariate}, ::Type{<:AbstractVector}, ::Type{<:AbstractVector}, ef, container) = (MapBasedLogpdfCall(), container)
check_logpdf(::Type{Multivariate}, ::Type{<:AbstractMatrix}, ::Type{<:Number}, ef, container) = (MapBasedLogpdfCall(), eachcol(container))
check_logpdf(::Type{Matrixvariate}, ::Type{<:AbstractVector}, ::Type{<:AbstractMatrix}, ef, container) = (MapBasedLogpdfCall(), container)

"""
pdf(ef::ExponentialFamilyDistribution, x)
Expand Down
4 changes: 1 addition & 3 deletions test/distributions/mv_normal_wishart_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,7 @@ end
option_assume_no_allocations = false,
test_basic_functions = false,
test_fisherinformation_against_hessian = false,
test_fisherinformation_against_jacobian = false,
test_gradlogpartition_properties = false,
test_plogpdf_interface = false
test_fisherinformation_against_jacobian = false
)

run_test_basic_functions(d; assume_no_allocations = false, test_samples_logpdf = false)
Expand Down

0 comments on commit 23343fa

Please sign in to comment.