From b6029f98841f9d9e7fd592f9362dbca93f2a808f Mon Sep 17 00:00:00 2001 From: aska-0096 Date: Wed, 22 Mar 2023 02:50:28 +0000 Subject: [PATCH 1/2] Add CMake Option "USE_OPT_NAVI3X" --- CMakeLists.txt | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/CMakeLists.txt b/CMakeLists.txt index f861e30203..c9fb6b4552 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -22,6 +22,7 @@ include(TargetFlags) list(APPEND CMAKE_PREFIX_PATH ${CMAKE_INSTALL_PREFIX} ${CMAKE_INSTALL_PREFIX}/llvm ${CMAKE_INSTALL_PREFIX}/hip /opt/rocm /opt/rocm/llvm /opt/rocm/hip) option(USE_BITINT_EXTENSION_INT4, "Whether to enable clang's BitInt extension to provide int4 data type." OFF) +option(USE_OPT_NAVI3X, "Whether to enable LDS cumode and Wavefront32 mode for NAVI3X silicons." OFF) if(USE_BITINT_EXTENSION_INT4) add_compile_definitions(CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4) @@ -29,6 +30,12 @@ if(USE_BITINT_EXTENSION_INT4) message("CK compiled with USE_BITINT_EXTENSION_INT4 set to ${USE_BITINT_EXTENSION_INT4}") endif() +if(USE_OPT_NAVI3X) + add_compile_options(-mcumode) + add_compile_options(-mno-wavefrontsize64) + message("CK compiled with USE_OPT_NAVI3X set to ${USE_OPT_NAVI3X}") +endif() + ## Threads set(THREADS_PREFER_PTHREAD_FLAG ON) find_package(Threads REQUIRED) From 71f5dd1d150f751f0e2762bdf1c3a02f29249340 Mon Sep 17 00:00:00 2001 From: aska-0096 Date: Thu, 23 Mar 2023 02:36:37 +0000 Subject: [PATCH 2/2] fix bug --- CMakeLists.txt | 7 ------- ...n_grouped_conv_fwd_bias_relu_add_wmma_example.inc | 4 ++-- .../grid/gridwise_gemm_multiple_d_wmma_cshuffle.hpp | 12 ++++++++++-- 3 files changed, 12 insertions(+), 11 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index c9fb6b4552..f861e30203 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -22,7 +22,6 @@ include(TargetFlags) list(APPEND CMAKE_PREFIX_PATH ${CMAKE_INSTALL_PREFIX} ${CMAKE_INSTALL_PREFIX}/llvm ${CMAKE_INSTALL_PREFIX}/hip /opt/rocm /opt/rocm/llvm /opt/rocm/hip) option(USE_BITINT_EXTENSION_INT4, "Whether to enable clang's BitInt extension to provide int4 data type." OFF) -option(USE_OPT_NAVI3X, "Whether to enable LDS cumode and Wavefront32 mode for NAVI3X silicons." OFF) if(USE_BITINT_EXTENSION_INT4) add_compile_definitions(CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4) @@ -30,12 +29,6 @@ if(USE_BITINT_EXTENSION_INT4) message("CK compiled with USE_BITINT_EXTENSION_INT4 set to ${USE_BITINT_EXTENSION_INT4}") endif() -if(USE_OPT_NAVI3X) - add_compile_options(-mcumode) - add_compile_options(-mno-wavefrontsize64) - message("CK compiled with USE_OPT_NAVI3X set to ${USE_OPT_NAVI3X}") -endif() - ## Threads set(THREADS_PREFER_PTHREAD_FLAG ON) find_package(Threads REQUIRED) diff --git a/example/30_grouped_conv_fwd_multiple_d/run_grouped_conv_fwd_bias_relu_add_wmma_example.inc b/example/30_grouped_conv_fwd_multiple_d/run_grouped_conv_fwd_bias_relu_add_wmma_example.inc index 8161b1088a..a6888649c0 100644 --- a/example/30_grouped_conv_fwd_multiple_d/run_grouped_conv_fwd_bias_relu_add_wmma_example.inc +++ b/example/30_grouped_conv_fwd_multiple_d/run_grouped_conv_fwd_bias_relu_add_wmma_example.inc @@ -74,8 +74,8 @@ using DeviceConvFwdInstance = 8, // BBlockTransferSrcScalarPerVector 8, // BBlockTransferDstScalarPerVector_BK1 true, // BBlockLdsExtraN - 1, - 1, + 4, + 2, S<1, 32, 1, 8>, 8>; diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_wmma_cshuffle.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_wmma_cshuffle.hpp index 1e8f8ff9fe..38edace197 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_wmma_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_wmma_cshuffle.hpp @@ -431,6 +431,9 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_wmma_cshuffle constexpr auto b_block_desc_k0perblock_nperblock_k1 = GetBBlockDescriptor_K0PerBlock_NPerBlock_K1(); + constexpr auto cshuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat = + GetCShuffleBlockDescriptor_MShRepeat_MPerShRepeat_NShRepeat_NPerShRepeat(); + constexpr auto max_lds_align = K1; constexpr auto a_block_space_size_aligned = math::integer_least_multiple( @@ -439,8 +442,13 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_wmma_cshuffle constexpr auto b_block_space_size_aligned = math::integer_least_multiple( b_block_desc_k0perblock_nperblock_k1.GetElementSpaceSize(), max_lds_align); - return (a_block_space_size_aligned * sizeof(ADataType) + - b_block_space_size_aligned * sizeof(BDataType)); + constexpr auto c_block_space_size_aligned = math::integer_least_multiple( + cshuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat.GetElementSpaceSize(), + max_lds_align); + + return math::max((a_block_space_size_aligned * sizeof(ADataType) + + b_block_space_size_aligned * sizeof(BDataType)), + c_block_space_size_aligned * sizeof(CShuffleDataType)); } // block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01}