From ebfd4a6cd86b5f88bb758d5147061ad68fc7772c Mon Sep 17 00:00:00 2001 From: Yichao Yu Date: Sun, 31 Dec 2023 19:16:35 -0500 Subject: [PATCH 01/14] Automatic convertion for gsl_function Some functions are explicitly forbidden from using these since they store the reference of gsl_function across multiple calls. --- src/gen/direct_wrappers/gsl_chebyshev_h.jl | 2 +- src/gen/direct_wrappers/gsl_deriv_h.jl | 6 +-- src/gen/direct_wrappers/gsl_diff_h.jl | 6 +-- src/gen/direct_wrappers/gsl_integration_h.jl | 46 ++++++++++---------- src/gen/direct_wrappers/gsl_min_h.jl | 6 +-- src/gen/direct_wrappers/gsl_roots_h.jl | 2 +- src/manual_wrappers.jl | 28 ++++++++++++ test/numdiff.jl | 7 ++- test/quadrature.jl | 20 +++++++++ 9 files changed, 87 insertions(+), 36 deletions(-) diff --git a/src/gen/direct_wrappers/gsl_chebyshev_h.jl b/src/gen/direct_wrappers/gsl_chebyshev_h.jl index ea58fab..407e431 100644 --- a/src/gen/direct_wrappers/gsl_chebyshev_h.jl +++ b/src/gen/direct_wrappers/gsl_chebyshev_h.jl @@ -57,7 +57,7 @@ GSL documentation: > and requires $n$ function evaluations. """ -function cheb_init(cs, func, a, b) +function cheb_init(cs, func::gsl_function, a, b) ccall((:gsl_cheb_init, libgsl), Cint, (Ref{gsl_cheb_series}, Ref{gsl_function}, Cdouble, Cdouble), cs, func, a, b) end diff --git a/src/gen/direct_wrappers/gsl_deriv_h.jl b/src/gen/direct_wrappers/gsl_deriv_h.jl index a6d6e25..545c76f 100644 --- a/src/gen/direct_wrappers/gsl_deriv_h.jl +++ b/src/gen/direct_wrappers/gsl_deriv_h.jl @@ -32,7 +32,7 @@ GSL documentation: > actually used. """ -function deriv_central(f, x, h, result, abserr) +function deriv_central(f::F, x, h, result, abserr) where F ccall((:gsl_deriv_central, libgsl), Cint, (Ref{gsl_function}, Cdouble, Cdouble, Ref{Cdouble}, Ref{Cdouble}), f, x, h, result, abserr) end @@ -58,7 +58,7 @@ GSL documentation: > negative step-size. """ -function deriv_backward(f, x, h, result, abserr) +function deriv_backward(f::F, x, h, result, abserr) where F ccall((:gsl_deriv_backward, libgsl), Cint, (Ref{gsl_function}, Cdouble, Cdouble, Ref{Cdouble}, Ref{Cdouble}), f, x, h, result, abserr) end @@ -89,7 +89,7 @@ GSL documentation: > $x+h/2$, $x+h$. """ -function deriv_forward(f, x, h, result, abserr) +function deriv_forward(f::F, x, h, result, abserr) where F ccall((:gsl_deriv_forward, libgsl), Cint, (Ref{gsl_function}, Cdouble, Cdouble, Ref{Cdouble}, Ref{Cdouble}), f, x, h, result, abserr) end diff --git a/src/gen/direct_wrappers/gsl_diff_h.jl b/src/gen/direct_wrappers/gsl_diff_h.jl index 97f99cf..6b885a4 100644 --- a/src/gen/direct_wrappers/gsl_diff_h.jl +++ b/src/gen/direct_wrappers/gsl_diff_h.jl @@ -12,7 +12,7 @@ C signature: `int gsl_diff_central (const gsl_function *f, double x, double *result, double *abserr)` """ -function diff_central(f, x, result, abserr) +function diff_central(f::F, x, result, abserr) where F ccall((:gsl_diff_central, libgsl), Cint, (Ref{gsl_function}, Cdouble, Ref{Cdouble}, Ref{Cdouble}), f, x, result, abserr) end @@ -22,7 +22,7 @@ end C signature: `int gsl_diff_backward (const gsl_function *f, double x, double *result, double *abserr)` """ -function diff_backward(f, x, result, abserr) +function diff_backward(f::F, x, result, abserr) where F ccall((:gsl_diff_backward, libgsl), Cint, (Ref{gsl_function}, Cdouble, Ref{Cdouble}, Ref{Cdouble}), f, x, result, abserr) end @@ -32,7 +32,7 @@ end C signature: `int gsl_diff_forward (const gsl_function *f, double x, double *result, double *abserr)` """ -function diff_forward(f, x, result, abserr) +function diff_forward(f::F, x, result, abserr) where F ccall((:gsl_diff_forward, libgsl), Cint, (Ref{gsl_function}, Cdouble, Ref{Cdouble}, Ref{Cdouble}), f, x, result, abserr) end diff --git a/src/gen/direct_wrappers/gsl_integration_h.jl b/src/gen/direct_wrappers/gsl_integration_h.jl index bbca63a..48624b2 100644 --- a/src/gen/direct_wrappers/gsl_integration_h.jl +++ b/src/gen/direct_wrappers/gsl_integration_h.jl @@ -253,7 +253,7 @@ end C signature: `void gsl_integration_qk15 (const gsl_function * f, double a, double b, double *result, double *abserr, double *resabs, double *resasc)` """ -function integration_qk15(f, a, b, result, abserr, resabs, resasc) +function integration_qk15(f::F, a, b, result, abserr, resabs, resasc) where F ccall((:gsl_integration_qk15, libgsl), Cvoid, (Ref{gsl_function}, Cdouble, Cdouble, Ref{Cdouble}, Ref{Cdouble}, Ref{Cdouble}, Ref{Cdouble}), f, a, b, result, abserr, resabs, resasc) end @@ -263,7 +263,7 @@ end C signature: `void gsl_integration_qk21 (const gsl_function * f, double a, double b, double *result, double *abserr, double *resabs, double *resasc)` """ -function integration_qk21(f, a, b, result, abserr, resabs, resasc) +function integration_qk21(f::F, a, b, result, abserr, resabs, resasc) where F ccall((:gsl_integration_qk21, libgsl), Cvoid, (Ref{gsl_function}, Cdouble, Cdouble, Ref{Cdouble}, Ref{Cdouble}, Ref{Cdouble}, Ref{Cdouble}), f, a, b, result, abserr, resabs, resasc) end @@ -273,7 +273,7 @@ end C signature: `void gsl_integration_qk31 (const gsl_function * f, double a, double b, double *result, double *abserr, double *resabs, double *resasc)` """ -function integration_qk31(f, a, b, result, abserr, resabs, resasc) +function integration_qk31(f::F, a, b, result, abserr, resabs, resasc) where F ccall((:gsl_integration_qk31, libgsl), Cvoid, (Ref{gsl_function}, Cdouble, Cdouble, Ref{Cdouble}, Ref{Cdouble}, Ref{Cdouble}, Ref{Cdouble}), f, a, b, result, abserr, resabs, resasc) end @@ -283,7 +283,7 @@ end C signature: `void gsl_integration_qk41 (const gsl_function * f, double a, double b, double *result, double *abserr, double *resabs, double *resasc)` """ -function integration_qk41(f, a, b, result, abserr, resabs, resasc) +function integration_qk41(f::F, a, b, result, abserr, resabs, resasc) where F ccall((:gsl_integration_qk41, libgsl), Cvoid, (Ref{gsl_function}, Cdouble, Cdouble, Ref{Cdouble}, Ref{Cdouble}, Ref{Cdouble}, Ref{Cdouble}), f, a, b, result, abserr, resabs, resasc) end @@ -293,7 +293,7 @@ end C signature: `void gsl_integration_qk51 (const gsl_function * f, double a, double b, double *result, double *abserr, double *resabs, double *resasc)` """ -function integration_qk51(f, a, b, result, abserr, resabs, resasc) +function integration_qk51(f::F, a, b, result, abserr, resabs, resasc) where F ccall((:gsl_integration_qk51, libgsl), Cvoid, (Ref{gsl_function}, Cdouble, Cdouble, Ref{Cdouble}, Ref{Cdouble}, Ref{Cdouble}, Ref{Cdouble}), f, a, b, result, abserr, resabs, resasc) end @@ -303,7 +303,7 @@ end C signature: `void gsl_integration_qk61 (const gsl_function * f, double a, double b, double *result, double *abserr, double *resabs, double *resasc)` """ -function integration_qk61(f, a, b, result, abserr, resabs, resasc) +function integration_qk61(f::F, a, b, result, abserr, resabs, resasc) where F ccall((:gsl_integration_qk61, libgsl), Cvoid, (Ref{gsl_function}, Cdouble, Cdouble, Ref{Cdouble}, Ref{Cdouble}, Ref{Cdouble}, Ref{Cdouble}), f, a, b, result, abserr, resabs, resasc) end @@ -313,7 +313,7 @@ end C signature: `void gsl_integration_qcheb (gsl_function * f, double a, double b, double *cheb12, double *cheb24)` """ -function integration_qcheb(f, a, b, cheb12, cheb24) +function integration_qcheb(f::F, a, b, cheb12, cheb24) where F ccall((:gsl_integration_qcheb, libgsl), Cvoid, (Ref{gsl_function}, Cdouble, Cdouble, Ref{Cdouble}, Ref{Cdouble}), f, a, b, cheb12, cheb24) end @@ -323,7 +323,7 @@ end C signature: `void gsl_integration_qk (const int n, const double xgk[], const double wg[], const double wgk[], double fv1[], double fv2[], const gsl_function *f, double a, double b, double * result, double * abserr, double * resabs, double * resasc)` """ -function integration_qk(n, xgk, wg, wgk, fv1, fv2, f, a, b, result, abserr, resabs, resasc) +function integration_qk(n, xgk, wg, wgk, fv1, fv2, f::F, a, b, result, abserr, resabs, resasc) where F ccall((:gsl_integration_qk, libgsl), Cvoid, (Cint, Ref{Cdouble}, Ref{Cdouble}, Ref{Cdouble}, Ref{Cdouble}, Ref{Cdouble}, Ref{gsl_function}, Cdouble, Cdouble, Ref{Cdouble}, Ref{Cdouble}, Ref{Cdouble}, Ref{Cdouble}), n, xgk, wg, wgk, fv1, fv2, f, a, b, result, abserr, resabs, resasc) end @@ -348,7 +348,7 @@ GSL documentation: > of function evaluations. """ -function integration_qng(f, a, b, epsabs, epsrel, result, abserr, neval) +function integration_qng(f::F, a, b, epsabs, epsrel, result, abserr, neval) where F ccall((:gsl_integration_qng, libgsl), Cint, (Ref{gsl_function}, Cdouble, Cdouble, Cdouble, Cdouble, Ref{Cdouble}, Ref{Cdouble}, Ref{Csize_t}), f, a, b, epsabs, epsrel, result, abserr, neval) end @@ -394,7 +394,7 @@ GSL documentation: > allocated size of the workspace. """ -function integration_qag(f, a, b, epsabs, epsrel, limit, key, workspace, result, abserr) +function integration_qag(f::F, a, b, epsabs, epsrel, limit, key, workspace, result, abserr) where F ccall((:gsl_integration_qag, libgsl), Cint, (Ref{gsl_function}, Cdouble, Cdouble, Cdouble, Cdouble, Csize_t, Cint, Ref{gsl_integration_workspace}, Ref{Cdouble}, Ref{Cdouble}), f, a, b, epsabs, epsrel, limit, key, workspace, result, abserr) end @@ -420,7 +420,7 @@ GSL documentation: > In this case a lower-order rule is more efficient. """ -function integration_qagi(f, epsabs, epsrel, limit, workspace, result, abserr) +function integration_qagi(f::F, epsabs, epsrel, limit, workspace, result, abserr) where F ccall((:gsl_integration_qagi, libgsl), Cint, (Ref{gsl_function}, Cdouble, Cdouble, Csize_t, Ref{gsl_integration_workspace}, Ref{Cdouble}, Ref{Cdouble}), f, epsabs, epsrel, limit, workspace, result, abserr) end @@ -443,7 +443,7 @@ GSL documentation: > and then integrated using the QAGS algorithm. """ -function integration_qagiu(f, a, epsabs, epsrel, limit, workspace, result, abserr) +function integration_qagiu(f::F, a, epsabs, epsrel, limit, workspace, result, abserr) where F ccall((:gsl_integration_qagiu, libgsl), Cint, (Ref{gsl_function}, Cdouble, Cdouble, Cdouble, Csize_t, Ref{gsl_integration_workspace}, Ref{Cdouble}, Ref{Cdouble}), f, a, epsabs, epsrel, limit, workspace, result, abserr) end @@ -466,7 +466,7 @@ GSL documentation: > and then integrated using the QAGS algorithm. """ -function integration_qagil(f, b, epsabs, epsrel, limit, workspace, result, abserr) +function integration_qagil(f::F, b, epsabs, epsrel, limit, workspace, result, abserr) where F ccall((:gsl_integration_qagil, libgsl), Cint, (Ref{gsl_function}, Cdouble, Cdouble, Cdouble, Csize_t, Ref{gsl_integration_workspace}, Ref{Cdouble}, Ref{Cdouble}), f, b, epsabs, epsrel, limit, workspace, result, abserr) end @@ -493,7 +493,7 @@ GSL documentation: > which may not exceed the allocated size of the workspace. """ -function integration_qags(f, a, b, epsabs, epsrel, limit, workspace, result, abserr) +function integration_qags(f::F, a, b, epsabs, epsrel, limit, workspace, result, abserr) where F ccall((:gsl_integration_qags, libgsl), Cint, (Ref{gsl_function}, Cdouble, Cdouble, Cdouble, Cdouble, Csize_t, Ref{gsl_integration_workspace}, Ref{Cdouble}, Ref{Cdouble}), f, a, b, epsabs, epsrel, limit, workspace, result, abserr) end @@ -527,7 +527,7 @@ GSL documentation: > region then this routine will be faster than `gsl_integration_qags`. """ -function integration_qagp(f, pts, npts, epsabs, epsrel, limit, workspace, result, abserr) +function integration_qagp(f::F, pts, npts, epsabs, epsrel, limit, workspace, result, abserr) where F ccall((:gsl_integration_qagp, libgsl), Cint, (Ref{gsl_function}, Ref{Cdouble}, Csize_t, Cdouble, Cdouble, Csize_t, Ref{gsl_integration_workspace}, Ref{Cdouble}, Ref{Cdouble}), f, pts, npts, epsabs, epsrel, limit, workspace, result, abserr) end @@ -566,7 +566,7 @@ GSL documentation: > ordinary 15-point Gauss-Kronrod integration rule. """ -function integration_qawc(f, a, b, c, epsabs, epsrel, limit, workspace, result, abserr) +function integration_qawc(f::F, a, b, c, epsabs, epsrel, limit, workspace, result, abserr) where F ccall((:gsl_integration_qawc, libgsl), Cint, (Ref{gsl_function}, Cdouble, Cdouble, Cdouble, Cdouble, Cdouble, Csize_t, Ref{gsl_integration_workspace}, Ref{Cdouble}, Ref{Cdouble}), f, a, b, c, epsabs, epsrel, limit, workspace, result, abserr) end @@ -595,7 +595,7 @@ GSL documentation: > Gauss-Kronrod integration rule is used. """ -function integration_qaws(f, a, b, t, epsabs, epsrel, limit, workspace, result, abserr) +function integration_qaws(f::F, a, b, t, epsabs, epsrel, limit, workspace, result, abserr) where F ccall((:gsl_integration_qaws, libgsl), Cint, (Ref{gsl_function}, Cdouble, Cdouble, Ref{gsl_integration_qaws_table}, Cdouble, Cdouble, Csize_t, Ref{gsl_integration_workspace}, Ref{Cdouble}, Ref{Cdouble}), f, a, b, t, epsabs, epsrel, limit, workspace, result, abserr) end @@ -645,7 +645,7 @@ GSL documentation: > integration. """ -function integration_qawo(f, a, epsabs, epsrel, limit, workspace, wf, result, abserr) +function integration_qawo(f::F, a, epsabs, epsrel, limit, workspace, wf, result, abserr) where F ccall((:gsl_integration_qawo, libgsl), Cint, (Ref{gsl_function}, Cdouble, Cdouble, Cdouble, Csize_t, Ref{gsl_integration_workspace}, Ref{gsl_integration_qawo_table}, Ref{Cdouble}, Ref{Cdouble}), f, a, epsabs, epsrel, limit, workspace, wf, result, abserr) end @@ -737,7 +737,7 @@ GSL documentation: > `cycle_workspace` as workspace for the QAWO algorithm. """ -function integration_qawf(f, a, epsabs, limit, workspace, cycle_workspace, wf, result, abserr) +function integration_qawf(f::F, a, epsabs, limit, workspace, cycle_workspace, wf, result, abserr) where F ccall((:gsl_integration_qawf, libgsl), Cint, (Ref{gsl_function}, Cdouble, Cdouble, Csize_t, Ref{gsl_integration_workspace}, Ref{gsl_integration_workspace}, Ref{gsl_integration_qawo_table}, Ref{Cdouble}, Ref{Cdouble}), f, a, epsabs, limit, workspace, cycle_workspace, wf, result, abserr) end @@ -793,7 +793,7 @@ GSL documentation: > table `t` and returns the result. """ -function integration_glfixed(f, a, b, t) +function integration_glfixed(f::F, a, b, t) where F ccall((:gsl_integration_glfixed, libgsl), Cdouble, (Ref{gsl_function}, Cdouble, Cdouble, Ref{gsl_integration_glfixed_table}), f, a, b, t) end @@ -890,7 +890,7 @@ GSL documentation: > set to `NULL`. """ -function integration_cquad(f, a, b, epsabs, epsrel, ws, result, abserr, nevals) +function integration_cquad(f::F, a, b, epsabs, epsrel, ws, result, abserr, nevals) where F ccall((:gsl_integration_cquad, libgsl), Cint, (Ref{gsl_function}, Cdouble, Cdouble, Cdouble, Cdouble, Ref{gsl_integration_cquad_workspace}, Ref{Cdouble}, Ref{Cdouble}, Ref{Csize_t}), f, a, b, epsabs, epsrel, ws, result, abserr, nevals) end @@ -957,7 +957,7 @@ GSL documentation: > `neval`. """ -function integration_romberg(f, a, b, epsabs, epsrel, result, neval, w) +function integration_romberg(f::F, a, b, epsabs, epsrel, result, neval, w) where F ccall((:gsl_integration_romberg, libgsl), Cint, (Ref{gsl_function}, Cdouble, Cdouble, Cdouble, Cdouble, Ref{Cdouble}, Ref{Csize_t}, Ref{gsl_integration_romberg_workspace}), f, a, b, epsabs, epsrel, result, neval, w) end @@ -1116,7 +1116,7 @@ GSL documentation: > approximated as """ -function integration_fixed(func, result, w) +function integration_fixed(func::F, result, w) where F ccall((:gsl_integration_fixed, libgsl), Cint, (Ref{gsl_function}, Ref{Cdouble}, Ref{gsl_integration_fixed_workspace}), func, result, w) end diff --git a/src/gen/direct_wrappers/gsl_min_h.jl b/src/gen/direct_wrappers/gsl_min_h.jl index d2bb693..f7ed0a2 100644 --- a/src/gen/direct_wrappers/gsl_min_h.jl +++ b/src/gen/direct_wrappers/gsl_min_h.jl @@ -67,7 +67,7 @@ GSL documentation: > returns an error code of `GSL_EINVAL`. """ -function min_fminimizer_set(s, f, x_minimum, x_lower, x_upper) +function min_fminimizer_set(s, f::gsl_function, x_minimum, x_lower, x_upper) ccall((:gsl_min_fminimizer_set, libgsl), Cint, (Ref{gsl_min_fminimizer}, Ref{gsl_function}, Cdouble, Cdouble, Cdouble), s, f, x_minimum, x_lower, x_upper) end @@ -86,7 +86,7 @@ GSL documentation: > `f(x_minimum)`, `f(x_lower)` and `f(x_upper)`. """ -function min_fminimizer_set_with_values(s, f, x_minimum, f_minimum, x_lower, f_lower, x_upper, f_upper) +function min_fminimizer_set_with_values(s, f::gsl_function, x_minimum, f_minimum, x_lower, f_lower, x_upper, f_upper) ccall((:gsl_min_fminimizer_set_with_values, libgsl), Cint, (Ref{gsl_min_fminimizer}, Ref{gsl_function}, Cdouble, Cdouble, Cdouble, Cdouble, Cdouble, Cdouble), s, f, x_minimum, f_minimum, x_lower, f_lower, x_upper, f_upper) end @@ -301,7 +301,7 @@ end C signature: `int gsl_min_find_bracket(gsl_function *f,double *x_minimum,double * f_minimum, double *x_lower, double * f_lower, double *x_upper, double * f_upper, size_t eval_max)` """ -function min_find_bracket(f, x_minimum, f_minimum, x_lower, f_lower, x_upper, f_upper, eval_max) +function min_find_bracket(f::F, x_minimum, f_minimum, x_lower, f_lower, x_upper, f_upper, eval_max) where F ccall((:gsl_min_find_bracket, libgsl), Cint, (Ref{gsl_function}, Ref{Cdouble}, Ref{Cdouble}, Ref{Cdouble}, Ref{Cdouble}, Ref{Cdouble}, Ref{Cdouble}, Csize_t), f, x_minimum, f_minimum, x_lower, f_lower, x_upper, f_upper, eval_max) end diff --git a/src/gen/direct_wrappers/gsl_roots_h.jl b/src/gen/direct_wrappers/gsl_roots_h.jl index a835e33..68a497b 100644 --- a/src/gen/direct_wrappers/gsl_roots_h.jl +++ b/src/gen/direct_wrappers/gsl_roots_h.jl @@ -66,7 +66,7 @@ GSL documentation: > `x_upper`\]. """ -function root_fsolver_set(s, f, x_lower, x_upper) +function root_fsolver_set(s, f::gsl_function, x_lower, x_upper) ccall((:gsl_root_fsolver_set, libgsl), Cint, (Ref{gsl_root_fsolver}, Ref{gsl_function}, Cdouble, Cdouble), s, f, x_lower, x_upper) end diff --git a/src/manual_wrappers.jl b/src/manual_wrappers.jl index c73e789..c34fc56 100644 --- a/src/manual_wrappers.jl +++ b/src/manual_wrappers.jl @@ -30,6 +30,34 @@ end ## Root finding +gsl_function_helper(x::Cdouble, fn)::Cdouble = fn(x) + +# The following code relies on `gsl_function` being a mutable type +# (such that we can call `pointer_from_objref` on it) to simplify the object structure +# a little bit and avoid hitting some limitation of the allocation optimizer. +@assert ismutable(gsl_function(C_NULL, C_NULL)) + +function wrap_gsl_function(fn::F) where F + # We need to allocate the `gsl_function` here to be kept alive by ccall + # This require us to create the pointer to the function and the callable object + param_ref = Base.cconvert(Ref{F}, fn) + fptr = @cfunction(gsl_function_helper, Cdouble, (Cdouble, Ref{F})) + param_ptr = Base.unsafe_convert(Ref{F}, param_ref) + gsl_func = gsl_function(fptr, param_ptr) + return gsl_func, param_ref +end + +function Base.cconvert(::Type{Ref{gsl_function}}, fn::F) where F + return wrap_gsl_function(fn) +end +function Base.unsafe_convert(::Type{Ref{gsl_function}}, + (gsl_func,)::Tuple{gsl_function,Any}) + return pointer_from_objref(gsl_func) +end + +Base.cconvert(::Type{Ref{gsl_function}}, gslf::gsl_function) = + convert(Ref{gsl_function}, gslf) + # Macros for easier creation of gsl_function and gsl_function_fdf structs export @gsl_function, @gsl_function_fdf diff --git a/test/numdiff.jl b/test/numdiff.jl index eb1c445..45d0bb1 100644 --- a/test/numdiff.jl +++ b/test/numdiff.jl @@ -1,7 +1,7 @@ using GSL using Test -func(x) = x^3 +func(x) = x^3 @testset "Numerical differentiation" begin x = 1.0 @@ -14,6 +14,9 @@ func(x) = x^3 df,ddf = Cdouble[0], Cdouble[0] deriv(@gsl_function(func), x, h, df, ddf) @test abs(df_dx - df[]) <= ddf[] <= d2f_dx2*h/2 + 2eps()/h + + df,ddf = Cdouble[0], Cdouble[0] + deriv(func, x, h, df, ddf) + @test abs(df_dx - df[]) <= ddf[] <= d2f_dx2*h/2 + 2eps()/h end end - diff --git a/test/quadrature.jl b/test/quadrature.jl index b428412..cb29702 100644 --- a/test/quadrature.jl +++ b/test/quadrature.jl @@ -19,3 +19,23 @@ fquad = x -> x^1.5 integration_cquad_workspace_free(ws) end + +@testset "Quadrature (closure)" begin + # Make sure this is actually a closure + fquad = let n = 1.5 + x -> x^n + end + ws_size = 100 + ws = integration_cquad_workspace_alloc(ws_size) + + a = 0 + b = 1 + result = Cdouble[0] + abserr = Cdouble[0] + nevals = Csize_t[0] + integration_cquad(fquad, a, b, 1e-10, 1e-10, ws, result, abserr, nevals) + + @test abs(result[] - 1/2.5) < abserr[] + + integration_cquad_workspace_free(ws) +end From cbddc45c32907a66580e101fabf59f424328bf73 Mon Sep 17 00:00:00 2001 From: Yichao Yu Date: Thu, 4 Jan 2024 10:51:44 -0500 Subject: [PATCH 02/14] Automatic conversion for gsl_function_fdf --- src/gen/direct_wrappers/gsl_roots_h.jl | 2 +- src/manual_wrappers.jl | 45 ++++++++++++++++++++++++++ 2 files changed, 46 insertions(+), 1 deletion(-) diff --git a/src/gen/direct_wrappers/gsl_roots_h.jl b/src/gen/direct_wrappers/gsl_roots_h.jl index 68a497b..0fed5a5 100644 --- a/src/gen/direct_wrappers/gsl_roots_h.jl +++ b/src/gen/direct_wrappers/gsl_roots_h.jl @@ -221,7 +221,7 @@ GSL documentation: > use the function and derivative `fdf` and the initial guess `root`. """ -function root_fdfsolver_set(s, fdf, root) +function root_fdfsolver_set(s, fdf::gsl_function_fdf, root) ccall((:gsl_root_fdfsolver_set, libgsl), Cint, (Ref{gsl_root_fdfsolver}, Ref{gsl_function_fdf}, Cdouble), s, fdf, root) end diff --git a/src/manual_wrappers.jl b/src/manual_wrappers.jl index c34fc56..c06b76d 100644 --- a/src/manual_wrappers.jl +++ b/src/manual_wrappers.jl @@ -58,6 +58,51 @@ end Base.cconvert(::Type{Ref{gsl_function}}, gslf::gsl_function) = convert(Ref{gsl_function}, gslf) +# The following code relies on `gsl_function_fdf` being a mutable type +@assert ismutable(gsl_function_fdf(C_NULL, C_NULL, C_NULL, C_NULL)) + +function gsl_function_f_helper(x::Cdouble, (f,))::Cdouble + return f(x) +end +function gsl_function_df_helper(x::Cdouble, (f, df))::Cdouble + return df(x) +end +function gsl_function_fdf_helper(x::Cdouble, (f, df)::NTuple{2,Any}, pf, pdf) + unsafe_store!(pf, f(x)) + unsafe_store!(pdf, df(x)) + return +end +function gsl_function_fdf_helper(x::Cdouble, (f, df, fdf)::NTuple{3,Any}, pf, pdf) + f, df = fdf(x) + unsafe_store!(pf, f) + unsafe_store!(pdf, df) + return +end + +function wrap_gsl_function_fdf(fn::FDF) where FDF + param_ref = Base.cconvert(Ref{FDF}, fn) + fptr = @cfunction(gsl_function_f_helper, Cdouble, (Cdouble, Ref{FDF})) + dfptr = @cfunction(gsl_function_df_helper, Cdouble, (Cdouble, Ref{FDF})) + fdfptr = @cfunction(gsl_function_fdf_helper, Cvoid, + (Cdouble, Ref{FDF}, Ptr{Cdouble}, Ptr{Cdouble})) + param_ptr = Base.unsafe_convert(Ref{FDF}, param_ref) + gsl_func = gsl_function_fdf(fptr, dfptr, fdfptr, param_ptr) + return gsl_func, param_ref +end + +# Do not define these since there's no safe way to use these at the moment +# function Base.cconvert(::Type{Ref{gsl_function_fdf}}, +# fn::Union{NTuple{2,Any},NTuple{3,Any}}) +# return wrap_gsl_function_fdf(fn) +# end +# function Base.unsafe_convert(::Type{Ref{gsl_function_fdf}}, +# (gsl_func,)::Tuple{gsl_function_fdf,Any}) +# return pointer_from_objref(gsl_func) +# end + +# Base.cconvert(::Type{Ref{gsl_function_fdf}}, gslf::gsl_function_fdf) = +# convert(Ref{gsl_function_fdf}, gslf) + # Macros for easier creation of gsl_function and gsl_function_fdf structs export @gsl_function, @gsl_function_fdf From 0045776941fc0afe3be91a80d09a4a2d48c32cf0 Mon Sep 17 00:00:00 2001 From: Yichao Yu Date: Thu, 4 Jan 2024 11:14:36 -0500 Subject: [PATCH 03/14] Automatic conversion for gsl_multiroot_function --- src/gen/direct_wrappers/gsl_multiroots_h.jl | 4 +-- src/manual_wrappers.jl | 29 +++++++++++++++++++++ 2 files changed, 31 insertions(+), 2 deletions(-) diff --git a/src/gen/direct_wrappers/gsl_multiroots_h.jl b/src/gen/direct_wrappers/gsl_multiroots_h.jl index 0f9a148..934aaa1 100644 --- a/src/gen/direct_wrappers/gsl_multiroots_h.jl +++ b/src/gen/direct_wrappers/gsl_multiroots_h.jl @@ -12,7 +12,7 @@ C signature: `int gsl_multiroot_fdjacobian (gsl_multiroot_function * F, const gsl_vector * x, const gsl_vector * f, double epsrel, gsl_matrix * jacobian)` """ -function multiroot_fdjacobian(F, x, f, epsrel, jacobian) +function multiroot_fdjacobian(F::Fn, x, f, epsrel, jacobian) where Fn ccall((:gsl_multiroot_fdjacobian, libgsl), Cint, (Ref{gsl_multiroot_function}, Ref{gsl_vector}, Ref{gsl_vector}, Cdouble, Ref{gsl_matrix}), F, x, f, epsrel, jacobian) end @@ -81,7 +81,7 @@ GSL documentation: > is not modified by subsequent iterations. """ -function multiroot_fsolver_set(s, f, x) +function multiroot_fsolver_set(s, f::gsl_multiroot_function, x) ccall((:gsl_multiroot_fsolver_set, libgsl), Cint, (Ref{gsl_multiroot_fsolver}, Ref{gsl_multiroot_function}, Ref{gsl_vector}), s, f, x) end diff --git a/src/manual_wrappers.jl b/src/manual_wrappers.jl index c06b76d..502935a 100644 --- a/src/manual_wrappers.jl +++ b/src/manual_wrappers.jl @@ -175,6 +175,35 @@ macro gsl_function_fdf(f, df) ) end +# The following code relies on `gsl_multiroot_function` being a mutable type +@assert ismutable(gsl_multiroot_function(C_NULL, 0, C_NULL)) + +function gsl_multiroot_function_helper(x_vec, f, y_vec) + x = wrap_gsl_vector(x_vec) + y = wrap_gsl_vector(y_vec) + f(x, y) + return Cint(GSL.GSL_SUCCESS) +end + +function wrap_gsl_multiroot_function(fn::F, n) where F + param_ref = Base.cconvert(Ref{F}, fn) + fptr = @cfunction(gsl_multiroot_function_helper, + Cint, (Ptr{gsl_vector}, Ref{F}, Ptr{gsl_vector})) + param_ptr = Base.unsafe_convert(Ref{F}, param_ref) + gsl_func = gsl_multiroot_function(fptr, n, param_ptr) + return gsl_func, param_ref +end + +function Base.cconvert(::Type{Ref{gsl_multiroot_function}}, (fn, n)::NTuple{2,Any}) + return wrap_gsl_multiroot_function(fn, n) +end +function Base.unsafe_convert(::Type{Ref{gsl_multiroot_function}}, + (gsl_func,)::Tuple{gsl_multiroot_function,Any}) + return pointer_from_objref(gsl_func) +end + +Base.cconvert(::Type{Ref{gsl_multiroot_function}}, gslf::gsl_multiroot_function) = + convert(Ref{gsl_multiroot_function}, gslf) export @gsl_multiroot_function, @gsl_multiroot_function_fdf From 93f33b3a4882554caf43fabbb7d53a264de17a21 Mon Sep 17 00:00:00 2001 From: Yichao Yu Date: Thu, 4 Jan 2024 11:25:37 -0500 Subject: [PATCH 04/14] Automatic conversion for gsl_multiroot_function_fdf --- src/gen/direct_wrappers/gsl_multiroots_h.jl | 2 +- src/manual_wrappers.jl | 62 +++++++++++++++++++++ 2 files changed, 63 insertions(+), 1 deletion(-) diff --git a/src/gen/direct_wrappers/gsl_multiroots_h.jl b/src/gen/direct_wrappers/gsl_multiroots_h.jl index 934aaa1..ba95e01 100644 --- a/src/gen/direct_wrappers/gsl_multiroots_h.jl +++ b/src/gen/direct_wrappers/gsl_multiroots_h.jl @@ -242,7 +242,7 @@ end C signature: `int gsl_multiroot_fdfsolver_set (gsl_multiroot_fdfsolver * s, gsl_multiroot_function_fdf * fdf, const gsl_vector * x)` """ -function multiroot_fdfsolver_set(s, fdf, x) +function multiroot_fdfsolver_set(s, fdf::gsl_multiroot_function_fdf, x) ccall((:gsl_multiroot_fdfsolver_set, libgsl), Cint, (Ref{gsl_multiroot_fdfsolver}, Ref{gsl_multiroot_function_fdf}, Ref{gsl_vector}), s, fdf, x) end diff --git a/src/manual_wrappers.jl b/src/manual_wrappers.jl index 502935a..575389f 100644 --- a/src/manual_wrappers.jl +++ b/src/manual_wrappers.jl @@ -205,6 +205,68 @@ end Base.cconvert(::Type{Ref{gsl_multiroot_function}}, gslf::gsl_multiroot_function) = convert(Ref{gsl_multiroot_function}, gslf) +# The following code relies on `gsl_multiroot_function_fdf` being a mutable type +@assert ismutable(gsl_multiroot_function_fdf(C_NULL, C_NULL, C_NULL, 0, C_NULL)) + +function gsl_multiroot_function_f_helper(x_vec, (f,), f_vec) + xarr = wrap_gsl_vector(x_vec) + farr = wrap_gsl_vector(f_vec) + f(xarr, farr) + return Cint(GSL.GSL_SUCCESS) +end +function gsl_multiroot_function_df_helper(x_vec, (f, df), J_vec) + xarr = wrap_gsl_vector(x_vec) + Jmat = wrap_gsl_matrix(J_vec) + df(xarr, Jmat) + return Cint(GSL.GSL_SUCCESS) +end +function gsl_multiroot_function_fdf_helper(x_vec, (f, df)::NTuple{2,Any}, f_vec, J_vec) + xarr = wrap_gsl_vector(x_vec) + farr = wrap_gsl_vector(f_vec) + Jmat = wrap_gsl_matrix(J_vec) + f(xarr, farr) + df(xarr, Jmat) + return Cint(GSL.GSL_SUCCESS) +end +function gsl_multiroot_function_fdf_helper(x_vec, (f, df, fdf)::NTuple{3,Any}, + f_vec, J_vec) + xarr = wrap_gsl_vector(x_vec) + farr = wrap_gsl_vector(f_vec) + Jmat = wrap_gsl_matrix(J_vec) + fdf(xarr, farr, Jmat) + return Cint(GSL.GSL_SUCCESS) +end + +function wrap_gsl_multiroot_function_fdf(fn::FDF, n) where FDF + param_ref = Base.cconvert(Ref{FDF}, fn) + fptr = @cfunction(gsl_multiroot_function_f_helper, + Cint, (Ptr{gsl_vector}, Ref{FDF}, Ptr{gsl_vector})) + dfptr = @cfunction(gsl_multiroot_function_df_helper, + Cint, (Ptr{gsl_vector}, Ref{FDF}, Ptr{gsl_matrix})) + fdfptr = @cfunction(gsl_multiroot_function_fdf_helper, + Cint, (Ptr{gsl_vector}, Ref{FDF}, Ptr{gsl_vector}, Ptr{gsl_matrix})) + param_ptr = Base.unsafe_convert(Ref{FDF}, param_ref) + gsl_func = gsl_multiroot_function_fdf(fptr, dfptr, fdfptr, n, param_ptr) + return gsl_func, param_ref +end + +# Do not define these since there's no safe way to use these at the moment +# function Base.cconvert(::Type{Ref{gsl_multiroot_function_fdf}}, +# (f, df, n)::NTuple{3,Any}) +# return wrap_gsl_multiroot_function_fdf((f, df), n) +# end +# function Base.cconvert(::Type{Ref{gsl_multiroot_function_fdf}}, +# (f, df, fdf, n)::NTuple{4,Any}) +# return wrap_gsl_multiroot_function_fdf((f, df, fdf), n) +# end +# function Base.unsafe_convert(::Type{Ref{gsl_multiroot_function_fdf}}, +# (gsl_func,)::Tuple{gsl_multiroot_function_fdf,Any}) +# return pointer_from_objref(gsl_func) +# end + +# Base.cconvert(::Type{Ref{gsl_multiroot_function_fdf}}, gslf::gsl_multiroot_function_fdf) = +# convert(Ref{gsl_multiroot_function_fdf}, gslf) + export @gsl_multiroot_function, @gsl_multiroot_function_fdf """ From 8ebf3c0b201f0dfccf2a6388abc19e7916e66d4e Mon Sep 17 00:00:00 2001 From: Yichao Yu Date: Thu, 4 Jan 2024 15:35:32 -0500 Subject: [PATCH 05/14] Add Dependency on REPL for Docs.doc --- Project.toml | 2 ++ src/GSL.jl | 1 + 2 files changed, 3 insertions(+) diff --git a/Project.toml b/Project.toml index 8dcca1b..586f7f2 100644 --- a/Project.toml +++ b/Project.toml @@ -6,10 +6,12 @@ version = "1.0.1" Libdl = "8f399da3-3557-5675-b5ff-fb832c97cbdb" GSL_jll = "1b77fbbe-d8ee-58f0-85f9-836ddc23a7a4" Markdown = "d6f4376e-aef5-505a-96c1-9c027394607a" +REPL = "3fa0cd96-eef1-5676-8a61-b3b8758bbffb" [compat] SpecialFunctions = "0.8, 0.9, 0.10, 1" GSL_jll = "2.6" +REPL = "1.3.0" julia = "1.3.0" [extras] diff --git a/src/GSL.jl b/src/GSL.jl index f93b4f6..8ef40f3 100644 --- a/src/GSL.jl +++ b/src/GSL.jl @@ -1,6 +1,7 @@ module GSL using Markdown +using REPL # For Docs.doc # BEGIN MODULE C # low-level interface From db59ed06a252f12fb52970dc64b73a98034de628 Mon Sep 17 00:00:00 2001 From: Yichao Yu Date: Thu, 4 Jan 2024 11:53:57 -0500 Subject: [PATCH 06/14] Safe Chebyshev wrapper --- src/manual_wrappers.jl | 49 ++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 49 insertions(+) diff --git a/src/manual_wrappers.jl b/src/manual_wrappers.jl index 575389f..a2c07a6 100644 --- a/src/manual_wrappers.jl +++ b/src/manual_wrappers.jl @@ -455,3 +455,52 @@ function hypergeom_e(a, b, x) error("hypergeometric function of order $n is not implemented") end end + +@doc md""" +$(Docs.doc(C.cheb_alloc)) +""" +mutable struct GSLCheb + ptr::Ptr{gsl_cheb_series} + param_ref + gsl_func::gsl_function + function GSLCheb(order) + cs = new(cheb_alloc(order), nothing) + finalizer(cheb_free, cs) + return cs + end +end +export GSLCheb + +Base.cconvert(::Type{Ref{gsl_cheb_series}}, cs::GSLCheb) = cs +Base.unsafe_convert(::Type{Ref{gsl_cheb_series}}, cs::GSLCheb) = cs.ptr +Base.unsafe_convert(::Type{Ptr{gsl_cheb_series}}, cs::GSLCheb) = cs.ptr + +@doc md""" +$(Docs.doc(C.cheb_init)) +""" +function cheb_init(cs::GSLCheb, func::F, a, b) where F + cs.gsl_func, cs.param_ref = wrap_gsl_function(func) + return C.cheb_init(cs, cs.gsl_func, a, b) +end +function cheb_init(cs::GSLCheb, func::gsl_function, a, b) + cs.gsl_func = func + cs.param_ref = nothing + return C.cheb_init(cs, cs.gsl_func, a, b) +end +function cheb_init(cs::Ptr{gsl_cheb_series}, func::gsl_function, a, b) + return C.cheb_init(cs, func, a, b) +end + +@doc md""" +$(Docs.doc(C.cheb_free)) +""" +function cheb_free(cs::Ptr{gsl_cheb_series}) + C.cheb_free(cs) +end +function cheb_free(cs::GSLCheb) + if cs.ptr != C_NULL + C.cheb_free(cs) + cs.ptr = C_NULL + end + return +end From 8ba5e2ce450fb9dd8301d81873939a9f47a869c4 Mon Sep 17 00:00:00 2001 From: Yichao Yu Date: Thu, 4 Jan 2024 12:02:25 -0500 Subject: [PATCH 07/14] Safe FMinimizer wrapper --- src/manual_wrappers.jl | 76 ++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 76 insertions(+) diff --git a/src/manual_wrappers.jl b/src/manual_wrappers.jl index a2c07a6..62cc813 100644 --- a/src/manual_wrappers.jl +++ b/src/manual_wrappers.jl @@ -504,3 +504,79 @@ function cheb_free(cs::GSLCheb) end return end + +@doc md""" +$(Docs.doc(C.min_fminimizer_alloc)) +""" +mutable struct GSLMinFMinimizer + ptr::Ptr{gsl_min_fminimizer} + param_ref + gsl_func::gsl_function + function GSLMinFMinimizer(T) + s = new(min_fminimizer_alloc(T), nothing) + finalizer(min_fminimizer_free, s) + return s + end +end +export GSLMinFMinimizer + +Base.cconvert(::Type{Ref{gsl_min_fminimizer}}, s::GSLMinFMinimizer) = s +Base.unsafe_convert(::Type{Ref{gsl_min_fminimizer}}, s::GSLMinFMinimizer) = s.ptr +Base.unsafe_convert(::Type{Ptr{gsl_min_fminimizer}}, s::GSLMinFMinimizer) = s.ptr + +@doc md""" +$(Docs.doc(C.min_fminimizer_set)) +""" +function min_fminimizer_set(s::GSLMinFMinimizer, f::F, + x_minimum, x_lower, x_upper) where F + s.gsl_func, s.param_ref = wrap_gsl_function(f) + return C.min_fminimizer_set(s, s.gsl_func, x_minimum, x_lower, x_upper) +end +function min_fminimizer_set(s::GSLMinFMinimizer, f::gsl_function, + x_minimum, x_lower, x_upper) + s.gsl_func = f + s.param_ref = nothing + return C.min_fminimizer_set(s, s.gsl_func, x_minimum, x_lower, x_upper) +end +function min_fminimizer_set(s::Ptr{gsl_min_fminimizer}, f::gsl_function, + x_minimum, x_lower, x_upper) + return C.min_fminimizer_set(s, f, x_minimum, x_lower, x_upper) +end + +@doc md""" +$(Docs.doc(C.min_fminimizer_set_with_values)) +""" +function min_fminimizer_set_with_values(s::GSLMinFMinimizer, f::F, x_minimum, f_minimum, + x_lower, f_lower, x_upper, f_upper) where F + s.gsl_func, s.param_ref = wrap_gsl_function(f) + return C.min_fminimizer_set_with_values(s, s.gsl_func, x_minimum, f_minimum, + x_lower, f_lower, x_upper, f_upper) +end +function min_fminimizer_set_with_values(s::GSLMinFMinimizer, f::gsl_function, x_minimum, + f_minimum, x_lower, f_lower, x_upper, f_upper) + s.gsl_func = f + s.param_ref = nothing + return C.min_fminimizer_set_with_values(s, s.gsl_func, x_minimum, + f_minimum, x_lower, f_lower, + x_upper, f_upper) +end +function min_fminimizer_set_with_values(s::Ptr{gsl_min_fminimizer}, f::gsl_function, + x_minimum, f_minimum, x_lower, f_lower, + x_upper, f_upper) + return C.min_fminimizer_set_with_values(s, f, x_minimum, f_minimum, x_lower, + f_lower, x_upper, f_upper) +end + +@doc md""" +$(Docs.doc(C.min_fminimizer_free)) +""" +function min_fminimizer_free(s::Ptr{gsl_min_fminimizer}) + C.min_fminimizer_free(s) +end +function min_fminimizer_free(s::GSLMinFMinimizer) + if s.ptr != C_NULL + C.min_fminimizer_free(s) + s.ptr = C_NULL + end + return +end From af18f1411774ea25314b351049bbca3ccc0287b9 Mon Sep 17 00:00:00 2001 From: Yichao Yu Date: Thu, 4 Jan 2024 12:20:12 -0500 Subject: [PATCH 08/14] Safe FSolver wrapper --- src/manual_wrappers.jl | 49 ++++++++++++++++++++++++++++++++++++++++++ test/rootfinding.jl | 27 +++++++++++++++++++++++ 2 files changed, 76 insertions(+) diff --git a/src/manual_wrappers.jl b/src/manual_wrappers.jl index 62cc813..540ceff 100644 --- a/src/manual_wrappers.jl +++ b/src/manual_wrappers.jl @@ -580,3 +580,52 @@ function min_fminimizer_free(s::GSLMinFMinimizer) end return end + +@doc md""" +$(Docs.doc(C.root_fsolver_alloc)) +""" +mutable struct GSLRootFSolver + ptr::Ptr{gsl_root_fsolver} + param_ref + gsl_func::gsl_function + function GSLRootFSolver(T) + s = new(root_fsolver_alloc(T), nothing) + finalizer(root_fsolver_free, s) + return s + end +end +export GSLRootFSolver + +Base.cconvert(::Type{Ref{gsl_root_fsolver}}, s::GSLRootFSolver) = s +Base.unsafe_convert(::Type{Ref{gsl_root_fsolver}}, s::GSLRootFSolver) = s.ptr +Base.unsafe_convert(::Type{Ptr{gsl_root_fsolver}}, s::GSLRootFSolver) = s.ptr + +@doc md""" +$(Docs.doc(C.root_fsolver_set)) +""" +function root_fsolver_set(s::GSLRootFSolver, f::F, x_lower, x_upper) where F + s.gsl_func, s.param_ref = wrap_gsl_function(f) + return C.root_fsolver_set(s, s.gsl_func, x_lower, x_upper) +end +function root_fsolver_set(s::GSLRootFSolver, f::gsl_function, x_lower, x_upper) + s.gsl_func = f + s.param_ref = nothing + return C.root_fsolver_set(s, s.gsl_func, x_lower, x_upper) +end +function root_fsolver_set(s::Ptr{gsl_root_fsolver}, f::gsl_function, x_lower, x_upper) + return C.root_fsolver_set(s, f, x_lower, x_upper) +end + +@doc md""" +$(Docs.doc(C.root_fsolver_free)) +""" +function root_fsolver_free(s::Ptr{gsl_root_fsolver}) + C.root_fsolver_free(s) +end +function root_fsolver_free(s::GSLRootFSolver) + if s.ptr != C_NULL + C.root_fsolver_free(s) + s.ptr = C_NULL + end + return +end diff --git a/test/rootfinding.jl b/test/rootfinding.jl index 70c769c..8dc668b 100644 --- a/test/rootfinding.jl +++ b/test/rootfinding.jl @@ -64,6 +64,33 @@ fdf2 = @gsl_function_fdf(myfun, myfun_deriv) end end + @testset "Secant method wrapper" begin + T = gsl_root_fsolver_bisection + @testset "alloc/free" begin + A = GSLRootFSolver(T) + end + @testset "Solve" begin + solver = GSLRootFSolver(T) + root_fsolver_set(solver, myfun, -10, 10) + + status = GSL_CONTINUE + maxiter = 40 + iter = 0 + while status == GSL_CONTINUE + root_fsolver_iterate(solver) + x = root_fsolver_root(solver) + status = root_test_residual(myfun(x), 1e-10) + iter += 1 + if iter==maxiter + error("No convergence") + end + end + @test status == GSL_SUCCESS + x = root_fsolver_root(solver) + @test abs(myfun(x)) < 1e-10 + end + end + @testset "Newton's method" begin T = gsl_root_fdfsolver_newton From aa63a3e6f0155eaab526c79f129a3be048597273 Mon Sep 17 00:00:00 2001 From: Yichao Yu Date: Thu, 4 Jan 2024 14:24:13 -0500 Subject: [PATCH 09/14] Safe FDFSolver wrapper --- src/manual_wrappers.jl | 50 ++++++++++++++++++++++++++++++++++++++++++ test/rootfinding.jl | 48 ++++++++++++++++++++++++++++++++++++++++ 2 files changed, 98 insertions(+) diff --git a/src/manual_wrappers.jl b/src/manual_wrappers.jl index 540ceff..4bb955e 100644 --- a/src/manual_wrappers.jl +++ b/src/manual_wrappers.jl @@ -629,3 +629,53 @@ function root_fsolver_free(s::GSLRootFSolver) end return end + +@doc md""" +$(Docs.doc(C.root_fdfsolver_alloc)) +""" +mutable struct GSLRootFDFSolver + ptr::Ptr{gsl_root_fdfsolver} + param_ref + gsl_func::gsl_function_fdf + function GSLRootFDFSolver(T) + s = new(root_fdfsolver_alloc(T), nothing) + finalizer(root_fdfsolver_free, s) + return s + end +end +export GSLRootFDFSolver + +Base.cconvert(::Type{Ref{gsl_root_fdfsolver}}, s::GSLRootFDFSolver) = s +Base.unsafe_convert(::Type{Ref{gsl_root_fdfsolver}}, s::GSLRootFDFSolver) = s.ptr +Base.unsafe_convert(::Type{Ptr{gsl_root_fdfsolver}}, s::GSLRootFDFSolver) = s.ptr + +@doc md""" +$(Docs.doc(C.root_fdfsolver_set)) +""" +function root_fdfsolver_set(s::GSLRootFDFSolver, + fdf::Union{NTuple{2,Any},NTuple{3,Any}}, root) + s.gsl_func, s.param_ref = wrap_gsl_function_fdf(fdf) + return C.root_fdfsolver_set(s, s.gsl_func, root) +end +function root_fdfsolver_set(s::GSLRootFDFSolver, f::gsl_function_fdf, root) + s.gsl_func = f + s.param_ref = nothing + return C.root_fdfsolver_set(s, s.gsl_func, root) +end +function root_fdfsolver_set(s::Ptr{gsl_root_fdfsolver}, f::gsl_function_fdf, root) + return C.root_fdfsolver_set(s, f, root) +end + +@doc md""" +$(Docs.doc(C.root_fdfsolver_free)) +""" +function root_fdfsolver_free(s::Ptr{gsl_root_fdfsolver}) + C.root_fdfsolver_free(s) +end +function root_fdfsolver_free(s::GSLRootFDFSolver) + if s.ptr != C_NULL + C.root_fdfsolver_free(s) + s.ptr = C_NULL + end + return +end diff --git a/test/rootfinding.jl b/test/rootfinding.jl index 8dc668b..85f6fe7 100644 --- a/test/rootfinding.jl +++ b/test/rootfinding.jl @@ -142,4 +142,52 @@ fdf2 = @gsl_function_fdf(myfun, myfun_deriv) end end + + @testset "Newton's method wrapper" begin + T = gsl_root_fdfsolver_newton + + @testset "alloc/free" begin + A = GSLRootFDFSolver(T) + end + + @testset "Solve" begin + solver = GSLRootFDFSolver(T) + root_fdfsolver_set(solver, (myfun, myfun_deriv, myfun_fdf), 5) + + status = GSL_CONTINUE + iter, maxiter = 0,20 + while status == GSL_CONTINUE + root_fdfsolver_iterate(solver) + x = root_fdfsolver_root(solver) + status = root_test_residual(myfun(x), 1e-10) + iter += 1 + if iter==maxiter + error("No convergence") + end + end + @test status == GSL_SUCCESS + x = root_fdfsolver_root(solver) + @test abs(myfun(x)) < 1e-10 + end + + @testset "Solve / simplestruct" begin + solver = GSLRootFDFSolver(T) + root_fdfsolver_set(solver, (myfun, myfun_deriv), 5) + + status = GSL_CONTINUE + iter, maxiter = 0,20 + while status == GSL_CONTINUE + root_fdfsolver_iterate(solver) + x = root_fdfsolver_root(solver) + status = root_test_residual(myfun(x), 1e-10) + iter += 1 + if iter==maxiter + error("No convergence") + end + end + @test status == GSL_SUCCESS + x = root_fdfsolver_root(solver) + @test abs(myfun(x)) < 1e-10 + end + end end From ed1bb809ff0f0f93b0f1e23b484c82e28f0b7c87 Mon Sep 17 00:00:00 2001 From: Yichao Yu Date: Thu, 4 Jan 2024 14:53:30 -0500 Subject: [PATCH 10/14] Safe MultirootFSolver wrapper --- src/manual_wrappers.jl | 50 ++++++++++++++++++++++++++++++++++++ test/multidim_rootfinding.jl | 39 ++++++++++++++++++++++++++++ 2 files changed, 89 insertions(+) diff --git a/src/manual_wrappers.jl b/src/manual_wrappers.jl index 4bb955e..41c8f5c 100644 --- a/src/manual_wrappers.jl +++ b/src/manual_wrappers.jl @@ -679,3 +679,53 @@ function root_fdfsolver_free(s::GSLRootFDFSolver) end return end + +@doc md""" +$(Docs.doc(C.multiroot_fsolver_alloc)) +""" +mutable struct GSLMultirootFSolver + ptr::Ptr{gsl_multiroot_fsolver} + param_ref + gsl_func::gsl_multiroot_function + function GSLMultirootFSolver(T, n) + s = new(multiroot_fsolver_alloc(T, n), nothing) + finalizer(multiroot_fsolver_free, s) + return s + end +end +export GSLMultirootFSolver + +Base.cconvert(::Type{Ref{gsl_multiroot_fsolver}}, s::GSLMultirootFSolver) = s +Base.unsafe_convert(::Type{Ref{gsl_multiroot_fsolver}}, s::GSLMultirootFSolver) = s.ptr +Base.unsafe_convert(::Type{Ptr{gsl_multiroot_fsolver}}, s::GSLMultirootFSolver) = s.ptr + +@doc md""" +$(Docs.doc(C.multiroot_fsolver_set)) +""" +function multiroot_fsolver_set(s::GSLMultirootFSolver, (f, n), x) + s.gsl_func, s.param_ref = wrap_gsl_multiroot_function(f, n) + return C.multiroot_fsolver_set(s, s.gsl_func, x) +end +function multiroot_fsolver_set(s::GSLMultirootFSolver, f::gsl_multiroot_function, x) + s.gsl_func = f + s.param_ref = nothing + return C.multiroot_fsolver_set(s, s.gsl_func, x) +end +function multiroot_fsolver_set(s::Ptr{gsl_multiroot_fsolver}, + f::gsl_multiroot_function, x) + return C.multiroot_fsolver_set(s, f, x) +end + +@doc md""" +$(Docs.doc(C.multiroot_fsolver_free)) +""" +function multiroot_fsolver_free(s::Ptr{gsl_multiroot_fsolver}) + C.multiroot_fsolver_free(s) +end +function multiroot_fsolver_free(s::GSLMultirootFSolver) + if s.ptr != C_NULL + C.multiroot_fsolver_free(s) + s.ptr = C_NULL + end + return +end diff --git a/test/multidim_rootfinding.jl b/test/multidim_rootfinding.jl index 51fb5df..e10419f 100644 --- a/test/multidim_rootfinding.jl +++ b/test/multidim_rootfinding.jl @@ -50,6 +50,45 @@ end vector_free(vinit) end +@testset "Multidimensional Rootfinding without Derivative wrapper" begin + # Initial guess + v0 = Cdouble[1.0, 5.0, 2.0, 1.5, -1.0] + + vinit = vector_alloc(n) + for i=1:n + vector_set(vinit, i-1, v0[i]) + end + + # Setup solver + gsl_func = (fmulti, 5) + dnewton_solver = GSLMultirootFSolver(gsl_multiroot_fsolver_dnewton, n) + multiroot_fsolver_set(dnewton_solver, gsl_func, vinit) + + + # Run + maxiter = 100 + converged, status = 0,0 + for iter = 1:maxiter + status = multiroot_fsolver_iterate(dnewton_solver) + x = multiroot_fsolver_root(dnewton_solver) + dx = multiroot_fsolver_dx(dnewton_solver) + converged = multiroot_test_delta(dx, x, 0, 1e-8) + if status==GSL_SUCCESS && converged==GSL_SUCCESS + break + end + end + + @test converged==GSL_SUCCESS + @test status==GSL_SUCCESS + + vec = multiroot_fsolver_root(dnewton_solver) + v1 = unsafe_wrap(Array{Cdouble}, vector_ptr(vec, 0), n) + + @test v1 ≈ roots atol=1e-5 + + vector_free(vinit) +end + ### WITH DERIVATIVE # Define problem n = 2 From 195e3947be30dec471ca977c415d75cec85946bf Mon Sep 17 00:00:00 2001 From: Yichao Yu Date: Thu, 4 Jan 2024 14:58:50 -0500 Subject: [PATCH 11/14] Safe MultirootFDFSolver wrapper --- src/manual_wrappers.jl | 58 ++++++++++++++++++++++++++++++++++++ test/multidim_rootfinding.jl | 40 +++++++++++++++++++++++++ 2 files changed, 98 insertions(+) diff --git a/src/manual_wrappers.jl b/src/manual_wrappers.jl index 41c8f5c..989c69a 100644 --- a/src/manual_wrappers.jl +++ b/src/manual_wrappers.jl @@ -729,3 +729,61 @@ function multiroot_fsolver_free(s::GSLMultirootFSolver) end return end + +@doc md""" +$(Docs.doc(C.multiroot_fdfsolver_alloc)) +""" +mutable struct GSLMultirootFDFSolver + ptr::Ptr{gsl_multiroot_fdfsolver} + param_ref + gsl_func::gsl_multiroot_function_fdf + function GSLMultirootFDFSolver(T, n) + s = new(multiroot_fdfsolver_alloc(T, n), nothing) + finalizer(multiroot_fdfsolver_free, s) + return s + end +end +export GSLMultirootFDFSolver + +Base.cconvert(::Type{Ref{gsl_multiroot_fdfsolver}}, s::GSLMultirootFDFSolver) = s +Base.unsafe_convert(::Type{Ref{gsl_multiroot_fdfsolver}}, s::GSLMultirootFDFSolver) = + s.ptr +Base.unsafe_convert(::Type{Ptr{gsl_multiroot_fdfsolver}}, s::GSLMultirootFDFSolver) = + s.ptr + +@doc md""" +$(Docs.doc(C.multiroot_fdfsolver_set)) +""" +function multiroot_fdfsolver_set(s::GSLMultirootFDFSolver, + (f, df, n)::NTuple{3,Any}, x) + s.gsl_func, s.param_ref = wrap_gsl_multiroot_function_fdf((f, df), n) + return C.multiroot_fdfsolver_set(s, s.gsl_func, x) +end +function multiroot_fdfsolver_set(s::GSLMultirootFDFSolver, + (f, df, fdf, n)::NTuple{4,Any}, x) + s.gsl_func, s.param_ref = wrap_gsl_multiroot_function_fdf((f, df, fdf), n) + return C.multiroot_fdfsolver_set(s, s.gsl_func, x) +end +function multiroot_fdfsolver_set(s::GSLMultirootFDFSolver, + f::gsl_multiroot_function_fdf, x) + s.gsl_func = f + s.param_ref = nothing + return C.multiroot_fdfsolver_set(s, s.gsl_func, x) +end +function multiroot_fdfsolver_set(s::Ptr{gsl_multiroot_fdfsolver}, f::gsl_multiroot_function_fdf, x) + return C.multiroot_fdfsolver_set(s, f, x) +end + +@doc md""" +$(Docs.doc(C.multiroot_fdfsolver_free)) +""" +function multiroot_fdfsolver_free(s::Ptr{gsl_multiroot_fdfsolver}) + C.multiroot_fdfsolver_free(s) +end +function multiroot_fdfsolver_free(s::GSLMultirootFDFSolver) + if s.ptr != C_NULL + C.multiroot_fdfsolver_free(s) + s.ptr = C_NULL + end + return +end diff --git a/test/multidim_rootfinding.jl b/test/multidim_rootfinding.jl index e10419f..97feca0 100644 --- a/test/multidim_rootfinding.jl +++ b/test/multidim_rootfinding.jl @@ -211,6 +211,38 @@ function run_newton(gsl_func::gsl_multiroot_function_fdf) multiroot_fdfsolver_free(newton_solver) vector_free(vinit) end + +function run_newton(gsl_func) + # Initial guess + Random.seed!(1) + v0 = rand(2) + vinit = vector_alloc(n) + for i=1:n + vector_set(vinit, i-1, v0[i]) + end + # Setup solver + newton_solver = GSLMultirootFDFSolver(gsl_multiroot_fdfsolver_newton, n) + multiroot_fdfsolver_set(newton_solver, gsl_func, vinit) + # Run + abstol = 1e-14 + reltol = 1e-14 + maxiter = 10 + converged, status = 0,0 + for iter = 1:maxiter + status = multiroot_fdfsolver_iterate(newton_solver) + @test status==GSL_SUCCESS + x = multiroot_fdfsolver_root(newton_solver) + dx = multiroot_fdfsolver_dx(newton_solver) + converged = multiroot_test_delta(dx, x, abstol, reltol) + if converged==GSL_SUCCESS + break + end + end + @test converged==GSL_SUCCESS + #x = gsl_multiroot_fdfsolver_root(newton_solver) + #@show wrap_gsl_vector(x) + vector_free(vinit) +end # Do tests @testset "Multidimenaional Rootfinding with Derivative" begin @testset "GSL-style wrappers" begin @@ -238,5 +270,13 @@ end gsl_func_vec = @gsl_multiroot_function_fdf(f_vec, df_vec, fdf_vec, n) run_newton(gsl_func_vec) end + @testset "3-argument" begin + gsl_func_vec = (f_vec, df_vec, n) + run_newton(gsl_func_vec) + end + @testset "4-argument" begin + gsl_func_vec = (f_vec, df_vec, fdf_vec, n) + run_newton(gsl_func_vec) + end end end From 3de70a5d6046eab1e44dfa266fcb5beb387ad46d Mon Sep 17 00:00:00 2001 From: Yichao Yu Date: Thu, 4 Jan 2024 16:45:32 -0500 Subject: [PATCH 12/14] Update document --- README.md | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/README.md b/README.md index 512913f..70a3186 100644 --- a/README.md +++ b/README.md @@ -72,9 +72,8 @@ GSL.C.sf_legendre_array(GSL_SF_LEGENDRE_SPHARM, lmax, x, result) ```julia f = x -> x^5+1 df = x -> 5*x^4 -fdf = @gsl_function_fdf(f, df) -solver = root_fdfsolver_alloc(gsl_root_fdfsolver_newton) -root_fdfsolver_set(solver, fdf, -2) +solver = GSLRootFDFSolver(gsl_root_fdfsolver_newton) +root_fdfsolver_set(solver, (f, df), -2) while abs(f(root_fdfsolver_root(solver))) > 1e-10 root_fdfsolver_iterate(solver) end @@ -87,7 +86,6 @@ println("x = ", root_fdfsolver_root(solver)) Extra functionality defined in this package: * Convenience functions `hypergeom` and `hypergeom_e` for the hypergeometric functions. -* Function wrapping macros `@gsl_function`, `@gsl_function_fdf`, `@gsl_multiroot_function` and `@gsl_multiroot_function_fdf` that are used for packaging Julia functions so that they can be passed to GSL. * Functions `wrap_gsl_vector` and `wrap_gsl_matrix` that return a Julia array or matrix pointing to the data in a `gsl_vector` or `gsl_matrix`. In addition, some effort has been put into giving most types and functions proper docstrings, e.g. From f02ba902f46c839932cd571a0ae4c57294588452 Mon Sep 17 00:00:00 2001 From: lxvm Date: Fri, 5 Jan 2024 21:33:47 -0800 Subject: [PATCH 13/14] test Cheb --- test/chebyshev.jl | 20 ++++++++++++++++++++ test/runtests.jl | 8 ++++---- 2 files changed, 24 insertions(+), 4 deletions(-) create mode 100644 test/chebyshev.jl diff --git a/test/chebyshev.jl b/test/chebyshev.jl new file mode 100644 index 0000000..1aaf542 --- /dev/null +++ b/test/chebyshev.jl @@ -0,0 +1,20 @@ +using Test +using GSL + + +@testset "Chebyshev" begin + # tests from FastChebInterp.jl + f = x -> exp(x) / (1 + 2x^2) + f′ = x -> f(x) * (1 - 4x/(1 + 2x^2)) + lb, ub = (-0.3, 0.9) + x = 0.2 + + p = GSLCheb(48) + + cheb_init(p, f, lb, ub) + @test !(cheb_eval_n(p, 10, x) ≈ f(x)) + @test cheb_eval(p, x) ≈ f(x) + cheb_init(p, f′, lb, ub) + @test !(cheb_eval_n(p, 10, x) ≈ f′(x)) + @test cheb_eval(p, x) ≈ f′(x) +end diff --git a/test/runtests.jl b/test/runtests.jl index 6833d2f..96aa445 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -6,15 +6,15 @@ using SpecialFunctions Random.seed!(1) @testset "GSL" begin - include("error.jl") + include("error.jl") include("hypergeom.jl") include("interp.jl") include("legendre.jl") - include("multidim_rootfinding.jl") + include("multidim_rootfinding.jl") include("numdiff.jl") include("quadrature.jl") - include("rng.jl") + include("rng.jl") include("rootfinding.jl") include("specfunc.jl") + include("chebyshev.jl") end - From 2b2e2aeff6e46df30e86a88ad581d1d65a73e409 Mon Sep 17 00:00:00 2001 From: lxvm Date: Fri, 5 Jan 2024 22:12:20 -0800 Subject: [PATCH 14/14] test FMinimizer --- test/minimization.jl | 31 +++++++++++++++++++++++++++++++ test/runtests.jl | 1 + 2 files changed, 32 insertions(+) create mode 100644 test/minimization.jl diff --git a/test/minimization.jl b/test/minimization.jl new file mode 100644 index 0000000..98224dc --- /dev/null +++ b/test/minimization.jl @@ -0,0 +1,31 @@ +using Test +using GSL + +@testset "1d minimization" begin + f = x -> cos(x) + one(x) + @testset "fminimizer $name" for (alg, name) in [ + (gsl_min_fminimizer_goldensection, "goldensection") + (gsl_min_fminimizer_brent, "brent") + (gsl_min_fminimizer_quad_golden, "quad-golden") + ] + m, m_exact = 2.0, pi + a, b = 0.0, 6.0 + epsabs, epsrel = 1e-3, 0e-0 + s = GSLMinFMinimizer(alg) + @test min_fminimizer_name(s) == name + min_fminimizer_set(s, f, m, a, b) + + maxiter = 100 + status = converged = 0 + for i in 1:maxiter + status = min_fminimizer_iterate(s) + m = min_fminimizer_x_minimum(s) + a = min_fminimizer_x_lower(s) + b = min_fminimizer_x_upper(s) + converged = min_test_interval(a, b, epsabs, epsrel) + status == converged == GSL_SUCCESS && break + end + @test status == converged == GSL_SUCCESS + @test m ≈ m_exact atol=epsabs rtol=epsrel + end +end diff --git a/test/runtests.jl b/test/runtests.jl index 96aa445..59b66a7 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -17,4 +17,5 @@ Random.seed!(1) include("rootfinding.jl") include("specfunc.jl") include("chebyshev.jl") + include("minimization.jl") end