[pallas] Use the same primitive run_scoped_p for both mosaic and mosaic_gpu

PiperOrigin-RevId: 655751205
cperivol authored and jax authors committed Jul 25, 2024
1 parent aa99113 commit 80a193d
Showing 14 changed files with 136 additions and 186 deletions.
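In short: run_scoped and its primitive run_scoped_p move out of the backend-specific modules (jax/_src/pallas/mosaic/primitives.py and the now-deleted jax/_src/pallas/mosaic_gpu/primitives.py) into pallas core (jax/_src/pallas/primitives.py), and both backends register lowering rules for the single shared primitive. A minimal usage sketch after this change, adapted from the tests in this commit (the shape and memory space are illustrative):

import jax.numpy as jnp
from jax.experimental import pallas as pl
from jax.experimental.pallas import tpu as pltpu

def kernel(x_ref, o_ref):
  def body(tmp_ref):
    # tmp_ref is scratch that exists only while `body` runs.
    tmp_ref[...] = x_ref[...] + 1.0
    return tmp_ref[...]
  # The same pl.run_scoped call now works in Mosaic TPU and Mosaic GPU
  # kernels; only the memory-space type differs per backend.
  o_ref[...] = pl.run_scoped(body, pltpu.VMEM((8, 128), jnp.float32))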
2 changes: 1 addition & 1 deletion jax/BUILD
@@ -626,6 +626,7 @@ pytype_strict_library(
deps = [
":pallas", # build_cleaner: keep
":tpu_custom_call",
"//jax/_src/pallas",
"//jax/_src/pallas/mosaic:core",
"//jax/_src/pallas/mosaic:lowering",
"//jax/_src/pallas/mosaic:pallas_call_registration", # build_cleaner: keep
@@ -672,7 +673,6 @@ pytype_strict_library(
":pallas",
"//jax/_src/pallas/mosaic_gpu:core", # build_cleaner: keep
"//jax/_src/pallas/mosaic_gpu:pallas_call_registration", # build_cleaner: keep
"//jax/_src/pallas/mosaic_gpu:primitives", # build_cleaner: keep
"//jax/_src/pallas/triton:pallas_call_registration", # build_cleaner: keep
"//jax/_src/pallas/triton:primitives",
],
2 changes: 1 addition & 1 deletion jax/_src/pallas/mosaic/lowering.py
@@ -2552,7 +2552,7 @@ def _run_scoped_lowering_rule(ctx: LoweringRuleContext, *consts, jaxpr):
return region.results


lowering_rules[tpu_primitives.run_scoped_p] = _run_scoped_lowering_rule
lowering_rules[primitives.run_scoped_p] = _run_scoped_lowering_rule

def _device_id_to_logical(
ctx: LoweringRuleContext, device_id,
3 changes: 2 additions & 1 deletion jax/_src/pallas/mosaic/pipeline.py
@@ -28,6 +28,7 @@
from jax import tree_util
from jax._src import util as jax_util
from jax._src.pallas import core as pallas_core
from jax._src.pallas import primitives as primitives
from jax._src.pallas.mosaic import core as tpu_core
from jax._src.pallas.mosaic import primitives as tpu_primitives
from jax.experimental import pallas as pl
@@ -988,7 +989,7 @@ def pipeline(
scratches = ()
if allocations is None:
# run with inline scoped allocations
return tpu_primitives.run_scoped(
return primitives.run_scoped(
lambda allocations: pipeline(
*refs,
scratches=scratches,
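The change above is one word, but the idiom is worth spelling out: when no allocations are supplied, pipeline re-enters itself under run_scoped, so the scratch references are allocated on entry to the scope and freed on exit. A hedged sketch of the pattern in isolation (simplified; scratch_types is a hypothetical placeholder for however the caller describes the allocations, not part of this commit):

def pipeline(*refs, allocations=None, scratches=()):
  if allocations is None:
    # Re-enter the function under run_scoped so the references live
    # exactly as long as the pipeline body.
    return primitives.run_scoped(
        lambda allocations: pipeline(
            *refs, allocations=allocations, scratches=scratches),
        scratch_types,  # hypothetical: the types to allocate
    )
  ...  # the actual pipeline body, using the allocated references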
34 changes: 0 additions & 34 deletions jax/_src/pallas/mosaic/primitives.py
@@ -15,25 +15,20 @@
"""Module for Pallas:TPU-specific JAX primitives and functions."""
from __future__ import annotations

from collections.abc import Callable
import dataclasses
import enum
from typing import Any

import jax
from jax._src import api_util
from jax._src import core as jax_core
from jax._src import dtypes
from jax._src import effects
from jax._src import linear_util as lu
from jax._src import pretty_printer as pp
from jax._src import state
from jax._src import tree_util
from jax._src import util
from jax._src.state import indexing
from jax._src.state import primitives as sp
from jax._src.interpreters import mlir
from jax._src.interpreters import partial_eval as pe
from jax._src.pallas import core as pl_core
from jax._src.pallas.mosaic import core as tpu_core
from jax._src.state import discharge as state_discharge
@@ -157,35 +152,6 @@ def _roll(x, shift):
mlir.register_lowering(roll_p, _roll_lowering_rule)


run_scoped_p = jax_core.Primitive('run_scoped')
run_scoped_p.multiple_results = True


def run_scoped(f: Callable[..., Any], *types, **kw_types) -> Any:
  flat_types, in_tree = tree_util.tree_flatten((types, kw_types))
  flat_fun, out_tree_thunk = api_util.flatten_fun(lu.wrap_init(f), in_tree)
  avals = map(lambda t: t.get_aval(), flat_types)
  jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(flat_fun, avals)
  out = run_scoped_p.bind(*consts, jaxpr=jaxpr)
  return tree_util.tree_unflatten(out_tree_thunk(), out)


@run_scoped_p.def_effectful_abstract_eval
def _run_scoped_abstract_eval(*args, jaxpr):
  del args
  # jaxpr will have effects for its inputs (Refs that are allocated) and for
  # constvars (closed over Refs). The effects for the allocated Refs are local
  # to the jaxpr and shouldn't propagate out.
  nonlocal_effects = {
      eff for eff in jaxpr.effects
      if not (
          isinstance(eff, effects.JaxprInputEffect)
          and eff.input_index >= len(jaxpr.constvars)
      )
  }
  return [v.aval for v in jaxpr.outvars], nonlocal_effects


class DeviceIdType(enum.Enum):
MESH = "mesh"
LOGICAL = "logical"
15 changes: 1 addition & 14 deletions jax/_src/pallas/mosaic_gpu/BUILD
@@ -34,7 +34,6 @@ py_library(
deps = [
":core",
":pallas_call_registration",
":primitives",
],
)

@@ -55,11 +54,11 @@ pytype_strict_library(
srcs = ["lowering.py"],
deps = [
":core",
":primitives",
"//jax",
"//jax:core",
"//jax:mlir",
"//jax:mosaic_gpu",
"//jax:pallas",
"//jax:util",
"//jax/_src/lib",
"//jax/_src/pallas",
@@ -74,15 +73,3 @@ pytype_strict_library(
"//jax/_src/pallas",
],
)

pytype_strict_library(
name = "primitives",
srcs = ["primitives.py"],
deps = [
"//jax:api_util",
"//jax:core",
"//jax:effects",
"//jax:partial_eval",
"//jax:tree_util",
],
)
3 changes: 1 addition & 2 deletions jax/_src/pallas/mosaic_gpu/lowering.py
@@ -34,7 +34,6 @@
from jax._src.lib.mlir.dialects import nvgpu as nvgpu_dialect
from jax._src.pallas import core as pl_core
from jax._src.pallas import primitives
from jax._src.pallas.mosaic_gpu import primitives as mosaic_primitives
from jax._src.state import primitives as sp
from jax.experimental.mosaic import gpu as mosaic_gpu
from jax.experimental.mosaic.gpu import dsl as mgpu
@@ -403,7 +402,7 @@ def _debug_print_lowering_rule(
return ()


@register_lowering_rule(mosaic_primitives.run_scoped_p)
@register_lowering_rule(primitives.run_scoped_p)
def _run_scoped_lowering_rule(
ctx: LoweringRuleContext, *consts, jaxpr: jax_core.Jaxpr
):
70 changes: 0 additions & 70 deletions jax/_src/pallas/mosaic_gpu/primitives.py

This file was deleted.

48 changes: 47 additions & 1 deletion jax/_src/pallas/primitives.py
@@ -19,20 +19,23 @@
import enum
import functools
import string
from typing import Any
from typing import Any, Callable

import jax
from jax import lax
from jax import tree_util
from jax._src import ad_util
from jax._src import api_util
from jax._src import core as jax_core
from jax._src import dtypes
from jax._src import effects
from jax._src import linear_util as lu
from jax._src import pretty_printer as pp
from jax._src import state
from jax._src import util
from jax._src.interpreters import ad
from jax._src.interpreters import batching
from jax._src.interpreters import partial_eval as pe
from jax._src.pallas import core as pallas_core
from jax._src.state import discharge as state_discharge
from jax._src.state import indexing
@@ -797,3 +800,46 @@ def debug_print_lowering_rule(ctx, *args, **params):
has_side_effect=True,
)
return result


run_scoped_p = jax_core.Primitive("run_scoped")
run_scoped_p.multiple_results = True


def run_scoped(f: Callable[..., Any], *types, **kw_types) -> Any:
  """Call the function with allocated references.

  Args:
    f: The function that generates the jaxpr.
    *types: The types of the function's positional arguments.
    **kw_types: The types of the function's keyword arguments.
  """

  flat_types, in_tree = tree_util.tree_flatten((types, kw_types))
  flat_fun, out_tree_thunk = api_util.flatten_fun(lu.wrap_init(f), in_tree)
  avals = [t.get_aval() for t in flat_types]
  # Turn the function into a jaxpr. The body of run_scoped may have
  # effects (IO) on constvars (i.e. variables inherited from the
  # parent scope). Jax can't reason about effects to references that
  # are not in the invars of an operation so we just put them all
  # there.
  jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(flat_fun, avals)
  out = run_scoped_p.bind(*consts, jaxpr=jaxpr)
  return tree_util.tree_unflatten(out_tree_thunk(), out)


@run_scoped_p.def_effectful_abstract_eval
def _run_scoped_abstract_eval(*args, jaxpr):
  del args
  # jaxpr will have effects for its inputs (Refs that are allocated) and for
  # constvars (closed over Refs). The effects for the allocated Refs are local
  # to the jaxpr and shouldn't propagate out.
  nonlocal_effects = {
      eff
      for eff in jaxpr.effects
      if not (
          isinstance(eff, effects.JaxprInputEffect)
          and eff.input_index >= len(jaxpr.constvars)
      )
  }
  return [v.aval for v in jaxpr.outvars], nonlocal_effects
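Two properties of the block just added are worth noting. First, positional and keyword type arguments map onto positional and keyword parameters of f, because (types, kw_types) is flattened as an (args, kwargs) tree. Second, the abstract eval drops effects on the references the scope itself allocates (JaxprInputEffects whose index points past the constvars) while keeping effects on closed-over references. A usage sketch under those semantics, assuming it runs inside a Mosaic TPU kernel (shapes are illustrative):

import jax.numpy as jnp
from jax.experimental import pallas as pl
from jax.experimental.pallas import tpu as pltpu

def body(scratch, *, sem):
  # Writes to `scratch` are local effects and are filtered out by
  # _run_scoped_abstract_eval; writes to refs closed over from the
  # enclosing kernel would propagate out as nonlocal effects.
  scratch[...] = jnp.zeros((8, 128), jnp.float32)
  del sem  # in real code, pass it to an async copy instead

pl.run_scoped(
    body,
    pltpu.VMEM((8, 128), jnp.float32),  # becomes the positional `scratch`
    sem=pltpu.SemaphoreType.DMA,        # becomes the keyword-only `sem`
)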
10 changes: 6 additions & 4 deletions jax/experimental/pallas/__init__.py
@@ -14,14 +14,16 @@

"""Module for Pallas, a JAX extension for custom kernels.
See the Pallas documentation at https://jax.readthedocs.io/en/latest/pallas.html.
See the Pallas documentation at
https://jax.readthedocs.io/en/latest/pallas.html.
"""

from jax._src import pallas
from jax._src.deprecations import register as _register_deprecation
from jax._src.pallas.core import Blocked
from jax._src.pallas.core import BlockSpec
from jax._src.pallas.core import no_block_spec
from jax._src.pallas.core import IndexingMode
from jax._src.pallas.core import Blocked
from jax._src.pallas.core import no_block_spec
from jax._src.pallas.core import Unblocked
from jax._src.pallas.core import unblocked
from jax._src.pallas.pallas_call import pallas_call
@@ -41,6 +43,7 @@
from jax._src.pallas.primitives import multiple_of
from jax._src.pallas.primitives import num_programs
from jax._src.pallas.primitives import program_id
from jax._src.pallas.primitives import run_scoped
from jax._src.pallas.primitives import store
from jax._src.pallas.primitives import swap
from jax._src.pallas.utils import cdiv
@@ -52,6 +55,5 @@
from jax._src.state.indexing import Slice
from jax._src.state.primitives import broadcast_to

from jax._src.deprecations import register as _register_deprecation
_register_deprecation("pallas-block-spec-order")
del _register_deprecation
5 changes: 4 additions & 1 deletion jax/experimental/pallas/tpu.py
@@ -40,7 +40,6 @@
from jax._src.pallas.mosaic.primitives import make_async_remote_copy
from jax._src.pallas.mosaic.primitives import repeat
from jax._src.pallas.mosaic.primitives import roll
from jax._src.pallas.mosaic.primitives import run_scoped
from jax._src.pallas.mosaic.primitives import semaphore_read
from jax._src.pallas.mosaic.primitives import semaphore_signal
from jax._src.pallas.mosaic.primitives import semaphore_wait
@@ -49,6 +48,10 @@
from jax._src.pallas.mosaic.random import to_pallas_key
from jax._src.tpu_custom_call import CostEstimate

# TODO(cperivol): Temporary alias to the global run_scoped. Remove
# this once everyone has migrated to the pallas core one.
from jax._src.pallas.primitives import run_scoped

import types
from jax._src.pallas.mosaic.verification import assume
from jax._src.pallas.mosaic.verification import pretend
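Since the TPU module re-exports the pallas-core function, both spellings point at the same object and behave identically during the migration window:

from jax.experimental import pallas as pl
from jax.experimental.pallas import tpu as pltpu

# Temporary alias (see the TODO above): same function object either way.
assert pltpu.run_scoped is pl.run_scoped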
5 changes: 2 additions & 3 deletions tests/pallas/mosaic_gpu_test.py
@@ -17,8 +17,7 @@
from absl.testing import parameterized
import jax
from jax._src import test_util as jtu
import jax._src.pallas.mosaic_gpu.core as plgpu_core
import jax._src.pallas.mosaic_gpu.primitives as plgpu_prims
import jax._src.pallas.mosaic_gpu.core as plgpu
from jax.experimental import pallas as pl
import jax.numpy as jnp
import numpy as np
@@ -122,7 +121,7 @@ def body(tmp_ref):
tmp_ref[...] = x_ref[...] + 1.0
return tmp_ref[...]

tmp = plgpu_prims.run_scoped(body, plgpu_core.SMEM((8, 128), jnp.float32))
tmp = pl.run_scoped(body, plgpu.SMEM((8, 128), jnp.float32))
self.assertEqual(tmp.shape, (8, 128))
o_ref[...] = tmp

4 changes: 2 additions & 2 deletions tests/pallas/tpu_pallas_mesh_test.py
Original file line number Diff line number Diff line change
@@ -53,7 +53,7 @@ def inner(refs):
def kernel():
def alloc(sem):
pltpu.async_copy(x_ref, y_ref, sem).wait()
pltpu.run_scoped(alloc, pltpu.SemaphoreType.DMA)
pl.run_scoped(alloc, pltpu.SemaphoreType.DMA)
shard_map.shard_map(kernel, mesh, in_specs=(), out_specs=None,
check_rep=False)()
_, y = state_discharge.run_state(inner)((x, y))
@@ -84,7 +84,7 @@ def alloc(x_vmem_ref, y_vmem_ref, sem):
y = x_vmem_ref[...] + jax.lax.axis_index("x")
y_vmem_ref[...] = y
pltpu.async_copy(y_vmem_ref, y_ref.at[slc], sem).wait()
pltpu.run_scoped(
pl.run_scoped(
alloc,
pltpu.VMEM((slc_size, 128), x_ref.dtype),
pltpu.VMEM((slc_size, 128), y_ref.dtype),