Skip to content

Commit

Permalink
Merge pull request #202 from lxvm/mcintegration
Browse files Browse the repository at this point in the history
wrap MCintegration.jl
  • Loading branch information
ChrisRackauckas authored Jan 7, 2024
2 parents 7f71009 + 90162bb commit d84e854
Show file tree
Hide file tree
Showing 6 changed files with 80 additions and 7 deletions.
6 changes: 5 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ Cuba = "8a292aeb-7a57-582c-b821-06e4c11590b1"
Cubature = "667455a9-e2ce-5579-9412-b964f529a492"
FastGaussQuadrature = "442a2c76-b920-505d-bb47-c5924d526838"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
MCIntegration = "ea1e2de9-7db7-4b42-91ee-0cd1bf6df167"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[extensions]
Expand All @@ -27,6 +28,7 @@ IntegralsCubaExt = "Cuba"
IntegralsCubatureExt = "Cubature"
IntegralsFastGaussQuadratureExt = "FastGaussQuadrature"
IntegralsForwardDiffExt = "ForwardDiff"
IntegralsMCIntegrationExt = "MCIntegration"
IntegralsZygoteExt = ["Zygote", "ChainRulesCore"]

[compat]
Expand All @@ -42,6 +44,7 @@ FiniteDiff = "2.12"
ForwardDiff = "0.10.19"
HCubature = "1.5"
LinearAlgebra = "1.9"
MCIntegration = "0.4.2"
MonteCarloIntegration = "0.0.3, 0.1"
Pkg = "1"
QuadGK = "2.9"
Expand All @@ -63,11 +66,12 @@ Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
FastGaussQuadrature = "442a2c76-b920-505d-bb47-c5924d526838"
FiniteDiff = "6a86dc24-6348-571c-b903-95158fe2bd41"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
MCIntegration = "ea1e2de9-7db7-4b42-91ee-0cd1bf6df167"
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[targets]
test = ["Aqua", "Arblib", "StaticArrays", "FiniteDiff", "Pkg", "SafeTestsets", "Test", "Distributions", "ForwardDiff", "Zygote", "ChainRulesCore", "FastGaussQuadrature", "Cuba", "Cubature"]
test = ["Aqua", "Arblib", "StaticArrays", "FiniteDiff", "Pkg", "SafeTestsets", "Test", "Distributions", "ForwardDiff", "Zygote", "ChainRulesCore", "FastGaussQuadrature", "Cuba", "Cubature", "MCIntegration"]
5 changes: 4 additions & 1 deletion docs/src/solvers/IntegralSolvers.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,9 @@ The following algorithms are available:

- `QuadGKJL`: Uses QuadGK.jl. Requires `nout=1` and `batch=0`, in-place is not allowed.
- `HCubatureJL`: Uses HCubature.jl. Requires `batch=0`.
- `VEGAS`: Uses MonteCarloIntegration.jl. Requires `nout=1`. Works only for `>1`-dimensional integrations.
- `VEGAS`: Uses MonteCarloIntegration.jl. Requires `nout=1`. Works only for
`>1`-dimensional integrations.
- `VEGASMC`: Uses MCIntegration.jl. Requires `using MCIntegration`. Doesn't support batching.
- `CubatureJLh`: h-Cubature from Cubature.jl. Requires `using Cubature`.
- `CubatureJLp`: p-Cubature from Cubature.jl. Requires `using Cubature`.
- `CubaVegas`: Vegas from Cuba.jl. Requires `using Cuba`, `nout=1`.
Expand All @@ -20,6 +22,7 @@ The following algorithms are available:
QuadGKJL
HCubatureJL
VEGAS
VEGASMC
CubaVegas
CubaSUAVE
CubaDivonne
Expand Down
48 changes: 48 additions & 0 deletions ext/IntegralsMCIntegrationExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
module IntegralsMCIntegrationExt

using MCIntegration, Integrals

function Integrals.__solvebp_call(prob::IntegralProblem, alg::VEGASMC, sensealg, domain, p;
reltol = nothing, abstol = nothing, maxiters = 1000)
lb, ub = domain
mid = vec(collect((lb + ub) / 2))
vars = Continuous(vec([tuple(a,b) for (a,b) in zip(lb, ub)]))

if prob.f isa BatchIntegralFunction
error("VEGASMC doesn't support batching. See https://github.com/numericalEFT/MCIntegration.jl/issues/29")
else
if isinplace(prob)
f0 = similar(prob.f.integrand_prototype)
f_ = (x, f, c) -> begin
n = 0
for v in x
mid[n+=1] = first(v)
end
prob.f(f0, mid, p)
f .= vec(f0)
end
else
f0 = prob.f(mid, p)
f_ = (x, c) -> begin
n = 0
for v in x
mid[n+=1] = first(v)
end
fx = prob.f(mid, p)
fx isa AbstractArray ? vec(fx) : fx
end
end
dof = ones(Int, length(f0)) # each composite Continuous var gets 1 dof
res = integrate(f_, inplace=isinplace(prob), var=vars, dof=dof, solver=:vegasmc,
neval=alg.neval, niter=min(maxiters,alg.niter), block=alg.block, adapt=alg.adapt,
gamma=alg.gamma, verbose=alg.verbose, debug=alg.debug, type=eltype(f0), print=-2)
out, err, chi = if f0 isa Number
map(only, (res.mean, res.stdev, res.chi2))
else
map(a -> reshape(a, size(f0)), (res.mean, res.stdev, res.chi2))
end
SciMLBase.build_solution(prob, VEGASMC(), out, err, chi=chi, retcode = ReturnCode.Success)
end
end

end
3 changes: 2 additions & 1 deletion src/Integrals.jl
Original file line number Diff line number Diff line change
Expand Up @@ -217,7 +217,8 @@ function __solvebp_call(prob::IntegralProblem, alg::VEGAS, sensealg, domain, p;
SciMLBase.build_solution(prob, alg, val, err, chi = chi, retcode = ReturnCode.Success)
end

export QuadGKJL, HCubatureJL, VEGAS, GaussLegendre, QuadratureRule, TrapezoidalRule

export QuadGKJL, HCubatureJL, VEGAS, VEGASMC, GaussLegendre, QuadratureRule, TrapezoidalRule
export CubaVegas, CubaSUAVE, CubaDivonne, CubaCuhre
export CubatureJLh, CubatureJLp
export ArblibJL
Expand Down
17 changes: 17 additions & 0 deletions src/algorithms.jl
Original file line number Diff line number Diff line change
Expand Up @@ -370,3 +370,20 @@ end
function ArblibJL(; check_analytic=false, take_prec=false, warn_on_no_convergence=false, opts=C_NULL)
return ArblibJL(check_analytic, take_prec, warn_on_no_convergence, opts)
end

"""
VEGASMC(; neval=10^4, niter=20, block=16, adapt=true, gamma=1.0, verbose=-2, debug=false)
Markov-chain based Vegas algorithm from MCIntegration.jl
"""
struct VEGASMC <: SciMLBase.AbstractIntegralAlgorithm
neval::Int
niter::Int
block::Int
adapt::Bool
gamma::Float64
verbose::Int
debug::Bool
end
VEGASMC(; neval=10^4, niter=20, block=16, adapt=true, gamma=1.0, verbose=-2, debug=false) =
VEGASMC(neval, niter, block, adapt, gamma, verbose, debug)
8 changes: 4 additions & 4 deletions test/interface_tests.jl
Original file line number Diff line number Diff line change
@@ -1,22 +1,22 @@
using Integrals
using Cuba, Cubature, Arblib
using Cuba, Cubature, Arblib, MCIntegration
using Test

max_dim_test = 2
max_nout_test = 2
reltol = 1e-3
abstol = 1e-3


algs = [QuadGKJL, HCubatureJL, CubatureJLh, CubatureJLp, #VEGAS, #CubaVegas,
CubaSUAVE, CubaDivonne, CubaCuhre]
algs = [QuadGKJL, HCubatureJL, CubatureJLh, CubatureJLp, VEGAS, VEGASMC, #CubaVegas,
CubaSUAVE, CubaDivonne, CubaCuhre, ArblibJL]

alg_req = Dict(QuadGKJL => (nout = 1, allows_batch = true, min_dim = 1, max_dim = 1,
allows_iip = true),
HCubatureJL => (nout = Inf, allows_batch = false, min_dim = 1,
max_dim = Inf, allows_iip = true),
VEGAS => (nout = 1, allows_batch = true, min_dim = 2, max_dim = Inf,
allows_iip = true),
VEGASMC => (nout = Inf, allows_batch = false, min_dim = 1, max_dim = Inf, allows_iip = true),
CubatureJLh => (nout = Inf, allows_batch = true, min_dim = 1,
max_dim = Inf, allows_iip = true),
CubatureJLp => (nout = Inf, allows_batch = true, min_dim = 1,
Expand Down

0 comments on commit d84e854

Please sign in to comment.