Skip to content

Commit

Permalink
finitediff: Less complex types
Browse files Browse the repository at this point in the history
  • Loading branch information
stefan-k committed Mar 8, 2024
1 parent 1d4a3fb commit 35e00e2
Show file tree
Hide file tree
Showing 12 changed files with 236 additions and 250 deletions.
6 changes: 4 additions & 2 deletions crates/finitediff/src/array/diff.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,11 @@ use num::FromPrimitive;

use crate::utils::mod_and_calc_const;

use super::CostFn;

pub fn forward_diff_const<const N: usize, F>(
x: &[F; N],
f: &dyn Fn(&[F; N]) -> Result<F, Error>,
f: CostFn<'_, N, F>,
) -> Result<[F; N], Error>
where
F: Float + FromPrimitive,
Expand All @@ -35,7 +37,7 @@ where

pub fn central_diff_const<const N: usize, F>(
x: &[F; N],
f: &dyn Fn(&[F; N]) -> Result<F, Error>,
f: CostFn<'_, N, F>,
) -> Result<[F; N], Error>
where
F: Float + FromPrimitive,
Expand Down
14 changes: 8 additions & 6 deletions crates/finitediff/src/array/hessian.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,11 @@ use num::{Float, FromPrimitive};

use crate::utils::{mod_and_calc, restore_symmetry_const, KV};

use super::{CostFn, GradientFn};

pub fn forward_hessian_const<const N: usize, F>(
x: &[F; N],
grad: &dyn Fn(&[F; N]) -> Result<[F; N], Error>,
grad: GradientFn<'_, N, F>,
) -> Result<[[F; N]; N], Error>
where
F: Float + FromPrimitive,
Expand All @@ -36,7 +38,7 @@ where

pub fn central_hessian_const<const N: usize, F>(
x: &[F; N],
grad: &dyn Fn(&[F; N]) -> Result<[F; N], Error>,
grad: GradientFn<'_, N, F>,
) -> Result<[[F; N]; N], Error>
where
F: Float + FromPrimitive,
Expand All @@ -59,7 +61,7 @@ where

pub fn forward_hessian_vec_prod_const<const N: usize, F>(
x: &[F; N],
grad: &dyn Fn(&[F; N]) -> Result<[F; N], Error>,
grad: GradientFn<'_, N, F>,
p: &[F; N],
) -> Result<[F; N], Error>
where
Expand All @@ -83,7 +85,7 @@ where

pub fn central_hessian_vec_prod_const<const N: usize, F>(
x: &[F; N],
grad: &dyn Fn(&[F; N]) -> Result<[F; N], Error>,
grad: GradientFn<'_, N, F>,
p: &[F; N],
) -> Result<[F; N], Error>
where
Expand All @@ -108,7 +110,7 @@ where

pub fn forward_hessian_nograd_const<const N: usize, F>(
x: &[F; N],
f: &dyn Fn(&[F; N]) -> Result<F, Error>,
f: CostFn<'_, N, F>,
) -> Result<[[F; N]; N], Error>
where
F: Float + FromPrimitive + AddAssign,
Expand Down Expand Up @@ -149,7 +151,7 @@ where

pub fn forward_hessian_nograd_sparse_const<const N: usize, F>(
x: &[F; N],
f: &dyn Fn(&[F; N]) -> Result<F, Error>,
f: CostFn<'_, N, F>,
indices: Vec<[usize; 2]>,
) -> Result<[[F; N]; N], Error>
where
Expand Down
14 changes: 8 additions & 6 deletions crates/finitediff/src/array/jacobian.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,11 @@ use num::{Float, FromPrimitive};
use crate::pert::PerturbationVectors;
use crate::utils::{mod_and_calc, mod_and_calc_const};

use super::OpFn;

pub fn forward_jacobian_const<const N: usize, const M: usize, F>(
x: &[F; N],
fs: &dyn Fn(&[F; N]) -> Result<[F; M], Error>,
fs: OpFn<'_, N, M, F>,
) -> Result<[[F; N]; M], Error>
where
F: Float + FromPrimitive,
Expand All @@ -37,7 +39,7 @@ where

pub fn central_jacobian_const<const N: usize, const M: usize, F>(
x: &[F; N],
fs: &dyn Fn(&[F; N]) -> Result<[F; M], Error>,
fs: OpFn<'_, N, M, F>,
) -> Result<[[F; N]; M], Error>
where
F: Float + FromPrimitive,
Expand All @@ -58,7 +60,7 @@ where

pub fn forward_jacobian_vec_prod_const<const N: usize, const M: usize, F>(
x: &[F; N],
fs: &dyn Fn(&[F; N]) -> Result<[F; M], Error>,
fs: OpFn<'_, N, M, F>,
p: &[F; N],
) -> Result<[F; M], Error>
where
Expand All @@ -85,7 +87,7 @@ where

pub fn central_jacobian_vec_prod_const<const N: usize, const M: usize, F>(
x: &[F; N],
fs: &dyn Fn(&[F; N]) -> Result<[F; M], Error>,
fs: OpFn<'_, N, M, F>,
p: &[F; N],
) -> Result<[F; M], Error>
where
Expand Down Expand Up @@ -117,7 +119,7 @@ where

pub fn forward_jacobian_pert_const<const N: usize, const M: usize, F>(
x: &[F; N],
fs: &dyn Fn(&[F; N]) -> Result<[F; M], Error>,
fs: OpFn<'_, N, M, F>,
pert: &PerturbationVectors,
) -> Result<[[F; N]; M], Error>
where
Expand Down Expand Up @@ -149,7 +151,7 @@ where

pub fn central_jacobian_pert_const<const N: usize, const M: usize, F>(
x: &[F; N],
fs: &dyn Fn(&[F; N]) -> Result<[F; M], Error>,
fs: OpFn<'_, N, M, F>,
pert: &PerturbationVectors,
) -> Result<[[F; N]; M], Error>
where
Expand Down
125 changes: 60 additions & 65 deletions crates/finitediff/src/array/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,154 +26,149 @@ use jacobian::{
forward_jacobian_const, forward_jacobian_pert_const, forward_jacobian_vec_prod_const,
};

pub(crate) type CostFn<'a, const N: usize, F> = &'a dyn Fn(&[F; N]) -> Result<F, Error>;
pub(crate) type GradientFn<'a, const N: usize, F> = &'a dyn Fn(&[F; N]) -> Result<[F; N], Error>;
pub(crate) type OpFn<'a, const N: usize, const M: usize, F> =
&'a dyn Fn(&[F; N]) -> Result<[F; M], Error>;

#[inline(always)]
pub fn forward_diff<const N: usize, Func, F>(f: Func) -> impl Fn(&[F; N]) -> Result<[F; N], Error>
pub fn forward_diff<const N: usize, F>(
f: CostFn<'_, N, F>,
) -> impl Fn(&[F; N]) -> Result<[F; N], Error> + '_
where
Func: Fn(&[F; N]) -> Result<F, Error>,
F: Float + FromPrimitive,
{
move |p: &[F; N]| forward_diff_const(p, &f)
}

#[inline(always)]
pub fn central_diff<const N: usize, Func, F>(f: Func) -> impl Fn(&[F; N]) -> Result<[F; N], Error>
pub fn central_diff<const N: usize, F>(
f: CostFn<'_, N, F>,
) -> impl Fn(&[F; N]) -> Result<[F; N], Error> + '_
where
Func: Fn(&[F; N]) -> Result<F, Error>,
F: Float + FromPrimitive,
{
move |p: &[F; N]| central_diff_const(p, &f)
}

#[inline(always)]
pub fn forward_jacobian<const N: usize, const M: usize, Func, F>(
f: Func,
) -> impl Fn(&[F; N]) -> Result<[[F; N]; M], Error>
pub fn forward_jacobian<const N: usize, const M: usize, F>(
f: OpFn<'_, N, M, F>,
) -> impl Fn(&[F; N]) -> Result<[[F; N]; M], Error> + '_
where
Func: Fn(&[F; N]) -> Result<[F; M], Error>,
F: Float + FromPrimitive,
{
move |p: &[F; N]| forward_jacobian_const(p, &f)
}

#[inline(always)]
pub fn central_jacobian<const N: usize, const M: usize, Func, F>(
f: Func,
) -> impl Fn(&[F; N]) -> Result<[[F; N]; M], Error>
pub fn central_jacobian<const N: usize, const M: usize, F>(
f: OpFn<'_, N, M, F>,
) -> impl Fn(&[F; N]) -> Result<[[F; N]; M], Error> + '_
where
Func: Fn(&[F; N]) -> Result<[F; M], Error>,
F: Float + FromPrimitive,
{
move |p: &[F; N]| central_jacobian_const(p, &f)
}

#[inline(always)]
pub fn forward_jacobian_vec_prod<const N: usize, const M: usize, Func, F>(
f: Func,
) -> impl Fn(&[F; N], &[F; N]) -> Result<[F; M], Error>
pub fn forward_jacobian_vec_prod<const N: usize, const M: usize, F>(
f: OpFn<'_, N, M, F>,
) -> impl Fn(&[F; N], &[F; N]) -> Result<[F; M], Error> + '_
where
Func: Fn(&[F; N]) -> Result<[F; M], Error>,
F: Float + FromPrimitive,
{
move |p: &[F; N], v: &[F; N]| forward_jacobian_vec_prod_const(p, &f, v)
move |p: &[F; N], v: &[F; N]| forward_jacobian_vec_prod_const(p, f, v)
}

#[inline(always)]
pub fn central_jacobian_vec_prod<const N: usize, const M: usize, Func, F>(
f: Func,
) -> impl Fn(&[F; N], &[F; N]) -> Result<[F; M], Error>
pub fn central_jacobian_vec_prod<const N: usize, const M: usize, F>(
f: OpFn<'_, N, M, F>,
) -> impl Fn(&[F; N], &[F; N]) -> Result<[F; M], Error> + '_
where
Func: Fn(&[F; N]) -> Result<[F; M], Error>,
F: Float + FromPrimitive,
{
move |p: &[F; N], v: &[F; N]| central_jacobian_vec_prod_const(p, &f, v)
move |p: &[F; N], v: &[F; N]| central_jacobian_vec_prod_const(p, f, v)
}

#[inline(always)]
pub fn forward_jacobian_pert<const N: usize, const M: usize, Func, F>(
f: Func,
) -> impl Fn(&[F; N], &PerturbationVectors) -> Result<[[F; N]; M], Error>
pub fn forward_jacobian_pert<const N: usize, const M: usize, F>(
f: OpFn<'_, N, M, F>,
) -> impl Fn(&[F; N], &PerturbationVectors) -> Result<[[F; N]; M], Error> + '_
where
Func: Fn(&[F; N]) -> Result<[F; M], Error>,
F: Float + FromPrimitive + AddAssign,
{
move |p: &[F; N], pert: &PerturbationVectors| forward_jacobian_pert_const(p, &f, pert)
move |p: &[F; N], pert: &PerturbationVectors| forward_jacobian_pert_const(p, f, pert)
}

#[inline(always)]
pub fn central_jacobian_pert<const N: usize, const M: usize, Func, F>(
f: Func,
) -> impl Fn(&[F; N], &PerturbationVectors) -> Result<[[F; N]; M], Error>
pub fn central_jacobian_pert<const N: usize, const M: usize, F>(
f: OpFn<'_, N, M, F>,
) -> impl Fn(&[F; N], &PerturbationVectors) -> Result<[[F; N]; M], Error> + '_
where
Func: Fn(&[F; N]) -> Result<[F; M], Error>,
F: Float + FromPrimitive + AddAssign,
{
move |p: &[F; N], pert: &PerturbationVectors| central_jacobian_pert_const(p, &f, pert)
move |p: &[F; N], pert: &PerturbationVectors| central_jacobian_pert_const(p, f, pert)
}

#[inline(always)]
pub fn forward_hessian<const N: usize, Func, F>(
f: Func,
) -> impl Fn(&[F; N]) -> Result<[[F; N]; N], Error>
pub fn forward_hessian<const N: usize, F>(
f: GradientFn<'_, N, F>,
) -> impl Fn(&[F; N]) -> Result<[[F; N]; N], Error> + '_
where
Func: Fn(&[F; N]) -> Result<[F; N], Error>,
F: Float + FromPrimitive,
{
move |p: &[F; N]| forward_hessian_const(p, &f)
move |p: &[F; N]| forward_hessian_const(p, f)
}

#[inline(always)]
pub fn central_hessian<const N: usize, Func, F>(
f: Func,
) -> impl Fn(&[F; N]) -> Result<[[F; N]; N], Error>
pub fn central_hessian<const N: usize, F>(
f: GradientFn<'_, N, F>,
) -> impl Fn(&[F; N]) -> Result<[[F; N]; N], Error> + '_
where
Func: Fn(&[F; N]) -> Result<[F; N], Error>,
F: Float + FromPrimitive,
{
move |p: &[F; N]| central_hessian_const(p, &f)
move |p: &[F; N]| central_hessian_const(p, f)
}

#[inline(always)]
pub fn forward_hessian_vec_prod<const N: usize, Func, F>(
f: Func,
) -> impl Fn(&[F; N], &[F; N]) -> Result<[F; N], Error>
pub fn forward_hessian_vec_prod<const N: usize, F>(
f: GradientFn<'_, N, F>,
) -> impl Fn(&[F; N], &[F; N]) -> Result<[F; N], Error> + '_
where
Func: Fn(&[F; N]) -> Result<[F; N], Error>,
F: Float + FromPrimitive,
{
move |p: &[F; N], v: &[F; N]| forward_hessian_vec_prod_const(p, &f, v)
move |p: &[F; N], v: &[F; N]| forward_hessian_vec_prod_const(p, f, v)
}

#[inline(always)]
pub fn central_hessian_vec_prod<const N: usize, Func, F>(
f: Func,
) -> impl Fn(&[F; N], &[F; N]) -> Result<[F; N], Error>
pub fn central_hessian_vec_prod<const N: usize, F>(
f: GradientFn<'_, N, F>,
) -> impl Fn(&[F; N], &[F; N]) -> Result<[F; N], Error> + '_
where
Func: Fn(&[F; N]) -> Result<[F; N], Error>,
F: Float + FromPrimitive,
{
move |p: &[F; N], v: &[F; N]| central_hessian_vec_prod_const(p, &f, v)
move |p: &[F; N], v: &[F; N]| central_hessian_vec_prod_const(p, f, v)
}

#[inline(always)]
pub fn forward_hessian_nograd<const N: usize, Func, F>(
f: Func,
) -> impl Fn(&[F; N]) -> Result<[[F; N]; N], Error>
pub fn forward_hessian_nograd<const N: usize, F>(
f: CostFn<'_, N, F>,
) -> impl Fn(&[F; N]) -> Result<[[F; N]; N], Error> + '_
where
Func: Fn(&[F; N]) -> Result<F, Error>,
F: Float + FromPrimitive + AddAssign,
{
move |p: &[F; N]| forward_hessian_nograd_const(p, &f)
move |p: &[F; N]| forward_hessian_nograd_const(p, f)
}

#[inline(always)]
pub fn forward_hessian_nograd_sparse<const N: usize, Func, F>(
f: Func,
) -> impl Fn(&[F; N], Vec<[usize; 2]>) -> Result<[[F; N]; N], Error>
pub fn forward_hessian_nograd_sparse<const N: usize, F>(
f: CostFn<'_, N, F>,
) -> impl Fn(&[F; N], Vec<[usize; 2]>) -> Result<[[F; N]; N], Error> + '_
where
Func: Fn(&[F; N]) -> Result<F, Error>,
F: Float + FromPrimitive + AddAssign,
{
move |p: &[F; N], indices: Vec<[usize; 2]>| forward_hessian_nograd_sparse_const(p, &f, indices)
move |p: &[F; N], indices: Vec<[usize; 2]>| forward_hessian_nograd_sparse_const(p, f, indices)
}

#[cfg(test)]
Expand Down Expand Up @@ -267,7 +262,7 @@ mod tests {

#[test]
fn test_forward_diff_func() {
let grad = forward_diff(f1);
let grad = forward_diff(&f1);
let out = grad(&x1()).unwrap();
let res = [1.0, 2.0];

Expand All @@ -287,7 +282,7 @@ mod tests {

#[test]
fn test_central_diff_func() {
let grad = central_diff(f1);
let grad = central_diff(&f1);
let out = grad(&x1()).unwrap();
let res = [1.0f64, 2.0];

Expand All @@ -296,7 +291,7 @@ mod tests {
}

let p = [1.0f64, 2.0f64];
let grad = central_diff(f1);
let grad = central_diff(&f1);
let out = grad(&p).unwrap();
let res = [1.0f64, 4.0];

Expand Down
6 changes: 4 additions & 2 deletions crates/finitediff/src/ndarr/diff.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,11 @@ use num::{Float, FromPrimitive};

use crate::utils::*;

use super::CostFn;

pub fn forward_diff_ndarray<F>(
x: &ndarray::Array1<F>,
f: &dyn Fn(&ndarray::Array1<F>) -> Result<F, Error>,
f: CostFn<'_, F>,
) -> Result<ndarray::Array1<F>, Error>
where
F: Float,
Expand All @@ -31,7 +33,7 @@ where

pub fn central_diff_ndarray<F>(
x: &ndarray::Array1<F>,
f: &dyn Fn(&ndarray::Array1<F>) -> Result<F, Error>,
f: CostFn<'_, F>,
) -> Result<ndarray::Array1<F>, Error>
where
F: Float + FromPrimitive,
Expand Down
Loading

0 comments on commit 35e00e2

Please sign in to comment.