From 670980bc3ee167a08d4e4493f6d4e5c980c436e4 Mon Sep 17 00:00:00 2001 From: Steven Lyubomirsky Date: Thu, 21 Mar 2024 21:41:20 -0400 Subject: [PATCH 1/3] Do not pass target explicitly to FP8 legalization, use BindTarget instead --- include/tvm/tir/transform.h | 8 ++++---- python/tvm/tir/transform/transform.py | 16 ++++------------ src/driver/driver_api.cc | 7 +++---- src/tir/transforms/unsupported_dtype_legalize.cc | 6 ++++-- .../test_tir_transform_fp8_legalize.py | 15 ++++++++------- 5 files changed, 23 insertions(+), 29 deletions(-) diff --git a/include/tvm/tir/transform.h b/include/tvm/tir/transform.h index e219cc684657..98edbeaceb26 100644 --- a/include/tvm/tir/transform.h +++ b/include/tvm/tir/transform.h @@ -398,7 +398,6 @@ TVM_DLL Pass ForceNarrowIndexToInt32(); /*! * \brief Legalize bf16 compute Ops. Add a cast to fp32 * before Ops, then add a cast back to bf16. - * \param target The target used for checking native bf16 support * \return The pass. */ TVM_DLL Pass BF16ComputeLegalize(); @@ -406,11 +405,11 @@ TVM_DLL Pass BF16ComputeLegalize(); /*! * \brief Legalize fp8 compute Ops. Add a cast to fp16/fp32 * before Ops, then add a cast back to fp8. - * \param target The target used for checking native fp8 support * \param promote_dtype_str The data type used for type promotion, defaults to float16 + * \note Must be run after BindTarget, as it relies on target attributes for PrimFuncs * \return The pass. */ -TVM_DLL Pass FP8ComputeLegalize(Target target, String promote_dtype_str = "float16"); +TVM_DLL Pass FP8ComputeLegalize(String promote_dtype_str = "float16"); /*! * \brief Legalize bf16 storage types to u16. @@ -420,9 +419,10 @@ TVM_DLL Pass BF16StorageLegalize(); /*! * \brief Legalize fp8 storage types to u8. + * \note Must be run after BindTarget, as it relies on target attributes for PrimFuncs * \return The pass. */ -TVM_DLL Pass FP8StorageLegalize(Target target); +TVM_DLL Pass FP8StorageLegalize(); /*! * \brief Inline calls to private functions diff --git a/python/tvm/tir/transform/transform.py b/python/tvm/tir/transform/transform.py index 9f7f92dbed74..e41f6509c50b 100644 --- a/python/tvm/tir/transform/transform.py +++ b/python/tvm/tir/transform/transform.py @@ -323,7 +323,7 @@ def BF16ComputeLegalize(): return _ffi_api.BF16ComputeLegalize() # type: ignore -def FP8ComputeLegalize(target: Any, promote_dtype_str: str = "float32"): +def FP8ComputeLegalize(promote_dtype_str: str = "float32"): """Legalize fp8 compute Ops. Parameters @@ -331,15 +331,12 @@ def FP8ComputeLegalize(target: Any, promote_dtype_str: str = "float32"): promote_dtype : str The data type we promote fp8 to, options: float16/float32. - target : tvm.target.Target - The legalization target - Returns ------- fpass : tvm.transform.Pass The result pass """ - return _ffi_api.FP8ComputeLegalize(target, promote_dtype_str) # type: ignore + return _ffi_api.FP8ComputeLegalize(promote_dtype_str) # type: ignore def BF16StorageLegalize(): @@ -353,20 +350,15 @@ def BF16StorageLegalize(): return _ffi_api.BF16StorageLegalize() # type: ignore -def FP8StorageLegalize(target: Any): +def FP8StorageLegalize(): """Legalize fp8 storage types to u8. - Parameters - ---------- - target : tvm.target.Target - The legalization target - Returns ------- fpass : tvm.transform.Pass The result pass """ - return _ffi_api.FP8StorageLegalize(target) # type: ignore + return _ffi_api.FP8StorageLegalize() # type: ignore def CommonSubexprElimTIR(enable_cse_tir: bool = True, identify_equiv_terms: bool = False): diff --git a/src/driver/driver_api.cc b/src/driver/driver_api.cc index e3b4a5a6517c..ba640720aab0 100644 --- a/src/driver/driver_api.cc +++ b/src/driver/driver_api.cc @@ -569,15 +569,14 @@ transform::Sequential MixedModulePassManager(IRModule mixed_mod, Target target) Array mixed_pass_list; - mixed_pass_list.push_back(tir::transform::FP8ComputeLegalize(target)); + mixed_pass_list.push_back(tir::transform::BindTarget(target)); + mixed_pass_list.push_back(tir::transform::FP8ComputeLegalize()); // VerifyVTCMLimit must occur before LowerVtcmAlloc mixed_pass_list.push_back(tir::transform::VerifyVTCMLimit(target)); // LowerVtcmAlloc must occur after any transformations that modify memory allocation locations mixed_pass_list.push_back(tir::transform::LowerVtcmAlloc()); - mixed_pass_list.push_back(tir::transform::BindTarget(target)); - mixed_pass_list.push_back(tir::transform::VerifyMemory()); mixed_pass_list.push_back(tir::transform::AnnotateEntryFunc()); @@ -618,7 +617,7 @@ transform::Sequential MixedModulePassManager(IRModule mixed_mod, Target target) } else { mixed_pass_list.push_back(tir::transform::MakePackedAPI()); } - mixed_pass_list.push_back(tir::transform::FP8StorageLegalize(target)); + mixed_pass_list.push_back(tir::transform::FP8StorageLegalize()); mixed_pass_list.push_back(tir::transform::BF16StorageLegalize()); mixed_pass_list.push_back(tir::transform::LowerDeviceKernelLaunch()); diff --git a/src/tir/transforms/unsupported_dtype_legalize.cc b/src/tir/transforms/unsupported_dtype_legalize.cc index c0378790740f..5537c8a409a0 100644 --- a/src/tir/transforms/unsupported_dtype_legalize.cc +++ b/src/tir/transforms/unsupported_dtype_legalize.cc @@ -727,8 +727,9 @@ Pass BF16StorageLegalize() { TVM_REGISTER_GLOBAL("tir.transform.BF16StorageLegalize").set_body_typed(BF16StorageLegalize); -Pass FP8ComputeLegalize(Target target, String promote_dtype_str) { +Pass FP8ComputeLegalize(String promote_dtype_str) { auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) { + auto target = f->GetAttr(tvm::attr::kTarget).value(); if (CheckDataTypeSupport(target, "tvm.contrib.nvcc.supports_fp8")) { return f; } @@ -739,8 +740,9 @@ Pass FP8ComputeLegalize(Target target, String promote_dtype_str) { TVM_REGISTER_GLOBAL("tir.transform.FP8ComputeLegalize").set_body_typed(FP8ComputeLegalize); -Pass FP8StorageLegalize(Target target) { +Pass FP8StorageLegalize() { auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) { + auto target = f->GetAttr(tvm::attr::kTarget).value(); if (CheckDataTypeSupport(target, "tvm.contrib.nvcc.supports_fp8")) { return f; } diff --git a/tests/python/tir-transform/test_tir_transform_fp8_legalize.py b/tests/python/tir-transform/test_tir_transform_fp8_legalize.py index 6e44b53d0cae..e1f487c572df 100644 --- a/tests/python/tir-transform/test_tir_transform_fp8_legalize.py +++ b/tests/python/tir-transform/test_tir_transform_fp8_legalize.py @@ -19,6 +19,7 @@ import tvm.testing from tvm.target import Target from tvm.script import tir as T +from tvm.tir.transform.transform import BindTarget # pylint: disable=no-member,invalid-name,unused-variable @@ -206,20 +207,20 @@ def main(Aptr: T.handle("uint8"), Bptr: T.handle("uint8"), Dptr: T.handle("uint8 def test_fp8_compute_legalize(dtype, promote_dtype): target = Target("cuda") - before = get_before(dtype) - expected = get_after_compute_legalize(dtype, promote_dtype) + before = BindTarget(target)(get_before(dtype)) + expected = BindTarget(target)(get_after_compute_legalize(dtype, promote_dtype)) # run the transform twice to ensure we can afford to deal # with this repeative optimizations - after = tvm.tir.transform.FP8ComputeLegalize(target, promote_dtype)(before) - after = tvm.tir.transform.FP8ComputeLegalize(target, promote_dtype)(after) + after = tvm.tir.transform.FP8ComputeLegalize(promote_dtype)(before) + after = tvm.tir.transform.FP8ComputeLegalize(promote_dtype)(after) tvm.ir.assert_structural_equal(after, expected) def test_fp8_storage_legalize(dtype, promote_dtype): target = Target("cuda") - before = get_after_compute_legalize(dtype, promote_dtype) - after = tvm.tir.transform.FP8StorageLegalize(target)(before) - expected = get_after_storage_legalize(dtype, promote_dtype) + before = BindTarget(target)(get_after_compute_legalize(dtype, promote_dtype)) + after = tvm.tir.transform.FP8StorageLegalize()(before) + expected = BindTarget(target)(get_after_storage_legalize(dtype, promote_dtype)) tvm.ir.assert_structural_equal(after, expected) From f47bdbaf430f97845ab3ba2fc0258906b50f88bc Mon Sep 17 00:00:00 2001 From: Steven Lyubomirsky Date: Fri, 22 Mar 2024 12:57:51 -0400 Subject: [PATCH 2/3] Lint: Remove unused import --- python/tvm/tir/transform/transform.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/tvm/tir/transform/transform.py b/python/tvm/tir/transform/transform.py index e41f6509c50b..c2022b918643 100644 --- a/python/tvm/tir/transform/transform.py +++ b/python/tvm/tir/transform/transform.py @@ -19,7 +19,7 @@ import enum -from typing import Any, Callable, Optional +from typing import Callable, Optional from . import _ffi_api from . import function_pass as _fpass From d787fd4cb6aaa85c60073363f4d9ab18f04f43f6 Mon Sep 17 00:00:00 2001 From: Steven Lyubomirsky Date: Fri, 22 Mar 2024 12:58:04 -0400 Subject: [PATCH 3/3] Add comment on pass ordering --- src/driver/driver_api.cc | 1 + 1 file changed, 1 insertion(+) diff --git a/src/driver/driver_api.cc b/src/driver/driver_api.cc index ba640720aab0..622010988340 100644 --- a/src/driver/driver_api.cc +++ b/src/driver/driver_api.cc @@ -569,6 +569,7 @@ transform::Sequential MixedModulePassManager(IRModule mixed_mod, Target target) Array mixed_pass_list; + // FPComputeLegalize uses the target attrs added by BindTarget, so it must come first mixed_pass_list.push_back(tir::transform::BindTarget(target)); mixed_pass_list.push_back(tir::transform::FP8ComputeLegalize());