Skip to content

Commit

Permalink
[TIR][Driver] Use BindTarget to specify target for FP8 legalization (
Browse files Browse the repository at this point in the history
…apache#16767)

* Do not pass target explicitly to FP8 legalization, use BindTarget instead

* Lint: Remove unused import

* Add comment on pass ordering
  • Loading branch information
slyubomirsky authored and thaisacs committed Apr 3, 2024
1 parent 432e847 commit 8411756
Show file tree
Hide file tree
Showing 5 changed files with 25 additions and 30 deletions.
8 changes: 4 additions & 4 deletions include/tvm/tir/transform.h
Original file line number Diff line number Diff line change
Expand Up @@ -398,19 +398,18 @@ TVM_DLL Pass ForceNarrowIndexToInt32();
/*!
* \brief Legalize bf16 compute Ops. Add a cast to fp32
* before Ops, then add a cast back to bf16.
* \param target The target used for checking native bf16 support
* \return The pass.
*/
TVM_DLL Pass BF16ComputeLegalize();

/*!
* \brief Legalize fp8 compute Ops. Add a cast to fp16/fp32
* before Ops, then add a cast back to fp8.
* \param target The target used for checking native fp8 support
* \param promote_dtype_str The data type used for type promotion, defaults to float16
* \note Must be run after BindTarget, as it relies on target attributes for PrimFuncs
* \return The pass.
*/
TVM_DLL Pass FP8ComputeLegalize(Target target, String promote_dtype_str = "float16");
TVM_DLL Pass FP8ComputeLegalize(String promote_dtype_str = "float16");

/*!
* \brief Legalize bf16 storage types to u16.
Expand All @@ -420,9 +419,10 @@ TVM_DLL Pass BF16StorageLegalize();

/*!
* \brief Legalize fp8 storage types to u8.
* \note Must be run after BindTarget, as it relies on target attributes for PrimFuncs
* \return The pass.
*/
TVM_DLL Pass FP8StorageLegalize(Target target);
TVM_DLL Pass FP8StorageLegalize();

/*!
* \brief Inline calls to private functions
Expand Down
18 changes: 5 additions & 13 deletions python/tvm/tir/transform/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@


import enum
from typing import Any, Callable, Optional
from typing import Callable, Optional

from . import _ffi_api
from . import function_pass as _fpass
Expand Down Expand Up @@ -323,23 +323,20 @@ def BF16ComputeLegalize():
return _ffi_api.BF16ComputeLegalize() # type: ignore


def FP8ComputeLegalize(target: Any, promote_dtype_str: str = "float32"):
def FP8ComputeLegalize(promote_dtype_str: str = "float32"):
"""Legalize fp8 compute Ops.
Parameters
----------
promote_dtype_str : str
The data type we promote fp8 to, options: float16/float32.
target : tvm.target.Target
The legalization target
Returns
-------
fpass : tvm.transform.Pass
The result pass
"""
return _ffi_api.FP8ComputeLegalize(target, promote_dtype_str) # type: ignore
return _ffi_api.FP8ComputeLegalize(promote_dtype_str) # type: ignore


def BF16StorageLegalize():
Expand All @@ -353,20 +350,15 @@ def BF16StorageLegalize():
return _ffi_api.BF16StorageLegalize() # type: ignore


def FP8StorageLegalize(target: Any):
def FP8StorageLegalize():
"""Legalize fp8 storage types to u8.
Parameters
----------
target : tvm.target.Target
The legalization target
Returns
-------
fpass : tvm.transform.Pass
The result pass
"""
return _ffi_api.FP8StorageLegalize(target) # type: ignore
return _ffi_api.FP8StorageLegalize() # type: ignore


def CommonSubexprElimTIR(enable_cse_tir: bool = True, identify_equiv_terms: bool = False):
Expand Down
8 changes: 4 additions & 4 deletions src/driver/driver_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -569,15 +569,15 @@ transform::Sequential MixedModulePassManager(IRModule mixed_mod, Target target)

Array<Pass> mixed_pass_list;

mixed_pass_list.push_back(tir::transform::FP8ComputeLegalize(target));
// FP8ComputeLegalize uses the target attrs added by BindTarget, so BindTarget must come first
mixed_pass_list.push_back(tir::transform::BindTarget(target));
mixed_pass_list.push_back(tir::transform::FP8ComputeLegalize());

// VerifyVTCMLimit must occur before LowerVtcmAlloc
mixed_pass_list.push_back(tir::transform::VerifyVTCMLimit(target));
// LowerVtcmAlloc must occur after any transformations that modify memory allocation locations
mixed_pass_list.push_back(tir::transform::LowerVtcmAlloc());

mixed_pass_list.push_back(tir::transform::BindTarget(target));

mixed_pass_list.push_back(tir::transform::VerifyMemory());

mixed_pass_list.push_back(tir::transform::AnnotateEntryFunc());
Expand Down Expand Up @@ -620,7 +620,7 @@ transform::Sequential MixedModulePassManager(IRModule mixed_mod, Target target)
} else {
mixed_pass_list.push_back(tir::transform::MakePackedAPI());
}
mixed_pass_list.push_back(tir::transform::FP8StorageLegalize(target));
mixed_pass_list.push_back(tir::transform::FP8StorageLegalize());
mixed_pass_list.push_back(tir::transform::BF16StorageLegalize());

mixed_pass_list.push_back(tir::transform::LowerDeviceKernelLaunch());
Expand Down
6 changes: 4 additions & 2 deletions src/tir/transforms/unsupported_dtype_legalize.cc
Original file line number Diff line number Diff line change
Expand Up @@ -727,8 +727,9 @@ Pass BF16StorageLegalize() {

TVM_REGISTER_GLOBAL("tir.transform.BF16StorageLegalize").set_body_typed(BF16StorageLegalize);

Pass FP8ComputeLegalize(Target target, String promote_dtype_str) {
Pass FP8ComputeLegalize(String promote_dtype_str) {
auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) {
auto target = f->GetAttr<Target>(tvm::attr::kTarget).value();
if (CheckDataTypeSupport(target, "tvm.contrib.nvcc.supports_fp8")) {
return f;
}
Expand All @@ -739,8 +740,9 @@ Pass FP8ComputeLegalize(Target target, String promote_dtype_str) {

TVM_REGISTER_GLOBAL("tir.transform.FP8ComputeLegalize").set_body_typed(FP8ComputeLegalize);

Pass FP8StorageLegalize(Target target) {
Pass FP8StorageLegalize() {
auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) {
auto target = f->GetAttr<Target>(tvm::attr::kTarget).value();
if (CheckDataTypeSupport(target, "tvm.contrib.nvcc.supports_fp8")) {
return f;
}
Expand Down
15 changes: 8 additions & 7 deletions tests/python/tir-transform/test_tir_transform_fp8_legalize.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import tvm.testing
from tvm.target import Target
from tvm.script import tir as T
from tvm.tir.transform.transform import BindTarget

# pylint: disable=no-member,invalid-name,unused-variable

Expand Down Expand Up @@ -206,20 +207,20 @@ def main(Aptr: T.handle("uint8"), Bptr: T.handle("uint8"), Dptr: T.handle("uint8

def test_fp8_compute_legalize(dtype, promote_dtype):
target = Target("cuda")
before = get_before(dtype)
expected = get_after_compute_legalize(dtype, promote_dtype)
before = BindTarget(target)(get_before(dtype))
expected = BindTarget(target)(get_after_compute_legalize(dtype, promote_dtype))
# run the transform twice to ensure the pass can safely be
# applied repeatedly to an already-legalized module
after = tvm.tir.transform.FP8ComputeLegalize(target, promote_dtype)(before)
after = tvm.tir.transform.FP8ComputeLegalize(target, promote_dtype)(after)
after = tvm.tir.transform.FP8ComputeLegalize(promote_dtype)(before)
after = tvm.tir.transform.FP8ComputeLegalize(promote_dtype)(after)
tvm.ir.assert_structural_equal(after, expected)


def test_fp8_storage_legalize(dtype, promote_dtype):
target = Target("cuda")
before = get_after_compute_legalize(dtype, promote_dtype)
after = tvm.tir.transform.FP8StorageLegalize(target)(before)
expected = get_after_storage_legalize(dtype, promote_dtype)
before = BindTarget(target)(get_after_compute_legalize(dtype, promote_dtype))
after = tvm.tir.transform.FP8StorageLegalize()(before)
expected = BindTarget(target)(get_after_storage_legalize(dtype, promote_dtype))
tvm.ir.assert_structural_equal(after, expected)


Expand Down

0 comments on commit 8411756

Please sign in to comment.