Skip to content

Commit

Permalink
Migrate from jax.core to jax.extend.core for several deprecated symbols
Browse files Browse the repository at this point in the history
A number of symbols from jax.core are deprecated as of recent JAX releases; some of them are newly available in jax.extend.core.

PiperOrigin-RevId: 705932268
  • Loading branch information
Jake VanderPlas authored and The oryx Authors committed Dec 18, 2024
1 parent cba9837 commit de8b5cc
Show file tree
Hide file tree
Showing 10 changed files with 84 additions and 76 deletions.
43 changes: 22 additions & 21 deletions oryx/core/interpreters/harvest.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,7 @@ def f(x):
from jax._src import sharding_impls
from jax._src.lax import control_flow as lcf
from jax.experimental import shard_map
import jax.extend as jex
import jax.extend.linear_util as lu
from jax.interpreters import ad
from jax.interpreters import batching
Expand All @@ -178,7 +179,7 @@ def f(x):

Value = Any

sow_p = jax_core.Primitive('sow')
sow_p = jex.core.Primitive('sow')
sow_p.multiple_results = True


Expand Down Expand Up @@ -383,15 +384,15 @@ def __init__(self, parent_trace, context):
self.context = context

def process_primitive(
self, primitive: jax_core.Primitive, vals: List[Any],
self, primitive: jex.core.Primitive, vals: List[Any],
params: Dict[str, Any]) -> Union[Any, List[Any]]:
custom_rule = self.context.get_custom_rule(primitive)
if custom_rule:
return custom_rule(self, *vals, **params)
return self.default_process_primitive(primitive, vals, params)

def default_process_primitive(
self, primitive: jax_core.Primitive, vals: List[Any],
self, primitive: jex.core.Primitive, vals: List[Any],
params: Dict[str, Any]) -> Union[Any, List[Any]]:
if primitive is sow_p:
with jax_core.set_current_trace(self.parent_trace):
Expand All @@ -403,15 +404,15 @@ def default_process_primitive(
return outvals
return outvals[0]

def process_call(self, call_primitive: jax_core.Primitive, f: Any,
def process_call(self, call_primitive: jex.core.Primitive, f: Any,
vals: List[Any], params: Dict[str, Any]):
context = self.context
if call_primitive is nest_p:
return context.process_nest(self, f, *vals, **params)
return context.process_higher_order_primitive(self, call_primitive, f,
vals, params, False)

def process_map(self, call_primitive: jax_core.Primitive, f: Any,
def process_map(self, call_primitive: jex.core.Primitive, f: Any,
vals: List[Any], params: Dict[str, Any]):
return self.context.process_higher_order_primitive(
self, call_primitive, f, vals, params, True)
Expand Down Expand Up @@ -469,7 +470,7 @@ def process_nest(self, trace, f, *vals, scope, name):
raise NotImplementedError

def process_higher_order_primitive(self, trace: HarvestTrace,
call_primitive: jax_core.Primitive, f: Any,
call_primitive: jex.core.Primitive, f: Any,
vals: List[Any],
params: Dict[str, Any], is_map: bool):
raise NotImplementedError
Expand Down Expand Up @@ -783,7 +784,7 @@ def _reap_metadata_wrapper(*args):

def _get_harvest_metadata(closed_jaxpr, settings, *args):
"""Probes a jaxpr for metadata like its sown values."""
fun = lu.wrap_init(jax_core.jaxpr_as_fun(closed_jaxpr))
fun = lu.wrap_init(jex.core.jaxpr_as_fun(closed_jaxpr))

settings = HarvestSettings(settings.tag, settings.blocklist,
settings.allowlist, True)
Expand Down Expand Up @@ -843,7 +844,7 @@ def _reap_scan_rule(trace: HarvestTrace, *vals, length, reverse, jaxpr,
reap_carry_avals[name] = aval
cond_carry_avals[name] = jax_core.get_aval(True)

body_fun = jax_core.jaxpr_as_fun(jaxpr)
body_fun = jex.core.jaxpr_as_fun(jaxpr)

reap_carry_flat_avals = tree_util.tree_leaves(
(reap_carry_avals, cond_carry_avals)
Expand Down Expand Up @@ -931,8 +932,8 @@ def _reap_while_rule(trace: HarvestTrace, *tracers, cond_jaxpr, body_jaxpr,
if mode == 'cond_clobber':
cond_avals[k] = jax_core.get_aval(True)

cond_fun = jax_core.jaxpr_as_fun(cond_jaxpr)
body_fun = jax_core.jaxpr_as_fun(body_jaxpr)
cond_fun = jex.core.jaxpr_as_fun(cond_jaxpr)
body_fun = jex.core.jaxpr_as_fun(body_jaxpr)
reap_settings = dict(
tag=settings.tag,
allowlist=settings.allowlist,
Expand Down Expand Up @@ -1009,7 +1010,7 @@ def _reap_cond_rule(trace, *tracers, branches, linear=None):
_get_harvest_metadata(branch, settings, *ops_vals)
for branch in branches)
_check_branch_metadata(branch_metadatas)
branch_funs = tuple(map(jax_core.jaxpr_as_fun, branches))
branch_funs = tuple(map(jex.core.jaxpr_as_fun, branches))
reaped_branches = tuple(
_call_and_reap(f, **reap_settings) for f in branch_funs)
in_tree = tree_util.tree_structure(ops_avals)
Expand Down Expand Up @@ -1046,9 +1047,9 @@ def _reap_checkpoint_rule(trace, *invals, jaxpr, policy, prevent_cse,
allowlist=settings.allowlist,
blocklist=settings.blocklist,
exclusive=settings.exclusive)
closed_jaxpr = jax_core.ClosedJaxpr(jaxpr, ())
closed_jaxpr = jex.core.ClosedJaxpr(jaxpr, ())
reap_metadata = _get_harvest_metadata(closed_jaxpr, settings, *invals)
remat_fun = jax_core.jaxpr_as_fun(closed_jaxpr)
remat_fun = jex.core.jaxpr_as_fun(closed_jaxpr)
reaped_remat_fun = _call_and_reap(remat_fun, **reap_settings)
reap_jaxpr, consts, out_tree = lcf._initial_style_jaxpr( # pylint: disable=protected-access
reaped_remat_fun, tree_util.tree_structure(invals),
Expand Down Expand Up @@ -1078,7 +1079,7 @@ def _oryx_pjit_jaxpr(flat_fun, in_avals):
jaxpr = pe.close_jaxpr(jaxpr)
final_consts = consts
else:
jaxpr = jax_core.ClosedJaxpr(jaxpr, consts)
jaxpr = jex.core.ClosedJaxpr(jaxpr, consts)
final_consts = []

return jaxpr, final_consts, out_avals
Expand Down Expand Up @@ -1117,7 +1118,7 @@ def _reap_pjit_rule(trace, *invals, **params):
exclusive=settings.exclusive)
closed_jaxpr = params['jaxpr']
reap_metadata = _get_harvest_metadata(closed_jaxpr, settings, *invals)
pjit_fun = jax_core.jaxpr_as_fun(closed_jaxpr)
pjit_fun = jex.core.jaxpr_as_fun(closed_jaxpr)
reaped_pjit_fun = lu.wrap_init(_call_and_reap(pjit_fun, **reap_settings))
in_tree = tree_util.tree_structure(invals)
flat_fun, out_tree = api_util.flatten_fun_nokwargs(reaped_pjit_fun, in_tree)
Expand Down Expand Up @@ -1332,7 +1333,7 @@ def _plant_scan_rule(trace: HarvestTrace, *tracers, length, reverse, jaxpr,
plant_modes[mode].add(name)
if mode == 'append' and name in plants:
plant_xs_avals[name] = aval
body_fun = jax_core.jaxpr_as_fun(jaxpr)
body_fun = jex.core.jaxpr_as_fun(jaxpr)
all_clobber_plants = {
name: value
for name, value in plants.items()
Expand Down Expand Up @@ -1400,7 +1401,7 @@ def _plant_while_rule(trace: HarvestTrace, *tracers, cond_jaxpr, body_jaxpr,
' `while_loop`.'
)

body_fun = jax_core.jaxpr_as_fun(body_jaxpr)
body_fun = jex.core.jaxpr_as_fun(body_jaxpr)
plant_settings = dict(
tag=settings.tag,
allowlist=settings.allowlist,
Expand Down Expand Up @@ -1444,7 +1445,7 @@ def _plant_cond_rule(trace, *tracers, branches, linear=None):
for branch in branches)
_check_branch_metadata(branch_metadatas)
plants = trace.context.plants
branch_funs = tuple(map(jax_core.jaxpr_as_fun, branches))
branch_funs = tuple(map(jex.core.jaxpr_as_fun, branches))
planted_branches = tuple(
functools.partial(plant(f, **plant_settings), plants)
for f in branch_funs)
Expand Down Expand Up @@ -1478,9 +1479,9 @@ def _plant_checkpoint_rule(trace, *invals, jaxpr, policy, prevent_cse,
allowlist=settings.allowlist,
blocklist=settings.blocklist,
exclusive=settings.exclusive)
closed_jaxpr = jax_core.ClosedJaxpr(jaxpr, ())
closed_jaxpr = jex.core.ClosedJaxpr(jaxpr, ())
plants = trace.context.plants
remat_fun = jax_core.jaxpr_as_fun(closed_jaxpr)
remat_fun = jex.core.jaxpr_as_fun(closed_jaxpr)
planted_remat_fun = functools.partial(
plant(remat_fun, **plant_settings), plants)
plant_jaxpr, consts, _ = lcf._initial_style_jaxpr( # pylint: disable=protected-access
Expand Down Expand Up @@ -1524,7 +1525,7 @@ def _plant_pjit_rule(trace, *invals, **params):
closed_jaxpr = params['jaxpr']
plants = trace.context.plants

pjit_fun = jax_core.jaxpr_as_fun(closed_jaxpr)
pjit_fun = jex.core.jaxpr_as_fun(closed_jaxpr)
planted_pjit_fun = lu.wrap_init(functools.partial(
plant(pjit_fun, **plant_settings), plants))
in_tree = tree_util.tree_structure(invals)
Expand Down
3 changes: 2 additions & 1 deletion oryx/core/interpreters/log_prob_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from jax import random
from jax._src import api_util
from jax._src import core as jax_core
import jax.extend as jex
from jax.extend.core import primitives
from jax.extend import linear_util as lu
import jax.numpy as jnp
Expand All @@ -32,7 +33,7 @@

tfd = tfp.distributions

random_normal_p = jax_core.Primitive('random_normal')
random_normal_p = jex.core.Primitive('random_normal')


def random_normal(rng, name=None):
Expand Down
34 changes: 18 additions & 16 deletions oryx/core/interpreters/propagate.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
when no further progress can be made. Finally, Cell values for all nodes in the
graph are returned.
"""

import collections
import dataclasses
import functools
Expand All @@ -36,8 +37,9 @@
from jax._src import core as jax_core
from jax._src import pjit
from jax._src import sharding_impls
from jax.extend.core import primitives
import jax.extend as jex
from jax.extend import linear_util as lu
from jax.extend.core import primitives
from jax.interpreters import partial_eval as pe

from oryx.core import pytree
Expand All @@ -52,7 +54,7 @@
]

State = Any
VarOrLiteral = Union[jax_core.Var, jax_core.Literal]
VarOrLiteral = Union[jex.core.Var, jex.core.Literal]
safe_map = jax_core.safe_map


Expand Down Expand Up @@ -113,10 +115,10 @@ def unknown(cls, aval):

@dataclasses.dataclass(frozen=True)
class Equation:
"""Hashable wrapper for jax_core.Jaxprs."""
invars: Tuple[jax_core.Var]
outvars: Tuple[jax_core.Var]
primitive: jax_core.Primitive
"""Hashable wrapper for jex.core.Jaxprs."""
invars: Tuple[jex.core.Var]
outvars: Tuple[jex.core.Var]
primitive: jex.core.Primitive
params_flat: Tuple[Any]
params_tree: Any

Expand All @@ -135,7 +137,7 @@ def __hash__(self):
# Override __hash__ to use Literal object IDs because Literals are not
# natively hashable
hashable_invars = tuple(
id(invar) if isinstance(invar, jax_core.Literal) else invar
id(invar) if isinstance(invar, jex.core.Literal) else invar
for invar in self.invars)
return hash(
(hashable_invars, self.outvars, self.primitive, self.params_tree))
Expand All @@ -153,18 +155,18 @@ class Environment:

def __init__(self, cell_type, jaxpr):
self.cell_type = cell_type
self.env: Dict[jax_core.Var, Cell] = {}
self.env: Dict[jex.core.Var, Cell] = {}
self.states: Dict[Equation, Cell] = {}
self.jaxpr: jax_core.Jaxpr = jaxpr
self.jaxpr: jex.core.Jaxpr = jaxpr

def read(self, var: VarOrLiteral) -> Cell:
if isinstance(var, jax_core.Literal):
if isinstance(var, jex.core.Literal):
return self.cell_type.new(var.val)
else:
return self.env.get(var, self.cell_type.unknown(var.aval))

def write(self, var: VarOrLiteral, cell: Cell) -> Cell:
if isinstance(var, jax_core.Literal):
if isinstance(var, jex.core.Literal):
return cell
cur_cell = self.read(var)
if isinstance(var, jax_core.DropVar):
Expand All @@ -180,7 +182,7 @@ def __setitem__(self, key, val):
'`write` method instead.')

def __contains__(self, var: VarOrLiteral):
if isinstance(var, jax_core.Literal):
if isinstance(var, jex.core.Literal):
return True
return var in self.env

Expand All @@ -196,12 +198,12 @@ def construct_graph_representation(eqns):
neighbors = collections.defaultdict(set)
for eqn in eqns:
for var in it.chain(eqn.invars, eqn.outvars):
if isinstance(var, jax_core.Literal):
if isinstance(var, jex.core.Literal):
continue
neighbors[var].add(eqn)

def get_neighbors(var):
if isinstance(var, jax_core.Literal):
if isinstance(var, jex.core.Literal):
return set()
return neighbors[var]

Expand Down Expand Up @@ -238,12 +240,12 @@ def identity_reducer(env, eqn, state, new_state):
@lu.cache
def _to_jaxpr(flat_fun, in_avals):
new_jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(flat_fun, in_avals)
new_jaxpr = jax_core.ClosedJaxpr(new_jaxpr, consts)
new_jaxpr = jex.core.ClosedJaxpr(new_jaxpr, consts)
return new_jaxpr


def propagate(cell_type: Type[Cell],
rules: Dict[jax_core.Primitive, PropagationRule],
rules: Dict[jex.core.Primitive, PropagationRule],
jaxpr: pe.Jaxpr,
constcells: List[Cell],
incells: List[Cell],
Expand Down
21 changes: 11 additions & 10 deletions oryx/core/ppl/effect_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,7 @@ def f(x):
from jax._src import core as jax_core
from jax._src import pjit
from jax._src import sharding_impls
import jax.extend as jex
from jax.extend import linear_util as lu
from jax.interpreters import partial_eval as pe

Expand All @@ -118,14 +119,14 @@ def f(x):

Value = Any
EffectHandler = Callable[..., Any]
VarOrLiteral = Union[jax_core.Var, jax_core.Literal]
Rules = Dict[jax_core.Primitive, EffectHandler]
VarOrLiteral = Union[jex.core.Var, jex.core.Literal]
Rules = Dict[jex.core.Primitive, EffectHandler]

_effect_handler_call_rules: Rules = {}
custom_effect_handler_rules: Rules = {}


def register_call_rule(primitive: jax_core.Primitive,
def register_call_rule(primitive: jex.core.Primitive,
rule: EffectHandler) -> None:
_effect_handler_call_rules[primitive] = rule

Expand All @@ -147,20 +148,20 @@ def __init__(self):

def read(self, var: VarOrLiteral) -> Value:
"""Reads a value from an environment."""
if isinstance(var, jax_core.Literal):
if isinstance(var, jex.core.Literal):
return var.val
if var not in self.env:
raise ValueError(f'Couldn\'t find {var} in environment: {self.env}')
return self.env[var]

def write(self, var: VarOrLiteral, val: Value) -> None:
"""Writes a value to an environment."""
if isinstance(var, jax_core.Literal):
if isinstance(var, jex.core.Literal):
return
self.env[var] = val


def eval_jaxpr_with_state(jaxpr: jax_core.Jaxpr, rules: Rules,
def eval_jaxpr_with_state(jaxpr: jex.core.Jaxpr, rules: Rules,
consts: Sequence[Value], state: Value,
*args: Value) -> Tuple[List[Value], Value]:
"""Interprets a JAXpr and manages an input state with primitive rules.
Expand Down Expand Up @@ -214,7 +215,7 @@ def eval_jaxpr_with_state(jaxpr: jax_core.Jaxpr, rules: Rules,
def default_call_interpreter_rule(primitive: jax_core.CallPrimitive,
rules: Rules, state: Value,
invals: Sequence[Value],
call_jaxpr: jax_core.Jaxpr,
call_jaxpr: jex.core.Jaxpr,
**params: Any) -> Tuple[Value, Value]:
"""Handles simple call primitives like `jax_core.call_p`.
Expand All @@ -232,7 +233,7 @@ def default_call_interpreter_rule(primitive: jax_core.CallPrimitive,
state: The interpreter `state` value at the time of calling evaluating the
call primitive.
invals: The input values to the call primitive.
call_jaxpr: The `jax_core.Jaxpr` that corresponds to the body of the call
call_jaxpr: The `jex.core.Jaxpr` that corresponds to the body of the call
primitive.
**params: The parameters of the call primitive.
Expand All @@ -252,7 +253,7 @@ def default_call_interpreter_rule(primitive: jax_core.CallPrimitive,
@lu.cache
def _to_jaxpr(flat_fun, in_avals):
new_jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(flat_fun, in_avals)
new_jaxpr = jax_core.ClosedJaxpr(new_jaxpr, consts)
new_jaxpr = jex.core.ClosedJaxpr(new_jaxpr, consts)
return new_jaxpr


Expand Down Expand Up @@ -297,7 +298,7 @@ def _pjit_effect_handler_rule(rules, state, invals, **params):


def make_effect_handler(
handlers: Dict[jax_core.Primitive, EffectHandler]
handlers: Dict[jex.core.Primitive, EffectHandler]
) -> Callable[[Callable[..., Any]], Callable[..., Any]]:
"""Returns a function transformation that applies a provided set of handlers.
Expand Down
Loading

0 comments on commit de8b5cc

Please sign in to comment.