Faster logpdf implementation for container based inputs #147
Conversation
… or pack_parameters
Codecov Report

Additional details and impacted files:

@@            Coverage Diff             @@
##             main     #147      +/-   ##
==========================================
+ Coverage   79.51%   79.77%   +0.25%
==========================================
  Files          39       39
  Lines        2841     2887      +46
==========================================
+ Hits         2259     2303      +44
- Misses        582      584       +2

View full report in Codecov by Sentry.
@ismailsenoz A new evaluation strategy has not been implemented for MvNormalWishart, and I am a bit confused about it: its samples are not matrices, but its variate form is Matrixvariate.
@bvdmitri We have two new trait objects in this PR:
```julia
function _logpdf(ef::ExponentialFamilyDistribution{MvNormalWishart}, x)
    # TODO: Think of what to do with this assert
    @assert insupport(ef, x)
    _logpartition = logpartition(ef)
    return _logpdf(ef, x, _logpartition)
end

function _logpdf(ef::ExponentialFamilyDistribution{MvNormalWishart}, x, logpartition)
    # TODO: Think of what to do with this assert
    @assert insupport(ef, x)
    η = getnaturalparameters(ef)
    # Use `_` to avoid name collisions with the actual functions
    _statistics = sufficientstatistics(ef, x)
    _basemeasure = basemeasure(ef, x)
    return log(_basemeasure) + dot(η, flatten_parameters(MvNormalWishart, _statistics)) - logpartition
end

function _pdf(ef::ExponentialFamilyDistribution{MvNormalWishart}, x)
    exp(_logpdf(ef, x))
end
```
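For context, both `_logpdf` methods evaluate the same quantity as the generic exponential-family log-density,

$$\log p(x \mid \eta) = \log h(x) + \langle \eta, T(x) \rangle - A(\eta),$$

where $h$ is the base measure, $T(x)$ the sufficient statistics, and $A(\eta)$ the log-partition; the two-argument method accepts a precomputed $A(\eta)$, so a container of points pays for the log-partition only once.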
What are these methods for? They duplicate the generic function.
You can either
I can explain my reasoning for this issue.
My generic strategy fails for this particular distribution: MvNormalWishart is not a matrix-variate distribution, yet it is declared with that variate form. I am uncertain why it has this variate form, so to avoid altering it I decided to implement a specific realisation for this distribution.
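A hypothetical illustration of the mismatch, in plain Julia and not using the package's actual sample type: a Normal-Wishart draw is naturally a pair of a mean vector and a precision matrix rather than a single matrix, so container logic that expects matrix elements cannot apply.

```julia
using LinearAlgebra

# Hypothetical sample layout for illustration only: a (mean vector, precision matrix) pair.
sample = (randn(3), Matrix(1.0I, 3, 3))
container = [sample for _ in 1:10]

# A generic matrix-variate container path would expect matrix elements:
eltype(container) <: AbstractMatrix   # false, so a dedicated method is needed
```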
Almost everything looks very good; please address the minor comments 👍
Indeed, the Normal-Wishart distribution is mix-variate and referring to it as Matrixvariate is not ideal. We should consider it in a separate PR. I am OK with the new strategy.
I like the idea of making it a mix-variate type (we can do it with abstract types and tuples). But indeed it probably should be a separate PR. Regarding the failing documentation build, see my previous comment.
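A minimal sketch of one way such a mix-variate form could be encoded with abstract types and tuples; the type names here are hypothetical and are not part of ExponentialFamily.jl or Distributions.jl.

```julia
# Hypothetical, self-contained sketch: component variate forms plus a tuple-parametrized
# "mix-variate" form, so a Normal-Wishart sample (a vector together with a matrix)
# can be described without calling the distribution matrix-variate.
abstract type SketchVariateForm end
struct SketchMultivariate  <: SketchVariateForm end
struct SketchMatrixvariate <: SketchVariateForm end
struct MixVariate{T <: Tuple} <: SketchVariateForm end

# Normal-Wishart would then carry a mix of a multivariate and a matrix-variate component.
const NormalWishartVariate = MixVariate{Tuple{SketchMultivariate, SketchMatrixvariate}}
```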
Co-authored-by: Bagaev Dmitry <[email protected]>
@bvdmitri I resolved the issue with the documentation. There are several ways the MvNormalWishart issue can be resolved, but let's document our decision somewhere, and afterwards I can implement it in a separate PR.
Great! Thanks, @Nimrais!
In this issue it was shown that `logpdf` evaluation inside `ExponentialFamily.jl` is far slower than it theoretically could be. This PR aims to partially resolve this issue through the two following ideas:

1. Evaluate the `logpartition` only once for a container of points.
2. Use `pack_parameters` instead of `flatten_parameters` for multivariate distributions.

To benchmark `logpdf` I am using the following script:
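(The original benchmark script is not reproduced on this page. Below is a minimal sketch of this kind of benchmark; the distribution, the container size, and the exact calls are assumptions, not the setup that produced the numbers that follow. The names `benchmark_old_logpdf` and `benchmark_new_logpdf` are taken from the description, but their mapping to per-point versus container evaluation is an assumption.)

```julia
using ExponentialFamily, Distributions, LinearAlgebra, BenchmarkTools

# Exponential-family representation of an illustrative multivariate normal
# (assumes `convert(ExponentialFamilyDistribution, dist)` works for this distribution).
ef = convert(ExponentialFamilyDistribution, MvNormal(zeros(4), Matrix(1.0I, 4, 4)))

# A container of points at which to evaluate the log-density.
points = [randn(4) for _ in 1:1000]

# Per-point evaluation: the logpartition is recomputed for every point.
benchmark_old_logpdf = @benchmark map(p -> logpdf($ef, p), $points)

# Container evaluation (the case this PR optimizes): the logpartition
# is evaluated once and reused for the whole container.
benchmark_new_logpdf = @benchmark logpdf($ef, $points)
```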
Old evaluation strategy

There are two outputs showing the speed of `logpdf` evaluation with naïve per-point re-evaluation of the `logpartition` (`benchmark_old_logpdf`):

BenchmarkTools.Trial: 10000 samples with 1 evaluation.
Range (min … max): 51.875 μs … 2.353 ms ┊ GC (min … max): 0.00% … 95.82%
Time (median): 53.083 μs ┊ GC (median): 0.00%
Time (mean ± σ): 61.119 μs ± 121.205 μs ┊ GC (mean ± σ): 11.50% ± 5.63%
▄▅██▇▅▄▃▃▃▂▃▂▂▂▁▁▁▁ ▁▁ ▂
███████████████████▇████▇▇▇███▇▇▇▇▇▆▅▅▄▆▅▅▃▄▄▄▄▄▄▄▄▁▄▁▃▄▅▅▄▄ █
51.9 μs Histogram: log(frequency) by time 72 μs <
Memory estimate: 164.19 KiB, allocs estimate: 2001.
BenchmarkTools.Trial: 9754 samples with 1 evaluation.
Range (min … max): 449.666 μs … 1.863 ms ┊ GC (min … max): 0.00% … 68.39%
Time (median): 481.959 μs ┊ GC (median): 0.00%
Time (mean ± σ): 511.372 μs ± 172.085 μs ┊ GC (mean ± σ): 4.95% ± 10.20%
▅█▆▄▃▁ ▁
███████▇▆▄▄▃▁▁▁▁▃▁▁▃▁▁▁▁▁▁▁▁▁▁▃▁▁▁▃▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▃▄▇██ █
450 μs Histogram: log(frequency) by time 1.63 ms <
Memory estimate: 1007.94 KiB, allocs estimate: 5001.
New evaluation strategy

There are two outputs showing the speed of `logpdf` evaluation with only one `logpartition` evaluation (`benchmark_new_logpdf`):

BenchmarkTools.Trial: 10000 samples with 1 evaluation.
Time (median): 45.167 μs ┊ GC (median): 0.00%
Time (mean ± σ): 53.773 μs ± 135.426 μs ┊ GC (mean ± σ): 14.61% ± 5.67%
▂▄▆▇█▆▃▁ ▁▁ ▁
█████████▇▇▇▆▇████▇▆▆▆▅▄▅▇▆▇▇▇▆▇▆▆▆▆▇▆▇▆▆▆▅▅▅▅▆▅▅▆▄▅▅▅▄▄▅▄▄▄ █
43.8 μs Histogram: log(frequency) by time 59.9 μs <
Memory estimate: 164.19 KiB, allocs estimate: 2001.
BenchmarkTools.Trial: 10000 samples with 1 evaluation.
Range (min … max): 109.875 μs … 1.425 ms ┊ GC (min … max): 0.00% … 89.07%
Time (median): 117.541 μs ┊ GC (median): 0.00%
Time (mean ± σ): 134.204 μs ± 135.941 μs ┊ GC (mean ± σ): 11.87% ± 10.49%
█▂ ▁
███▄▄▃▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁█ █
110 μs Histogram: log(frequency) by time 1.25 ms <
Memory estimate: 633.31 KiB, allocs estimate: 3003.
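In summary: for the second benchmark the median time drops from 481.959 μs to 117.541 μs (roughly a 4× speedup) and the memory estimate from 1007.94 KiB to 633.31 KiB, while the first benchmark improves more modestly, from a 53.083 μs to a 45.167 μs median.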