f0uriest/quadax

quadax

quadax is a library for numerical quadrature and integration using JAX.

  • vmap-able, jit-able, differentiable.
  • Scalar or vector valued integrands.
  • Finite or infinite domains with discontinuities or singularities within the domain of integration.
  • Globally adaptive Gauss-Kronrod and Clenshaw-Curtis quadrature for smooth integrands (similar to scipy.integrate.quad).
  • Adaptive tanh-sinh quadrature for singular or near-singular integrands.
  • Quadrature from sampled values using trapezoidal and Simpson's methods.
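To illustrate what a Gauss-Kronrod rule computes, here is a minimal sketch of a single 15-point Kronrod / embedded 7-point Gauss pair applied once to one interval, written in plain JAX. It is not quadax's implementation: a globally adaptive routine repeatedly subdivides the interval with the largest error estimate, whereas this sketch applies the rule a single time. The helper name `gk15` is hypothetical; the nodes and weights are the standard G7-K15 constants as tabulated in QUADPACK. Because it is built only from `jax.numpy` operations, it can be jitted, vmapped, and differentiated.

```python
import jax
jax.config.update("jax_enable_x64", True)  # float64, so errors near 1e-12 are visible
import jax.numpy as jnp

# Standard G7-K15 constants (as in QUADPACK's qk15).
# Kronrod abscissae on [-1, 1], largest first; the last entry is the midpoint.
_XGK = jnp.array([
    0.991455371120812639206854697526329,
    0.949107912342758524526189684047851,
    0.864864423359769072789712788640926,
    0.741531185599394439863864773280788,
    0.586087235467691130294144845693013,
    0.405845151377397166906606412076961,
    0.207784955007898467600689403773245,
    0.0,
])
# Kronrod weights matching _XGK.
_WGK = jnp.array([
    0.022935322010529224963732008058970,
    0.063092092629978553290700663189204,
    0.104790010322250183839876322541518,
    0.140653259715525918745189590510238,
    0.169004726639267902826583426598550,
    0.190350578064785409913256402421014,
    0.204432940075298892414161999234649,
    0.209482141084727828012999174891714,
])
# Weights of the embedded 7-point Gauss rule, whose nodes are
# _XGK[1], _XGK[3], _XGK[5] and the midpoint _XGK[7].
_WG = jnp.array([
    0.129484966168869693270611432679082,
    0.279705391489276667901467771423780,
    0.381830050505118944950369775488975,
    0.417959183673469387755102040816327,
])


def gk15(fun, a, b):
    """Apply the G7-K15 pair once on [a, b]; return (estimate, crude error)."""
    c = 0.5 * (a + b)  # interval midpoint
    h = 0.5 * (b - a)  # half-width
    x = _XGK[:7]
    # Sum f over each symmetric pair of nodes c - h*x and c + h*x.
    f_pair = fun(c - h * x) + fun(c + h * x)
    f_mid = fun(c)
    kronrod = h * (jnp.sum(_WGK[:7] * f_pair) + _WGK[7] * f_mid)
    # The Gauss nodes are every other Kronrod node, so no extra evaluations.
    gauss = h * (jnp.sum(_WG[:3] * f_pair[1::2]) + _WG[3] * f_mid)
    # |K - G| is a simple, usually pessimistic error estimate;
    # QUADPACK applies an extra heuristic rescaling to it.
    return kronrod, jnp.abs(kronrod - gauss)


# jit, vmap and grad compose with the rule because it is pure jax.numpy.
val, err = gk15(lambda t: t * jnp.log1p(t), 0.0, 1.0)  # exact value: 1/4
sines = jax.vmap(lambda b: gk15(jnp.cos, 0.0, b)[0])(jnp.array([0.5, 1.0, 2.0]))
dIdp = jax.grad(lambda p: gk15(lambda t: jnp.exp(p * t), 0.0, 1.0)[0])(0.0)  # exact: 1/2
sin1 = jax.jit(lambda b: gk15(jnp.cos, 0.0, b)[0])(1.0)
```

Evaluating the Gauss rule on a subset of the Kronrod nodes is what makes the pair cheap: the error estimate costs no additional function evaluations.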

Coming soon:

  • Custom JVP/VJP rules (currently, AD works by differentiating through the quadrature loop, which is not the most efficient approach).
  • N-D quadrature (cubature)
  • QMC methods
  • Integration with weight functions
  • Sparse grids (maybe; this needs some experimentation with data structures in JAX).

Installation

quadax is installable with pip:

pip install quadax

Usage

import jax.numpy as jnp
import numpy as np
from quadax import quadgk

fun = lambda t: t * jnp.log(1 + t)

epsabs = epsrel = 1e-5 # JAX uses 32-bit floats by default; tighter tolerances require enabling 64-bit mode
a, b = 0, 1
y, info = quadgk(fun, [a, b], epsabs=epsabs, epsrel=epsrel)
assert info.err < max(epsabs, epsrel*abs(y))
np.testing.assert_allclose(y, 1/4, rtol=epsrel, atol=epsabs)
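The example above integrates a smooth function with the Gauss-Kronrod routine. For an integrand with an endpoint singularity, the feature list points to tanh-sinh quadrature instead. Below is a minimal fixed-step sketch of the tanh-sinh substitution in plain JAX, for illustration only: it is not quadax's adaptive implementation, and the function `tanh_sinh` with its `step` and `t_max` parameters is hypothetical. It switches JAX to 64-bit floats first, since, as the comment in the example notes, JAX defaults to 32-bit. The sketch only resolves singularities at a left endpoint a = 0; nodes near a nonzero endpoint can round onto it in floating point, which a real implementation has to guard against.

```python
import math

import jax
jax.config.update("jax_enable_x64", True)  # JAX defaults to 32-bit; use 64-bit here
import jax.numpy as jnp


def tanh_sinh(fun, a, b, step=0.125, t_max=4.0):
    """Fixed-step tanh-sinh rule on [a, b] (illustrative, non-adaptive)."""
    n = int(round(t_max / step))
    t = step * jnp.arange(-n, n + 1)  # uniform grid in the new variable t
    u = 0.5 * math.pi * jnp.sinh(t)
    h = 0.5 * (b - a)
    # x = a + h * (1 + tanh(u)), rewritten so the distance from `a` is computed
    # without cancellation; with a = 0, nodes approach 0 without rounding onto it.
    x = a + 2.0 * h / (1.0 + jnp.exp(-2.0 * u))
    # Jacobian dx/dt = h * (pi/2) * cosh(t) / cosh(u)**2 decays double-exponentially.
    w = h * 0.5 * math.pi * jnp.cosh(t) / jnp.cosh(u) ** 2
    # The plain trapezoidal rule in t is very accurate for the transformed integrand.
    return step * jnp.sum(w * fun(x))


# Both integrands are singular at the left endpoint x = 0.
inv_sqrt = tanh_sinh(lambda x: 1.0 / jnp.sqrt(x), 0.0, 1.0)  # exact value: 2
log_int = tanh_sinh(jnp.log, 0.0, 1.0)                       # exact value: -1
```

The substitution pushes the endpoints to t = ±∞ and makes the integrand decay double-exponentially there, so truncating at |t| ≤ 4 and summing with a uniform step already reaches double-precision accuracy for these two integrals.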

For full details of the available options, see the API documentation.
