Skip to content

Commit

Permalink
Update
Browse files Browse the repository at this point in the history
  • Loading branch information
Routhleck committed Jan 6, 2025
1 parent adc0eab commit 6f9bdee
Show file tree
Hide file tree
Showing 2 changed files with 0 additions and 52 deletions.
14 changes: 0 additions & 14 deletions brainunit/autograd/_hessian.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,20 +35,6 @@ def hessian(
Physical unit-aware version of `jax.hessian <https://jax.readthedocs.io/en/latest/_autosummary/jax.hessian.html>`_,
computing Hessian of ``fun`` as a dense array.
Example::
>>> import jax.numpy as jnp
>>> import brainunit as u
>>> def scalar_function1(x):
... return x ** 2 + 3 * x * u.ms + 2 * u.msecond2
>>> hess_fn = u.autograd.hessian(scalar_function1)
>>> hess_fn(jnp.array(1.0) * u.ms)
[2]
>>> def scalar_function2(x):
... return x ** 3 + 3 * x * u.msecond2 + 2 * u.msecond3
>>> hess_fn = u.autograd.hessian(scalar_function2)
>>> hess_fn(jnp.array(1.0) * u.ms)
[6] * ms
Args:
fun: Function whose Hessian is to be computed. Its arguments at positions
specified by ``argnums`` should be arrays, scalars, or standard Python
Expand Down
38 changes: 0 additions & 38 deletions brainunit/autograd/_jacobian.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,25 +49,6 @@ def jacrev(
"""
Physical unit-aware version of `jax.jacrev <https://jax.readthedocs.io/en/latest/_autosummary/jax.jacrev.html>`_.
Example::
>>> import jax.numpy as jnp
>>> import brainunit as u
>>> def simple_function1(x):
... return x ** 2
>>> jac_fn = u.autograd.jacrev(simple_function)
>>> jac_fn(jnp.array(3.0) * u.ms)
6.0 * ms
>>> def simple_function2(x, y):
... return x * y
>>> jac_fn = u.autograd.jacrev(simple_function2, argnums=(0, 1))
>>> x = jnp.array([3.0, 4.0]) * u.ohm
>>> y = jnp.array([5.0, 6.0]) * u.mA
>>> jac_fn(x, y)
([[5., 0.],
[0., 6.]] * mA,
[[3., 0.],
[0., 4.]] * ohm)
Args:
fun: Function whose Jacobian is to be computed.
argnums: Optional, integer or sequence of integers. Specifies which
Expand Down Expand Up @@ -259,25 +240,6 @@ def jacfwd(
"""
Physical unit-aware version of `jax.jacfwd <https://jax.readthedocs.io/en/latest/_autosummary/jax.jacfwd.html>`_.
Example::
>>> import jax.numpy as jnp
>>> import brainunit as u
>>> def simple_function(x):
... return x ** 2
>>> jac_fn = u.autograd.jacfwd(simple_function)
>>> jac_fn(jnp.array(3.0) * u.ms)
6.0 * ms
>>> def simple_function(x, y):
... return x * y
>>> jac_fn = u.autograd.jacfwd(simple_function, argnums=(0, 1))
>>> x = jnp.array([3.0, 4.0]) * u.ohm
>>> y = jnp.array([5.0, 6.0]) * u.mA
>>> jac_fn(x, y)
([[5., 0.],
[0., 6.]] * mA,
[[3., 0.],
[0., 4.]] * ohm)
Args:
fun: Function whose Jacobian is to be computed.
argnums: Optional, integer or sequence of integers. Specifies which
Expand Down

0 comments on commit 6f9bdee

Please sign in to comment.