Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add conv bwd weight fp16 comp bf8 fp8 op, instances and example #945

Merged
merged 43 commits into from
Oct 4, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
43 commits
Select commit Hold shift + click to select a range
c3fdcaf
Add f8 bf8 gemm example
geyyer Sep 20, 2023
4d90c0d
Add element-wise ops
geyyer Sep 20, 2023
7cf5d8f
Add intrinsics
geyyer Sep 20, 2023
8038fad
Update reference calculation
geyyer Sep 20, 2023
7c45a65
Add an additional type option for xdlops gemm
geyyer Sep 20, 2023
78bfffb
Fix build process
geyyer Sep 20, 2023
ead87d7
Add bf8 to buffer addressing
geyyer Sep 20, 2023
2ab1dbf
Update blockwise op, split typeA and typeB
geyyer Sep 20, 2023
4cb6395
Update for compatibility
geyyer Sep 20, 2023
9b7093a
Merge branch 'develop' into lwpck-791
geyyer Sep 21, 2023
2034488
Update naming to f8->fp8
geyyer Sep 21, 2023
54f2bb9
Update naming
geyyer Sep 21, 2023
cc3862e
Format
geyyer Sep 21, 2023
879f9c8
Update naming (#937)
geyyer Sep 22, 2023
ed8d62d
Add a client example
geyyer Sep 26, 2023
196da4b
Add computetypes to device and gridwise ops
geyyer Sep 26, 2023
337c076
Add instances, update instance factory
geyyer Sep 26, 2023
1664978
Format
geyyer Sep 26, 2023
7f5f03e
Fix a flag
geyyer Sep 26, 2023
9f06049
Add ckProfiler mode
geyyer Sep 27, 2023
1504090
Fix typos
geyyer Sep 28, 2023
b19f862
Add an example
geyyer Sep 28, 2023
2150167
Add bf8 generator
geyyer Sep 28, 2023
bfc207f
add bf8 mfma; fixed type_convert for bf8
Sep 28, 2023
9b882a7
Merge branch 'lwpck-865' of github.com:ROCmSoftwarePlatform/composabl…
Sep 28, 2023
ab7b938
move verification ahead of timing
Sep 28, 2023
aaca450
Update reference calculation
geyyer Sep 29, 2023
990c020
Fix reference
geyyer Sep 29, 2023
e06fe14
Narrow down float init range
geyyer Sep 29, 2023
88e033c
Fix bf8 bf8 mfma
geyyer Sep 29, 2023
5a7185a
Add bf8 @ fp8 mfma
geyyer Sep 29, 2023
df9ce71
Update example
geyyer Sep 29, 2023
3cc2cb1
Update instances
geyyer Oct 2, 2023
0aef5b9
Update profiler api
geyyer Oct 2, 2023
fe95e76
Merge branch 'develop' into lwpck-865
geyyer Oct 2, 2023
6849987
Update for compatibility
geyyer Oct 2, 2023
e445a8d
Format
geyyer Oct 2, 2023
74085e6
Remove extra example
geyyer Oct 2, 2023
d4855e7
Merge branch 'develop' into lwpck-865
geyyer Oct 2, 2023
ce6e3c3
Clean up
geyyer Oct 3, 2023
ccdadb1
Merge branch 'develop' into lwpck-865
geyyer Oct 3, 2023
df86d6d
workaround convert
Oct 4, 2023
092b916
Merge branch 'lwpck-865' of github.com:ROCmSoftwarePlatform/composabl…
Oct 4, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions example/20_grouped_conv_bwd_weight/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,12 @@ foreach(gpu IN LISTS GPU_TARGETS)
if(result EQUAL 0)
add_dependencies(example_grouped_conv_bwd_weight example_grouped_conv_bwd_weight_xdl_bf16)
endif()
# Configure the fp16 example with bf8/fp8 compute types only when the requested
# GPU targets mention gfx940, gfx941 or gfx942.
if(GPU_TARGETS MATCHES "gfx94[0-2]")
    add_example_executable(example_grouped_conv_bwd_weight_xdl_fp16_comp_bf8_fp8
                           grouped_conv_bwd_weight_xdl_fp16_comp_bf8_fp8.cpp)
    # NOTE(review): `result` is not set in this block; presumably it is produced by
    # add_example_executable - confirm against that helper's definition.
    if(result EQUAL 0)
        add_dependencies(example_grouped_conv_bwd_weight
                         example_grouped_conv_bwd_weight_xdl_fp16_comp_bf8_fp8)
    endif()
endif()
set(target 1)
endif()
endforeach()
Expand Down
6 changes: 6 additions & 0 deletions example/20_grouped_conv_bwd_weight/common.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,12 @@
// Short aliases for the element types used by the grouped conv bwd-weight examples.
using BF16 = ck::bhalf_t;
using F16 = ck::half_t;
using F32 = float;
// F8 is declared only when the build defines CK_ENABLE_FP8.
#ifdef CK_ENABLE_FP8
using F8 = ck::f8_t;
#endif
// BF8 is declared only when the build defines CK_ENABLE_BF8.
#ifdef CK_ENABLE_BF8
using BF8 = ck::bf8_t;
#endif

// Shorthand for a compile-time integer list: S<1, 4, 16, 4> names ck::Sequence<1, 4, 16, 4>.
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,15 @@ using DeviceConvBwdWeightInstance = ck::tensor_operation::device::DeviceGroupedC
5, // CThreadTransferSrcDstVectorDim
4>; // CThreadTransferDstScalarPerVector

// Host-side reference backward-weight convolution type for this example, parameterised by
// the number of spatial dimensions and bound to the example's data types and element-wise ops.
// NOTE(review): presumably consumed by the included run_grouped_conv_bwd_weight_example.inc
// to produce the result the device output is checked against - confirm in that file.
template <ck::index_t NDimSpatial>
using HostConvBwdWeightInstance = ck::tensor_operation::host::ReferenceConvBwdWeight<NDimSpatial,
InDataType,
WeiDataType,
OutDataType,
InElementOp,
WeiElementOp,
OutElementOp>;

#include "run_grouped_conv_bwd_weight_example.inc"

// Entry point: hands the command line to the example runner and exits with 0 when the
// runner returns a truthy value, 1 otherwise.
int main(int argc, char* argv[])
{
    if(run_grouped_conv_bwd_weight_example(argc, argv))
    {
        return 0;
    }
    return 1;
}
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,15 @@ using DeviceConvBwdWeightInstance =
S<1, 32, 1, 4>, // CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
128 / (sizeof(WeiDataType) * CHAR_BIT)>; // CBlockTransferScalarPerVector_NWaveNPerXdl

// Host-side reference backward-weight convolution type for this example, parameterised by
// the number of spatial dimensions and bound to the example's data types and element-wise ops.
// NOTE(review): presumably consumed by the included run_grouped_conv_bwd_weight_example.inc
// to produce the result the device output is checked against - confirm in that file.
template <ck::index_t NDimSpatial>
using HostConvBwdWeightInstance = ck::tensor_operation::host::ReferenceConvBwdWeight<NDimSpatial,
InDataType,
WeiDataType,
OutDataType,
InElementOp,
WeiElementOp,
OutElementOp>;

#include "run_grouped_conv_bwd_weight_example.inc"

// Entry point: hands the command line to the example runner and exits with 0 when the
// runner returns a truthy value, 1 otherwise.
int main(int argc, char* argv[])
{
    if(run_grouped_conv_bwd_weight_example(argc, argv))
    {
        return 0;
    }
    return 1;
}
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,15 @@ using DeviceConvBwdWeightInstance =
S<1, 32, 1, 4>, // CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
128 / (sizeof(WeiDataType) * CHAR_BIT)>; // CBlockTransferScalarPerVector_NWaveNPerXdl

// Host-side reference backward-weight convolution type for this example, parameterised by
// the number of spatial dimensions and bound to the example's data types and element-wise ops.
// NOTE(review): presumably consumed by the included run_grouped_conv_bwd_weight_example.inc
// to produce the result the device output is checked against - confirm in that file.
template <ck::index_t NDimSpatial>
using HostConvBwdWeightInstance = ck::tensor_operation::host::ReferenceConvBwdWeight<NDimSpatial,
InDataType,
WeiDataType,
OutDataType,
InElementOp,
WeiElementOp,
OutElementOp>;

#include "run_grouped_conv_bwd_weight_example.inc"

// Entry point: hands the command line to the example runner and exits with 0 when the
// runner returns a truthy value, 1 otherwise.
int main(int argc, char* argv[])
{
    if(run_grouped_conv_bwd_weight_example(argc, argv))
    {
        return 0;
    }
    return 1;
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.

#include "common.hpp"

#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle.hpp"

// Data types handed to the device op and the host reference as In/Wei/Out data types.
using InDataType = F16;
using WeiDataType = F16;
using OutDataType = F16;
// Accumulator type handed to the device op.
using AccDataType = F32;
// Compute types handed to the device op and the host reference as ComputeTypeA / ComputeTypeB.
// The BF8 and F8 aliases in common.hpp exist only when CK_ENABLE_BF8 / CK_ENABLE_FP8 are defined.
using ComputeTypeA = BF8;
using ComputeTypeB = F8;

// Element-wise operations applied to the input, weight and output tensors.
// NOTE(review): PassThrough is presumably an identity / type-converting op - confirm.
using InElementOp = PassThrough;
using WeiElementOp = PassThrough;
using OutElementOp = PassThrough;

// Device operation exercised by this example: DeviceGroupedConvBwdWeight_Xdl_CShuffle
// instantiated with the file-level aliases (F16 in/wei/out data, F32 accumulation,
// BF8 / F8 as ComputeTypeA / ComputeTypeB).
// Each of the three tuple_element_t arguments picks one layout from its tuple using index
// NDimSpatial - 1, i.e. the 1st / 2nd / 3rd entry for 1 / 2 / 3 spatial dimensions.
// NOTE(review): the three layout arguments are presumably the input, weight and output
// layouts in that order - confirm against the device op's template parameter list.
// The trailing comment on every remaining argument names the template parameter it fills.
template <ck::index_t NDimSpatial>
using DeviceConvBwdWeightInstance =
ck::tensor_operation::device::DeviceGroupedConvBwdWeight_Xdl_CShuffle<
NDimSpatial,
ck::tuple_element_t<NDimSpatial - 1,
ck::Tuple<ck::tensor_layout::convolution::GNWC,
ck::tensor_layout::convolution::GNHWC,
ck::tensor_layout::convolution::GNDHWC>>,
ck::tuple_element_t<NDimSpatial - 1,
ck::Tuple<ck::tensor_layout::convolution::GKXC,
ck::tensor_layout::convolution::GKYXC,
ck::tensor_layout::convolution::GKZYXC>>,
ck::tuple_element_t<NDimSpatial - 1,
ck::Tuple<ck::tensor_layout::convolution::GNWK,
ck::tensor_layout::convolution::GNHWK,
ck::tensor_layout::convolution::GNDHWK>>,
InDataType, // InDataType
WeiDataType, // WeiDataType
OutDataType, // OutDataType
AccDataType, // AccDataType
InElementOp, // InElementwiseOperation
WeiElementOp, // WeiElementwiseOperation
OutElementOp, // OutElementwiseOperation
ConvBwdWeightDefault, // ConvolutionBackwardWeightSpecialization
256, // BlockSize
128, // MPerBlock
128, // NPerBlock
4, // K0PerBlock
8, // K1
32, // MPerXdl
32, // NPerXdl
2, // MXdlPerWave
2, // NXdlPerWave
S<1, 4, 16, 4>, // ABlockTransferThreadClusterLengths_K0_M_K1
S<0, 3, 1, 2>, // ABlockTransferThreadClusterArrangeOrder
S<0, 2, 1, 3>, // ABlockTransferSrcAccessOrder
2, // ABlockTransferSrcVectorDim
1, // ABlockTransferSrcScalarPerVector
1, // ABlockTransferDstScalarPerVector_K1
true, // ABlockLdsAddExtraM
S<1, 4, 16, 4>, // BBlockTransferThreadClusterLengths_K0_N_K1
S<0, 3, 1, 2>, // BBlockTransferThreadClusterArrangeOrder
S<0, 2, 1, 3>, // BBlockTransferSrcAccessOrder
2, // BBlockTransferSrcVectorDim
1, // BBlockTransferSrcScalarPerVector
1, // BBlockTransferDstScalarPerVector_K1
true, // BBlockLdsAddExtraN
1, // CShuffleMXdlPerWavePerShuffle
1, // CShuffleNXdlPerWavePerShuffle
S<1, 32, 1, 4>, // CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
2, // CBlockTransferScalarPerVector_NWaveNPerXdl
ComputeTypeA, // ComputeTypeA
ComputeTypeB>; // ComputeTypeB

// Host-side reference backward-weight convolution type for this example, parameterised by
// the number of spatial dimensions and bound to the example's data types and element-wise ops.
// Unlike the other examples' aliases it also forwards ComputeTypeA / ComputeTypeB
// (BF8 / F8 here), the same compute types given to the device instance.
// NOTE(review): presumably this makes the reference emulate the reduced-precision compute
// path so results are comparable - confirm in ReferenceConvBwdWeight.
template <ck::index_t NDimSpatial>
using HostConvBwdWeightInstance = ck::tensor_operation::host::ReferenceConvBwdWeight<NDimSpatial,
InDataType,
WeiDataType,
OutDataType,
InElementOp,
WeiElementOp,
OutElementOp,
ComputeTypeA,
ComputeTypeB>;

#include "run_grouped_conv_bwd_weight_example.inc"

// Entry point: hands the command line to the example runner and exits with 0 when the
// runner returns a truthy value, 1 otherwise.
int main(int argc, char* argv[])
{
    if(run_grouped_conv_bwd_weight_example(argc, argv))
    {
        return 0;
    }
    return 1;
}
Original file line number Diff line number Diff line change
@@ -1,15 +1,6 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.

template <ck::index_t NDimSpatial>
using HostConvBwdWeightInstance = ck::tensor_operation::host::ReferenceConvBwdWeight<NDimSpatial,
InDataType,
WeiDataType,
OutDataType,
InElementOp,
WeiElementOp,
OutElementOp>;

template <ck::index_t NDimSpatial>
bool run_grouped_conv_bwd_weight(const ExecutionConfig& config,
const ck::utils::conv::ConvParam& conv_param)
Expand Down Expand Up @@ -46,8 +37,8 @@ bool run_grouped_conv_bwd_weight(const ExecutionConfig& config,
out.GenerateTensorValue(GeneratorTensor_2<OutDataType>{-5, 5});
break;
default:
in.GenerateTensorValue(GeneratorTensor_3<InDataType>{0.0, 1.0});
out.GenerateTensorValue(GeneratorTensor_3<OutDataType>{-0.5, 0.5});
in.GenerateTensorValue(GeneratorTensor_3<InDataType>{0.0, 0.2});
out.GenerateTensorValue(GeneratorTensor_3<OutDataType>{-0.1, 0.1});
}

DeviceMem in_device_buf(sizeof(InDataType) * in.mDesc.GetElementSpaceSize());
Expand Down Expand Up @@ -113,18 +104,7 @@ bool run_grouped_conv_bwd_weight(const ExecutionConfig& config,
return true;
}

float avg_time = invoker.Run(argument, StreamConfig{nullptr, config.time_kernel});

std::size_t flop = conv_param.GetFlops();
std::size_t num_btype = conv_param.GetByte<InDataType, WeiDataType, OutDataType>();

float tflops = static_cast<float>(flop) / 1.E9 / avg_time;

float gb_per_sec = num_btype / 1.E6 / avg_time;

std::cerr << "Perf: " << avg_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s"
<< std::endl
<< "DeviceOp: " << conv.GetTypeString() << std::endl;
invoker.Run(argument, StreamConfig{nullptr, false});

if(config.do_verification)
{
Expand All @@ -148,6 +128,19 @@ bool run_grouped_conv_bwd_weight(const ExecutionConfig& config,
return ck::utils::check_err(wei_device_result.mData, wei_host_result.mData);
}

float avg_time = invoker.Run(argument, StreamConfig{nullptr, config.time_kernel});

std::size_t flop = conv_param.GetFlops();
std::size_t num_btype = conv_param.GetByte<InDataType, WeiDataType, OutDataType>();

float tflops = static_cast<float>(flop) / 1.E9 / avg_time;

float gb_per_sec = num_btype / 1.E6 / avg_time;

std::cerr << "Perf: " << avg_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s"
<< std::endl
<< "DeviceOp: " << conv.GetTypeString() << std::endl;

return true;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,9 @@ template <ck::index_t NDimSpatial,
typename OutDataType,
typename InElementwiseOperation,
typename WeiElementwiseOperation,
typename OutElementwiseOperation>
typename OutElementwiseOperation,
typename ComputeTypeA = InDataType,
typename ComputeTypeB = ComputeTypeA>
struct DeviceGroupedConvBwdWeight : public BaseOperator
{
virtual std::unique_ptr<BaseArgument>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,8 @@ struct ComputePtrOffsetOfStridedBatch
} // namespace

template <typename GridwiseGemm,
typename FloatAB,
typename FloatA,
typename FloatB,
typename FloatC,
typename AElementwiseOperation,
typename BElementwiseOperation,
Expand All @@ -64,8 +65,8 @@ __global__ void
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
kernel_batched_gemm_xdlops_bwd_weight(
const FloatAB* __restrict__ p_a_grid,
const FloatAB* __restrict__ p_b_grid,
const FloatA* __restrict__ p_a_grid,
const FloatB* __restrict__ p_b_grid,
FloatC* __restrict__ p_c_grid,
const AElementwiseOperation a_element_op,
const BElementwiseOperation b_element_op,
Expand All @@ -91,7 +92,7 @@ __global__ void
const long_index_t c_batch_offset = __builtin_amdgcn_readfirstlane(
static_cast<long_index_t>(compute_ptr_offset_of_batch.GetCPtrOffset(g_idx)));

__shared__ FloatAB p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatAB)];
__shared__ FloatA p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatA)];

GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid + a_batch_offset,
p_b_grid + b_batch_offset,
Expand Down Expand Up @@ -163,7 +164,9 @@ template <ck::index_t NDimSpatial,
index_t CShuffleMXdlPerWavePerShuffle,
index_t CShuffleNXdlPerWavePerShuffle,
typename CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
index_t CBlockTransferScalarPerVector_NWaveNPerXdl>
index_t CBlockTransferScalarPerVector_NWaveNPerXdl,
typename ComputeTypeA = InDataType,
typename ComputeTypeB = ComputeTypeA>
struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
: public DeviceGroupedConvBwdWeight<NDimSpatial,
InLayout,
Expand All @@ -174,7 +177,9 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
OutDataType,
InElementwiseOperation,
WeiElementwiseOperation,
OutElementwiseOperation>
OutElementwiseOperation,
ComputeTypeA,
ComputeTypeB>
{
using DeviceOp = DeviceGroupedConvBwdWeight_Xdl_CShuffle;

Expand Down Expand Up @@ -1045,7 +1050,8 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle

using GridwiseGemm = GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight<
BlockSize,
ADataType, // TODO: distinguish A/B datatype
ADataType,
BDataType,
AccDataType,
CDataType,
InMemoryDataOperationEnum::AtomicAdd,
Expand Down Expand Up @@ -1090,7 +1096,11 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
CBlockTransferScalarPerVector_NWaveNPerXdl,
CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
true,
true>;
true,
1,
PipelineVersion::v1,
ComputeTypeA,
ComputeTypeB>;

// Argument
using CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock =
Expand Down Expand Up @@ -1217,8 +1227,8 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
index_t M01_;
index_t N01_;

InElementwiseOperation a_element_op_;
OutElementwiseOperation b_element_op_;
OutElementwiseOperation a_element_op_;
InElementwiseOperation b_element_op_;
WeiElementwiseOperation c_element_op_;

// for checking IsSupportedArgument()
Expand Down Expand Up @@ -1281,7 +1291,8 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle

const auto kernel = kernel_batched_gemm_xdlops_bwd_weight<
GridwiseGemm,
ADataType, // TODO: distiguish A/B datatype
ADataType,
BDataType,
CDataType,
OutElementwiseOperation,
InElementwiseOperation,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -185,7 +185,8 @@ struct PassThrough
// PassThrough specialization that writes a half_t input into a bf8_t output.
//   y: destination, overwritten with the converted value.
//   x: source value.
template <>
__host__ __device__ void operator()<bf8_t, half_t>(bf8_t& y, const half_t& x) const
{
    // TODO: fix the direct half_t -> bf8_t type_convert.
    // Until then the value is widened to float first and converted from there; the
    // direct conversion must not be used here, so y is assigned exactly once.
    y = ck::type_convert<bf8_t>(ck::type_convert<float>(x));
}
#endif
};
Expand Down
Loading