Skip to content

Commit

Permalink
wip
Browse files Browse the repository at this point in the history
  • Loading branch information
liqiangxl committed Feb 17, 2024
1 parent 248429c commit 2a8fabd
Showing 1 changed file with 19 additions and 3 deletions.
22 changes: 19 additions & 3 deletions csrc/codegen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -175,6 +175,8 @@ class CudaKernelGenerator : private kir::ConstIrVisitor {
initStringStreamFormat(code_);
}

std::unordered_set<Val*> array_of_regs_;

using kir::ConstIrVisitor::handle;

void initStringStreamFormat(std::stringstream& ss) {
Expand Down Expand Up @@ -2298,9 +2300,19 @@ class CudaKernelGenerator : private kir::ConstIrVisitor {
ArgumentBuilder func_args;

// outputs
Val* out0 =output.get(0);
std::cout << "output: " << out0 << ", dtype: " << out0->dtype()<< std::endl;
func_args.arg(genVariableName(output.get(0)));
auto out0 =output.get(0);
if(out0->isA<kir::TensorIndex>()){
std::cout << "TensorIndex: " << out0 << ", tv: " << out0->as<kir::TensorIndex>()->view() << std::endl;
}else if(out0->isA<TensorView>()){
std::cout << "TensorView: " << out0 << ", dtype: " << out0->dtype()<< std::endl;
}else{
std::cout << "??? output: " << out0 << ", dtype: " << out0->dtype()<< std::endl;
}
if(array_of_regs_.count(out0->as<kir::TensorIndex>()->view())){
func_args.arg(genVariableName(output.get(0))).append(".array");
}else{
func_args.arg(genVariableName(output.get(0)));
}
func_args.arg(genVariableName(output.get(1)));
func_args.arg(genVariableName(output.get(2)));
// inputs
Expand Down Expand Up @@ -2812,6 +2824,8 @@ class CudaKernelGenerator : private kir::ConstIrVisitor {
<< " = *reinterpret_cast<Array<" << buffer_dtype << ", "
<< genInline(size) << ">*>(&" << genVariableName(alias_tv)
<< ");\n";
std::cout << "array_of_regs_ insert: " << tv << std::endl;
array_of_regs_.insert(tv);
}
} else {
// Standard Memory Allocation
Expand All @@ -2838,6 +2852,8 @@ class CudaKernelGenerator : private kir::ConstIrVisitor {
indent() << "Array<" << buffer_dtype << ", " << genInline(size)
<< ", " << va.at(tv) << "> " << genVariableName(tv)
<< ";\n";
std::cout << "array_of_regs_ insert: " << tv << std::endl;
array_of_regs_.insert(tv);
} else {
indent() << buffer_dtype << " " << genVariableName(tv) << "["
<< genInline(size) << "];\n";
Expand Down

0 comments on commit 2a8fabd

Please sign in to comment.