Skip to content

Commit

Permalink
Fix sbrc 3d erc unaligned (#372)
Browse files Browse the repository at this point in the history
  • Loading branch information
solaslin authored Aug 25, 2022
1 parent 753b1ed commit 7575816
Show file tree
Hide file tree
Showing 8 changed files with 72 additions and 28 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,9 @@ Full documentation for rocFFT is available at [rocfft.readthedocs.io](https://ro
### Fixed
- Fixed occasional failures to parallelize runtime compilation of kernels.
Failures would be retried serially and ultimately succeed, but this would take extra time.
- Fixed failures of some R2C 3D transforms that use the unsupported TILE_UNALGNED SBRC kernels.
An example is 98^3 R2C out-of-place.
- Fixed bugs in SBRC_ERC type.

## rocFFT 1.0.17 for ROCm 5.2.0
### Added
Expand Down
8 changes: 8 additions & 0 deletions clients/tests/accuracy_test_adhoc.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,9 @@ std::vector<std::vector<size_t>> adhoc_sizes = {

// Failure with build_CS_3D_BLOCK_RC
{680, 128, 128},

// TILE_UNALIGNED type of SBRC 3D ERC
{98, 98, 98},
};

const static std::vector<std::vector<size_t>> stride_range = {{1}};
Expand Down Expand Up @@ -77,6 +80,11 @@ INSTANTIATE_TEST_SUITE_P(DISABLED_offset_adhoc,
inline auto param_permissive_iodist()
{
std::vector<std::vector<size_t>> lengths = adhoc_sizes;
// TODO- for these permissive iodist tests,
// some 98^3 sizes take too long for the exhaustive search buffer assignments
// about millions of assignments, thus the program is hung there.
// So we take this length out from iodist test for now.
lengths.erase(std::find(lengths.begin(), lengths.end(), std::vector<size_t>{98, 98, 98}));
lengths.push_back({4});

std::vector<fft_params> params;
Expand Down
4 changes: 4 additions & 0 deletions library/src/device/generator/stockham_gen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -363,6 +363,10 @@ std::string stockham_variants(const std::string& filename,
"CS_KERNEL_STOCKHAM_R_TO_CMPLX_TRANSPOSE_Z_XY",
"SBRC_3D_FFT_ERC_TRANS_Z_XY",
"TILE_ALIGNED"});
suffixes.push_back({"sbrc3d_fft_erc_trans_z_xy_tile_unaligned",
"CS_KERNEL_STOCKHAM_R_TO_CMPLX_TRANSPOSE_Z_XY",
"SBRC_3D_FFT_ERC_TRANS_Z_XY",
"TILE_UNALIGNED"});

output += make_launcher(specs.length,
false,
Expand Down
46 changes: 32 additions & 14 deletions library/src/device/generator/stockham_gen_rc.h
Original file line number Diff line number Diff line change
Expand Up @@ -435,8 +435,8 @@ struct StockhamKernelRC : public StockhamKernel
StatementList store_to_global(bool store_registers) override
{
StatementList stmts;
StatementList regular_store;
StatementList edge_store;
StatementList non_edge_stmts;
StatementList edge_stmts;
Expression pred{tile_index_in_plane * transforms_per_block + tid_hor < len_along_block};

if(!store_registers)
Expand All @@ -461,6 +461,8 @@ struct StockhamKernelRC : public StockhamKernel
stmts += Assign{tid_hor, thread_id % store_block_w};
}

StatementList regular_store;

// we need to take care about two diff cases for offset in buf and lds
// divisible: each store leads to a perfect block: update offset much simpler
// indivisible: need extra div and mod, otherwise each store will have some elements un-set:
Expand All @@ -485,12 +487,24 @@ struct StockhamKernelRC : public StockhamKernel

// ERC_Z_XY
auto i = num_store_blocks;
StatementList stmts_erc_post;
stmts_erc_post += If{
StatementList stmts_erc_post_no_edge;
stmts_erc_post_no_edge
+= CommentLines{"extra global write for SBRC_3D_FFT_ERC_TRANS_Z_XY"};
stmts_erc_post_no_edge += If{
thread_id < transforms_per_block,
{StoreGlobal{buf, offset + offset_tile_wbuf(i), lds_complex[offset_tile_rlds(i)]}}};
regular_store += CommentLines{"extra global write for SBRC_3D_FFT_ERC_TRANS_Z_XY"};
regular_store += If{sbrc_type == "SBRC_3D_FFT_ERC_TRANS_Z_XY", stmts_erc_post};
non_edge_stmts += regular_store;
non_edge_stmts += If{sbrc_type == "SBRC_3D_FFT_ERC_TRANS_Z_XY", stmts_erc_post_no_edge};

StatementList stmts_erc_post_edge;
stmts_erc_post_edge
+= CommentLines{"extra global write for SBRC_3D_FFT_ERC_TRANS_Z_XY"};
stmts_erc_post_edge += If{thread_id < Parens{len_along_block % transforms_per_block},
{StoreGlobal{buf,
offset + offset_tile_wbuf(i),
lds_complex[tid_hor * stride_lds + length]}}};
edge_stmts += regular_store;
edge_stmts += If{sbrc_type == "SBRC_3D_FFT_ERC_TRANS_Z_XY", stmts_erc_post_edge};
}
else
{
Expand All @@ -503,14 +517,14 @@ struct StockhamKernelRC : public StockhamKernel
auto height = static_cast<float>(length) / width / threads_per_transform;

auto store_global = std::mem_fn(&StockhamKernelRC::store_global_generator);
regular_store += add_work(
non_edge_stmts += add_work(
std::bind(store_global, this, _1, _2, _3, _4, cumheight), width, height, true);
}

edge_store += If{pred, regular_store};
edge_stmts = non_edge_stmts;
}

stmts += If{Or{transpose_type != "TILE_UNALIGNED", Not{edge}}, regular_store};
stmts += Else{edge_store};
stmts += If{Or{transpose_type != "TILE_UNALIGNED", Not{edge}}, non_edge_stmts};
stmts += Else{{If{pred, edge_stmts}}};

return stmts;
}
Expand Down Expand Up @@ -548,17 +562,21 @@ struct StockhamKernelRC : public StockhamKernel
// Todo: We might not have to sync here which depends on the access pattern
stmts += SyncThreads{};
stmts += LineBreak{};

// length is Half_N, remember quarter_N should be is (Half_N + 1) / 2
// And need to set the Ndiv4 template argument
Variable Ndiv4{length % 2 == 0 ? "true" : "false", "bool"};
for(unsigned int h = 0; h < transforms_per_block; ++h)
{
stmts += Call{"post_process_interleaved_inplace",
{scalar_type, Variable{"true", ""}, Variable{"CallbackType::NONE", ""}},
{scalar_type, Ndiv4, Variable{"CallbackType::NONE", ""}},
{thread_id,
length - thread_id,
length,
length / 2,
(length + 1) / 2,
lds_complex + (h * stride_lds),
0,
twiddles + length - factors.front(),
twiddles + (length - factors.front()),
null,
null,
0,
Expand Down
5 changes: 4 additions & 1 deletion library/src/fuse_shim.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -630,7 +630,10 @@ std::unique_ptr<TreeNode> STK_R2CTrans_FuseShim::FuseKernels()
auto fused = NodeFactory::CreateNodeFromScheme(CS_KERNEL_STOCKHAM_R_TO_CMPLX_TRANSPOSE_Z_XY,
stockham->parent);
fused->CopyNodeData(*stockham);
// no need to check kernel exists, this scheme uses a built-in kernel
// check if kernel exists, since the fused kernel uses different scheme other than stockham
if(!fused->KernelCheck())
return nullptr;

fused->placement = rocfft_placement_notinplace;
fused->outArrayType = transpose->outArrayType;
fused->obOut = transpose->obOut;
Expand Down
5 changes: 3 additions & 2 deletions library/src/include/function_pool.h
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ static inline FMKey fpkey(size_t length1,
return {{length1, length2}, precision, scheme, transpose};
}

inline void PrintMissingKernelInfo(const FMKey& key)
inline std::string PrintMissingKernelInfo(const FMKey& key)
{
const auto& lengthVec = std::get<0>(key);
const rocfft_precision precision = std::get<1>(key);
Expand All @@ -60,7 +60,8 @@ inline void PrintMissingKernelInfo(const FMKey& key)
<< "\tprecision: " << precision << "\n"
<< "\tscheme: " << PrintScheme(scheme) << "\n"
<< "\tSBRC Transpose type: " << PrintSBRCTransposeType(trans) << std::endl;
throw std::runtime_error(msg.str());

return msg.str();
}

struct SimpleHash
Expand Down
5 changes: 4 additions & 1 deletion library/src/tree_node.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
#include "tree_node.h"
#include "function_pool.h"
#include "kernel_launch.h"
#include "logging.h"
#include "repo.h"
#include "twiddles.h"

Expand Down Expand Up @@ -114,7 +115,9 @@ bool LeafNode::KernelCheck()
: fpkey(length[0], length[1], precision, scheme);
if(!function_pool::has_function(key))
{
PrintMissingKernelInfo(key);
if(LOG_TRACE_ENABLED())
(*LogSingleton::GetInstance().GetTraceOS()) << PrintMissingKernelInfo(key);

return false;
}

Expand Down
24 changes: 14 additions & 10 deletions library/src/tree_node_3D.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
#include "tree_node_3D.h"
#include "arithmetic.h"
#include "function_pool.h"
#include "logging.h"
#include "node_factory.h"
#include <numeric>

Expand Down Expand Up @@ -723,23 +724,26 @@ void RC3DNode::AssignParams_internal()
*****************************************************/
bool SBRCTranspose3DNode::KernelCheck()
{
// check we have the kernel
// TODO: TILE_UNALIGNED if we have it
// check we have the kernel,
// we always have aligned, get the kernel and the bwd
FMKey key = fpkey(length[0], precision, scheme, TILE_ALIGNED);
if(!function_pool::has_function(key))
{
PrintMissingKernelInfo(key);
if(LOG_TRACE_ENABLED())
(*LogSingleton::GetInstance().GetTraceOS()) << PrintMissingKernelInfo(key);
return false;
}

if(is_diagonal_sbrc_3D_length(length[0]) && is_cube_size(length))
auto bwd = function_pool::get_kernel(key).transforms_per_block;

// check if we have the sbrc_type that we are actually applying
sbrcTranstype = sbrc_transpose_type(bwd);
if(!function_pool::has_function(fpkey(length[0], precision, scheme, sbrcTranstype)))
{
key = fpkey(length[0], precision, scheme, DIAGONAL);
if(!function_pool::has_function(key))
{
PrintMissingKernelInfo(key);
return false;
}
if(LOG_TRACE_ENABLED())
(*LogSingleton::GetInstance().GetTraceOS())
<< PrintMissingKernelInfo(fpkey(length[0], precision, scheme, sbrcTranstype));
return false;
}

dir2regMode = (function_pool::get_kernel(key).direct_to_from_reg)
Expand Down

0 comments on commit 7575816

Please sign in to comment.