diff --git a/jax/BUILD b/jax/BUILD index 171ec9599bdb..4556fa285220 100644 --- a/jax/BUILD +++ b/jax/BUILD @@ -626,6 +626,7 @@ pytype_strict_library( deps = [ ":pallas", # build_cleaner: keep ":tpu_custom_call", + "//jax/_src/pallas", "//jax/_src/pallas/mosaic:core", "//jax/_src/pallas/mosaic:lowering", "//jax/_src/pallas/mosaic:pallas_call_registration", # build_cleaner: keep @@ -672,7 +673,6 @@ pytype_strict_library( ":pallas", "//jax/_src/pallas/mosaic_gpu:core", # build_cleaner: keep "//jax/_src/pallas/mosaic_gpu:pallas_call_registration", # build_cleaner: keep - "//jax/_src/pallas/mosaic_gpu:primitives", # build_cleaner: keep "//jax/_src/pallas/triton:pallas_call_registration", # build_cleaner: keep "//jax/_src/pallas/triton:primitives", ], diff --git a/jax/_src/pallas/mosaic/lowering.py b/jax/_src/pallas/mosaic/lowering.py index b8905916a88e..31e47cc256ea 100644 --- a/jax/_src/pallas/mosaic/lowering.py +++ b/jax/_src/pallas/mosaic/lowering.py @@ -2552,7 +2552,7 @@ def _run_scoped_lowering_rule(ctx: LoweringRuleContext, *consts, jaxpr): return region.results -lowering_rules[tpu_primitives.run_scoped_p] = _run_scoped_lowering_rule +lowering_rules[primitives.run_scoped_p] = _run_scoped_lowering_rule def _device_id_to_logical( ctx: LoweringRuleContext, device_id, diff --git a/jax/_src/pallas/mosaic/pipeline.py b/jax/_src/pallas/mosaic/pipeline.py index 08ec28d71546..1935f89f1699 100644 --- a/jax/_src/pallas/mosaic/pipeline.py +++ b/jax/_src/pallas/mosaic/pipeline.py @@ -28,6 +28,7 @@ from jax import tree_util from jax._src import util as jax_util from jax._src.pallas import core as pallas_core +from jax._src.pallas import primitives as primitives from jax._src.pallas.mosaic import core as tpu_core from jax._src.pallas.mosaic import primitives as tpu_primitives from jax.experimental import pallas as pl @@ -988,7 +989,7 @@ def pipeline( scratches = () if allocations is None: # run with inline scoped allocations - return tpu_primitives.run_scoped( + return primitives.run_scoped( lambda allocations: pipeline( *refs, scratches=scratches, diff --git a/jax/_src/pallas/mosaic/primitives.py b/jax/_src/pallas/mosaic/primitives.py index 263c8aa15586..21dd63e8a811 100644 --- a/jax/_src/pallas/mosaic/primitives.py +++ b/jax/_src/pallas/mosaic/primitives.py @@ -15,17 +15,13 @@ """Module for Pallas:TPU-specific JAX primitives and functions.""" from __future__ import annotations -from collections.abc import Callable import dataclasses import enum from typing import Any import jax -from jax._src import api_util from jax._src import core as jax_core from jax._src import dtypes -from jax._src import effects -from jax._src import linear_util as lu from jax._src import pretty_printer as pp from jax._src import state from jax._src import tree_util @@ -33,7 +29,6 @@ from jax._src.state import indexing from jax._src.state import primitives as sp from jax._src.interpreters import mlir -from jax._src.interpreters import partial_eval as pe from jax._src.pallas import core as pl_core from jax._src.pallas.mosaic import core as tpu_core from jax._src.state import discharge as state_discharge @@ -157,35 +152,6 @@ def _roll(x, shift): mlir.register_lowering(roll_p, _roll_lowering_rule) -run_scoped_p = jax_core.Primitive('run_scoped') -run_scoped_p.multiple_results = True - - -def run_scoped(f: Callable[..., Any], *types, **kw_types) -> Any: - flat_types, in_tree = tree_util.tree_flatten((types, kw_types)) - flat_fun, out_tree_thunk = api_util.flatten_fun(lu.wrap_init(f), in_tree) - avals = map(lambda t: t.get_aval(), 
flat_types) - jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(flat_fun, avals) - out = run_scoped_p.bind(*consts, jaxpr=jaxpr) - return tree_util.tree_unflatten(out_tree_thunk(), out) - - -@run_scoped_p.def_effectful_abstract_eval -def _run_scoped_abstract_eval(*args, jaxpr): - del args - # jaxpr will have effects for its inputs (Refs that are allocated) and for - # constvars (closed over Refs). The effects for the allocated Refs are local - # to the jaxpr and shouldn't propagate out. - nonlocal_effects = { - eff for eff in jaxpr.effects - if not ( - isinstance(eff, effects.JaxprInputEffect) - and eff.input_index >= len(jaxpr.constvars) - ) - } - return [v.aval for v in jaxpr.outvars], nonlocal_effects - - class DeviceIdType(enum.Enum): MESH = "mesh" LOGICAL = "logical" diff --git a/jax/_src/pallas/mosaic_gpu/BUILD b/jax/_src/pallas/mosaic_gpu/BUILD index b4168271bfac..038826a663e8 100644 --- a/jax/_src/pallas/mosaic_gpu/BUILD +++ b/jax/_src/pallas/mosaic_gpu/BUILD @@ -34,7 +34,6 @@ py_library( deps = [ ":core", ":pallas_call_registration", - ":primitives", ], ) @@ -55,11 +54,11 @@ pytype_strict_library( srcs = ["lowering.py"], deps = [ ":core", - ":primitives", "//jax", "//jax:core", "//jax:mlir", "//jax:mosaic_gpu", + "//jax:pallas", "//jax:util", "//jax/_src/lib", "//jax/_src/pallas", @@ -74,15 +73,3 @@ pytype_strict_library( "//jax/_src/pallas", ], ) - -pytype_strict_library( - name = "primitives", - srcs = ["primitives.py"], - deps = [ - "//jax:api_util", - "//jax:core", - "//jax:effects", - "//jax:partial_eval", - "//jax:tree_util", - ], -) diff --git a/jax/_src/pallas/mosaic_gpu/lowering.py b/jax/_src/pallas/mosaic_gpu/lowering.py index c1f07d52ba20..6e619abd8216 100644 --- a/jax/_src/pallas/mosaic_gpu/lowering.py +++ b/jax/_src/pallas/mosaic_gpu/lowering.py @@ -34,7 +34,6 @@ from jax._src.lib.mlir.dialects import nvgpu as nvgpu_dialect from jax._src.pallas import core as pl_core from jax._src.pallas import primitives -from jax._src.pallas.mosaic_gpu import primitives as mosaic_primitives from jax._src.state import primitives as sp from jax.experimental.mosaic import gpu as mosaic_gpu from jax.experimental.mosaic.gpu import dsl as mgpu @@ -403,7 +402,7 @@ def _debug_print_lowering_rule( return () -@register_lowering_rule(mosaic_primitives.run_scoped_p) +@register_lowering_rule(primitives.run_scoped_p) def _run_scoped_lowering_rule( ctx: LoweringRuleContext, *consts, jaxpr: jax_core.Jaxpr ): diff --git a/jax/_src/pallas/mosaic_gpu/primitives.py b/jax/_src/pallas/mosaic_gpu/primitives.py deleted file mode 100644 index b7905f1f0f06..000000000000 --- a/jax/_src/pallas/mosaic_gpu/primitives.py +++ /dev/null @@ -1,70 +0,0 @@ -# Copyright 2024 The JAX Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# https://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-
-"""Module for Pallas:MosaicGPU-specific JAX primitives and functions."""
-from __future__ import annotations
-
-from collections.abc import Callable
-from typing import Any
-
-from jax._src import api_util
-from jax._src import core as jax_core
-from jax._src import effects
-from jax._src import linear_util as lu
-from jax._src import tree_util
-from jax._src.interpreters import partial_eval as pe
-
-run_scoped_p = jax_core.Primitive("run_scoped")
-run_scoped_p.multiple_results = True
-
-
-# TODO(cperivol): consolidate run_scoped with the pallas TPU version
-# of the same op.
-def run_scoped(f: Callable[..., Any], *types, **kw_types) -> Any:
-  """Call the function with allocated shared mem references.
-
-  Args:
-    f: The function that generatest the jaxpr.
-    *types: The types of the function's positional arguments.
-    **kw_types: The types of the function's keyword arguments.
-  """
-
-  flat_types, in_tree = tree_util.tree_flatten((types, kw_types))
-  flat_fun, out_tree_thunk = api_util.flatten_fun(lu.wrap_init(f), in_tree)
-  avals = [t.get_aval() for t in flat_types]
-  # Turn the function into a jaxpr. The body of run_scoped may have
-  # effects (IO) on constvars (i.e. variables inherited from the
-  # parent scope). Jax can't reason about effects to references that
-  # are not in the invars of an operation so we just put them all
-  # there.
-  jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(flat_fun, avals)
-  out = run_scoped_p.bind(*consts, jaxpr=jaxpr)
-  return tree_util.tree_unflatten(out_tree_thunk(), out)
-
-
-@run_scoped_p.def_effectful_abstract_eval
-def _run_scoped_abstract_eval(*args, jaxpr):
-  del args
-  # jaxpr will have effects for its inputs (Refs that are allocated) and for
-  # constvars (closed over Refs). The effects for the allocated Refs are local
-  # to the jaxpr and shouldn't propagate out.
-  nonlocal_effects = {
-      eff
-      for eff in jaxpr.effects
-      if not (
-          isinstance(eff, effects.JaxprInputEffect)
-          and eff.input_index >= len(jaxpr.constvars)
-      )
-  }
-  return [v.aval for v in jaxpr.outvars], nonlocal_effects
diff --git a/jax/_src/pallas/primitives.py b/jax/_src/pallas/primitives.py
index ce449bf978bb..f1adf53b4179 100644
--- a/jax/_src/pallas/primitives.py
+++ b/jax/_src/pallas/primitives.py
@@ -19,20 +19,23 @@
 import enum
 import functools
 import string
-from typing import Any
+from typing import Any, Callable
 
 import jax
 from jax import lax
 from jax import tree_util
 from jax._src import ad_util
+from jax._src import api_util
 from jax._src import core as jax_core
 from jax._src import dtypes
 from jax._src import effects
+from jax._src import linear_util as lu
 from jax._src import pretty_printer as pp
 from jax._src import state
 from jax._src import util
 from jax._src.interpreters import ad
 from jax._src.interpreters import batching
+from jax._src.interpreters import partial_eval as pe
 from jax._src.pallas import core as pallas_core
 from jax._src.state import discharge as state_discharge
 from jax._src.state import indexing
@@ -797,3 +800,46 @@ def debug_print_lowering_rule(ctx, *args, **params):
       has_side_effect=True,
   )
   return result
+
+
+run_scoped_p = jax_core.Primitive("run_scoped")
+run_scoped_p.multiple_results = True
+
+
+def run_scoped(f: Callable[..., Any], *types, **kw_types) -> Any:
+  """Call the function with allocated references.
+
+  Args:
+    f: The function that generates the jaxpr.
+    *types: The types of the function's positional arguments.
+    **kw_types: The types of the function's keyword arguments.
+ """ + + flat_types, in_tree = tree_util.tree_flatten((types, kw_types)) + flat_fun, out_tree_thunk = api_util.flatten_fun(lu.wrap_init(f), in_tree) + avals = [t.get_aval() for t in flat_types] + # Turn the function into a jaxpr. The body of run_scoped may have + # effects (IO) on constvars (i.e. variables inherited from the + # parent scope). Jax can't reason about effects to references that + # are not in the invars of an operation so we just put them all + # there. + jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(flat_fun, avals) + out = run_scoped_p.bind(*consts, jaxpr=jaxpr) + return tree_util.tree_unflatten(out_tree_thunk(), out) + + +@run_scoped_p.def_effectful_abstract_eval +def _run_scoped_abstract_eval(*args, jaxpr): + del args + # jaxpr will have effects for its inputs (Refs that are allocated) and for + # constvars (closed over Refs). The effects for the allocated Refs are local + # to the jaxpr and shouldn't propagate out. + nonlocal_effects = { + eff + for eff in jaxpr.effects + if not ( + isinstance(eff, effects.JaxprInputEffect) + and eff.input_index >= len(jaxpr.constvars) + ) + } + return [v.aval for v in jaxpr.outvars], nonlocal_effects diff --git a/jax/experimental/pallas/__init__.py b/jax/experimental/pallas/__init__.py index b4c402eae43e..0b205c8e815a 100644 --- a/jax/experimental/pallas/__init__.py +++ b/jax/experimental/pallas/__init__.py @@ -14,14 +14,16 @@ """Module for Pallas, a JAX extension for custom kernels. -See the Pallas documentation at https://jax.readthedocs.io/en/latest/pallas.html. +See the Pallas documentation at +https://jax.readthedocs.io/en/latest/pallas.html. """ from jax._src import pallas +from jax._src.deprecations import register as _register_deprecation +from jax._src.pallas.core import Blocked from jax._src.pallas.core import BlockSpec -from jax._src.pallas.core import no_block_spec from jax._src.pallas.core import IndexingMode -from jax._src.pallas.core import Blocked +from jax._src.pallas.core import no_block_spec from jax._src.pallas.core import Unblocked from jax._src.pallas.core import unblocked from jax._src.pallas.pallas_call import pallas_call @@ -41,6 +43,7 @@ from jax._src.pallas.primitives import multiple_of from jax._src.pallas.primitives import num_programs from jax._src.pallas.primitives import program_id +from jax._src.pallas.primitives import run_scoped from jax._src.pallas.primitives import store from jax._src.pallas.primitives import swap from jax._src.pallas.utils import cdiv @@ -52,6 +55,5 @@ from jax._src.state.indexing import Slice from jax._src.state.primitives import broadcast_to -from jax._src.deprecations import register as _register_deprecation _register_deprecation("pallas-block-spec-order") del _register_deprecation diff --git a/jax/experimental/pallas/tpu.py b/jax/experimental/pallas/tpu.py index 279e6955e393..98f9d5c0f03e 100644 --- a/jax/experimental/pallas/tpu.py +++ b/jax/experimental/pallas/tpu.py @@ -40,7 +40,6 @@ from jax._src.pallas.mosaic.primitives import make_async_remote_copy from jax._src.pallas.mosaic.primitives import repeat from jax._src.pallas.mosaic.primitives import roll -from jax._src.pallas.mosaic.primitives import run_scoped from jax._src.pallas.mosaic.primitives import semaphore_read from jax._src.pallas.mosaic.primitives import semaphore_signal from jax._src.pallas.mosaic.primitives import semaphore_wait @@ -49,6 +48,10 @@ from jax._src.pallas.mosaic.random import to_pallas_key from jax._src.tpu_custom_call import CostEstimate +# TODO(cperivol): Temporary alias to the global 
run_scoped. Remove +# this once everyone has migrated to the pallas core one. +from jax._src.pallas.primitives import run_scoped + import types from jax._src.pallas.mosaic.verification import assume from jax._src.pallas.mosaic.verification import pretend diff --git a/tests/pallas/mosaic_gpu_test.py b/tests/pallas/mosaic_gpu_test.py index 9bcbe6ab2ab8..5a7c59f77a00 100644 --- a/tests/pallas/mosaic_gpu_test.py +++ b/tests/pallas/mosaic_gpu_test.py @@ -17,8 +17,7 @@ from absl.testing import parameterized import jax from jax._src import test_util as jtu -import jax._src.pallas.mosaic_gpu.core as plgpu_core -import jax._src.pallas.mosaic_gpu.primitives as plgpu_prims +import jax._src.pallas.mosaic_gpu.core as plgpu from jax.experimental import pallas as pl import jax.numpy as jnp import numpy as np @@ -122,7 +121,7 @@ def body(tmp_ref): tmp_ref[...] = x_ref[...] + 1.0 return tmp_ref[...] - tmp = plgpu_prims.run_scoped(body, plgpu_core.SMEM((8, 128), jnp.float32)) + tmp = pl.run_scoped(body, plgpu.SMEM((8, 128), jnp.float32)) self.assertEqual(tmp.shape, (8, 128)) o_ref[...] = tmp diff --git a/tests/pallas/tpu_pallas_mesh_test.py b/tests/pallas/tpu_pallas_mesh_test.py index 61ad45a5eaa5..0df759aec724 100644 --- a/tests/pallas/tpu_pallas_mesh_test.py +++ b/tests/pallas/tpu_pallas_mesh_test.py @@ -53,7 +53,7 @@ def inner(refs): def kernel(): def alloc(sem): pltpu.async_copy(x_ref, y_ref, sem).wait() - pltpu.run_scoped(alloc, pltpu.SemaphoreType.DMA) + pl.run_scoped(alloc, pltpu.SemaphoreType.DMA) shard_map.shard_map(kernel, mesh, in_specs=(), out_specs=None, check_rep=False)() _, y = state_discharge.run_state(inner)((x, y)) @@ -84,7 +84,7 @@ def alloc(x_vmem_ref, y_vmem_ref, sem): y = x_vmem_ref[...] + jax.lax.axis_index("x") y_vmem_ref[...] = y pltpu.async_copy(y_vmem_ref, y_ref.at[slc], sem).wait() - pltpu.run_scoped( + pl.run_scoped( alloc, pltpu.VMEM((slc_size, 128), x_ref.dtype), pltpu.VMEM((slc_size, 128), y_ref.dtype), diff --git a/tests/pallas/tpu_pallas_pipeline_test.py b/tests/pallas/tpu_pallas_pipeline_test.py index 6fa89a8136e5..e61f0dfa56b3 100644 --- a/tests/pallas/tpu_pallas_pipeline_test.py +++ b/tests/pallas/tpu_pallas_pipeline_test.py @@ -1469,7 +1469,7 @@ def run(acc_scratch_ref): accum_dtype = ( jnp.float32 if jnp.issubdtype(x.dtype, jnp.floating) else jnp.int32 ) - pltpu.run_scoped(run, pltpu.VMEM((bm, bn), accum_dtype)) + pl.run_scoped(run, pltpu.VMEM((bm, bn), accum_dtype)) num_cores = jax.devices()[0].num_cores return pl.pallas_call( diff --git a/tests/pallas/tpu_pallas_test.py b/tests/pallas/tpu_pallas_test.py index d4586bacbbaf..88bcde7eea16 100644 --- a/tests/pallas/tpu_pallas_test.py +++ b/tests/pallas/tpu_pallas_test.py @@ -749,7 +749,7 @@ def body(temp_ref): temp_ref[...] = jnp.ones_like(temp_ref) x_ref[...] = 4 * y_ref[...] + temp_ref[...] - pltpu.run_scoped(body, pltpu.VMEM((8,), jnp.float32)) + pl.run_scoped(body, pltpu.VMEM((8,), jnp.float32)) return [] jaxpr, _, _, () = pe.trace_to_jaxpr_dynamic( @@ -768,7 +768,7 @@ def body(x_ref): x_ref[...] = jnp.ones_like(x_ref) y_ref[...] = 4 * x_ref[...] 
- pltpu.run_scoped(body, pltpu.VMEM((8, 128), jnp.float32)) + pl.run_scoped(body, pltpu.VMEM((8, 128), jnp.float32)) o = self.pallas_call( kernel, @@ -783,7 +783,7 @@ def body(x_ref): x_ref[0] += 1 return x_ref[0] + 2 - out = pltpu.run_scoped(body, pltpu.SMEM((1,), jnp.int32)) + out = pl.run_scoped(body, pltpu.SMEM((1,), jnp.int32)) y_ref[0] = out o = self.pallas_call( @@ -803,7 +803,7 @@ def body(x_ref): x_ref[0] += 1 return x_ref[0] + 2, x_ref[0] - out = pltpu.run_scoped(body, pltpu.SMEM((1,), jnp.int32)) + out = pl.run_scoped(body, pltpu.SMEM((1,), jnp.int32)) y_ref[0], y_ref[1] = out o = self.pallas_call( @@ -822,7 +822,7 @@ def body(x_ref): x_ref[...] = jnp.ones_like(x_ref) return x_ref[...] + 1 - out = pltpu.run_scoped(body, pltpu.VMEM((16, 128), jnp.int32)) + out = pl.run_scoped(body, pltpu.VMEM((16, 128), jnp.int32)) y_ref[...] = out o = self.pallas_call( @@ -841,7 +841,7 @@ def body(x_ref): x_ref[...] = jnp.ones_like(x_ref) return x_ref[...] + 1 - out = pltpu.run_scoped(body, pltpu.VMEM((17, 128), jnp.int32)) + out = pl.run_scoped(body, pltpu.VMEM((17, 128), jnp.int32)) y_ref[...] = out o = self.pallas_call( @@ -861,9 +861,9 @@ def body(x_ref): def inner_body(z_ref): z_ref[...] = jnp.ones_like(z_ref) x_ref[...] = z_ref[...] - pltpu.run_scoped(inner_body, pltpu.VMEM((8, 128), jnp.float32)) + pl.run_scoped(inner_body, pltpu.VMEM((8, 128), jnp.float32)) y_ref[...] = 4 * x_ref[...] - pltpu.run_scoped(body, pltpu.VMEM((8, 128), jnp.float32)) + pl.run_scoped(body, pltpu.VMEM((8, 128), jnp.float32)) o = self.pallas_call( kernel, @@ -875,7 +875,7 @@ def test_can_allocate_semaphore(self): def kernel(y_ref): def body(sem1): pass - pltpu.run_scoped(body, pltpu.SemaphoreType.DMA) + pl.run_scoped(body, pltpu.SemaphoreType.DMA) jax.block_until_ready(self.pallas_call( kernel, @@ -886,8 +886,7 @@ def test_can_allocate_multiple_semaphores(self): def kernel(y_ref): def body(sem1, sem2): pass - pltpu.run_scoped(body, pltpu.SemaphoreType.DMA, - pltpu.SemaphoreType.REGULAR) + pl.run_scoped(body, pltpu.SemaphoreType.DMA, pltpu.SemaphoreType.REGULAR) jax.block_until_ready(self.pallas_call( kernel, @@ -901,8 +900,9 @@ def body(dma_sems, sems): self.assertTupleEqual(sems.shape, (3,)) self.assertTrue(jnp.issubdtype(dma_sems.dtype, pltpu.dma_semaphore)) self.assertTrue(jnp.issubdtype(sems.dtype, pltpu.semaphore)) - pltpu.run_scoped(body, pltpu.SemaphoreType.DMA((4,)), - pltpu.SemaphoreType.REGULAR((3,))) + pl.run_scoped( + body, pltpu.SemaphoreType.DMA((4,)), pltpu.SemaphoreType.REGULAR((3,)) + ) jax.block_until_ready(self.pallas_call( kernel, @@ -936,12 +936,12 @@ def kernel(y_ref): def body(sem): pltpu.semaphore_signal(sem) pltpu.semaphore_wait(sem) - pltpu.run_scoped(body, pltpu.SemaphoreType.REGULAR) + pl.run_scoped(body, pltpu.SemaphoreType.REGULAR) def body2(sem): pltpu.semaphore_signal(sem, 2) pltpu.semaphore_wait(sem) pltpu.semaphore_wait(sem) - pltpu.run_scoped(body2, pltpu.SemaphoreType.REGULAR) + pl.run_scoped(body2, pltpu.SemaphoreType.REGULAR) def body3(sem): pltpu.semaphore_signal(sem) pltpu.semaphore_signal(sem) @@ -949,7 +949,7 @@ def body3(sem): pltpu.semaphore_wait(sem) pltpu.semaphore_wait(sem) pltpu.semaphore_wait(sem) - pltpu.run_scoped(body3, pltpu.SemaphoreType.REGULAR) + pl.run_scoped(body3, pltpu.SemaphoreType.REGULAR) # TODO(b/345534352): Add interpret support for semaphore signal/wait. 
jax.block_until_ready(self.pallas_call( @@ -973,7 +973,7 @@ def body(sems): pltpu.semaphore_wait(sems.at[2]) pltpu.semaphore_wait(sems.at[2]) pltpu.semaphore_wait(sems.at[2]) - pltpu.run_scoped(body, pltpu.SemaphoreType.REGULAR((3,))) + pl.run_scoped(body, pltpu.SemaphoreType.REGULAR((3,))) # TODO(b/345534352): Add interpret support for semaphore signal/wait. jax.block_until_ready(pl.pallas_call( @@ -998,7 +998,7 @@ def body(sems): pltpu.semaphore_wait(sems.at[i, 2]) pltpu.semaphore_wait(sems.at[i, 2]) pltpu.semaphore_wait(sems.at[i, 2]) - pltpu.run_scoped(body, pltpu.SemaphoreType.REGULAR((4, 3))) + pl.run_scoped(body, pltpu.SemaphoreType.REGULAR((4, 3))) # TODO(b/345534352): Add interpret support for semaphore signal/wait. jax.block_until_ready( @@ -1024,7 +1024,7 @@ def body(sems): y_ref[r, c] = pltpu.semaphore_read(sems.at[r, c]) pltpu.semaphore_wait(sems.at[r, c], v) - pltpu.run_scoped(body, pltpu.SemaphoreType.REGULAR((m, n))) + pl.run_scoped(body, pltpu.SemaphoreType.REGULAR((m, n))) # TODO(b/345534352): Add interpret support for semaphore signal/wait. y = jax.block_until_ready( @@ -1072,7 +1072,7 @@ def kernel(x_hbm_ref, y_hbm_ref): def body(sem): pltpu.async_copy(x_hbm_ref.at[pl.ds(8), :], y_hbm_ref.at[:, pl.ds(128)], sem).wait() - pltpu.run_scoped(body, pltpu.SemaphoreType.DMA) + pl.run_scoped(body, pltpu.SemaphoreType.DMA) x = jnp.arange(8 * 128.).reshape((8, 128)) y = self.pallas_call( kernel, @@ -1089,7 +1089,7 @@ def kernel(x_hbm_ref, y_hbm_ref): def body(sem): pltpu.async_copy(x_hbm_ref.at[pl.ds(8), :], y_hbm_ref.at[:, pl.ds(128)], sem).wait() - pltpu.run_scoped(body, pltpu.SemaphoreType.DMA((1,))) + pl.run_scoped(body, pltpu.SemaphoreType.DMA((1,))) # TODO(b/345534352): Add interpret support for nonscalar semaphores. with self.assertRaisesRegex(ValueError, 'Cannot signal'): @@ -1108,7 +1108,7 @@ def kernel(x_hbm_ref, y_hbm_ref): def body(sem): pltpu.async_copy(x_hbm_ref.at[pl.ds(8), :], y_hbm_ref.at[:, pl.ds(128)], sem.at[0]).wait() - pltpu.run_scoped(body, pltpu.SemaphoreType.DMA((1,))) + pl.run_scoped(body, pltpu.SemaphoreType.DMA((1,))) x = jnp.arange(8 * 128.).reshape((8, 128)) # TODO(b/345534352): Add interpret support for nonscalar semaphores. @@ -1131,7 +1131,7 @@ def body(sem): pltpu.async_copy( x_hbm_ref.at[pl.ds(i, 1)], y_hbm_ref.at[pl.ds(i, 1)], sem ).wait() - pltpu.run_scoped(body, pltpu.SemaphoreType.DMA) + pl.run_scoped(body, pltpu.SemaphoreType.DMA) x = jnp.arange(2 * 8 * 128.).reshape((2, 8, 128)) y = self.pallas_call( kernel, @@ -1150,8 +1150,9 @@ def body(x_ref, sem): pltpu.async_copy(x_hbm_ref.at[pl.ds(8), :], x_ref.at[:, pl.ds(128)], sem).wait() y_ref[...] = x_ref[...] - pltpu.run_scoped(body, pltpu.VMEM((8, 128), jnp.float32), - pltpu.SemaphoreType.DMA) + pl.run_scoped( + body, pltpu.VMEM((8, 128), jnp.float32), pltpu.SemaphoreType.DMA + ) x = jnp.arange(8 * 128.).reshape((8, 128)) y = self.pallas_call( kernel, @@ -1167,8 +1168,9 @@ def kernel(x_ref, y_hbm_ref): def body(y_ref, sem): y_ref[...] = x_ref[...] pltpu.async_copy(y_hbm_ref, y_ref, sem).wait() - pltpu.run_scoped(body, pltpu.VMEM((8, 128), jnp.float32), - pltpu.SemaphoreType.DMA) + pl.run_scoped( + body, pltpu.VMEM((8, 128), jnp.float32), pltpu.SemaphoreType.DMA + ) x = jnp.arange(8 * 128.).reshape((8, 128)) y = self.pallas_call( kernel, @@ -1183,10 +1185,12 @@ def body(x_ref, y_ref, sem): pltpu.async_copy(x_hbm_ref, x_ref, sem).wait() y_ref[...] = x_ref[...] 
pltpu.async_copy(y_ref, y_hbm_ref, sem).wait() - pltpu.run_scoped(body, - pltpu.VMEM((8, 128), jnp.float32), - pltpu.VMEM((8, 128), jnp.float32), - pltpu.SemaphoreType.DMA) + pl.run_scoped( + body, + pltpu.VMEM((8, 128), jnp.float32), + pltpu.VMEM((8, 128), jnp.float32), + pltpu.SemaphoreType.DMA, + ) x = jnp.arange(8 * 128.).reshape((8, 128)) y = self.pallas_call( kernel, @@ -1201,8 +1205,9 @@ def kernel(x_hbm_ref, y_ref): def body(x_ref, sem): pltpu.async_copy(x_hbm_ref, x_ref, sem).wait() y_ref[...] = x_ref[0, 0] * jnp.ones_like(y_ref) - pltpu.run_scoped(body, pltpu.SMEM((8, 128), jnp.float32), - pltpu.SemaphoreType.DMA) + pl.run_scoped( + body, pltpu.SMEM((8, 128), jnp.float32), pltpu.SemaphoreType.DMA + ) x = 4 * jnp.ones((8, 128), jnp.float32) y = self.pallas_call( kernel, @@ -1219,8 +1224,9 @@ def body(y_ref, sem): y_ref[0, 0] = 0.0 y_ref[0, 1] = x_ref[4, 4] pltpu.async_copy(y_ref, y_hbm_ref, sem).wait() - pltpu.run_scoped(body, pltpu.SMEM((1, 2), jnp.float32), - pltpu.SemaphoreType.DMA) + pl.run_scoped( + body, pltpu.SMEM((1, 2), jnp.float32), pltpu.SemaphoreType.DMA + ) x = jnp.arange(8 * 128.).reshape((8, 128)) y = self.pallas_call( kernel, @@ -1237,7 +1243,7 @@ def test_vmem_vmem_dma(self): def kernel(x_ref, y_ref): def body(sem): pltpu.async_copy(x_ref, y_ref, sem).wait() - pltpu.run_scoped(body, pltpu.SemaphoreType.DMA) + pl.run_scoped(body, pltpu.SemaphoreType.DMA) x = jnp.arange(8 * 128.).reshape((8, 128)) y = self.pallas_call( kernel, @@ -1260,7 +1266,7 @@ def body(sem): ) dma1.wait() dma2.wait() - pltpu.run_scoped(body, pltpu.SemaphoreType.DMA) + pl.run_scoped(body, pltpu.SemaphoreType.DMA) x = jnp.arange(2 * 8 * 128.).reshape((16, 128)) y = self.pallas_call( kernel, @@ -1283,7 +1289,7 @@ def body(sem): ) dma1.wait() dma2.wait() - pltpu.run_scoped(body, pltpu.SemaphoreType.DMA) + pl.run_scoped(body, pltpu.SemaphoreType.DMA) x = jnp.arange(2 * 8 * 128.).reshape((2, 8, 128)) y = self.pallas_call( kernel, @@ -1309,7 +1315,7 @@ def body(sem): ) dma1.wait() dma2.wait() - pltpu.run_scoped(body, pltpu.SemaphoreType.DMA) + pl.run_scoped(body, pltpu.SemaphoreType.DMA) x = jnp.arange(3 * 2 * 8 * 128.).reshape((3, 2, 8, 128)) y = self.pallas_call( kernel, @@ -1332,7 +1338,7 @@ def body(sem): ) dma1.wait() dma2.wait() - pltpu.run_scoped(body, pltpu.SemaphoreType.DMA) + pl.run_scoped(body, pltpu.SemaphoreType.DMA) x = jnp.arange(2 * 8 * 128.).reshape((2, 8, 128)) with self.assertRaises(Exception): _ = self.pallas_call( @@ -1586,8 +1592,12 @@ def body(ready_sem, send_sem, recv_sem): copy_done.wait_send() copy_done.wait_recv() - pltpu.run_scoped(body, pltpu.SemaphoreType.REGULAR, - pltpu.SemaphoreType.DMA, pltpu.SemaphoreType.DMA) + pl.run_scoped( + body, + pltpu.SemaphoreType.REGULAR, + pltpu.SemaphoreType.DMA, + pltpu.SemaphoreType.DMA, + ) x = jnp.arange(2 * 8 * 128.0).reshape((2 * 8, 128)) @@ -1633,8 +1643,12 @@ def body(ready_sem, send_sem, recv_sem): copy_done.wait_send() copy_done.wait_recv() - pltpu.run_scoped(body, pltpu.SemaphoreType.REGULAR, - pltpu.SemaphoreType.DMA, pltpu.SemaphoreType.DMA) + pl.run_scoped( + body, + pltpu.SemaphoreType.REGULAR, + pltpu.SemaphoreType.DMA, + pltpu.SemaphoreType.DMA, + ) num_devices = jax.local_device_count() x = jnp.arange(num_devices * 8 * 128).reshape((num_devices * 8, 128)) @@ -1683,7 +1697,7 @@ def body(ready_sem, send_sem, recv_sem): copy_done.wait_send() copy_done.wait_recv() - pltpu.run_scoped( + pl.run_scoped( body, pltpu.SemaphoreType.REGULAR, pltpu.SemaphoreType.DMA, @@ -1735,8 +1749,12 @@ def body(ready_sem, send_sem, recv_sem): 
x_ref, y_ref, send_sem, recv_sem, device_id=neighbor ).wait() - pltpu.run_scoped(body, pltpu.SemaphoreType.REGULAR, - pltpu.SemaphoreType.DMA, pltpu.SemaphoreType.DMA) + pl.run_scoped( + body, + pltpu.SemaphoreType.REGULAR, + pltpu.SemaphoreType.DMA, + pltpu.SemaphoreType.DMA, + ) num_devices = jax.local_device_count() x = jnp.arange(num_devices * 8 * 128).reshape((num_devices * 8, 128)) @@ -2196,12 +2214,12 @@ def kernel(o_ref): def scope1(): with jax.named_scope('scope1'): o_ref[...] = jnp.zeros_like(o_ref[...]) - pltpu.run_scoped(scope1) + pl.run_scoped(scope1) def scope2(): with jax.named_scope('scope2'): o_ref[...] = o_ref[...] + 1 - pltpu.run_scoped(scope2) + pl.run_scoped(scope2) with string_stdout() as msg: _ = self.pallas_call( @@ -2250,8 +2268,7 @@ def kernel(x_ref, o_ref): def inner_scope(scoped_ref): scoped_ref[0, 0] = jnp.logical_not(x_ref[0, 0]) o_ref[0, 0] = scoped_ref[0, 0] - pltpu.run_scoped(inner_scope, - pltpu.SMEM((1, 1), dtype=jnp.bool_)) + pl.run_scoped(inner_scope, pltpu.SMEM((1, 1), dtype=jnp.bool_)) input_arr = jnp.array([[value]]) output_shape = jax.ShapeDtypeStruct((1, 1), jnp.bool_) result = self.pallas_call(