Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Gauss-Legendre quadrature #147

Merged
merged 5 commits into from
Feb 19, 2023
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "Integrals"
uuid = "de52edbc-65ea-441a-8357-d3a637375a31"
authors = ["Chris Rackauckas <[email protected]>"]
version = "3.6.0"
version = "3.7.0"

[deps]
CommonSolve = "38540f10-b2f7-11e9-35d8-d573e4eb0ff2"
Expand All @@ -26,10 +26,12 @@ Requires = "1"
SciMLBase = "1.70"
Zygote = "0.4.22, 0.5, 0.6"
julia = "1.6"
FastGaussQuadrature = "0.5"

[extensions]
IntegralsForwardDiffExt = "ForwardDiff"
IntegralsZygoteExt = ["Zygote", "ChainRulesCore"]
IntegralsFastGaussQuadratureExt = "FastGaussQuadrature"

[extras]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Expand All @@ -42,6 +44,7 @@ SciMLSensitivity = "1ed8b502-d754-442c-8d5d-10ac956f44a1"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
FastGaussQuadrature = "442a2c76-b920-505d-bb47-c5924d526838"

[targets]
test = ["SciMLSensitivity", "StaticArrays", "FiniteDiff", "Pkg", "SafeTestsets", "Test", "Distributions", "ForwardDiff", "Zygote", "ChainRulesCore"]
ChrisRackauckas marked this conversation as resolved.
Show resolved Hide resolved
Expand All @@ -50,3 +53,4 @@ test = ["SciMLSensitivity", "StaticArrays", "FiniteDiff", "Pkg", "SafeTestsets",
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
FastGaussQuadrature = "442a2c76-b920-505d-bb47-c5924d526838"
2 changes: 2 additions & 0 deletions docs/src/solvers/IntegralSolvers.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,11 @@ The following algorithms are available:
- `CubaSUAVE`: SUAVE from Cuba.jl. Requires `using IntegralsCuba`.
- `CubaDivonne`: Divonne from Cuba.jl. Requires `using IntegralsCuba`.
- `CubaCuhre`: Cuhre from Cuba.jl. Requires `using IntegralsCuba`.
- `GaussLegendre`: Uses Gauss-Legendre quadrature with nodes and weights from FastGaussQuadrature.jl.

```@docs
QuadGKJL
HCubatureJL
VEGAS
GaussLegendre
```
51 changes: 51 additions & 0 deletions ext/IntegralsFastGaussQuadratureExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
module IntegralsFastGaussQuadratureExt
using Integrals
if isdefined(Base, :get_extension)
import FastGaussQuadrature
import FastGaussQuadrature: gausslegendre
# and eventually gausschebyshev, etc.
else
import ..FastGaussQuadrature
import ..FastGaussQuadrature: gausslegendre
end
using LinearAlgebra

Integrals.gausslegendre(n) = FastGaussQuadrature.gausslegendre(n)

function gauss_legendre(f, p, lb, ub, nodes, weights)
scale = (ub - lb) / 2
shift = (lb + ub) / 2
I = dot(weights, @. f(scale * nodes + shift, $Ref(p)))
return scale * I
end
function composite_gauss_legendre(f, p, lb, ub, nodes, weights, subintervals)
h = (ub - lb) / subintervals
I = zero(h)
for i in 1:subintervals
_lb = lb + (i - 1) * h
_ub = _lb + h
I += gauss_legendre(f, p, _lb, _ub, nodes, weights)
end
return I
end

function Integrals.__solvebp_call(prob::IntegralProblem, alg::Integrals.GaussLegendre{C},
sensealg, lb, ub, p;
reltol = nothing, abstol = nothing,
maxiters = nothing) where {C}
if isinplace(prob) || lb isa AbstractArray || ub isa AbstractArray
error("GaussLegendre only accepts one-dimensional quadrature problems.")
end
@assert prob.batch == 0
@assert prob.nout == 1
if C
val = composite_gauss_legendre(prob.f, prob.p, lb, ub,
alg.nodes, alg.weights, alg.subintervals)
else
val = gauss_legendre(prob.f, prob.p, lb, ub,
alg.nodes, alg.weights)
end
err = nothing
SciMLBase.build_solution(prob, alg, val, err, retcode = ReturnCode.Success)
end
end
21 changes: 11 additions & 10 deletions src/Integrals.jl
Original file line number Diff line number Diff line change
Expand Up @@ -108,15 +108,16 @@ function SciMLBase.solve(prob::IntegralProblem,
__solvebp(prob, alg, sensealg, prob.lb, prob.ub, prob.p; kwargs...)
end
# Throw error if alg is not provided, as defaults are not implemented.
SciMLBase.solve(::IntegralProblem) = throw(ArgumentError("""
No integration algorithm `alg` was supplied as the second positional argument.
Reccomended integration algorithms are:
For scalar functions: QuadGKJL()
For ≤ 8 dimensional vector functions: HCubatureJL()
For > 8 dimensional vector functions: MonteCarloIntegration.vegas(f, st, en, kwargs...)
See the docstrings of the different algorithms for more detail.
"""
))
function SciMLBase.solve(::IntegralProblem)
throw(ArgumentError("""
No integration algorithm `alg` was supplied as the second positional argument.
Reccomended integration algorithms are:
For scalar functions: QuadGKJL()
For ≤ 8 dimensional vector functions: HCubatureJL()
For > 8 dimensional vector functions: MonteCarloIntegration.vegas(f, st, en, kwargs...)
See the docstrings of the different algorithms for more detail.
"""))
end

# Give a layer to intercept with AD
__solvebp(args...; kwargs...) = __solvebp_call(args...; kwargs...)
Expand Down Expand Up @@ -188,5 +189,5 @@ function __solvebp_call(prob::IntegralProblem, alg::VEGAS, sensealg, lb, ub, p;
SciMLBase.build_solution(prob, alg, val, err, chi = chi, retcode = ReturnCode.Success)
end

export QuadGKJL, HCubatureJL, VEGAS
export QuadGKJL, HCubatureJL, VEGAS, GaussLegendre
end # module
40 changes: 40 additions & 0 deletions src/algorithms.jl
Original file line number Diff line number Diff line change
Expand Up @@ -82,3 +82,43 @@ struct VEGAS <: SciMLBase.AbstractIntegralAlgorithm
debug::Bool
end
VEGAS(; nbins = 100, ncalls = 1000, debug = false) = VEGAS(nbins, ncalls, debug)

"""
GaussLegendre{C, N, W}

Struct for evaluating an integral via (composite) Gauss-Legendre quadrature.
The field `C` will be `true` if `subintervals > 1`, and `false` otherwise.

The fields `nodes::N` and `weights::W` are defined by
`nodes, weights = gausslegendre(n)` for a given number of nodes `n`.

The field `subintervals::Int64 = 1` (with default value `1`) defines the
number of intervals to partition the original interval of integration
`[a, b]` into, splitting it into `[xⱼ, xⱼ₊₁]` for `j = 1,…,subintervals`,
where `xⱼ = a + (j-1)h` and `h = (b-a)/subintervals`. Gauss-Legendre
quadrature is then applied on each subinterval. For example, if
`[a, b] = [-1, 1]` and `subintervals = 2`, then Gauss-Legendre
quadrature will be applied separately on `[-1, 0]` and `[0, 1]`,
summing the two results.
"""
struct GaussLegendre{C, N, W} <: SciMLBase.AbstractIntegralAlgorithm
nodes::N
weights::W
subintervals::Int64
function GaussLegendre(nodes::N, weights::W, subintervals = 1) where {N, W}
if subintervals > 1
return new{true, N, W}(nodes, weights, subintervals)
elseif subintervals == 1
return new{false, N, W}(nodes, weights, subintervals)
else
throw(ArgumentError("Cannot use a nonpositive number of subintervals."))
end
end
end
function gausslegendre end
function GaussLegendre(; n = 250, subintervals = 1, nodes = nothing, weights = nothing)
if isnothing(nodes) || isnothing(weights)
nodes, weights = gausslegendre(n)
end
return GaussLegendre(nodes, weights, subintervals)
end
1 change: 1 addition & 0 deletions src/init.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,5 +2,6 @@
function __init__()
@require ForwardDiff="f6369f11-7733-5829-9624-2563aa707210" begin include("../ext/IntegralsForwardDiffExt.jl") end
@require Zygote="e88e6eb3-aa80-5325-afca-941959d7151f" begin include("../ext/IntegralsZygoteExt.jl") end
@require FastGaussQuadrature="442a2c76-b920-505d-bb47-c5924d526838" begin include("../ext/IntegralsFastGaussQuadratureExt.jl") end
end
end
99 changes: 99 additions & 0 deletions test/gaussian_quadrature_tests.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
using Integrals, Test, FastGaussQuadrature

#=
f = (x, p) -> x^3 * sin(5x)
n = 250
nodes, weights = gausslegendre(n)
I = gauss_legendre(f, nothing, -1, 1, nodes, weights)
@test I ≈ 2 / (625) * (69sin(5) - 95cos(5))
I = Integrals.composite_gauss_legendre(f, nothing, -1, 1, nodes, weights, 2)
@test I ≈ 2 / (625) * (69sin(5) - 95cos(5))

f = (x, p) -> (x + p) * abs(x)
n = 100
nodes, weights = gausslegendre(n)
I = Integrals.gauss_legendre(f, 0.0, -2, 2, nodes, weights)
Ic = Integrals.composite_gauss_legendre(f, 6, -2, 2, nodes, weights, 5)
@inferred Integrals.gauss_legendre(f, 0.0, -2, 2, nodes, weights)
@inferred Integrals.composite_gauss_legendre(f, 6, -2, 2, nodes, weights, 5)
@test I≈0.0 atol=1e-6
@test Ic≈24 rtol=1e-4
=#

alg = GaussLegendre()
n = 250
nd, wt = gausslegendre(n)
@test alg.nodes == nd
@test alg.weights == wt
@test alg.subintervals == 1
alg = GaussLegendre(n = 125, subintervals = 3)
n = 125
nd, wt = gausslegendre(n)
@test alg.nodes == nd
@test alg.weights == wt
@test alg.subintervals == 3
@test typeof(alg).parameters[1]
nd, wt = gausslegendre(275)
alg = GaussLegendre(nodes = nd, weights = wt)
@test !typeof(alg).parameters[1]
@test alg.nodes == nd
@test alg.weights == wt
@test alg.subintervals == 1
alg = GaussLegendre(nodes = nd, weights = wt, subintervals = 20)
@test typeof(alg).parameters[1]
@test alg.nodes == nd
@test alg.weights == wt
@test alg.subintervals == 20

f = (x, p) -> 5x + sin(x) - p * exp(x)
prob = IntegralProblem(f, -5, 3, 3.3)
alg = GaussLegendre()
sol = solve(prob, alg)
@test isnothing(sol.chi)
@test sol.alg === alg
@test sol.prob === prob
@test isnothing(sol.resid)
@test SciMLBase.successful_retcode(sol)
@test sol.u ≈ -exp(3) * 3.3 + 3.3 / exp(5) - 40 + cos(5) - cos(3)
alg = GaussLegendre(subintervals = 7)
sol = solve(prob, alg)
@test sol.u ≈ -exp(3) * 3.3 + 3.3 / exp(5) - 40 + cos(5) - cos(3)

f = (x, p) -> exp(-x^2)
prob = IntegralProblem(f, 0.0, Inf)
alg = GaussLegendre()
sol = solve(prob, alg)
@test sol.u ≈ sqrt(π)/2
alg = GaussLegendre(subintervals=1)
@test sol.u ≈ sqrt(π)/2
alg = GaussLegendre(subintervals=17)
@test sol.u ≈ sqrt(π)/2

prob = IntegralProblem(f, -Inf, Inf)
alg = GaussLegendre()
sol = solve(prob, alg)
@test sol.u ≈ sqrt(π)
alg = GaussLegendre(subintervals=1)
@test sol.u ≈ sqrt(π)
alg = GaussLegendre(subintervals=17)
@test sol.u ≈ sqrt(π)

prob = IntegralProblem(f, -Inf, 0.0)
alg = GaussLegendre()
sol = solve(prob, alg)
@test sol.u ≈ sqrt(π)/2
alg = GaussLegendre(subintervals=1)
@test sol.u ≈ sqrt(π)/2
alg = GaussLegendre(subintervals=17)
@test sol.u ≈ sqrt(π)/2

# Make sure broadcasting correctly handles the argument p
f = (x, p) -> 1 + x + x^p[1] - cos(x*p[2]) + exp(x)*p[3]
p = [0.3, 1.3, -0.5]
prob = IntegralProblem(f, 2, 6.3, p)
alg = GaussLegendre()
sol = solve(prob, alg)
@test sol.u ≈ -240.25235266303063249920743158729
alg = GaussLegendre(n = 500, subintervals = 17)
sol = solve(prob, alg)
@test sol.u ≈ -240.25235266303063249920743158729
1 change: 1 addition & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,3 +13,4 @@ dev_subpkg("IntegralsCubature")
@time @safetestset "Interface Tests" begin include("interface_tests.jl") end
@time @safetestset "Derivative Tests" begin include("derivative_tests.jl") end
@time @safetestset "Infinite Integral Tests" begin include("inf_integral_tests.jl") end
@time @safetestset "Gaussian Quadrature Tests" begin include("gaussian_quadrature_tests.jl") end