Remove LogDensityProblemsAD #806

Draft · wants to merge 1 commit into base: release-0.35
Conversation

@penelopeysm penelopeysm commented Feb 10, 2025

Tests are not passing yet - am on it
Tests pass! 🎉

Now need to benchmark performance before vs after this PR.
Benchmarks are in the comment below.

Requires TuringLang/AbstractMCMC.jl#158 to be merged first.

Personally I still have some concerns about this - I'm not sure whether prep should be moved into the LogDensityFunction struct.

@penelopeysm penelopeysm changed the base branch from master to release-0.35 February 10, 2025 13:19
codecov bot commented Feb 10, 2025

Codecov Report

Attention: Patch coverage is 50.00000% with 5 lines in your changes missing coverage. Please review.

Project coverage is 85.60%. Comparing base (7613dbb) to head (8e22c05).

| Files with missing lines | Patch % | Lines |
|---|---|---|
| src/logdensityfunction.jl | 62.50% | 3 Missing ⚠️ |
| src/contexts.jl | 0.00% | 1 Missing ⚠️ |
| src/sampler.jl | 0.00% | 1 Missing ⚠️ |
Additional details and impacted files
```
@@               Coverage Diff                @@
##           release-0.35     #806      +/-   ##
================================================
- Coverage         85.78%   85.60%   -0.19%
================================================
  Files                36       36
  Lines              4207     4195      -12
================================================
- Hits               3609     3591      -18
- Misses              598      604       +6
```


coveralls commented Feb 10, 2025

Pull Request Test Coverage Report for Build 13243689609

Details

  • 0 of 10 (0.0%) changed or added relevant lines in 3 files are covered.
  • 2859 unchanged lines in 27 files lost coverage.
  • Overall coverage decreased (-60.3%) to 25.584%

| Changes missing coverage | Covered lines | Changed/added lines | % |
|---|---|---|---|
| src/contexts.jl | 0 | 1 | 0.0% |
| src/sampler.jl | 0 | 1 | 0.0% |
| src/logdensityfunction.jl | 0 | 8 | 0.0% |
| Files with coverage reduction | New missed lines | % |
|---|---|---|
| ext/DynamicPPLForwardDiffExt.jl | 1 | 0.0% |
| src/selector.jl | 2 | 0.0% |
| src/varname.jl | 6 | 0.0% |
| src/test_utils/model_interface.jl | 7 | 0.0% |
| src/model_utils.jl | 11 | 0.0% |
| src/test_utils/contexts.jl | 12 | 0.0% |
| src/distribution_wrappers.jl | 13 | 0.0% |
| src/logdensityfunction.jl | 13 | 0.0% |
| src/test_utils/varinfo.jl | 23 | 0.0% |
| src/submodel_macro.jl | 26 | 0.0% |
Totals (coverage status):

- Change from base Build 13229447007: -60.3%
- Covered lines: 1062
- Relevant lines: 4151


Member Author

penelopeysm commented Feb 10, 2025

Benchmarks

Code

New version -- run on this PR

```julia
using Test, DynamicPPL, ADTypes, LogDensityProblems, Chairmarks, StatsBase, Random
import ForwardDiff
import ReverseDiff
import Mooncake

cmarks = []
for m in DynamicPPL.TestUtils.DEMO_MODELS
    vi = DynamicPPL.VarInfo(Xoshiro(468), m)
    f = DynamicPPL.LogDensityFunction(m, vi)
    θ = convert(Vector{Float64}, vi[:])
    for adtype in [AutoForwardDiff(), AutoReverseDiff(; compile=false), AutoReverseDiff(; compile=true), AutoMooncake(; config=nothing)]
        t = @be LogDensityProblems.logdensity_and_gradient($f, $θ, $adtype) evals=1
        push!(cmarks, (string(m.f) * "," * string(adtype), t))
    end
end
for cmark in cmarks
    println("$(cmark[1]),$(median(cmark[2]).time)")
end
```

Old version -- run on master branch

```julia
using Test, DynamicPPL, ADTypes, LogDensityProblems, LogDensityProblemsAD, Chairmarks, StatsBase, Random
import ForwardDiff
import ReverseDiff
import Mooncake
import DifferentiationInterface

cmarks = []
for m in DynamicPPL.TestUtils.DEMO_MODELS
    vi = DynamicPPL.VarInfo(Xoshiro(468), m)
    f = DynamicPPL.LogDensityFunction(m, vi)
    θ = convert(Vector{Float64}, vi[:])
    for adtype in [AutoForwardDiff(), AutoReverseDiff(; compile=false), AutoReverseDiff(; compile=true), AutoMooncake(; config=nothing)]
        t = @be LogDensityProblems.logdensity_and_gradient(DynamicPPL._make_ad_gradient($adtype, $f), $θ) evals=1
        push!(cmarks, (string(m.f) * "," * string(adtype), t))
    end
end
for cmark in cmarks
    println("$(cmark[1]),$(median(cmark[2]).time)")
end
```

Summary

ForwardDiff and Mooncake have pretty much identical performance to before.

This PR makes non-compiled ReverseDiff about 1.2x slower, and compiled ReverseDiff about 2x faster. It's unclear why.

This is generally true across all demo models.
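For context, DifferentiationInterface separates one-off preparation (where compiled ReverseDiff records and compiles its tape) from evaluation, so where the preparation happens affects which cost each benchmark measures. A minimal standalone sketch of that split, using a plain test function `f` rather than a DynamicPPL model:

```julia
using DifferentiationInterface
using ADTypes: AutoReverseDiff
import ReverseDiff

f(x) = sum(abs2, x)  # stand-in for a log-density
x = randn(10)
backend = AutoReverseDiff(; compile=true)

# One-off preparation: records and compiles the ReverseDiff tape.
prep = prepare_gradient(f, backend, x)

# Subsequent calls reuse the compiled tape, amortising the preparation cost.
val, grad = value_and_gradient(f, prep, backend, x)
```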

Results

| Model | AD backend | New (s) | Old (s) | New / old |
|---|---|---|---|---|
| demo_dot_assume_dot_observe | AutoForwardDiff() | 2.125E-06 | 2.167E-06 | 0.98061836640517 |
| demo_dot_assume_dot_observe | AutoReverseDiff() | 2.1667E-05 | 1.675E-05 | 1.2936 |
| demo_dot_assume_dot_observe | AutoReverseDiff(compile=true) | 2.1792E-05 | 4.175E-05 | 0.521964071856287 |
| demo_dot_assume_dot_observe | AutoMooncake{Nothing}(nothing) | 0.000174917 | 0.000176042 | 0.993609479556015 |
| demo_assume_index_observe | AutoForwardDiff() | 1.666E-06 | 2.208E-06 | 0.75452898550725 |
| demo_assume_index_observe | AutoReverseDiff() | 2.3791E-05 | 1.9667E-05 | 1.2097 |
| demo_assume_index_observe | AutoReverseDiff(compile=true) | 2.35E-05 | 4.7917E-05 | 0.4904 |
| demo_assume_index_observe | AutoMooncake{Nothing}(nothing) | 8.1042E-05 | 8.3541E-05 | 0.970086544331526 |
| demo_assume_multivariate_observe | AutoForwardDiff() | 1.625E-06 | 1.75E-06 | 0.92857142857143 |
| demo_assume_multivariate_observe | AutoReverseDiff() | 2.075E-05 | 1.7291E-05 | 1.20004626684402 |
| demo_assume_multivariate_observe | AutoReverseDiff(compile=true) | 2.0833E-05 | 4.0917E-05 | 0.5092 |
| demo_assume_multivariate_observe | AutoMooncake{Nothing}(nothing) | 0.000102542 | 0.000101708 | 1.00819994494042 |
| demo_dot_assume_observe_index | AutoForwardDiff() | 2.167E-06 | 2.625E-06 | 0.82552380952381 |
| demo_dot_assume_observe_index | AutoReverseDiff() | 2.2125E-05 | 1.7917E-05 | 1.2349 |
| demo_dot_assume_observe_index | AutoReverseDiff(compile=true) | 2.2084E-05 | 4.4916E-05 | 0.491673345801051 |
| demo_dot_assume_observe_index | AutoMooncake{Nothing}(nothing) | 0.000149 | 0.00015325 | 0.972267536704731 |
| demo_assume_dot_observe | AutoForwardDiff() | 1.375E-06 | 1.417E-06 | 0.97035991531404 |
| demo_assume_dot_observe | AutoReverseDiff() | 1.6458E-05 | 1.4042E-05 | 1.17205526278308 |
| demo_assume_dot_observe | AutoReverseDiff(compile=true) | 1.65E-05 | 3.525E-05 | 0.4681 |
| demo_assume_dot_observe | AutoMooncake{Nothing}(nothing) | 7.3E-05 | 7.5125E-05 | 0.97171381031614 |
| demo_assume_multivariate_observe_literal | AutoForwardDiff() | 1.667E-06 | 1.709E-06 | 0.975424224692803 |
| demo_assume_multivariate_observe_literal | AutoReverseDiff() | 2.0792E-05 | 1.7125E-05 | 1.2141 |
| demo_assume_multivariate_observe_literal | AutoReverseDiff(compile=true) | 2.0833E-05 | 4.0125E-05 | 0.5192 |
| demo_assume_multivariate_observe_literal | AutoMooncake{Nothing}(nothing) | 9.3209E-05 | 9.25625E-05 | 1.00698446995273 |
| demo_dot_assume_observe_index_literal | AutoForwardDiff() | 2.167E-06 | 2.542E-06 | 0.85247836349331 |
| demo_dot_assume_observe_index_literal | AutoReverseDiff() | 2.2333E-05 | 1.7667E-05 | 1.2641 |
| demo_dot_assume_observe_index_literal | AutoReverseDiff(compile=true) | 2.25E-05 | 4.5208E-05 | 0.4977 |
| demo_dot_assume_observe_index_literal | AutoMooncake{Nothing}(nothing) | 0.00015075 | 0.000151916 | 0.992324705758445 |
| demo_assume_dot_observe_literal | AutoForwardDiff() | 1.333E-06 | 1.375E-06 | 0.96945454545455 |
| demo_assume_dot_observe_literal | AutoReverseDiff() | 1.6708E-05 | 1.4E-05 | 1.1934 |
| demo_assume_dot_observe_literal | AutoReverseDiff(compile=true) | 1.6708E-05 | 3.4583E-05 | 0.483127548217332 |
| demo_assume_dot_observe_literal | AutoMooncake{Nothing}(nothing) | 6.5792E-05 | 6.375E-05 | 1.03203137254902 |
| demo_assume_observe_literal | AutoForwardDiff() | 1.334E-06 | 1.375E-06 | 0.97018181818182 |
| demo_assume_observe_literal | AutoReverseDiff() | 1.6792E-05 | 1.4291E-05 | 1.17500524805822 |
| demo_assume_observe_literal | AutoReverseDiff(compile=true) | 1.6833E-05 | 3.5708E-05 | 0.471406967626302 |
| demo_assume_observe_literal | AutoMooncake{Nothing}(nothing) | 6.125E-05 | 6.2375E-05 | 0.981963927855711 |
| demo_assume_submodel_observe_index_literal | AutoForwardDiff() | 2.417E-06 | 2.875E-06 | 0.840695652173913 |
| demo_assume_submodel_observe_index_literal | AutoReverseDiff() | 2.3083E-05 | 1.7875E-05 | 1.2914 |
| demo_assume_submodel_observe_index_literal | AutoReverseDiff(compile=true) | 2.4917E-05 | 4.5375E-05 | 0.549134986225895 |
| demo_assume_submodel_observe_index_literal | AutoMooncake{Nothing}(nothing) | 0.000155583 | 0.0001557705 | 0.998796306104172 |
| demo_dot_assume_observe_submodel | AutoForwardDiff() | 2.917E-06 | 3.584E-06 | 0.813895089285714 |
| demo_dot_assume_observe_submodel | AutoReverseDiff() | 2.3875E-05 | 1.7792E-05 | 1.3419 |
| demo_dot_assume_observe_submodel | AutoReverseDiff(compile=true) | 2.3875E-05 | 4.3334E-05 | 0.550953062260581 |
| demo_dot_assume_observe_submodel | AutoMooncake{Nothing}(nothing) | 0.000186875 | 0.000184125 | 1.01493550577054 |
| demo_dot_assume_dot_observe_matrix | AutoForwardDiff() | 2.084E-06 | 2.625E-06 | 0.793904761904762 |
| demo_dot_assume_dot_observe_matrix | AutoReverseDiff() | 2.1708E-05 | 1.7167E-05 | 1.2645 |
| demo_dot_assume_dot_observe_matrix | AutoReverseDiff(compile=true) | 2.2E-05 | 4.2208E-05 | 0.52122820318423 |
| demo_dot_assume_dot_observe_matrix | AutoMooncake{Nothing}(nothing) | 0.0001784585 | 0.0001756875 | 1.01577232301672 |
| demo_dot_assume_matrix_dot_observe_matrix | AutoForwardDiff() | 2.125E-06 | 2.334E-06 | 0.91045415595544 |
| demo_dot_assume_matrix_dot_observe_matrix | AutoReverseDiff() | 2.2791E-05 | 1.9083E-05 | 1.1943 |
| demo_dot_assume_matrix_dot_observe_matrix | AutoReverseDiff(compile=true) | 2.2833E-05 | 4.4458E-05 | 0.513585856313824 |
| demo_dot_assume_matrix_dot_observe_matrix | AutoMooncake{Nothing}(nothing) | 0.0001905 | 0.000183687 | 1.03709026768361 |
| demo_assume_matrix_dot_observe_matrix | AutoForwardDiff() | 1.708E-06 | 1.833E-06 | 0.93180578286961 |
| demo_assume_matrix_dot_observe_matrix | AutoReverseDiff() | 2.1875E-05 | 1.825E-05 | 1.1986 |
| demo_assume_matrix_dot_observe_matrix | AutoReverseDiff(compile=true) | 2.1875E-05 | 4.1417E-05 | 0.528164763261463 |
| demo_assume_matrix_dot_observe_matrix | AutoMooncake{Nothing}(nothing) | 0.000117041 | 0.000115333 | 1.01480929135633 |

Comment on lines +123 to 141

```julia
# By default, the AD backend to use is inferred from the context, which would
# typically be a SamplingContext which contains a sampler.
function LogDensityProblems.logdensity_and_gradient(
    f::LogDensityFunction, θ::AbstractVector
)
    adtype = getadtype(getsampler(getcontext(f)))
    return LogDensityProblems.logdensity_and_gradient(f, θ, adtype)
end

# Extra method allowing one to manually specify the AD backend to use, thus
# overriding the default AD backend inferred from the sampler.
function LogDensityProblems.logdensity_and_gradient(
    f::LogDensityFunction, θ::AbstractVector, adtype::ADTypes.AbstractADType
)
    # Ensure we concretise the elements of the params.
    θ = map(identity, θ) # TODO: Is this needed?
    prep = DI.prepare_gradient(_flipped_logdensity, adtype, θ, DI.Constant(f))
    return DI.value_and_gradient(_flipped_logdensity, prep, adtype, θ, DI.Constant(f))
end
```
Member Author

@penelopeysm penelopeysm Feb 11, 2025


I guess I'm not fully convinced by what I've done so far. The code here means that every time we call logdensity_and_gradient we run DI.prepare_gradient again, which seems inefficient.

I guess the way to get around this would be to call DI.prepare_gradient once and store that prep inside the LogDensityFunction, or (IMO better) create a new type LogDensityFunctionPrepped that wraps the LogDensityFunction plus the prep, and then call logdensity_and_gradient on that object. That would mean that we've basically reinvented / inlined the LogDensityProblemsAD.ADgradient type 😄

I'm not opposed to doing that if that's the best way forward, but I wonder if you concur with / see any holes in my assessment @willtebbutt.
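A rough sketch of the wrapper idea described in the comment above. The name `LogDensityFunctionPrepped` is hypothetical (taken from the comment, not actual DynamicPPL API), and `_flipped_logdensity` is assumed to be the internal helper used in this PR's code:

```julia
import DifferentiationInterface as DI
using ADTypes: AbstractADType

# Hypothetical wrapper pairing a LogDensityFunction with its DI preparation,
# so DI.prepare_gradient runs once at construction rather than on every call.
struct LogDensityFunctionPrepped{F,T<:AbstractADType,P}
    ldf::F        # the underlying LogDensityFunction
    adtype::T
    prep::P       # result of DI.prepare_gradient
end

function LogDensityFunctionPrepped(ldf, adtype::AbstractADType, θ::AbstractVector)
    prep = DI.prepare_gradient(_flipped_logdensity, adtype, θ, DI.Constant(ldf))
    return LogDensityFunctionPrepped(ldf, adtype, prep)
end

function logdensity_and_gradient(f::LogDensityFunctionPrepped, θ::AbstractVector)
    # Reuses the stored preparation instead of re-preparing each call.
    return DI.value_and_gradient(
        _flipped_logdensity, f.prep, f.adtype, θ, DI.Constant(f.ldf)
    )
end
```

As the comment notes, this is essentially an inlined version of LogDensityProblemsAD's ADgradient wrapper.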
