diff --git a/csrc/codegen.cpp b/csrc/codegen.cpp index 6e11cf93d89..8407d6b3b02 100644 --- a/csrc/codegen.cpp +++ b/csrc/codegen.cpp @@ -2298,6 +2298,8 @@ class CudaKernelGenerator : private kir::ConstIrVisitor { ArgumentBuilder func_args; // outputs + Val* out0 =output.get(0); + std::cout << "output: " << out0 << ", dtype: " << out0->dtype()<< std::endl; func_args.arg(genVariableName(output.get(0))); func_args.arg(genVariableName(output.get(1))); func_args.arg(genVariableName(output.get(2))); @@ -2792,11 +2794,18 @@ class CudaKernelGenerator : private kir::ConstIrVisitor { if (alloc->alias() != nullptr) { // Allocate alias another Allocate stmt const auto alias_tv = alloc->alias()->buffer()->as(); + + + if (alias_tv->getDataType() == tv->getDataType()) { indent() << "// Alias Allocation - " << alloc->memoryType() << "\n"; indent() << "auto& " << genVariableName(tv) << " = " << genVariableName(alias_tv) << ";\n"; } else { + + std::cout << "alias_tv getDataType - " << alias_tv->getDataType().value() << std::endl; + std::cout << "tv getDataType - " << tv->getDataType().value() << std::endl; + indent() << "// Alias Allocation (changing dtype) - " << alloc->memoryType() << "\n"; indent() << "auto " << genVariableName(tv)