Commit
Update _vector_grad.py
Routhleck committed Jan 6, 2025
1 parent ff8c10d commit 88cf060
Showing 1 changed file with 19 additions and 15 deletions.
34 changes: 19 additions & 15 deletions brainunit/autograd/_vector_grad.py
@@ -42,21 +42,6 @@ def vector_grad(
"""
Unit-aware compute the gradient of a vector with respect to the input.
Example::
>>> import jax.numpy as jnp
>>> import brainunit as u
>>> def simple_function(x):
... return x ** 2
>>> vector_grad_fn = u.autograd.vector_grad(simple_function)
>>> vector_grad_fn(jnp.array([3.0, 4.0]) * u.ms)
[6.0, 8.0] * ms
>>> vector_grad_fn = u.autograd.vector_grad(simple_function, return_value=True)
>>> grad, value = vector_grad_fn(jnp.array([3.0, 4.0]) * u.ms)
>>> grad
[6.0, 8.0] * ms
>>> value
[9.0, 16.0] * ms ** 2
Args:
fun: A Python callable that computes a scalar loss given arguments.
argnums: Optional, an integer or a tuple of integers. The argument number(s) to differentiate with respect to.
@@ -67,6 +52,25 @@
Returns:
A function that computes the gradient of `func` with respect to
the argument(s) indicated by `argnums`.
+ >>> import jax.numpy as jnp
+ >>> import brainunit as u
+ >>> def simple_function(x):
+ ...     return x ** 2
+ >>> vector_grad_fn = u.autograd.vector_grad(simple_function)
+ >>> vector_grad_fn(jnp.array([3.0, 4.0]) * u.ms)
+ [6.0, 8.0] * ms
+
+ >>> import jax.numpy as jnp
+ >>> import brainunit as u
+ >>> def simple_function(x):
+ ...     return x ** 2
+ >>> vector_grad_fn = u.autograd.vector_grad(simple_function, return_value=True)
+ >>> grad, value = vector_grad_fn(jnp.array([3.0, 4.0]) * u.ms)
+ >>> grad
+ [6.0, 8.0] * ms
+ >>> value
+ [9.0, 16.0] * ms ** 2
"""

_check_callable(func)
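For context, a minimal sketch of how the `argnums` option documented in the docstring above might be used. This is an editorial illustration, not part of the commit; the function `f`, the two-argument call, and the chosen units are assumptions based on the Args description.

    import jax.numpy as jnp
    import brainunit as u

    def f(x, y):
        return x * y  # element-wise product of two unit-carrying arrays

    # Differentiate with respect to both positional arguments, passing a
    # tuple of integers as the Args section describes.
    grad_fn = u.autograd.vector_grad(f, argnums=(0, 1))
    dx, dy = grad_fn(jnp.array([3.0, 4.0]) * u.ms,
                     jnp.array([1.0, 2.0]) * u.mV)
    # Per the unit-aware semantics shown in the docstring examples, dx
    # should carry the unit of y (mV) and dy the unit of x (ms).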
