[TIR][Driver] Use BindTarget to specify target for FP8 legalization #16767

Merged
merged 3 commits into from
Mar 25, 2024
8 changes: 4 additions & 4 deletions include/tvm/tir/transform.h
@@ -398,19 +398,18 @@ TVM_DLL Pass ForceNarrowIndexToInt32();
/*!
* \brief Legalize bf16 compute Ops. Add a cast to fp32
* before Ops, then add a cast back to bf16.
-* \param target The target used for checking native bf16 support
* \return The pass.
*/
TVM_DLL Pass BF16ComputeLegalize();

/*!
* \brief Legalize fp8 compute Ops. Add a cast to fp16/fp32
* before Ops, then add a cast back to fp8.
-* \param target The target used for checking native fp8 support
* \param promote_dtype_str The data type used for type promotion, defaults to float16
+* \note Must be run after BindTarget, as it relies on target attributes for PrimFuncs
* \return The pass.
*/
-TVM_DLL Pass FP8ComputeLegalize(Target target, String promote_dtype_str = "float16");
+TVM_DLL Pass FP8ComputeLegalize(String promote_dtype_str = "float16");

/*!
* \brief Legalize bf16 storage types to u16.
@@ -420,9 +419,10 @@ TVM_DLL Pass BF16StorageLegalize();

/*!
* \brief Legalize fp8 storage types to u8.
+* \note Must be run after BindTarget, as it relies on target attributes for PrimFuncs
* \return The pass.
*/
-TVM_DLL Pass FP8StorageLegalize(Target target);
+TVM_DLL Pass FP8StorageLegalize();

/*!
* \brief Inline calls to private functions
16 changes: 4 additions & 12 deletions python/tvm/tir/transform/transform.py
@@ -323,23 +323,20 @@ def BF16ComputeLegalize():
return _ffi_api.BF16ComputeLegalize() # type: ignore


-def FP8ComputeLegalize(target: Any, promote_dtype_str: str = "float32"):
+def FP8ComputeLegalize(promote_dtype_str: str = "float32"):
"""Legalize fp8 compute Ops.

Parameters
----------
promote_dtype_str : str
The data type we promote fp8 to, options: float16/float32.

-target : tvm.target.Target
-The legalization target

Returns
-------
fpass : tvm.transform.Pass
The result pass
"""
-return _ffi_api.FP8ComputeLegalize(target, promote_dtype_str) # type: ignore
+return _ffi_api.FP8ComputeLegalize(promote_dtype_str) # type: ignore


def BF16StorageLegalize():
@@ -353,20 +350,15 @@ def BF16StorageLegalize():
return _ffi_api.BF16StorageLegalize() # type: ignore


-def FP8StorageLegalize(target: Any):
+def FP8StorageLegalize():
"""Legalize fp8 storage types to u8.

-Parameters
-----------
-target : tvm.target.Target
-The legalization target

Returns
-------
fpass : tvm.transform.Pass
The result pass
"""
-return _ffi_api.FP8StorageLegalize(target) # type: ignore
+return _ffi_api.FP8StorageLegalize() # type: ignore


def CommonSubexprElimTIR(enable_cse_tir: bool = True, identify_equiv_terms: bool = False):
7 changes: 3 additions & 4 deletions src/driver/driver_api.cc
@@ -569,15 +569,14 @@ transform::Sequential MixedModulePassManager(IRModule mixed_mod, Target target)

Array<Pass> mixed_pass_list;

-mixed_pass_list.push_back(tir::transform::FP8ComputeLegalize(target));
+mixed_pass_list.push_back(tir::transform::BindTarget(target));
Contributor

Can we add a comment here specifying that the BindTarget should occur first in this sequence, so that later passes can rely on the target attribute being present? (It looks like both VerifyVTCMLimit and LowerVtcmAlloc were implemented more recently than BindTarget, and probably should have been placed after BindTarget at that point.)

Contributor Author

Yeah will do

Contributor

> but also, in principle, allows for PrimFuncs with different targets to coexist and be handled correctly by the passes.

Here BindTarget is applying the target to all functions in the mixed IRModule, no?

Contributor Author

Yeah that's what it does. In principle, we could set the attributes differently.

Contributor

Slight nitpick: it applies to all functions that do not already have the target attribute. A function could already have one, for example in a module that contains kernels for multiple different targets.
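
The binding behaviour described in this thread can be sketched with a toy model. The snippet below is not TVM code: `bind_target` and the dict-based module are hypothetical stand-ins, assumed here only to illustrate that functions already carrying a target attribute keep it while the rest receive the default.

```python
def bind_target(module, default_target):
    """Toy stand-in for BindTarget (hypothetical, not the TVM implementation).

    `module` maps function names to attribute dicts. Functions that already
    carry a "target" attribute keep it; all others receive `default_target`.
    """
    bound = {}
    for name, attrs in module.items():
        new_attrs = dict(attrs)  # copy so the input module is left untouched
        new_attrs.setdefault("target", default_target)
        bound[name] = new_attrs
    return bound


# A mixed module: one kernel is pre-annotated, the other is not.
mixed_mod = {
    "kernel_a": {"target": "llvm"},
    "kernel_b": {},
}
bound_mod = bind_target(mixed_mod, "cuda")
print(bound_mod)
```

Under this model, a pass that reads the attribute per function could treat `kernel_a` and `kernel_b` differently within the same module.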

+mixed_pass_list.push_back(tir::transform::FP8ComputeLegalize());

// VerifyVTCMLimit must occur before LowerVtcmAlloc
mixed_pass_list.push_back(tir::transform::VerifyVTCMLimit(target));
// LowerVtcmAlloc must occur after any transformations that modify memory allocation locations
mixed_pass_list.push_back(tir::transform::LowerVtcmAlloc());

-mixed_pass_list.push_back(tir::transform::BindTarget(target));

mixed_pass_list.push_back(tir::transform::VerifyMemory());

mixed_pass_list.push_back(tir::transform::AnnotateEntryFunc());
@@ -618,7 +617,7 @@ transform::Sequential MixedModulePassManager(IRModule mixed_mod, Target target)
} else {
mixed_pass_list.push_back(tir::transform::MakePackedAPI());
}
-mixed_pass_list.push_back(tir::transform::FP8StorageLegalize(target));
+mixed_pass_list.push_back(tir::transform::FP8StorageLegalize());
mixed_pass_list.push_back(tir::transform::BF16StorageLegalize());

mixed_pass_list.push_back(tir::transform::LowerDeviceKernelLaunch());
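
The ordering constraint in this file (BindTarget before any pass that reads the per-function target attribute) can be expressed as a small check. This is a sketch over plain pass names, not part of the TVM driver: `first_misordered_pass` and the `TARGET_DEPENDENT` set are hypothetical, and the set lists only the two FP8 passes touched by this change.

```python
# Hypothetical set of passes that read the per-function target attribute.
TARGET_DEPENDENT = {"FP8ComputeLegalize", "FP8StorageLegalize"}


def first_misordered_pass(pass_names, binder="BindTarget"):
    """Return the first target-dependent pass that runs before `binder`, else None."""
    for name in pass_names:
        if name == binder:
            return None  # binder reached first; everything after it is safe
        if name in TARGET_DEPENDENT:
            return name
    return None


# The attribute-reading FP8ComputeLegalize left in its old position vs. the new order.
old_order = ["FP8ComputeLegalize", "VerifyVTCMLimit", "LowerVtcmAlloc", "BindTarget"]
new_order = ["BindTarget", "FP8ComputeLegalize", "VerifyVTCMLimit", "LowerVtcmAlloc"]
print(first_misordered_pass(old_order))
print(first_misordered_pass(new_order))
```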
6 changes: 4 additions & 2 deletions src/tir/transforms/unsupported_dtype_legalize.cc
@@ -727,8 +727,9 @@ Pass BF16StorageLegalize() {

TVM_REGISTER_GLOBAL("tir.transform.BF16StorageLegalize").set_body_typed(BF16StorageLegalize);

-Pass FP8ComputeLegalize(Target target, String promote_dtype_str) {
+Pass FP8ComputeLegalize(String promote_dtype_str) {
auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) {
+auto target = f->GetAttr<Target>(tvm::attr::kTarget).value();
if (CheckDataTypeSupport(target, "tvm.contrib.nvcc.supports_fp8")) {
return f;
}
@@ -739,8 +740,9 @@ Pass FP8ComputeLegalize(Target target, String promote_dtype_str) {

TVM_REGISTER_GLOBAL("tir.transform.FP8ComputeLegalize").set_body_typed(FP8ComputeLegalize);

-Pass FP8StorageLegalize(Target target) {
+Pass FP8StorageLegalize() {
auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) {
+auto target = f->GetAttr<Target>(tvm::attr::kTarget).value();
if (CheckDataTypeSupport(target, "tvm.contrib.nvcc.supports_fp8")) {
return f;
}
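
Both passes now look up the target on each function instead of receiving it as an argument, so a function without the attribute has nothing to look up. The following is a rough pure-Python model of that control flow, not the C++ implementation: `legalize_fp8`, the `natively_supported` set, and the `compute_dtype` marker are all hypothetical, and raising `ValueError` is an assumed stand-in for the failure of `.value()` on an empty attribute.

```python
def legalize_fp8(func_attrs, natively_supported, promote_dtype="float16"):
    """Toy model of a per-function pass that reads its target from attributes."""
    target = func_attrs.get("target")
    if target is None:
        # Assumed analogue of calling .value() when no target was bound.
        raise ValueError("no target attribute; run BindTarget first")
    if target in natively_supported:
        return func_attrs  # native fp8 support: leave the function unchanged
    legalized = dict(func_attrs)
    legalized["compute_dtype"] = promote_dtype  # hypothetical marker for promotion
    return legalized


fp8_capable = {"cuda-with-fp8"}  # hypothetical target name
print(legalize_fp8({"target": "cuda-with-fp8"}, fp8_capable))
print(legalize_fp8({"target": "llvm"}, fp8_capable))
```

The early return mirrors the `CheckDataTypeSupport` guard in the diff: functions whose target supports fp8 natively pass through untouched.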
15 changes: 8 additions & 7 deletions tests/python/tir-transform/test_tir_transform_fp8_legalize.py
@@ -19,6 +19,7 @@
import tvm.testing
from tvm.target import Target
from tvm.script import tir as T
+from tvm.tir.transform.transform import BindTarget

# pylint: disable=no-member,invalid-name,unused-variable

@@ -206,20 +207,20 @@ def main(Aptr: T.handle("uint8"), Bptr: T.handle("uint8"), Dptr: T.handle("uint8

def test_fp8_compute_legalize(dtype, promote_dtype):
target = Target("cuda")
-before = get_before(dtype)
-expected = get_after_compute_legalize(dtype, promote_dtype)
+before = BindTarget(target)(get_before(dtype))
+expected = BindTarget(target)(get_after_compute_legalize(dtype, promote_dtype))
# run the transform twice to ensure the pass can handle
# being applied repeatedly
-after = tvm.tir.transform.FP8ComputeLegalize(target, promote_dtype)(before)
-after = tvm.tir.transform.FP8ComputeLegalize(target, promote_dtype)(after)
+after = tvm.tir.transform.FP8ComputeLegalize(promote_dtype)(before)
+after = tvm.tir.transform.FP8ComputeLegalize(promote_dtype)(after)
tvm.ir.assert_structural_equal(after, expected)


def test_fp8_storage_legalize(dtype, promote_dtype):
target = Target("cuda")
-before = get_after_compute_legalize(dtype, promote_dtype)
-after = tvm.tir.transform.FP8StorageLegalize(target)(before)
-expected = get_after_storage_legalize(dtype, promote_dtype)
+before = BindTarget(target)(get_after_compute_legalize(dtype, promote_dtype))
+after = tvm.tir.transform.FP8StorageLegalize()(before)
+expected = BindTarget(target)(get_after_storage_legalize(dtype, promote_dtype))
tvm.ir.assert_structural_equal(after, expected)

