Faster logpdf implementation for container based inputs #147

Merged: bvdmitri merged 14 commits into main from fast-logpdf on Dec 4, 2023.

Nimrais (Member) commented Nov 23, 2023

In this issue it was shown that logpdf evaluation in ExponentialFamily.jl is far slower than it theoretically could be.

This PR aims to partially resolve this issue through the following two ideas:
1. Evaluate the logpartition only once for a container of points.
2. Use pack_parameters instead of flatten_parameters for multivariate distributions.
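A minimal sketch of the first idea, using a hypothetical toy type rather than the ExponentialFamily.jl API (ToyNormalEF and the toy_* functions are invented for illustration): a univariate Normal written in natural-parameter form, where the point-wise logpdf recomputes the logpartition A(η) on every call, while the container version computes it once and reuses it for every point.

```julia
using LinearAlgebra: dot

# Hypothetical toy type (not part of ExponentialFamily.jl): Normal(μ, σ²) in natural form
#   η = (μ/σ², -1/(2σ²)),  T(x) = (x, x²),  log h(x) = -log(2π)/2
#   A(η) = -η₁²/(4η₂) - log(-2η₂)/2
struct ToyNormalEF
    η::Vector{Float64}
end

toy_sufficientstatistics(::ToyNormalEF, x::Real) = [x, x^2]
toy_logbasemeasure(::ToyNormalEF, ::Real) = -log(2π) / 2
function toy_logpartition(ef::ToyNormalEF)
    η1, η2 = ef.η
    return -η1^2 / (4 * η2) - log(-2 * η2) / 2
end

# log h(x) + ⟨η, T(x)⟩ - A(η), with A(η) passed in by the caller
toy_logpdf_given(ef::ToyNormalEF, x::Real, A::Real) =
    toy_logbasemeasure(ef, x) + dot(ef.η, toy_sufficientstatistics(ef, x)) - A

# Point-wise: A(η) is recomputed on every call
toy_logpdf(ef::ToyNormalEF, x::Real) = toy_logpdf_given(ef, x, toy_logpartition(ef))

# Container: A(η) is computed once and reused for every point
function toy_logpdf(ef::ToyNormalEF, xs::AbstractVector{<:Real})
    A = toy_logpartition(ef)
    return map(x -> toy_logpdf_given(ef, x, A), xs)
end

μ, σ² = 1.0, 4.0
ef = ToyNormalEF([μ / σ², -1 / (2 * σ²)])
xs = [-1.0, 0.0, 2.5]
toy_logpdf(ef, xs)  # same values as map(x -> toy_logpdf(ef, x), xs)
```

The saving grows with the cost of A(η): for a scalar Normal it is a couple of logs, but for multivariate families it typically involves a log-determinant, which is where computing it once per container should matter most.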

To benchmark logpdf, I use the following script:

```julia
using ExponentialFamily
using StableRNGs
using Distributions
using LinearAlgebra
using BenchmarkTools

# Random 4-dimensional MvNormal with a positive-definite covariance L * L' + I
function generate_random_normal()
    rng = StableRNG(42)
    n = 4
    golden_μ = randn(rng, n)
    L = randn(rng, n, n)
    golden_Σ = L * L' + Matrix{Float64}(I, n, n)
    return Distributions.MvNormal(golden_μ, golden_Σ)
end

dists = [Gamma(), generate_random_normal()]
efs = map(d -> convert(ExponentialFamily.ExponentialFamilyDistribution, d), dists)

# Split a container of samples into single points:
# columns of a matrix (multivariate) or elements of a vector (univariate)
linearize(samples::AbstractMatrix) = eachcol(samples)
linearize(samples::AbstractVector) = samples

# Old strategy: one logpdf call per point, so the logpartition is recomputed for every sample
function benchmark_old_logpdf(d)
    rng = StableRNG(42)
    samples = linearize(rand(rng, d, 1000))
    @benchmark map(s -> logpdf($d, s), $samples)
end

display(benchmark_old_logpdf(efs[1]))
display(benchmark_old_logpdf(efs[2]))

# New strategy: one logpdf call on the whole container, so the logpartition is computed once
function benchmark_new_logpdf(d)
    rng = StableRNG(42)
    samples = rand(rng, d, 1000)
    @benchmark logpdf($d, $samples)
end

display(benchmark_new_logpdf(efs[1]))
display(benchmark_new_logpdf(efs[2]))
```

Old evaluation strategy

The two outputs below show the speed of logpdf evaluation when the logpartition is naïvely re-evaluated for every point (benchmark_old_logpdf), first for Gamma and then for the 4-dimensional MvNormal:

```
BenchmarkTools.Trial: 10000 samples with 1 evaluation.
 Range (min … max):  51.875 μs …   2.353 ms  ┊ GC (min … max):  0.00% … 95.82%
 Time  (median):     53.083 μs               ┊ GC (median):     0.00%
 Time  (mean ± σ):   61.119 μs ± 121.205 μs  ┊ GC (mean ± σ):  11.50% ±  5.63%

 Histogram: log(frequency) by time, 51.9 μs to 72 μs (bars omitted)

 Memory estimate: 164.19 KiB, allocs estimate: 2001.
```

```
BenchmarkTools.Trial: 9754 samples with 1 evaluation.
 Range (min … max):  449.666 μs …   1.863 ms  ┊ GC (min … max): 0.00% … 68.39%
 Time  (median):     481.959 μs               ┊ GC (median):    0.00%
 Time  (mean ± σ):   511.372 μs ± 172.085 μs  ┊ GC (mean ± σ):  4.95% ± 10.20%

 Histogram: log(frequency) by time, 450 μs to 1.63 ms (bars omitted)

 Memory estimate: 1007.94 KiB, allocs estimate: 5001.
```

New evaluation strategy

The two outputs below show the speed of logpdf evaluation with a single logpartition evaluation per container (benchmark_new_logpdf), again first for Gamma and then for the 4-dimensional MvNormal:

```
BenchmarkTools.Trial: 10000 samples with 1 evaluation.
 Range (min … max):  43.750 μs …   2.385 ms  ┊ GC (min … max):  0.00% … 95.76%
 Time  (median):     45.167 μs               ┊ GC (median):     0.00%
 Time  (mean ± σ):   53.773 μs ± 135.426 μs  ┊ GC (mean ± σ):  14.61% ±  5.67%

 Histogram: log(frequency) by time, 43.8 μs to 59.9 μs (bars omitted)

 Memory estimate: 164.19 KiB, allocs estimate: 2001.
```

```
BenchmarkTools.Trial: 10000 samples with 1 evaluation.
 Range (min … max):  109.875 μs …   1.425 ms  ┊ GC (min … max):  0.00% … 89.07%
 Time  (median):     117.541 μs               ┊ GC (median):     0.00%
 Time  (mean ± σ):   134.204 μs ± 135.941 μs  ┊ GC (mean ± σ):  11.87% ± 10.49%

 Histogram: log(frequency) by time, 110 μs to 1.25 ms (bars omitted)

 Memory estimate: 633.31 KiB, allocs estimate: 3003.
```
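To compare the four trials at a glance, the ratios of the reported medians and memory estimates can be computed directly. This only reuses the numbers printed above (a single run on a single machine), so the ratios are indicative rather than definitive.

```julia
# Median times (μs) and memory estimates (KiB) copied from the four trials above
median_old = (gamma = 53.083, mvnormal = 481.959)
median_new = (gamma = 45.167, mvnormal = 117.541)
memory_old_mvnormal, memory_new_mvnormal = 1007.94, 633.31

speedup_gamma = median_old.gamma / median_new.gamma           # ≈ 1.18
speedup_mvnormal = median_old.mvnormal / median_new.mvnormal  # ≈ 4.10
memory_saving_mvnormal = 1 - memory_new_mvnormal / memory_old_mvnormal  # ≈ 0.37

# Allocations per sample for the 1000-point MvNormal benchmark
allocs_per_point_old = 5001 / 1000  # ≈ 5
allocs_per_point_new = 3003 / 1000  # ≈ 3

println("median speedup: Gamma ", round(speedup_gamma; digits = 2),
        "x, MvNormal ", round(speedup_mvnormal; digits = 2), "x")
```

Read this way, the Gamma case improves only modestly (about 1.18x at the median, with identical memory and allocation estimates), while the 4-dimensional MvNormal becomes about 4.1x faster, with roughly 37% less memory and about two fewer allocations per sample.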

codecov bot commented Nov 23, 2023

Codecov Report

Attention: 2 lines in your changes are missing coverage. Please review.

Comparison is base (e75c73e) 79.51% compared to head (784b7e0) 79.77%.
Report is 1 commit behind head on main.

| Files | Patch % | Lines |
|---|---|---|
| src/distributions/normal_gamma.jl | 75.00% | 1 Missing ⚠️ |
| src/exponential_family.jl | 97.50% | 1 Missing ⚠️ |
Additional details and impacted files
```diff
@@            Coverage Diff             @@
##             main     #147      +/-   ##
==========================================
+ Coverage   79.51%   79.77%   +0.25%
==========================================
  Files          39       39
  Lines        2841     2887      +46
==========================================
+ Hits         2259     2303      +44
- Misses        582      584       +2
```


Nimrais (Member, Author) commented Nov 23, 2023

@ismailsenoz The new evaluation strategy has not been implemented for MvNormalWishart, and I am a bit confused about this distribution: its samples are not matrices, but its variate form is Matrixvariate. Someone could do it in a separate PR if the same is needed for this distribution.

Nimrais (Member, Author) commented Nov 23, 2023

@bvdmitri We have two new trait objects in this PR: MapBasedLogpdfCall and PointBasedLogpdfCall. Their purpose may not be completely clear at first glance, so I wrote docstrings for them. However, I am not sure they should go into the documentation, because these two structures define the behavior of the undocumented method _logpdf. What should I do in this situation?
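The thread does not show how MapBasedLogpdfCall and PointBasedLogpdfCall are actually defined. As a hedged illustration of the trait pattern their names suggest, the sketch below reuses only those two names on a hypothetical toy exponential distribution (ToyExponentialEF, logpdf_style, toy_logpdf and toy_logpdf_with are invented here): a call-style function picks a trait from the type of x, and the container branch evaluates the logpartition once. It uses comments rather than docstrings.

```julia
# Trait types named after the ones in this PR; their real definitions are not shown here
struct PointBasedLogpdfCall end  # x is a single point
struct MapBasedLogpdfCall end    # x is a container of points

# Hypothetical toy distribution: Exponential(rate λ) with natural parameter η = -λ
struct ToyExponentialEF
    η::Float64
end
toy_logpartition(ef::ToyExponentialEF) = -log(-ef.η)  # A(η) = -log(-η) = -log λ

# log h(x) + η * T(x) - A(η) with h(x) = 1 and T(x) = x on the support x ≥ 0
toy_logpdf_with(ef::ToyExponentialEF, x::Real, A) = x >= 0 ? ef.η * x - A : -Inf

# Pick a call style from the type of x
logpdf_style(::ToyExponentialEF, ::Real) = PointBasedLogpdfCall()
logpdf_style(::ToyExponentialEF, ::AbstractVector{<:Real}) = MapBasedLogpdfCall()

# Entry point: dispatch on the trait
toy_logpdf(ef::ToyExponentialEF, x) = toy_logpdf(logpdf_style(ef, x), ef, x)

# Single point: evaluate the logpartition and use it immediately
toy_logpdf(::PointBasedLogpdfCall, ef::ToyExponentialEF, x::Real) =
    toy_logpdf_with(ef, x, toy_logpartition(ef))

# Container: evaluate the logpartition once, then map over the points
function toy_logpdf(::MapBasedLogpdfCall, ef::ToyExponentialEF, xs::AbstractVector{<:Real})
    A = toy_logpartition(ef)
    return map(x -> toy_logpdf_with(ef, x, A), xs)
end

ef = ToyExponentialEF(-2.0)  # rate λ = 2
toy_logpdf(ef, 0.5)          # log(2) - 1
toy_logpdf(ef, [0.5, 1.0])   # [log(2) - 1, log(2) - 2]
```

Because the trait is chosen by a separate function, new container types only need an extra logpdf_style method, without touching the two evaluation branches.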

Comment on lines 275 to 295

```julia
function _logpdf(ef::ExponentialFamilyDistribution{MvNormalWishart}, x)
    # TODO: Think of what to do with this assert
    @assert insupport(ef, x)
    _logpartition = logpartition(ef)
    return _logpdf(ef, x, _logpartition)
end

function _logpdf(ef::ExponentialFamilyDistribution{MvNormalWishart}, x, logpartition)
    # TODO: Think of what to do with this assert
    @assert insupport(ef, x)
    η = getnaturalparameters(ef)
    # Use `_` to avoid name collisions with the actual functions
    _statistics = sufficientstatistics(ef, x)
    _basemeasure = basemeasure(ef, x)
    return log(_basemeasure) + dot(η, flatten_parameters(MvNormalWishart, _statistics)) - logpartition
end

function _pdf(ef::ExponentialFamilyDistribution{MvNormalWishart}, x)
    exp(_logpdf(ef, x))
end
```
Review comment (Member):

What are these methods for? They duplicate the generic function.

bvdmitri (Member):

You can either

  • Use comments instead of the docstrings
  • Document the change

Outdated, resolved review thread on test/runtests.jl.
Review comment (Member), suggested change: remove the two _logpdf methods and the _pdf method for ExponentialFamilyDistribution{MvNormalWishart} on lines 275 to 295 (quoted above).

Nimrais (Member, Author) commented Nov 28, 2023


I can explain my reasoning here.

My generic strategy fails for this particular distribution: MvNormalWishart is not a matrix-variate distribution, yet it has the Matrixvariate variate form. I am not sure why it has this variate form, so to avoid changing it I decided to implement a dedicated method for this distribution.
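The generic code is not shown in this thread, so the following is only a sketch under an assumption: that the generic strategy splits a sample container into points according to its variate form. The Toy* types and the points function are hypothetical stand-ins for Distributions.jl's Univariate, Multivariate and Matrixvariate tags. A Normal-Wishart sample is a tuple of a vector and a matrix, so a container of such samples matches none of the splitting rules, even when it is tagged as Matrixvariate.

```julia
# Hypothetical stand-ins for the variate-form tags of Distributions.jl
abstract type ToyVariateForm end
struct ToyUnivariate <: ToyVariateForm end
struct ToyMultivariate <: ToyVariateForm end
struct ToyMatrixvariate <: ToyVariateForm end

# Assumed generic rule: split a container of samples into single points by variate form
points(::ToyUnivariate, xs::AbstractVector{<:Real}) = xs
points(::ToyMultivariate, xs::AbstractMatrix{<:Real}) = eachcol(xs)
points(::ToyMatrixvariate, xs::AbstractVector{<:AbstractMatrix{<:Real}}) = xs

# A Normal-Wishart-style sample: (mean vector, precision matrix)
sample = ([0.0, 1.0], [1.0 0.0; 0.0 1.0])
samples = [sample, sample]

applicable(points, ToyMatrixvariate(), [ones(2, 2), zeros(2, 2)])  # true: a vector of matrices
applicable(points, ToyMatrixvariate(), samples)                      # false: a vector of tuples
```

Under this assumption, a dedicated method (as in this PR) or a variate form that describes the (vector, matrix) pair is needed for such samples.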

bvdmitri (Member) left a review:

Almost everything looks very good, please address the minor comments 👍

bvdmitri changed the title from "Fast logpdf for exponetial family" to "Faster logpdf implementation for container based inputs" on Nov 24, 2023.
ismailsenoz (Collaborator) commented:

> @ismailsenoz A new evaluation strategy has not been implemented for MvNormalWishart, and I am a bit confused about it: its samples are not matrices, but its variate form is Matrixvariate. I think someone could do it in a separate PR if there is a need to do the same for this distribution.

Indeed, the Normal-Wishart distribution is mix-variate, and referring to it as Matrixvariate is not ideal. We should consider it in a separate PR. I am OK with the new strategy.

bvdmitri (Member) commented Dec 4, 2023

I like the idea of making it a mix-variate type (we can do it with abstract types and tuples), but indeed it should probably be a separate PR.

Regarding the failing documentation build, see my previous comment.
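One possible reading of the mix-variate idea ("abstract types and tuples"), sketched with hypothetical types only: nothing here is existing ExponentialFamily.jl or Distributions.jl API. A mix-variate form is parameterised by a tuple of component forms, so a Normal-Wishart sample can be described as a (vector, matrix) pair and checked component by component.

```julia
# Hypothetical variate-form hierarchy
abstract type ToyVariateForm end
struct ToyUnivariate <: ToyVariateForm end
struct ToyMultivariate <: ToyVariateForm end
struct ToyMatrixvariate <: ToyVariateForm end

# Mix-variate form: the type parameter is a tuple of component forms
struct ToyMixvariate{T<:Tuple{Vararg{ToyVariateForm}}} <: ToyVariateForm end

# Normal-Wishart: a vector-valued mean and a matrix-valued precision
const ToyNormalWishartForm = ToyMixvariate{Tuple{ToyMultivariate, ToyMatrixvariate}}

# Does a single sample x have the shape described by a variate form?
matches(::ToyUnivariate, x) = x isa Real
matches(::ToyMultivariate, x) = x isa AbstractVector{<:Real}
matches(::ToyMatrixvariate, x) = x isa AbstractMatrix{<:Real}
function matches(::ToyMixvariate{T}, x) where {T}
    forms = map(F -> F(), fieldtypes(T))  # instantiate each component form
    return x isa Tuple && length(x) == length(forms) && all(map(matches, forms, x))
end

matches(ToyNormalWishartForm(), ([0.0, 1.0], [1.0 0.0; 0.0 1.0]))  # true
matches(ToyNormalWishartForm(), [1.0 0.0; 0.0 1.0])                # false
```

Keeping the component forms in the type parameter means that methods could dispatch on, for example, ToyMixvariate{<:Tuple{ToyMultivariate, ToyMatrixvariate}} without inspecting the samples at runtime.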

Nimrais (Member, Author) commented Dec 4, 2023

@bvdmitri I resolved the issue with the documentation.

There are several ways the MvNormalWishart issue could be resolved, but let's document our decision somewhere first; afterwards I can implement it in a separate PR.

Nimrais requested a review from bvdmitri on December 4, 2023 at 14:13.
Nimrais self-assigned this on Dec 4, 2023.
Nimrais added the enhancement (New feature or request) label on Dec 4, 2023.
bvdmitri (Member) left a review:

Great! Thanks, @Nimrais!

bvdmitri merged commit 7348a2c into main on Dec 4, 2023 (4 of 6 checks passed).
bvdmitri deleted the fast-logpdf branch on December 4, 2023 at 16:19.
Labels: enhancement (New feature or request)
Projects: none yet
3 participants