Skip to content

Commit

Permalink
Make old sbcc kernels compatible
Browse files Browse the repository at this point in the history
  • Loading branch information
solaslin authored Jun 18, 2021
1 parent 2bd667b commit b93c40c
Show file tree
Hide file tree
Showing 7 changed files with 119 additions and 59 deletions.
48 changes: 27 additions & 21 deletions library/src/device/generator/generator.argument.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -38,65 +38,66 @@ struct support_bitwise_enum<EPredefineType> : std::true_type
class generator_argument
{
public:
size_t group_num = 8;
EPrecision precision = EPrecision::ALL;
EPredefineType predefineType = EPredefineType::ALL;
std::vector<size_t> manualSize;
std::vector<size_t> manualSizeLarge;
std::set<size_t> validManualSize;
std::set<size_t> validManualSizeLarge;

void init_precision(const std::vector<std::string>& argString)
size_t group_num = 8;
EPrecision precision = EPrecision::ALL;
EPredefineType predefineType = EPredefineType::ALL;
std::set<size_t> manualSize;
std::set<size_t> manualSizeLarge;
std::set<size_t> largeSizesWithoutSBCC;
std::set<size_t> validManualSize;
std::set<size_t> validManualSizeLarge;

void init_precision(const std::set<std::string>& argString)
{
// we're here only when -p is in the args, starting from none and do bit-OR
precision = EPrecision::NONE;
if(std::find(argString.begin(), argString.end(), "single") != argString.end())
if(argString.count("single"))
{
precision |= EPrecision::SINGLE;
}
if(std::find(argString.begin(), argString.end(), "double") != argString.end())
if(argString.count("double"))
{
precision |= EPrecision::DOUBLE;
}
if(std::find(argString.begin(), argString.end(), "all") != argString.end())
if(argString.count("all"))
{
precision |= EPrecision::ALL;
}
}

void init_type(const std::vector<std::string>& argString)
void init_type(const std::set<std::string>& argString)
{
// we're here only when -t is in the args, starting from none and do bit-OR
predefineType = EPredefineType::NONE;
if(std::find(argString.begin(), argString.end(), "pow2") != argString.end())
if(argString.count("pow2"))
{
predefineType |= EPredefineType::POW2;
}
if(std::find(argString.begin(), argString.end(), "pow3") != argString.end())
if(argString.count("pow3"))
{
predefineType |= EPredefineType::POW3;
}
if(std::find(argString.begin(), argString.end(), "pow5") != argString.end())
if(argString.count("pow5"))
{
predefineType |= EPredefineType::POW5;
}
if(std::find(argString.begin(), argString.end(), "pow7") != argString.end())
if(argString.count("pow7"))
{
predefineType |= EPredefineType::POW7;
}
if(std::find(argString.begin(), argString.end(), "small") != argString.end())
if(argString.count("small"))
{
predefineType |= EPredefineType::SMALL;
}
if(std::find(argString.begin(), argString.end(), "large") != argString.end())
if(argString.count("large"))
{
predefineType |= EPredefineType::LARGE;
}
if(std::find(argString.begin(), argString.end(), "2D") != argString.end())
if(argString.count("2D"))
{
predefineType |= EPredefineType::DIM2;
}
if(std::find(argString.begin(), argString.end(), "all") != argString.end())
if(argString.count("all"))
{
predefineType |= EPredefineType::ALL;
}
Expand Down Expand Up @@ -198,6 +199,11 @@ class generator_argument
ss << " " << i;
ss << separator;

ss << "don't gen sbcc for large size:";
for(auto i : largeSizesWithoutSBCC)
ss << " " << i;
ss << separator;

ss << "precision:";
if(has_precision(EPrecision::SINGLE))
ss << " single";
Expand Down
2 changes: 1 addition & 1 deletion library/src/device/generator/generator.kernel.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -1158,7 +1158,7 @@ namespace StockhamGenerator
// Function signature
if(NeedsLargeTwiddles())
{
str += "template <typename T, StrideBin sb, bool TwdLarge, CallbackType cbtype, "
str += "template <typename T, StrideBin sb, CallbackType cbtype, bool TwdLarge, "
"size_t LTBase="
+ std::to_string(LTWD_BASE_DEFAULT) + ">\n";
}
Expand Down
25 changes: 17 additions & 8 deletions library/src/device/generator/generator.main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -152,9 +152,9 @@ int main(int argc, char* argv[])
std::string precisionArgStr;
std::string manualSmallArgStr;
std::string manualLargeArgStr;
std::string noSBCCArgStr;

std::vector<std::string> precisionList;
std::vector<std::string> typeList;
std::set<std::string> argStrList;

// clang-format doesn't handle boost program options very well:
// clang-format off
Expand All @@ -168,6 +168,8 @@ int main(int argc, char* argv[])
"Manual 1D small sizes(Separate by comma)")
("manual-large", value<std::string>(&manualLargeArgStr),
"Manual 1D large sizes(Separate by comma)")
("no-sbcc", value<std::string>(&noSBCCArgStr),
"gen large sizes with sbrc only, no sbcc (Separate by comma)")
("group,g", value<size_t>(&argument.group_num)->default_value(8),
"Numbers of kernel launch cpp files for 1D small size");
// clang-format on
Expand Down Expand Up @@ -201,14 +203,19 @@ int main(int argc, char* argv[])
// default type is ALL if not specified, else init_type from arg
if(vm.count("type"))
{
parse_arg_strings(typeArgStr, typeList);
argument.init_type(typeList);
parse_arg_strings(typeArgStr, argStrList);
argument.init_type(argStrList);
}
// default precision is ALL if not specified, else init_precision from arg
if(vm.count("precision"))
{
parse_arg_strings(precisionArgStr, precisionList);
argument.init_precision(precisionList);
parse_arg_strings(precisionArgStr, argStrList);
argument.init_precision(argStrList);
}
// default large sizes gen both sbcc and sbrc, except for those tagged with "no-sbcc"
if(vm.count("no-sbcc"))
{
parse_arg_ints(noSBCCArgStr, argument.largeSizesWithoutSBCC);
}

if(argument.group_num <= 0)
Expand Down Expand Up @@ -293,7 +300,8 @@ int main(int argc, char* argv[])
{
for(auto i : supported_large_set)
{
// large1D_list.push_back(std::make_tuple(i, CS_KERNEL_STOCKHAM_BLOCK_CC));
if(argument.largeSizesWithoutSBCC.count(i) == 0)
large1D_list.push_back(std::make_tuple(i, CS_KERNEL_STOCKHAM_BLOCK_CC));
large1D_list.push_back(std::make_tuple(i, CS_KERNEL_STOCKHAM_BLOCK_RC));
}
}
Expand All @@ -304,7 +312,8 @@ int main(int argc, char* argv[])
{
for(auto i : argument.validManualSizeLarge)
{
// large1D_list.push_back(std::make_tuple(i, CS_KERNEL_STOCKHAM_BLOCK_CC));
if(argument.largeSizesWithoutSBCC.count(i) == 0)
large1D_list.push_back(std::make_tuple(i, CS_KERNEL_STOCKHAM_BLOCK_CC));
large1D_list.push_back(std::make_tuple(i, CS_KERNEL_STOCKHAM_BLOCK_RC));
}
}
Expand Down
14 changes: 9 additions & 5 deletions library/src/device/generator/generator.options_util.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -384,24 +384,28 @@ class parse_command_line
// We can define the notify() function as a no-op for our purposes
inline void notify(const variables_map&) {}

void parse_arg_ints(std::string const& inStr, std::vector<size_t>& outVector)
//
// Parse comma separated list of integers and append to `outSet`.
//
void parse_arg_ints(std::string const& inStr, std::set<size_t>& outSet)
{
// std::cout << inStr << std::endl;
outSet.clear();
for(std::sregex_token_iterator tok{inStr.begin(), inStr.end(), vector_delim, -1};
tok != std::sregex_token_iterator();
++tok)
{
outVector.push_back(std::stoi(tok->str()));
outSet.insert(std::stoi(tok->str()));
}
}

void parse_arg_strings(std::string const& inStr, std::vector<std::string>& outVector)
void parse_arg_strings(std::string const& inStr, std::set<std::string>& outSet)
{
// std::cout << inStr << std::endl;
outSet.clear();
for(std::sregex_token_iterator tok{inStr.begin(), inStr.end(), vector_delim, -1};
tok != std::sregex_token_iterator();
++tok)
{
outVector.push_back(tok->str());
outSet.insert(tok->str());
}
}
14 changes: 10 additions & 4 deletions library/src/device/kernel-generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,6 +198,8 @@ def __str__(self):
use_3steps_large_twd = getattr(self.function.meta, 'use_3steps_large_twd', None)
if use_3steps_large_twd is not None:
f += ', ' + str(use_3steps_large_twd[self.function.meta.precision])
else:
f += ', false'
factors = getattr(self.function.meta, 'factors', None)
if factors is not None:
f += ', {' + cjoin(factors) + '}'
Expand Down Expand Up @@ -285,9 +287,11 @@ def add(name, scheme, transpose=None):
transpose=transpose)))

for length, scheme in transforms.items():
if scheme == 'CS_KERNEL_STOCKHAM_BLOCK_CC':
if 0:
add(f'rocfft_internal_dfn_{precision}_ci_ci_sbcc_{length}', 'CS_KERNEL_STOCKHAM_BLOCK_CC')
elif scheme == 'CS_KERNEL_STOCKHAM_BLOCK_RC':
# for old-sbcc compatibility: always include the sbcc function (but will be overwritten if new gen has it)
add(f'rocfft_internal_dfn_{precision}_ci_ci_sbcc_{length}', 'CS_KERNEL_STOCKHAM_BLOCK_CC')
add(f'rocfft_internal_dfn_{precision}_op_ci_ci_sbrc_{length}', 'CS_KERNEL_STOCKHAM_BLOCK_RC')
add(f'rocfft_internal_dfn_{precision}_op_ci_ci_sbrc3d_fft_trans_xy_z_tile_aligned_{length}', 'CS_KERNEL_STOCKHAM_TRANSPOSE_XY_Z', 'TILE_ALIGNED')
add(f'rocfft_internal_dfn_{precision}_op_ci_ci_sbrc3d_fft_trans_z_xy_tile_aligned_{length}', 'CS_KERNEL_STOCKHAM_TRANSPOSE_Z_XY', 'TILE_ALIGNED')
Expand Down Expand Up @@ -559,7 +563,7 @@ def list_new_large_kernels():
NS(length=50, factors=[10, 5], use_3steps_large_twd={'sp': 'true', 'dp': 'true'}, threads_per_block=256),
NS(length=64, factors=[8, 8], use_3steps_large_twd={'sp': 'true', 'dp': 'false'}),
NS(length=81, factors=[3, 3, 3, 3], use_3steps_large_twd={'sp': 'true', 'dp': 'true'}),
NS(length=100, factors=[5, 5, 4], use_3steps_large_twd={'sp': 'true', 'dp': 'false'}),
# NS(length=100, factors=[5, 5, 4], use_3steps_large_twd={'sp': 'true', 'dp': 'false'}),
NS(length=128, factors=[8, 4, 4], use_3steps_large_twd={'sp': 'true', 'dp': 'false'}),
NS(length=200, factors=[8, 5, 5], use_3steps_large_twd={'sp': 'false', 'dp': 'false'}),
NS(length=256, factors=[4, 4, 4, 4], use_3steps_large_twd={'sp': 'true', 'dp': 'true'})
Expand Down Expand Up @@ -773,7 +777,7 @@ def cli():
# return the necessary include files to cmake
#
if args.command == 'list':

scprint(set(list_old_generated_kernels(patterns=patterns,
precisions=precisions,
num_small_kernel_groups=args.groups)
Expand Down Expand Up @@ -803,10 +807,12 @@ def cli():

old_small_lengths = {f.meta.length for f in psmall.values()}
old_large_lengths = {f.meta.length for f in plarge.values()} # sbcc=new-gen, sbrc/transpose=old-gen
new_large_lengths = {k.length for k in new_large_kernels} # sbcc by new-gen

if old_small_lengths:
subprocess.run([args.generator, '-g', str(args.groups), '-p', args.precision, '-t', 'none', '--manual-small', cjoin(sorted(old_small_lengths))], check=True)
if old_large_lengths:
subprocess.run([args.generator, '-g', str(args.groups), '-p', args.precision, '-t', 'none', '--manual-large', cjoin(sorted(old_large_lengths))], check=True)
subprocess.run([args.generator, '-g', str(args.groups), '-p', args.precision, '-t', 'none', '--manual-large', cjoin(sorted(old_large_lengths)), '--no-sbcc', cjoin(sorted(new_large_lengths))], check=True)
if dim2:
subprocess.run([args.generator, '-g', str(args.groups), '-p', args.precision, '-t', '2D'], check=True)

Expand Down
46 changes: 33 additions & 13 deletions library/src/plan.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1007,21 +1007,41 @@ bool TreeNode::use_CS_2D_SINGLE()

bool TreeNode::use_CS_2D_RC()
{
// For CS_2D_RC, we are reusing SBCC kernel for 1D middle size. The
// current implementation of 1D SBCC supports only 64, 128, and 256.
// However, technically no LDS limitation along the fast dimension
// on upper bound for 2D SBCC cases, and even should not limit to pow
// of 2.

// FIXME: use all SBCC kernels instead after we fix the bugs in buffer assignment
// std::set<int> sbcc_support = {50, 64, 81, 100, 128, 200, 256};
std::set<int> sbcc_support = {64, 128, 256};
if((sbcc_support.find(length[1]) != sbcc_support.end()) && (length[0] >= 64))
try
{
// find the sbcc kernel (throws if not found / or old-sbcc without factors / or new-sbcc with factor)
bool oldKernel
= function_pool::get_kernel(fpkey(length[1], precision, CS_KERNEL_STOCKHAM_BLOCK_CC))
.factors.empty();
if(oldKernel)
{
// old-sbcc:
// we are reusing SBCC kernel for 1D middle size. The
// current implementation of 1D SBCC supports only 64, 128, and 256.
// However, technically no LDS limitation along the fast dimension
// on upper bound for 2D SBCC cases, and even should not limit to pow
// of 2.
if(IsPo2(length[1]) && (length[0] >= 64))
{
size_t bwd, wgs, lds;
GetBlockComputeTable(length[1], bwd, wgs, lds);
// need tile-aligned
return (length[0] % bwd == 0);
}
return false;
}
else
{
// new-sbcc supports non-tile-aligned, only check if exceeds the min threshold.
// only 64,128,256 are available due to some buffer assign bug
return (IsPo2(length[1]) && length[0] >= 64);
}
}
catch(...)
{
return true;
// get_kernel throws, sbcc kernel not found in pool
return false;
}

return false;
}

size_t TreeNode::count_3D_SBRC_nodes()
Expand Down
29 changes: 22 additions & 7 deletions library/src/powX.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -206,13 +206,28 @@ bool PlanPowX(ExecPlan& execPlan)
execPlan.execSeq[0]->precision,
CS_KERNEL_STOCKHAM_BLOCK_CC));
ptr = kernel.device_function;
gp.b_x = ((execPlan.execSeq[i]->length[1]) - 1) / kernel.batches_per_block + 1;
// repeat for higher dimensions + batch
gp.b_x *= std::accumulate(execPlan.execSeq[i]->length.begin() + 2,
execPlan.execSeq[i]->length.end(),
execPlan.execSeq[i]->batch,
std::multiplies<size_t>());
gp.tpb_x = kernel.threads_per_block;

if(kernel.threads_per_block > 0)
{
gp.b_x = ((execPlan.execSeq[i]->length[1]) - 1) / kernel.batches_per_block + 1;
// repeat for higher dimensions + batch
gp.b_x *= std::accumulate(execPlan.execSeq[i]->length.begin() + 2,
execPlan.execSeq[i]->length.end(),
execPlan.execSeq[i]->batch,
std::multiplies<size_t>());
gp.tpb_x = kernel.threads_per_block;
}
else
{
GetBlockComputeTable(execPlan.execSeq[i]->length[0], bwd, wgs, lds);
gp.b_x = (execPlan.execSeq[i]->length[1]) / bwd;
// repeat for higher dimensions + batch
gp.b_x *= std::accumulate(execPlan.execSeq[i]->length.begin() + 2,
execPlan.execSeq[i]->length.end(),
execPlan.execSeq[i]->batch,
std::multiplies<size_t>());
gp.tpb_x = wgs;
}
}
break;
case CS_KERNEL_STOCKHAM_BLOCK_RC:
Expand Down

0 comments on commit b93c40c

Please sign in to comment.