Skip to content

Commit

Permalink
fix array initialization problem when import package (#49)
Browse files Browse the repository at this point in the history
  • Loading branch information
chaoming0625 authored Dec 13, 2024
1 parent 958463c commit a42340d
Showing 1 changed file with 43 additions and 7 deletions.
50 changes: 43 additions & 7 deletions braintaichi/_primitive/_ad_support.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,25 +15,61 @@

# -*- coding: utf-8 -*-

import functools
from functools import partial

import jax
from jax import tree_util
from jax.core import Primitive
from jax.interpreters import ad

__all__ = [
'defjvp',
]


def defjvp(primitive: Primitive, *jvp_rules):
"""
Define JVP rules for any JAX primitive.
def defjvp(primitive, *jvp_rules):
"""Define JVP rules for any JAX primitive.
This function is similar to ``jax.interpreters.ad.defjvp``.
However, this JAX function only supports primitives with ``multiple_results=False``.
``braintaichi.defjvp`` enables to define the independent JVP rule for
However, the JAX one only supports primitive with ``multiple_results=False``.
``brainpy.math.defjvp`` enables to define the independent JVP rule for
each input parameter no matter ``multiple_results=False/True``.
For examples, please see ``test_ad_support.py``.
Args:
primitive: Primitive, XLACustomOp.
*jvp_rules: The JVP translation rule for each primal.
"""
from brainstate.event._xla_custom_op import defjvp as defjvp_custom_op
defjvp_custom_op(primitive, *jvp_rules)
assert isinstance(primitive, Primitive)
if primitive.multiple_results:
ad.primitive_jvps[primitive] = partial(_standard_jvp, jvp_rules, primitive)
else:
ad.primitive_jvps[primitive] = partial(ad.standard_jvp, jvp_rules, primitive)


def _standard_jvp(jvp_rules, primitive: Primitive, primals, tangents, **params):
assert primitive.multiple_results
val_out = tuple(primitive.bind(*primals, **params))
tree = tree_util.tree_structure(val_out)
tangents_out = []
for rule, t in zip(jvp_rules, tangents):
if rule is not None and type(t) is not ad.Zero:
r = tuple(rule(t, *primals, **params))
tangents_out.append(r)
assert tree_util.tree_structure(r) == tree
r = functools.reduce(
_add_tangents,
tangents_out,
tree_util.tree_map(
# compatible with JAX 0.4.34
lambda a: ad.Zero.from_primal_value(a) if jax.__version__ >= '0.4.34' else ad.Zero.from_value(a),
val_out
)
)
return val_out, r


def _add_tangents(xs, ys):
return tree_util.tree_map(ad.add_tangents, xs, ys, is_leaf=lambda a: isinstance(a, ad.Zero))

0 comments on commit a42340d

Please sign in to comment.