Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make sampling, integration and mode-estimation also return an EvaluatedMeasure #459

Merged
merged 9 commits into from
Nov 4, 2024
6 changes: 6 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,10 @@ run more or less unchanged (with deprecation warnings). Also:

### New features

* Sampling, integration and mode-finding algorithms now generate a return
value `result = ..., evaluated::EvaluatedMeasure = ..., ...)` if their
target is a probability measure/distribution.

* The new `RAMTuning` is now the default (transform) tuning algorithm for
`RandomWalk` (formerly `MetropolisHastings`). It typically results in a much
faster burn-in process than `AdaptiveAffineTuning` (formerly
Expand Down Expand Up @@ -112,6 +116,8 @@ BAT.jl v3.0.0

* Use the new function `bat_report` to generate a sampling output report instead of `show(BAT.SampledDensity(samples))`.

* The field types of `EvaluatedMeasure` have changed.


### New features
------------
Expand Down
2 changes: 2 additions & 0 deletions docs/src/stable_api.md
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,8 @@ TransformAlgorithm
TransformedMCMC
VEGASIntegration

BAT.unevaluated

BAT.AbstractMedianEstimator
BAT.AbstractModeEstimator
BAT.AbstractSamplingAlgorithm
Expand Down
10 changes: 4 additions & 6 deletions examples/paper-example/paper_example.jl
Original file line number Diff line number Diff line change
Expand Up @@ -135,23 +135,21 @@ prior_bkg = distbind(make_child_prior(length(summary_dataset_table)), parent_pr

prior_bkg_signal = distbind(make_child_prior(length(summary_dataset_table)), parent_prior_bkg_signal, merge)

posterior_bkg = PosteriorMeasure(make_likelihood_bkg(summary_dataset_table, sample_table), prior_bkg)
posterior_bkg = lbqintegral(make_likelihood_bkg(summary_dataset_table, sample_table), prior_bkg)

posterior_bkg_signal = PosteriorMeasure(SignalBkgLikelihood(summary_dataset_table, sample_table), prior_bkg_signal)
posterior_bkg_signal = lbqintegral(SignalBkgLikelihood(summary_dataset_table, sample_table), prior_bkg_signal)

nchains = 4
nsteps = 10^5

algorithm = TransformedMCMC(proposal = HamiltonianMC(), nchains = nchains, nsteps = nsteps)

samples_bkg = bat_sample(posterior_bkg, algorithm).result
eval_bkg = EvaluatedMeasure(posterior_bkg, samples = samples_bkg)
samples_bkg, eval_bkg = bat_sample(posterior_bkg, algorithm)

@show evidence_bkg_bridge = bat_integrate(eval_bkg, BridgeSampling()).result
@show evidence_bkg_cuba = bat_integrate(eval_bkg, VEGASIntegration(maxevals = 10^6, rtol = 0.005)).result

samples_bkg_signal = bat_sample(posterior_bkg_signal, algorithm).result
eval_bkg_signal = EvaluatedMeasure(posterior_bkg_signal, samples = samples_bkg_signal)
samples_bkg_signal, eval_bkg_signal = bat_sample(posterior_bkg_signal, algorithm)

@show evidence_bkg_signal_bridge = bat_integrate(eval_bkg_signal, BridgeSampling()).result
@show evidence_bkg_signal_cuba = bat_integrate(eval_bkg_signal, VEGASIntegration(maxevals = 10^6, rtol = 0.005)).result
Expand Down
5 changes: 3 additions & 2 deletions ext/BATCubaExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ using Cuba
using BAT
BAT.pkgext(::Val{:Cuba}) = BAT.PackageExtension{:Cuba}()

using BAT: MeasureLike, BATMeasure
using BAT: MeasureLike, BATMeasure, unevaluated
using BAT: CubaIntegration
using BAT: measure_support, bat_integrate_impl
using BAT: transform_and_unshape, auto_renormalize
Expand Down Expand Up @@ -130,8 +130,9 @@ function BAT.bat_integrate_impl(target::MeasureLike, algorithm::CubaIntegration,
end

renormalized_measure, logweight = auto_renormalize(transformed_measure)
renormalized_measure_uneval = unevaluated(renormalized_measure)
dof = totalndof(varshape(renormalized_measure))
integrand = CubaIntegrand(logdensityof(renormalized_measure), dof)
integrand = CubaIntegrand(logdensityof(renormalized_measure_uneval), dof)

r_cuba = _integrate_impl_cuba(integrand, algorithm, context)

Expand Down
13 changes: 6 additions & 7 deletions ext/BATMGVIExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,11 @@ import MGVI
using MGVI: MGVIContext, MGVIConfig, MGVIResult, mgvi_step, mgvi_sample

import BAT
import MGVI
using MGVI: MGVIContext, MGVIConfig, MGVIResult, mgvi_step, mgvi_sample

import BAT
BAT.pkgext(::Val{:MGVI}) = BAT.PackageExtension{:MGVI}()

using BAT: MeasureLike, BATMeasure, DensitySample, DensitySampleVector, BATContext
using BAT: MeasureLike, BATMeasure, DensitySample, DensitySampleVector, BATContext, unevaluated
using BAT: transform_and_unshape, bat_initval, apply_trafo_to_init, exec_map!
using BAT: getlikelihood, getprior, StandardMvNormal
using BAT: checked_logdensityof
Expand Down Expand Up @@ -79,8 +77,9 @@ function BAT.bat_sample_impl(m::BATMeasure, algorithm::MGVISampling, context::BA
mgvi_context = MGVIContext(get_gencontext(context), BAT._get_checked_adselector(context, :MGVISampling))

transformed_m, f_pretransform = transform_and_unshape(pretransform, m, context)
transformed_m_uneval = unevaluated(transformed_m)

likelihood, prior = getlikelihood(transformed_m), getprior(transformed_m)
likelihood, prior = getlikelihood(transformed_m_uneval), getprior(transformed_m_uneval)
if !is_std_mvnormal(prior)
throw(ArgumentError("$(nameof(typeof(algorithm))) can't be used for measures that do not have a standard multivariate normal prior after `pretransform`"))
end
Expand Down Expand Up @@ -119,7 +118,7 @@ function BAT.bat_sample_impl(m::BATMeasure, algorithm::MGVISampling, context::BA
dummy_sample = DensitySample(center, zero(result.mnlp), one(BAT._IntWeightType), MGVISampleInfo(0, false, zero(result.mnlp)), nothing)
transformed_smpls = DensitySampleVector(typeof(dummy_sample), length(center))
if store_unconverged
_append_mgvi_samples!(transformed_smpls, transformed_m, result.samples, MGVISampleInfo(nsteps, false, result.mnlp))
_append_mgvi_samples!(transformed_smpls, transformed_m_uneval, result.samples, MGVISampleInfo(nsteps, false, result.mnlp))
end

isdone::Bool = false
Expand All @@ -130,7 +129,7 @@ function BAT.bat_sample_impl(m::BATMeasure, algorithm::MGVISampling, context::BA
result, center = mgvi_step(f_model, obs, step_nsamples, center, config, mgvi_context)
nsteps += 1
if store_unconverged
_append_mgvi_samples!(transformed_smpls, transformed_m, result.samples, MGVISampleInfo(nsteps, false, result.mnlp))
_append_mgvi_samples!(transformed_smpls, transformed_m_uneval, result.samples, MGVISampleInfo(nsteps, false, result.mnlp))
end
else
isdone = true
Expand All @@ -150,7 +149,7 @@ function BAT.bat_sample_impl(m::BATMeasure, algorithm::MGVISampling, context::BA
nsteps += 1
n_samples_total = size(final_flat_smpls, 2)
n_samples_indep = div(n_samples_total, 2)
_append_mgvi_samples!(transformed_smpls, transformed_m, final_flat_smpls, MGVISampleInfo(nsteps, true, oftype(result.mnlp,NaN)))
_append_mgvi_samples!(transformed_smpls, transformed_m_uneval, final_flat_smpls, MGVISampleInfo(nsteps, true, oftype(result.mnlp,NaN)))

elapsed_time = time() - start_time
@debug "Generated final MGVI samples in transformed space after $nsteps, produced $n_samples_indep independent samples after $(@sprintf "%.1f s" elapsed_time)."
Expand Down
13 changes: 7 additions & 6 deletions ext/BATNestedSamplersExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ using HeterogeneousComputing

BAT.pkgext(::Val{:NestedSamplers}) = BAT.PackageExtension{:NestedSamplers}()

using BAT: MeasureLike, BATMeasure
using BAT: MeasureLike, BATMeasure, unevaluated
using BAT: ENSBound, ENSNoBounds, ENSEllipsoidBound, ENSMultiEllipsoidBound
using BAT: ENSProposal, ENSUniformly, ENSAutoProposal, ENSRandomWalk, ENSSlice

Expand Down Expand Up @@ -63,14 +63,15 @@ function BAT.bat_sample_impl(m::BATMeasure, algorithm::EllipsoidalNestedSampling
# ToDo: Forward RNG from context!
rng = get_rng(context)

transformed_m, f_pretransform = BAT.transform_and_unshape(algorithm.pretransform, m, context) # BAT prior transformation
dims = totalndof(varshape(transformed_m))
transformed_m, f_pretransform = BAT.transform_and_unshape(algorithm.pretransform, m, context)
transformed_m_uneval = unevaluated(transformed_m)
dims = totalndof(varshape(transformed_m_uneval))

if !BAT.has_uhc_support(transformed_m)
if !BAT.has_uhc_support(transformed_m_uneval)
throw(ArgumentError("$algorithm doesn't measures that are not limited to the unit hypercube"))
end

model = NestedModel(logdensityof(transformed_m), identity); # identity, because ahead the BAT prior transformation is used instead
model = NestedModel(logdensityof(transformed_m_uneval), identity); # identity, because ahead the BAT prior transformation is used instead
bounding = ENSBounding(algorithm.bound)
prop = ENSprop(algorithm.proposal)
sampler = Nested(
Expand All @@ -87,7 +88,7 @@ function BAT.bat_sample_impl(m::BATMeasure, algorithm::EllipsoidalNestedSampling
weights = samples_w[:, end] # the last elements of the vectors are the weights
nsamples = size(samples_w,1)
samples = [samples_w[i, 1:end-1] for i in 1:nsamples] # the other ones (between 1 and end-1) are the samples
logvals = map(logdensityof(transformed_m), samples) # posterior values of the samples
logvals = map(logdensityof(transformed_m_uneval), samples) # posterior values of the samples
transformed_smpls = BAT.DensitySampleVector(samples, logvals, weight = weights)
smpls = inverse(f_pretransform).(transformed_smpls) # Here the samples are retransformed

Expand Down
5 changes: 3 additions & 2 deletions ext/BATOptimExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ using DensityInterface, ChangesOfVariables, InverseFunctions, FunctionChains
using HeterogeneousComputing, AutoDiffOperators
using StructArrays, ArraysOfArrays

using BAT: MeasureLike, BATMeasure
using BAT: MeasureLike, BATMeasure, unevaluated

using BAT: get_context, get_adselector, _NoADSelected
using BAT: bat_initval, transform_and_unshape, apply_trafo_to_init
Expand Down Expand Up @@ -69,13 +69,14 @@ end

function BAT.bat_findmode_impl(target::MeasureLike, algorithm::OptimAlg, context::BATContext)
transformed_density, f_pretransform = transform_and_unshape(algorithm.pretransform, target, context)
target_uneval = unevaluated(target)
inv_trafo = inverse(f_pretransform)

initalg = apply_trafo_to_init(f_pretransform, algorithm.init)
x_init = collect(bat_initval(transformed_density, initalg, context).result)

# Maximize density of original target, but run in transformed space, don't apply LADJ:
f = fchain(inv_trafo, logdensityof(target), -)
f = fchain(inv_trafo, logdensityof(target_uneval), -)
opts = convert_options(algorithm)
optim_result = _optim_minimize(f, x_init, algorithm.optalg, opts, context)
r_optim = Optim.MaximizationWrapper(optim_result)
Expand Down
9 changes: 5 additions & 4 deletions ext/BATOptimizationExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ using DensityInterface, ChangesOfVariables, InverseFunctions, FunctionChains
using HeterogeneousComputing, AutoDiffOperators
using StructArrays, ArraysOfArrays, ADTypes

using BAT: MeasureLike
using BAT: MeasureLike, unevaluated

using BAT: get_context, get_adselector, _NoADSelected
using BAT: bat_initval, transform_and_unshape, apply_trafo_to_init
Expand Down Expand Up @@ -42,14 +42,15 @@ end


function BAT.bat_findmode_impl(target::MeasureLike, algorithm::OptimizationAlg, context::BATContext)
transformed_density, f_pretransform = transform_and_unshape(algorithm.pretransform, target, context)
transformed_m, f_pretransform = transform_and_unshape(algorithm.pretransform, target, context)
target_uneval = unevaluated(target)
inv_trafo = inverse(f_pretransform)

initalg = apply_trafo_to_init(f_pretransform, algorithm.init)
x_init = collect(bat_initval(transformed_density, initalg, context).result)
x_init = collect(bat_initval(transformed_m, initalg, context).result)

# Maximize density of original target, but run in transformed space, don't apply LADJ:
f = fchain(inv_trafo, logdensityof(target), -)
f = fchain(inv_trafo, logdensityof(target_uneval), -)
target_f = (x, p) -> f(x)

adsel = get_adselector(context)
Expand Down
69 changes: 43 additions & 26 deletions ext/BATUltraNestExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -35,33 +35,50 @@

paramnames = all_active_names(varshape(m))

smplr = UltraNest.ultranest.ReactiveNestedSampler(
paramnames, vec_ultranest_logpstr, vectorized = true,
num_test_samples = algorithm.num_test_samples,
draw_multiple = algorithm.draw_multiple,
num_bootstraps = algorithm.num_bootstraps,
ndraw_min = algorithm.ndraw_min,
ndraw_max = algorithm.ndraw_max
)
ch = Channel()
function run_sampler()
try
smplr = UltraNest.ultranest.ReactiveNestedSampler(

Check warning on line 41 in ext/BATUltraNestExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/BATUltraNestExt.jl#L38-L41

Added lines #L38 - L41 were not covered by tests
paramnames, vec_ultranest_logpstr, vectorized = true,
num_test_samples = algorithm.num_test_samples,
draw_multiple = algorithm.draw_multiple,
num_bootstraps = algorithm.num_bootstraps,
ndraw_min = algorithm.ndraw_min,
ndraw_max = algorithm.ndraw_max
)

unest_result = smplr.run(

Check warning on line 50 in ext/BATUltraNestExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/BATUltraNestExt.jl#L50

Added line #L50 was not covered by tests
log_interval = algorithm.log_interval < 0 ? nothing : algorithm.log_interval,
show_status = algorithm.show_status,
viz_callback = algorithm.viz_callback,
dlogz = algorithm.dlogz,
dKL = algorithm.dKL,
frac_remain = algorithm.frac_remain,
Lepsilon = algorithm.Lepsilon,
min_ess = algorithm.min_ess,
max_iters = algorithm.max_iters < 0 ? nothing : algorithm.max_iters,
max_ncalls = algorithm.max_ncalls < 0 ? nothing : algorithm.max_ncalls,
max_num_improvement_loops = algorithm.max_num_improvement_loops,
min_num_live_points = algorithm.min_num_live_points,
cluster_num_live_points = algorithm.cluster_num_live_points,
insertion_test_window = algorithm.insertion_test_window,
insertion_test_zscore_threshold = algorithm.insertion_test_zscore_threshold
)
put!(ch, unest_result)

Check warning on line 67 in ext/BATUltraNestExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/BATUltraNestExt.jl#L67

Added line #L67 was not covered by tests
finally
close(ch)

Check warning on line 69 in ext/BATUltraNestExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/BATUltraNestExt.jl#L69

Added line #L69 was not covered by tests
end
end

# Force Python interaction to run on thread 1:
task_id = 1
task = Task(run_sampler)
task.sticky = true

Check warning on line 76 in ext/BATUltraNestExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/BATUltraNestExt.jl#L74-L76

Added lines #L74 - L76 were not covered by tests
# From ThreadPools.@tspawnat:
ccall(:jl_set_task_tid, Cvoid, (Any, Cint), task, task_id-1)
schedule(task)
unest_result = take!(ch)

Check warning on line 80 in ext/BATUltraNestExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/BATUltraNestExt.jl#L78-L80

Added lines #L78 - L80 were not covered by tests

unest_result = smplr.run(
log_interval = algorithm.log_interval < 0 ? nothing : algorithm.log_interval,
show_status = algorithm.show_status,
viz_callback = algorithm.viz_callback,
dlogz = algorithm.dlogz,
dKL = algorithm.dKL,
frac_remain = algorithm.frac_remain,
Lepsilon = algorithm.Lepsilon,
min_ess = algorithm.min_ess,
max_iters = algorithm.max_iters < 0 ? nothing : algorithm.max_iters,
max_ncalls = algorithm.max_ncalls < 0 ? nothing : algorithm.max_ncalls,
max_num_improvement_loops = algorithm.max_num_improvement_loops,
min_num_live_points = algorithm.min_num_live_points,
cluster_num_live_points = algorithm.cluster_num_live_points,
insertion_test_window = algorithm.insertion_test_window,
insertion_test_zscore_threshold = algorithm.insertion_test_zscore_threshold
)

r = convert(Dict{String, Any}, unest_result)

unest_wsamples = convert(Dict{String, Any}, r["weighted_samples"])
Expand Down
1 change: 1 addition & 0 deletions src/BAT.jl
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@ using MeasureBase: transport_to, transport_origin, from_origin, to_origin
using MeasureBase: StdMeasure, StdUniform, StdNormal
using MeasureBase: PowerMeasure, powermeasure, marginals
using MeasureBase: WeightedMeasure, weightedmeasure
using MeasureBase: massof

using MeasureBase: PushforwardMeasure, gettransform
using MeasureBase: TransformVolCorr as PushFwdStyle, NoVolCorr as ChangeRootMeasure, WithVolCorr as KeepRootMeasure
Expand Down
4 changes: 4 additions & 0 deletions src/algodefaults/default_mode_estimator.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,4 +12,8 @@

bat_default(::typeof(bat_findmode), ::Val{:algorithm}, ::BATDistMeasure) = ModeAsDefined()

function bat_default(::typeof(bat_findmode), ::Val{:algorithm}, m::EvaluatedMeasure)
bat_default(bat_findmode, Val(:algorithm), unevaluated(m))

Check warning on line 16 in src/algodefaults/default_mode_estimator.jl

View check run for this annotation

Codecov / codecov/patch

src/algodefaults/default_mode_estimator.jl#L15-L16

Added lines #L15 - L16 were not covered by tests
end

bat_default(::typeof(bat_marginalmode), ::Val{:algorithm}, ::DensitySampleVector) = BinnedModeEstimator()
2 changes: 1 addition & 1 deletion src/algodefaults/default_sampling_algorithm.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,5 +15,5 @@
bat_default(::typeof(bat_sample), ::Val{:algorithm}, ::PosteriorMeasure) = TransformedMCMC()

function bat_default(::typeof(bat_sample), ::Val{:algorithm}, m::EvaluatedMeasure)
bat_default(bat_sample, Val(:algorithm), m.measure)
bat_default(bat_sample, Val(:algorithm), unevaluated(m))

Check warning on line 18 in src/algodefaults/default_sampling_algorithm.jl

View check run for this annotation

Codecov / codecov/patch

src/algodefaults/default_sampling_algorithm.jl#L18

Added line #L18 was not covered by tests
end
63 changes: 63 additions & 0 deletions src/algotypes/bat_default.jl
Original file line number Diff line number Diff line change
Expand Up @@ -54,3 +54,66 @@
result_with_args(r::NamedTuple) = merge(r, (optargs = NamedTuple(),))

result_with_args(r::NamedTuple, optargs::NamedTuple) = merge(r, (optargs = optargs,))

function result_with_args(::Val, ::Any, r::NamedTuple, optargs::NamedTuple)
return result_with_args(r, optargs)
end

function result_with_args(::Val{resultname}, target::Union{AbstractMeasure,Distribution}, r::NamedTuple, optargs::NamedTuple) where resultname
measure = batmeasure(target)
augmented_result = _augment_bat_retval(Val(resultname), measure, r)
result_with_args(augmented_result, optargs)
end

function _augment_bat_retval(::Val{resultname}, measure, r::R) where {resultname,R}
if hasfield(R, :evaluated)
return r

Check warning on line 70 in src/algotypes/bat_default.jl

View check run for this annotation

Codecov / codecov/patch

src/algotypes/bat_default.jl#L70

Added line #L70 was not covered by tests
else
if resultname == :samples
samples = r.result
elseif hasfield(R, :samples)
samples = r.samples

Check warning on line 75 in src/algotypes/bat_default.jl

View check run for this annotation

Codecov / codecov/patch

src/algotypes/bat_default.jl#L75

Added line #L75 was not covered by tests
else
samples = maybe_samplesof(measure)
end

if resultname == :approx
approx = r.result

Check warning on line 81 in src/algotypes/bat_default.jl

View check run for this annotation

Codecov / codecov/patch

src/algotypes/bat_default.jl#L81

Added line #L81 was not covered by tests
elseif hasfield(R, :approx)
approx = r.approx

Check warning on line 83 in src/algotypes/bat_default.jl

View check run for this annotation

Codecov / codecov/patch

src/algotypes/bat_default.jl#L83

Added line #L83 was not covered by tests
else
approx = maybe_approxof(measure)
end

if resultname == :mass
mass = r.result
elseif hasfield(R, :mass)
mass = r.mass
else
mass = massof(measure)
end

if resultname == :modes
modes = r.result

Check warning on line 97 in src/algotypes/bat_default.jl

View check run for this annotation

Codecov / codecov/patch

src/algotypes/bat_default.jl#L97

Added line #L97 was not covered by tests
elseif resultname == :mode
modes = [r.result]
elseif hasfield(R, :modes)
modes = r.modes

Check warning on line 101 in src/algotypes/bat_default.jl

View check run for this annotation

Codecov / codecov/patch

src/algotypes/bat_default.jl#L101

Added line #L101 was not covered by tests
elseif hasfield(R, :mode)
modes = [r.mode]

Check warning on line 103 in src/algotypes/bat_default.jl

View check run for this annotation

Codecov / codecov/patch

src/algotypes/bat_default.jl#L103

Added line #L103 was not covered by tests
else
modes = maybe_modesof(measure)
end

if resultname == :generator
generator = r.result

Check warning on line 109 in src/algotypes/bat_default.jl

View check run for this annotation

Codecov / codecov/patch

src/algotypes/bat_default.jl#L109

Added line #L109 was not covered by tests
elseif hasfield(R, :generator)
generator = r.generator
else
generator = maybe_generator(measure)
end
evaluated = EvaluatedMeasure(unevaluated(measure), samples, approx, mass, modes, generator)
r_add = (result = r.result, evaluated = evaluated)
return merge(r_add, r)
end
end
Loading
Loading