Skip to content

Commit

Permalink
rename
Browse files Browse the repository at this point in the history
  • Loading branch information
liqiangxl committed Feb 22, 2024
1 parent 95ed23a commit 23aa248
Showing 1 changed file with 8 additions and 7 deletions.
15 changes: 8 additions & 7 deletions csrc/codegen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -253,7 +253,7 @@ class CudaKernelGenerator : private kir::ConstIrVisitor {
// array. This avoid the type mismatch in template functions when one of the
// arguments is an aligned array (Array<T,N>) while the other is a regular
// array T[N].
std::string genVarForTemplateFunction(Val* v) {
std::string genVariableNameConvertAlignedArray(Val* v) {
TensorView* tv = nullptr;
if (v->isA<kir::TensorIndex>()) {
tv = v->as<kir::TensorIndex>()->view();
Expand Down Expand Up @@ -2318,13 +2318,14 @@ class CudaKernelGenerator : private kir::ConstIrVisitor {
ArgumentBuilder func_args;

// outputs
func_args.arg(genVarForTemplateFunction(output.get(0)));
func_args.arg(genVarForTemplateFunction(output.get(1)));
func_args.arg(genVarForTemplateFunction(output.get(2)));
func_args.arg(genVariableNameConvertAlignedArray(output.get(0)));
func_args.arg(genVariableNameConvertAlignedArray(output.get(1)));
func_args.arg(genVariableNameConvertAlignedArray(output.get(2)));
// inputs
func_args.arg(genVarForTemplateFunction(input.get(0)));
func_args.arg(genVarForTemplateFunction(input.get(1)));
func_args.arg(genVarForTemplateFunction(input.get(2))).append("[0]");
func_args.arg(genVariableNameConvertAlignedArray(input.get(0)));
func_args.arg(genVariableNameConvertAlignedArray(input.get(1)));
func_args.arg(genVariableNameConvertAlignedArray(input.get(2)))
.append("[0]");

// global buf
for (const auto i : c10::irange(3)) {
Expand Down

0 comments on commit 23aa248

Please sign in to comment.