From 2a8fabd31a063b30d681772acc0fb71dbce43b4a Mon Sep 17 00:00:00 2001 From: Liqiang Lu Date: Sat, 17 Feb 2024 00:50:32 +0000 Subject: [PATCH] wip --- csrc/codegen.cpp | 22 +++++++++++++++++++--- 1 file changed, 19 insertions(+), 3 deletions(-) diff --git a/csrc/codegen.cpp b/csrc/codegen.cpp index 8407d6b3b02..2ae2af5b591 100644 --- a/csrc/codegen.cpp +++ b/csrc/codegen.cpp @@ -175,6 +175,8 @@ class CudaKernelGenerator : private kir::ConstIrVisitor { initStringStreamFormat(code_); } + std::unordered_set array_of_regs_; + using kir::ConstIrVisitor::handle; void initStringStreamFormat(std::stringstream& ss) { @@ -2298,9 +2300,19 @@ class CudaKernelGenerator : private kir::ConstIrVisitor { ArgumentBuilder func_args; // outputs - Val* out0 =output.get(0); - std::cout << "output: " << out0 << ", dtype: " << out0->dtype()<< std::endl; - func_args.arg(genVariableName(output.get(0))); + auto out0 =output.get(0); + if(out0->isA()){ + std::cout << "TensorIndex: " << out0 << ", tv: " << out0->as()->view() << std::endl; + }else if(out0->isA()){ + std::cout << "TensorView: " << out0 << ", dtype: " << out0->dtype()<< std::endl; + }else{ + std::cout << "??? output: " << out0 << ", dtype: " << out0->dtype()<< std::endl; + } + if(array_of_regs_.count(out0->as()->view())){ + func_args.arg(genVariableName(output.get(0))).append(".array"); + }else{ + func_args.arg(genVariableName(output.get(0))); + } func_args.arg(genVariableName(output.get(1))); func_args.arg(genVariableName(output.get(2))); // inputs @@ -2812,6 +2824,8 @@ class CudaKernelGenerator : private kir::ConstIrVisitor { << " = *reinterpret_cast*>(&" << genVariableName(alias_tv) << ");\n"; + std::cout << "array_of_regs_ insert: " << tv << std::endl; + array_of_regs_.insert(tv); } } else { // Standard Memory Allocation @@ -2838,6 +2852,8 @@ class CudaKernelGenerator : private kir::ConstIrVisitor { indent() << "Array<" << buffer_dtype << ", " << genInline(size) << ", " << va.at(tv) << "> " << genVariableName(tv) << ";\n"; + std::cout << "array_of_regs_ insert: " << tv << std::endl; + array_of_regs_.insert(tv); } else { indent() << buffer_dtype << " " << genVariableName(tv) << "[" << genInline(size) << "];\n";