From 3240d8bacf13bd4278069476ad97a25f994b7438 Mon Sep 17 00:00:00 2001 From: Yuanming Hu Date: Mon, 30 Mar 2020 20:41:07 -0400 Subject: [PATCH] removed llvm::Value *Stmt::value --- docs/contributor_guide.rst | 1 + taichi/codegen/codegen_llvm.h | 425 +++++++++++++++------------ taichi/codegen/codegen_llvm_cuda.cpp | 102 ++++--- taichi/ir/ir.h | 2 - 4 files changed, 289 insertions(+), 241 deletions(-) diff --git a/docs/contributor_guide.rst b/docs/contributor_guide.rst index 99d5c3da8825d..503ad0d814745 100644 --- a/docs/contributor_guide.rst +++ b/docs/contributor_guide.rst @@ -89,6 +89,7 @@ Existing tags: - ``[Type]``: type system; - ``[Infra]``: general infrastructure, e.g. logging, image reader; - ``[GUI]``: the built-in GUI system; +- ``[Refactor]``: code refactoring; - ``[CLI]``: commandline interfaces, e.g. the ``ti`` command; - ``[Doc]``: documentation; - ``[Example]``: examples under ``taichi/examples/``; diff --git a/taichi/codegen/codegen_llvm.h b/taichi/codegen/codegen_llvm.h index 0574fe18ed6d9..cc1b43530b4c1 100644 --- a/taichi/codegen/codegen_llvm.h +++ b/taichi/codegen/codegen_llvm.h @@ -2,6 +2,7 @@ #pragma once #include +#include #include "taichi/ir/ir.h" #include "taichi/program/program.h" @@ -27,6 +28,8 @@ class CodeGenLLVM : public IRVisitor, public ModuleBuilder { SNodeAttributes &snode_attr; int task_counter; + std::unordered_map llvm_val; + using IRVisitor::visit; using ModuleBuilder::call; @@ -266,19 +269,19 @@ class CodeGenLLVM : public IRVisitor, public ModuleBuilder { void visit(AllocaStmt *stmt) override { TI_ASSERT(stmt->width() == 1); - stmt->value = create_entry_block_alloca(stmt->ret_type.data_type); + llvm_val[stmt] = create_entry_block_alloca(stmt->ret_type.data_type); // initialize as zero builder->CreateStore(tlctx->get_constant(stmt->ret_type.data_type, 0), - stmt->value); + llvm_val[stmt]); } void visit(RandStmt *stmt) override { - stmt->value = create_call( + llvm_val[stmt] = create_call( fmt::format("rand_{}", data_type_short_name(stmt->ret_type.data_type))); } virtual void emit_extra_unary(UnaryOpStmt *stmt) { - auto input = stmt->operand->value; + auto input = llvm_val[stmt->operand]; auto input_taichi_type = stmt->operand->ret_type.data_type; auto op = stmt->op_type; auto input_type = input->getType(); @@ -286,13 +289,13 @@ class CodeGenLLVM : public IRVisitor, public ModuleBuilder { #define UNARY_STD(x) \ else if (op == UnaryOpType::x) { \ if (input_taichi_type == DataType::f32) { \ - stmt->value = \ + llvm_val[stmt] = \ builder->CreateCall(get_runtime_function(#x "_f32"), input); \ } else if (input_taichi_type == DataType::f64) { \ - stmt->value = \ + llvm_val[stmt] = \ builder->CreateCall(get_runtime_function(#x "_f64"), input); \ } else if (input_taichi_type == DataType::i32) { \ - stmt->value = \ + llvm_val[stmt] = \ builder->CreateCall(get_runtime_function(#x "_i32"), input); \ } else { \ TI_NOT_IMPLEMENTED \ @@ -312,8 +315,8 @@ class CodeGenLLVM : public IRVisitor, public ModuleBuilder { UNARY_STD(cos) UNARY_STD(sin) else if (op == UnaryOpType::sqrt) { - stmt->value = builder->CreateIntrinsic(llvm::Intrinsic::sqrt, - {input_type}, {input}); + llvm_val[stmt] = builder->CreateIntrinsic(llvm::Intrinsic::sqrt, + {input_type}, {input}); } else { TI_P(unary_op_type_name(op)); @@ -323,13 +326,13 @@ class CodeGenLLVM : public IRVisitor, public ModuleBuilder { } void visit(UnaryOpStmt *stmt) override { - auto input = stmt->operand->value; + auto input = llvm_val[stmt->operand]; auto input_type = input->getType(); auto op = stmt->op_type; #define UNARY_INTRINSIC(x) \ else if (op == UnaryOpType::x) { \ - stmt->value = \ + llvm_val[stmt] = \ builder->CreateIntrinsic(llvm::Intrinsic::x, {input_type}, {input}); \ } @@ -338,15 +341,15 @@ class CodeGenLLVM : public IRVisitor, public ModuleBuilder { llvm::Function *sqrt_fn = Intrinsic::getDeclaration( module.get(), Intrinsic::sqrt, input->getType()); auto intermediate = builder->CreateCall(sqrt_fn, input, "sqrt"); - stmt->value = builder->CreateFDiv( + llvm_val[stmt] = builder->CreateFDiv( tlctx->get_constant(stmt->ret_type.data_type, 1.0), intermediate); } else if (op == UnaryOpType::bit_not) { - stmt->value = builder->CreateNot(input); + llvm_val[stmt] = builder->CreateNot(input); } else if (op == UnaryOpType::neg) { if (is_real(stmt->operand->ret_type.data_type)) { - stmt->value = builder->CreateFNeg(input, "neg"); + llvm_val[stmt] = builder->CreateFNeg(input, "neg"); } else { - stmt->value = builder->CreateNeg(input, "neg"); + llvm_val[stmt] = builder->CreateNeg(input, "neg"); } } UNARY_INTRINSIC(floor) @@ -370,31 +373,31 @@ class CodeGenLLVM : public IRVisitor, public ModuleBuilder { TI_P(data_type_name(to)); TI_NOT_IMPLEMENTED; } - stmt->value = - builder->CreateCast(cast_op, stmt->operand->value, + llvm_val[stmt] = + builder->CreateCast(cast_op, llvm_val[stmt->operand], tlctx->get_data_type(stmt->cast_type)); } else if (is_real(from) && is_real(to)) { if (data_type_size(from) < data_type_size(to)) { - stmt->value = builder->CreateFPExt( - stmt->operand->value, tlctx->get_data_type(stmt->cast_type)); + llvm_val[stmt] = builder->CreateFPExt( + llvm_val[stmt->operand], tlctx->get_data_type(stmt->cast_type)); } else { - stmt->value = builder->CreateFPTrunc( - stmt->operand->value, tlctx->get_data_type(stmt->cast_type)); + llvm_val[stmt] = builder->CreateFPTrunc( + llvm_val[stmt->operand], tlctx->get_data_type(stmt->cast_type)); } } else if (!is_real(from) && !is_real(to)) { if (data_type_size(from) < data_type_size(to)) { - stmt->value = builder->CreateSExt( - stmt->operand->value, tlctx->get_data_type(stmt->cast_type)); + llvm_val[stmt] = builder->CreateSExt( + llvm_val[stmt->operand], tlctx->get_data_type(stmt->cast_type)); } else { - stmt->value = builder->CreateTrunc( - stmt->operand->value, tlctx->get_data_type(stmt->cast_type)); + llvm_val[stmt] = builder->CreateTrunc( + llvm_val[stmt->operand], tlctx->get_data_type(stmt->cast_type)); } } } else { TI_ASSERT(data_type_size(stmt->ret_type.data_type) == data_type_size(stmt->cast_type)); - stmt->value = builder->CreateBitCast( - stmt->operand->value, tlctx->get_data_type(stmt->cast_type)); + llvm_val[stmt] = builder->CreateBitCast( + llvm_val[stmt->operand], tlctx->get_data_type(stmt->cast_type)); } } } @@ -419,52 +422,66 @@ class CodeGenLLVM : public IRVisitor, public ModuleBuilder { auto ret_type = stmt->ret_type.data_type; if (op == BinaryOpType::add) { if (is_real(stmt->ret_type.data_type)) { - stmt->value = builder->CreateFAdd(stmt->lhs->value, stmt->rhs->value); + llvm_val[stmt] = + builder->CreateFAdd(llvm_val[stmt->lhs], llvm_val[stmt->rhs]); } else { - stmt->value = builder->CreateAdd(stmt->lhs->value, stmt->rhs->value); + llvm_val[stmt] = + builder->CreateAdd(llvm_val[stmt->lhs], llvm_val[stmt->rhs]); } } else if (op == BinaryOpType::sub) { if (is_real(stmt->ret_type.data_type)) { - stmt->value = builder->CreateFSub(stmt->lhs->value, stmt->rhs->value); + llvm_val[stmt] = + builder->CreateFSub(llvm_val[stmt->lhs], llvm_val[stmt->rhs]); } else { - stmt->value = builder->CreateSub(stmt->lhs->value, stmt->rhs->value); + llvm_val[stmt] = + builder->CreateSub(llvm_val[stmt->lhs], llvm_val[stmt->rhs]); } } else if (op == BinaryOpType::mul) { if (is_real(stmt->ret_type.data_type)) { - stmt->value = builder->CreateFMul(stmt->lhs->value, stmt->rhs->value); + llvm_val[stmt] = + builder->CreateFMul(llvm_val[stmt->lhs], llvm_val[stmt->rhs]); } else { - stmt->value = builder->CreateMul(stmt->lhs->value, stmt->rhs->value); + llvm_val[stmt] = + builder->CreateMul(llvm_val[stmt->lhs], llvm_val[stmt->rhs]); } } else if (op == BinaryOpType::floordiv) { if (is_integral(ret_type)) - stmt->value = create_call( + llvm_val[stmt] = create_call( fmt::format("floordiv_{}", data_type_short_name(ret_type)), - {stmt->lhs->value, stmt->rhs->value}); + {llvm_val[stmt->lhs], llvm_val[stmt->rhs]}); else { - auto div = builder->CreateFDiv(stmt->lhs->value, stmt->rhs->value); - stmt->value = builder->CreateIntrinsic( + auto div = + builder->CreateFDiv(llvm_val[stmt->lhs], llvm_val[stmt->rhs]); + llvm_val[stmt] = builder->CreateIntrinsic( llvm::Intrinsic::floor, {tlctx->get_data_type(ret_type)}, {div}); } } else if (op == BinaryOpType::div) { if (is_real(stmt->ret_type.data_type)) { - stmt->value = builder->CreateFDiv(stmt->lhs->value, stmt->rhs->value); + llvm_val[stmt] = + builder->CreateFDiv(llvm_val[stmt->lhs], llvm_val[stmt->rhs]); } else { - stmt->value = builder->CreateSDiv(stmt->lhs->value, stmt->rhs->value); + llvm_val[stmt] = + builder->CreateSDiv(llvm_val[stmt->lhs], llvm_val[stmt->rhs]); } } else if (op == BinaryOpType::mod) { - stmt->value = builder->CreateSRem(stmt->lhs->value, stmt->rhs->value); + llvm_val[stmt] = + builder->CreateSRem(llvm_val[stmt->lhs], llvm_val[stmt->rhs]); } else if (op == BinaryOpType::bit_and) { - stmt->value = builder->CreateAnd(stmt->lhs->value, stmt->rhs->value); + llvm_val[stmt] = + builder->CreateAnd(llvm_val[stmt->lhs], llvm_val[stmt->rhs]); } else if (op == BinaryOpType::bit_or) { - stmt->value = builder->CreateOr(stmt->lhs->value, stmt->rhs->value); + llvm_val[stmt] = + builder->CreateOr(llvm_val[stmt->lhs], llvm_val[stmt->rhs]); } else if (op == BinaryOpType::bit_xor) { - stmt->value = builder->CreateXor(stmt->lhs->value, stmt->rhs->value); + llvm_val[stmt] = + builder->CreateXor(llvm_val[stmt->lhs], llvm_val[stmt->rhs]); } else if (op == BinaryOpType::max) { if (is_real(ret_type)) { - stmt->value = builder->CreateMaxNum(stmt->lhs->value, stmt->rhs->value); + llvm_val[stmt] = + builder->CreateMaxNum(llvm_val[stmt->lhs], llvm_val[stmt->rhs]); } else if (ret_type == DataType::i32) { - stmt->value = - create_call("max_i32", {stmt->lhs->value, stmt->rhs->value}); + llvm_val[stmt] = + create_call("max_i32", {llvm_val[stmt->lhs], llvm_val[stmt->rhs]}); } else { TI_P(data_type_name(ret_type)); TI_NOT_IMPLEMENTED @@ -472,22 +489,22 @@ class CodeGenLLVM : public IRVisitor, public ModuleBuilder { } else if (op == BinaryOpType::atan2) { if (arch_is_cpu(current_arch())) { if (ret_type == DataType::f32) { - stmt->value = - create_call("atan2_f32", {stmt->lhs->value, stmt->rhs->value}); + llvm_val[stmt] = create_call( + "atan2_f32", {llvm_val[stmt->lhs], llvm_val[stmt->rhs]}); } else if (ret_type == DataType::f64) { - stmt->value = - create_call("atan2_f64", {stmt->lhs->value, stmt->rhs->value}); + llvm_val[stmt] = create_call( + "atan2_f64", {llvm_val[stmt->lhs], llvm_val[stmt->rhs]}); } else { TI_P(data_type_name(ret_type)); TI_NOT_IMPLEMENTED } } else if (current_arch() == Arch::cuda) { if (ret_type == DataType::f32) { - stmt->value = - create_call("__nv_atan2f", {stmt->lhs->value, stmt->rhs->value}); + llvm_val[stmt] = create_call( + "__nv_atan2f", {llvm_val[stmt->lhs], llvm_val[stmt->rhs]}); } else if (ret_type == DataType::f64) { - stmt->value = - create_call("__nv_atan2", {stmt->lhs->value, stmt->rhs->value}); + llvm_val[stmt] = create_call( + "__nv_atan2", {llvm_val[stmt->lhs], llvm_val[stmt->rhs]}); } else { TI_P(data_type_name(ret_type)); TI_NOT_IMPLEMENTED @@ -498,34 +515,34 @@ class CodeGenLLVM : public IRVisitor, public ModuleBuilder { } else if (op == BinaryOpType::pow) { if (arch_is_cpu(current_arch())) { if (ret_type == DataType::f32) { - stmt->value = - create_call("pow_f32", {stmt->lhs->value, stmt->rhs->value}); + llvm_val[stmt] = create_call( + "pow_f32", {llvm_val[stmt->lhs], llvm_val[stmt->rhs]}); } else if (ret_type == DataType::f64) { - stmt->value = - create_call("pow_f64", {stmt->lhs->value, stmt->rhs->value}); + llvm_val[stmt] = create_call( + "pow_f64", {llvm_val[stmt->lhs], llvm_val[stmt->rhs]}); } else if (ret_type == DataType::i32) { - stmt->value = - create_call("pow_i32", {stmt->lhs->value, stmt->rhs->value}); + llvm_val[stmt] = create_call( + "pow_i32", {llvm_val[stmt->lhs], llvm_val[stmt->rhs]}); } else if (ret_type == DataType::i64) { - stmt->value = - create_call("pow_i64", {stmt->lhs->value, stmt->rhs->value}); + llvm_val[stmt] = create_call( + "pow_i64", {llvm_val[stmt->lhs], llvm_val[stmt->rhs]}); } else { TI_P(data_type_name(ret_type)); TI_NOT_IMPLEMENTED } } else if (current_arch() == Arch::cuda) { if (ret_type == DataType::f32) { - stmt->value = - create_call("__nv_powf", {stmt->lhs->value, stmt->rhs->value}); + llvm_val[stmt] = create_call( + "__nv_powf", {llvm_val[stmt->lhs], llvm_val[stmt->rhs]}); } else if (ret_type == DataType::f64) { - stmt->value = - create_call("__nv_pow", {stmt->lhs->value, stmt->rhs->value}); + llvm_val[stmt] = create_call( + "__nv_pow", {llvm_val[stmt->lhs], llvm_val[stmt->rhs]}); } else if (ret_type == DataType::i32) { - stmt->value = - create_call("pow_i32", {stmt->lhs->value, stmt->rhs->value}); + llvm_val[stmt] = create_call( + "pow_i32", {llvm_val[stmt->lhs], llvm_val[stmt->rhs]}); } else if (ret_type == DataType::i64) { - stmt->value = - create_call("pow_i64", {stmt->lhs->value, stmt->rhs->value}); + llvm_val[stmt] = create_call( + "pow_i64", {llvm_val[stmt->lhs], llvm_val[stmt->rhs]}); } else { TI_P(data_type_name(ret_type)); TI_NOT_IMPLEMENTED @@ -535,10 +552,11 @@ class CodeGenLLVM : public IRVisitor, public ModuleBuilder { } } else if (op == BinaryOpType::min) { if (is_real(ret_type)) { - stmt->value = builder->CreateMinNum(stmt->lhs->value, stmt->rhs->value); + llvm_val[stmt] = + builder->CreateMinNum(llvm_val[stmt->lhs], llvm_val[stmt->rhs]); } else if (ret_type == DataType::i32) { - stmt->value = - create_call("min_i32", {stmt->lhs->value, stmt->rhs->value}); + llvm_val[stmt] = + create_call("min_i32", {llvm_val[stmt->lhs], llvm_val[stmt->rhs]}); } else { TI_P(data_type_name(ret_type)); TI_NOT_IMPLEMENTED @@ -548,60 +566,74 @@ class CodeGenLLVM : public IRVisitor, public ModuleBuilder { auto input_type = stmt->lhs->ret_type.data_type; if (op == BinaryOpType::cmp_eq) { if (is_real(input_type)) { - cmp = builder->CreateFCmpOEQ(stmt->lhs->value, stmt->rhs->value); + cmp = + builder->CreateFCmpOEQ(llvm_val[stmt->lhs], llvm_val[stmt->rhs]); } else { - cmp = builder->CreateICmpEQ(stmt->lhs->value, stmt->rhs->value); + cmp = builder->CreateICmpEQ(llvm_val[stmt->lhs], llvm_val[stmt->rhs]); } } else if (op == BinaryOpType::cmp_le) { if (is_real(input_type)) { - cmp = builder->CreateFCmpOLE(stmt->lhs->value, stmt->rhs->value); + cmp = + builder->CreateFCmpOLE(llvm_val[stmt->lhs], llvm_val[stmt->rhs]); } else { if (is_signed(input_type)) { - cmp = builder->CreateICmpSLE(stmt->lhs->value, stmt->rhs->value); + cmp = builder->CreateICmpSLE(llvm_val[stmt->lhs], + llvm_val[stmt->rhs]); } else { - cmp = builder->CreateICmpULE(stmt->lhs->value, stmt->rhs->value); + cmp = builder->CreateICmpULE(llvm_val[stmt->lhs], + llvm_val[stmt->rhs]); } } } else if (op == BinaryOpType::cmp_ge) { if (is_real(input_type)) { - cmp = builder->CreateFCmpOGE(stmt->lhs->value, stmt->rhs->value); + cmp = + builder->CreateFCmpOGE(llvm_val[stmt->lhs], llvm_val[stmt->rhs]); } else { if (is_signed(input_type)) { - cmp = builder->CreateICmpSGE(stmt->lhs->value, stmt->rhs->value); + cmp = builder->CreateICmpSGE(llvm_val[stmt->lhs], + llvm_val[stmt->rhs]); } else { - cmp = builder->CreateICmpUGE(stmt->lhs->value, stmt->rhs->value); + cmp = builder->CreateICmpUGE(llvm_val[stmt->lhs], + llvm_val[stmt->rhs]); } } } else if (op == BinaryOpType::cmp_lt) { if (is_real(input_type)) { - cmp = builder->CreateFCmpOLT(stmt->lhs->value, stmt->rhs->value); + cmp = + builder->CreateFCmpOLT(llvm_val[stmt->lhs], llvm_val[stmt->rhs]); } else { if (is_signed(input_type)) { - cmp = builder->CreateICmpSLT(stmt->lhs->value, stmt->rhs->value); + cmp = builder->CreateICmpSLT(llvm_val[stmt->lhs], + llvm_val[stmt->rhs]); } else { - cmp = builder->CreateICmpULT(stmt->lhs->value, stmt->rhs->value); + cmp = builder->CreateICmpULT(llvm_val[stmt->lhs], + llvm_val[stmt->rhs]); } } } else if (op == BinaryOpType::cmp_gt) { if (is_real(input_type)) { - cmp = builder->CreateFCmpOGT(stmt->lhs->value, stmt->rhs->value); + cmp = + builder->CreateFCmpOGT(llvm_val[stmt->lhs], llvm_val[stmt->rhs]); } else { if (is_signed(input_type)) { - cmp = builder->CreateICmpSGT(stmt->lhs->value, stmt->rhs->value); + cmp = builder->CreateICmpSGT(llvm_val[stmt->lhs], + llvm_val[stmt->rhs]); } else { - cmp = builder->CreateICmpUGT(stmt->lhs->value, stmt->rhs->value); + cmp = builder->CreateICmpUGT(llvm_val[stmt->lhs], + llvm_val[stmt->rhs]); } } } else if (op == BinaryOpType::cmp_ne) { if (is_real(input_type)) { - cmp = builder->CreateFCmpONE(stmt->lhs->value, stmt->rhs->value); + cmp = + builder->CreateFCmpONE(llvm_val[stmt->lhs], llvm_val[stmt->rhs]); } else { - cmp = builder->CreateICmpNE(stmt->lhs->value, stmt->rhs->value); + cmp = builder->CreateICmpNE(llvm_val[stmt->lhs], llvm_val[stmt->rhs]); } } else { TI_NOT_IMPLEMENTED } - stmt->value = builder->CreateSExt(cmp, llvm_type(DataType::i32)); + llvm_val[stmt] = builder->CreateSExt(cmp, llvm_type(DataType::i32)); } else { TI_P(binary_op_type_name(op)); TI_NOT_IMPLEMENTED @@ -610,9 +642,9 @@ class CodeGenLLVM : public IRVisitor, public ModuleBuilder { void visit(TernaryOpStmt *stmt) override { TI_ASSERT(stmt->op_type == TernaryOpType::select); - stmt->value = builder->CreateSelect( - builder->CreateTrunc(stmt->op1->value, llvm_type(DataType::u1)), - stmt->op2->value, stmt->op3->value); + llvm_val[stmt] = builder->CreateSelect( + builder->CreateTrunc(llvm_val[stmt->op1], llvm_type(DataType::u1)), + llvm_val[stmt->op2], llvm_val[stmt->op3]); } void visit(IfStmt *if_stmt) override { @@ -623,7 +655,7 @@ class CodeGenLLVM : public IRVisitor, public ModuleBuilder { BasicBlock::Create(*llvm_context, "false_block", func); BasicBlock *after_if = BasicBlock::Create(*llvm_context, "after_if", func); builder->CreateCondBr( - builder->CreateICmpNE(if_stmt->cond->value, tlctx->get_constant(0)), + builder->CreateICmpNE(llvm_val[if_stmt->cond], tlctx->get_constant(0)), true_block, false_block); builder->SetInsertPoint(true_block); if (if_stmt->true_statements) { @@ -671,7 +703,7 @@ class CodeGenLLVM : public IRVisitor, public ModuleBuilder { TI_ASSERT(stmt->width() == 1); std::vector args; std::string format; - auto value = stmt->stmt->value; + auto value = llvm_val[stmt->stmt]; auto dt = stmt->stmt->ret_type.data_type; if (dt == DataType::i32) { format = "%d"; @@ -695,23 +727,23 @@ class CodeGenLLVM : public IRVisitor, public ModuleBuilder { "format_string")); args.push_back(value); - stmt->value = builder->CreateCall(runtime_printf, args); + llvm_val[stmt] = builder->CreateCall(runtime_printf, args); } void visit(ConstStmt *stmt) override { TI_ASSERT(stmt->width() == 1); auto val = stmt->val[0]; if (val.dt == DataType::f32) { - stmt->value = llvm::ConstantFP::get(*llvm_context, - llvm::APFloat(val.val_float32())); + llvm_val[stmt] = llvm::ConstantFP::get(*llvm_context, + llvm::APFloat(val.val_float32())); } else if (val.dt == DataType::f64) { - stmt->value = llvm::ConstantFP::get(*llvm_context, - llvm::APFloat(val.val_float64())); + llvm_val[stmt] = llvm::ConstantFP::get(*llvm_context, + llvm::APFloat(val.val_float64())); } else if (val.dt == DataType::i32) { - stmt->value = llvm::ConstantInt::get( + llvm_val[stmt] = llvm::ConstantInt::get( *llvm_context, llvm::APInt(32, val.val_int32(), true)); } else if (val.dt == DataType::i64) { - stmt->value = llvm::ConstantInt::get( + llvm_val[stmt] = llvm::ConstantInt::get( *llvm_context, llvm::APInt(64, val.val_int64(), true)); } else { TI_NOT_IMPLEMENTED; @@ -723,7 +755,7 @@ class CodeGenLLVM : public IRVisitor, public ModuleBuilder { BasicBlock::Create(*llvm_context, "after_break", func); TI_ASSERT(while_after_loop); auto cond = - builder->CreateICmpEQ(stmt->cond->value, tlctx->get_constant(0)); + builder->CreateICmpEQ(llvm_val[stmt->cond], tlctx->get_constant(0)); builder->CreateCondBr(cond, while_after_loop, after_break); builder->SetInsertPoint(after_break); } @@ -800,11 +832,12 @@ class CodeGenLLVM : public IRVisitor, public ModuleBuilder { BasicBlock *after_loop = BasicBlock::Create(*llvm_context, "block", func); BasicBlock *test = BasicBlock::Create(*llvm_context, "test", func); if (!for_stmt->reversed) { - builder->CreateStore(for_stmt->begin->value, for_stmt->loop_var->value); + builder->CreateStore(llvm_val[for_stmt->begin], + llvm_val[for_stmt->loop_var]); } else { builder->CreateStore( - builder->CreateSub(for_stmt->end->value, tlctx->get_constant(1)), - for_stmt->loop_var->value); + builder->CreateSub(llvm_val[for_stmt->end], tlctx->get_constant(1)), + llvm_val[for_stmt->loop_var]); } builder->CreateBr(test); @@ -813,15 +846,15 @@ class CodeGenLLVM : public IRVisitor, public ModuleBuilder { builder->SetInsertPoint(test); llvm::Value *cond; if (!for_stmt->reversed) { - cond = - builder->CreateICmp(llvm::CmpInst::Predicate::ICMP_SLT, - builder->CreateLoad(for_stmt->loop_var->value), - for_stmt->end->value); + cond = builder->CreateICmp( + llvm::CmpInst::Predicate::ICMP_SLT, + builder->CreateLoad(llvm_val[for_stmt->loop_var]), + llvm_val[for_stmt->end]); } else { - cond = - builder->CreateICmp(llvm::CmpInst::Predicate::ICMP_SGE, - builder->CreateLoad(for_stmt->loop_var->value), - for_stmt->begin->value); + cond = builder->CreateICmp( + llvm::CmpInst::Predicate::ICMP_SGE, + builder->CreateLoad(llvm_val[for_stmt->loop_var]), + llvm_val[for_stmt->begin]); } builder->CreateCondBr(cond, body, after_loop); } @@ -833,9 +866,9 @@ class CodeGenLLVM : public IRVisitor, public ModuleBuilder { for_stmt->body->accept(this); if (!for_stmt->reversed) { - create_increment(for_stmt->loop_var->value, tlctx->get_constant(1)); + create_increment(llvm_val[for_stmt->loop_var], tlctx->get_constant(1)); } else { - create_increment(for_stmt->loop_var->value, tlctx->get_constant(-1)); + create_increment(llvm_val[for_stmt->loop_var], tlctx->get_constant(-1)); } builder->CreateBr(test); } @@ -855,13 +888,13 @@ class CodeGenLLVM : public IRVisitor, public ModuleBuilder { llvm::Type *dest_ty = nullptr; if (stmt->is_ptr) { dest_ty = PointerType::get(tlctx->get_data_type(DataType::i32), 0); - stmt->value = builder->CreateIntToPtr(raw_arg, dest_ty); + llvm_val[stmt] = builder->CreateIntToPtr(raw_arg, dest_ty); } else { dest_ty = tlctx->get_data_type(stmt->ret_type.data_type); auto dest_bits = dest_ty->getPrimitiveSizeInBits(); auto truncated = builder->CreateTrunc( raw_arg, Type::getIntNTy(*llvm_context, dest_bits)); - stmt->value = builder->CreateBitCast(truncated, dest_ty); + llvm_val[stmt] = builder->CreateBitCast(truncated, dest_ty); } } @@ -876,7 +909,8 @@ class CodeGenLLVM : public IRVisitor, public ModuleBuilder { llvm::Type::getIntNTy(*llvm_context, intermediate_bits); llvm::Type *dest_ty = tlctx->get_data_type(); auto extended = builder->CreateZExt( - builder->CreateBitCast(stmt->val->value, intermediate_type), dest_ty); + builder->CreateBitCast(llvm_val[stmt->val], intermediate_type), + dest_ty); // TODO: refactor this part if (get_current_program().config.arch == Arch::cuda && !get_current_program().config.use_unified_memory) { @@ -894,7 +928,7 @@ class CodeGenLLVM : public IRVisitor, public ModuleBuilder { void visit(LocalLoadStmt *stmt) override { TI_ASSERT(stmt->width() == 1); - stmt->value = builder->CreateLoad(stmt->ptr[0].var->value); + llvm_val[stmt] = builder->CreateLoad(llvm_val[stmt->ptr[0].var]); } void visit(LocalStoreStmt *stmt) override { @@ -902,24 +936,25 @@ class CodeGenLLVM : public IRVisitor, public ModuleBuilder { if (mask && stmt->width() != 1) { TI_NOT_IMPLEMENTED } else { - builder->CreateStore(stmt->data->value, stmt->ptr->value); + builder->CreateStore(llvm_val[stmt->data], llvm_val[stmt->ptr]); } } void visit(AssertStmt *stmt) override { if (stmt->args.empty()) { - stmt->value = call("taichi_assert", get_context(), stmt->cond->value, - builder->CreateGlobalStringPtr(stmt->text)); + llvm_val[stmt] = + call("taichi_assert", get_context(), llvm_val[stmt->cond], + builder->CreateGlobalStringPtr(stmt->text)); } else { std::vector args; args.emplace_back(get_runtime()); - args.emplace_back(stmt->cond->value); + args.emplace_back(llvm_val[stmt->cond]); args.emplace_back(builder->CreateGlobalStringPtr(stmt->text)); for (auto arg : stmt->args) { - TI_ASSERT(arg->value); - args.emplace_back(arg->value); + TI_ASSERT(llvm_val[arg]); + args.emplace_back(llvm_val[arg]); } - stmt->value = create_call("taichi_assert_format", args); + llvm_val[stmt] = create_call("taichi_assert_format", args); } } @@ -928,19 +963,20 @@ class CodeGenLLVM : public IRVisitor, public ModuleBuilder { if (stmt->op_type == SNodeOpType::append) { TI_ASSERT(snode->type == SNodeType::dynamic); TI_ASSERT(stmt->ret_type.data_type == DataType::i32); - stmt->value = call(snode, stmt->ptr->value, "append", {stmt->val->value}); + llvm_val[stmt] = + call(snode, llvm_val[stmt->ptr], "append", {llvm_val[stmt->val]}); } else if (stmt->op_type == SNodeOpType::length) { TI_ASSERT(snode->type == SNodeType::dynamic); - stmt->value = call(snode, stmt->ptr->value, "get_num_elements", {}); + llvm_val[stmt] = call(snode, llvm_val[stmt->ptr], "get_num_elements", {}); } else if (stmt->op_type == SNodeOpType::is_active) { - stmt->value = - call(snode, stmt->ptr->value, "is_active", {stmt->val->value}); + llvm_val[stmt] = + call(snode, llvm_val[stmt->ptr], "is_active", {llvm_val[stmt->val]}); } else if (stmt->op_type == SNodeOpType::deactivate) { if (snode->type == SNodeType::pointer || snode->type == SNodeType::hash) { - stmt->value = - call(snode, stmt->ptr->value, "deactivate", {stmt->val->value}); + llvm_val[stmt] = call(snode, llvm_val[stmt->ptr], "deactivate", + {llvm_val[stmt->val]}); } else if (snode->type == SNodeType::dynamic) { - stmt->value = call(snode, stmt->ptr->value, "deactivate", {}); + llvm_val[stmt] = call(snode, llvm_val[stmt->ptr], "deactivate", {}); } } else { TI_NOT_IMPLEMENTED @@ -956,79 +992,85 @@ class CodeGenLLVM : public IRVisitor, public ModuleBuilder { if (stmt->op_type == AtomicOpType::add) { if (is_integral(stmt->val->ret_type.data_type)) { old_value = builder->CreateAtomicRMW( - llvm::AtomicRMWInst::BinOp::Add, stmt->dest->value, - stmt->val->value, llvm::AtomicOrdering::SequentiallyConsistent); + llvm::AtomicRMWInst::BinOp::Add, llvm_val[stmt->dest], + llvm_val[stmt->val], + llvm::AtomicOrdering::SequentiallyConsistent); } else if (stmt->val->ret_type.data_type == DataType::f32) { old_value = builder->CreateCall(get_runtime_function("atomic_add_f32"), - {stmt->dest->value, stmt->val->value}); + {llvm_val[stmt->dest], llvm_val[stmt->val]}); } else if (stmt->val->ret_type.data_type == DataType::f64) { old_value = builder->CreateCall(get_runtime_function("atomic_add_f64"), - {stmt->dest->value, stmt->val->value}); + {llvm_val[stmt->dest], llvm_val[stmt->val]}); } else { TI_NOT_IMPLEMENTED } } else if (stmt->op_type == AtomicOpType::min) { if (is_integral(stmt->val->ret_type.data_type)) { old_value = builder->CreateAtomicRMW( - llvm::AtomicRMWInst::BinOp::Min, stmt->dest->value, - stmt->val->value, llvm::AtomicOrdering::SequentiallyConsistent); + llvm::AtomicRMWInst::BinOp::Min, llvm_val[stmt->dest], + llvm_val[stmt->val], + llvm::AtomicOrdering::SequentiallyConsistent); } else if (stmt->val->ret_type.data_type == DataType::f32) { old_value = builder->CreateCall(get_runtime_function("atomic_min_f32"), - {stmt->dest->value, stmt->val->value}); + {llvm_val[stmt->dest], llvm_val[stmt->val]}); } else if (stmt->val->ret_type.data_type == DataType::f64) { old_value = builder->CreateCall(get_runtime_function("atomic_min_f64"), - {stmt->dest->value, stmt->val->value}); + {llvm_val[stmt->dest], llvm_val[stmt->val]}); } else { TI_NOT_IMPLEMENTED } } else if (stmt->op_type == AtomicOpType::max) { if (is_integral(stmt->val->ret_type.data_type)) { old_value = builder->CreateAtomicRMW( - llvm::AtomicRMWInst::BinOp::Max, stmt->dest->value, - stmt->val->value, llvm::AtomicOrdering::SequentiallyConsistent); + llvm::AtomicRMWInst::BinOp::Max, llvm_val[stmt->dest], + llvm_val[stmt->val], + llvm::AtomicOrdering::SequentiallyConsistent); } else if (stmt->val->ret_type.data_type == DataType::f32) { old_value = builder->CreateCall(get_runtime_function("atomic_max_f32"), - {stmt->dest->value, stmt->val->value}); + {llvm_val[stmt->dest], llvm_val[stmt->val]}); } else if (stmt->val->ret_type.data_type == DataType::f64) { old_value = builder->CreateCall(get_runtime_function("atomic_max_f64"), - {stmt->dest->value, stmt->val->value}); + {llvm_val[stmt->dest], llvm_val[stmt->val]}); } else { TI_NOT_IMPLEMENTED } } else if (stmt->op_type == AtomicOpType::bit_and) { if (is_integral(stmt->val->ret_type.data_type)) { old_value = builder->CreateAtomicRMW( - llvm::AtomicRMWInst::BinOp::And, stmt->dest->value, - stmt->val->value, llvm::AtomicOrdering::SequentiallyConsistent); + llvm::AtomicRMWInst::BinOp::And, llvm_val[stmt->dest], + llvm_val[stmt->val], + llvm::AtomicOrdering::SequentiallyConsistent); } else { TI_NOT_IMPLEMENTED } } else if (stmt->op_type == AtomicOpType::bit_or) { if (is_integral(stmt->val->ret_type.data_type)) { old_value = builder->CreateAtomicRMW( - llvm::AtomicRMWInst::BinOp::Or, stmt->dest->value, - stmt->val->value, llvm::AtomicOrdering::SequentiallyConsistent); + llvm::AtomicRMWInst::BinOp::Or, llvm_val[stmt->dest], + llvm_val[stmt->val], + llvm::AtomicOrdering::SequentiallyConsistent); } else { TI_NOT_IMPLEMENTED } } else if (stmt->op_type == AtomicOpType::bit_xor) { if (is_integral(stmt->val->ret_type.data_type)) { old_value = builder->CreateAtomicRMW( - llvm::AtomicRMWInst::BinOp::Xor, stmt->dest->value, - stmt->val->value, llvm::AtomicOrdering::SequentiallyConsistent); + llvm::AtomicRMWInst::BinOp::Xor, llvm_val[stmt->dest], + llvm_val[stmt->val], + llvm::AtomicOrdering::SequentiallyConsistent); } else { TI_NOT_IMPLEMENTED } } else { TI_NOT_IMPLEMENTED } - stmt->value = old_value; + llvm_val[stmt] = old_value; } } @@ -1038,16 +1080,16 @@ class CodeGenLLVM : public IRVisitor, public ModuleBuilder { void visit(GlobalStoreStmt *stmt) override { TI_ASSERT(!stmt->parent->mask() || stmt->width() == 1); - TI_ASSERT(stmt->data->value); - TI_ASSERT(stmt->ptr->value); - builder->CreateStore(stmt->data->value, stmt->ptr->value); + TI_ASSERT(llvm_val[stmt->data]); + TI_ASSERT(llvm_val[stmt->ptr]); + builder->CreateStore(llvm_val[stmt->data], llvm_val[stmt->ptr]); } void visit(GlobalLoadStmt *stmt) override { int width = stmt->width(); TI_ASSERT(width == 1); - stmt->value = builder->CreateLoad( - tlctx->get_data_type(stmt->ret_type.data_type), stmt->ptr->value); + llvm_val[stmt] = builder->CreateLoad( + tlctx->get_data_type(stmt->ret_type.data_type), llvm_val[stmt->ptr]); } void visit(ElementShuffleStmt *stmt) override { @@ -1105,16 +1147,16 @@ class CodeGenLLVM : public IRVisitor, public ModuleBuilder { } void visit(GetRootStmt *stmt) override { - stmt->value = builder->CreateBitCast( + llvm_val[stmt] = builder->CreateBitCast( get_root(), PointerType::get(snode_attr[prog->snode_root.get()].llvm_type, 0)); } void visit(OffsetAndExtractBitsStmt *stmt) override { - auto shifted = builder->CreateAdd(stmt->input->value, + auto shifted = builder->CreateAdd(llvm_val[stmt->input], tlctx->get_constant((int32)stmt->offset)); int mask = (1u << (stmt->bit_end - stmt->bit_begin)) - 1; - stmt->value = + llvm_val[stmt] = builder->CreateAnd(builder->CreateLShr(shifted, stmt->bit_begin), tlctx->get_constant(mask)); } @@ -1124,9 +1166,9 @@ class CodeGenLLVM : public IRVisitor, public ModuleBuilder { for (int i = 0; i < (int)stmt->inputs.size(); i++) { val = builder->CreateAdd( builder->CreateMul(val, tlctx->get_constant(stmt->strides[i])), - stmt->inputs[i]->value); + llvm_val[stmt->inputs[i]]); } - stmt->value = val; + llvm_val[stmt] = val; } void visit(IntegerOffsetStmt *stmt) override { @@ -1145,21 +1187,21 @@ class CodeGenLLVM : public IRVisitor, public ModuleBuilder { void visit(SNodeLookupStmt *stmt) override { llvm::Value *parent = nullptr; - parent = stmt->input_snode->value; + parent = llvm_val[stmt->input_snode]; TI_ASSERT(parent); auto snode = stmt->snode; if (snode->type == SNodeType::root) { - stmt->value = builder->CreateGEP(parent, stmt->input_index->value); + llvm_val[stmt] = builder->CreateGEP(parent, llvm_val[stmt->input_index]); } else if (snode->type == SNodeType::dense || snode->type == SNodeType::pointer || snode->type == SNodeType::dynamic || snode->type == SNodeType::bitmasked) { if (stmt->activate) { - call(snode, stmt->input_snode->value, "activate", - {stmt->input_index->value}); + call(snode, llvm_val[stmt->input_snode], "activate", + {llvm_val[stmt->input_index]}); } - stmt->value = call(snode, stmt->input_snode->value, "lookup_element", - {stmt->input_index->value}); + llvm_val[stmt] = call(snode, llvm_val[stmt->input_snode], + "lookup_element", {llvm_val[stmt->input_index]}); } else { TI_INFO(snode_type_name(snode->type)); TI_NOT_IMPLEMENTED @@ -1169,9 +1211,9 @@ class CodeGenLLVM : public IRVisitor, public ModuleBuilder { void visit(GetChStmt *stmt) override { auto ch = create_call( stmt->output_snode->get_ch_from_parent_func_name(), - {builder->CreateBitCast(stmt->input_ptr->value, + {builder->CreateBitCast(llvm_val[stmt->input_ptr], PointerType::getInt8PtrTy(*llvm_context))}); - stmt->value = builder->CreateBitCast( + llvm_val[stmt] = builder->CreateBitCast( ch, PointerType::get(snode_attr[stmt->output_snode].llvm_type, 0)); } @@ -1192,16 +1234,17 @@ class CodeGenLLVM : public IRVisitor, public ModuleBuilder { auto dt = stmt->ret_type.data_type; auto base = builder->CreateBitCast( - stmt->base_ptrs[0]->value, + llvm_val[stmt->base_ptrs[0]], llvm::PointerType::get(tlctx->get_data_type(dt), 0)); auto linear_index = tlctx->get_constant(0); for (int i = 0; i < num_indices; i++) { linear_index = builder->CreateMul(linear_index, sizes[i]); - linear_index = builder->CreateAdd(linear_index, stmt->indices[i]->value); + linear_index = + builder->CreateAdd(linear_index, llvm_val[stmt->indices[i]]); } - stmt->value = builder->CreateGEP(base, linear_index); + llvm_val[stmt] = builder->CreateGEP(base, linear_index); } BasicBlock *func_body_bb; @@ -1323,7 +1366,7 @@ class CodeGenLLVM : public IRVisitor, public ModuleBuilder { auto begin_stmt = Stmt::make( stmt->begin_offset, VectorType(1, DataType::i32)); begin_stmt->accept(this); - begin = builder->CreateLoad(begin_stmt->value); + begin = builder->CreateLoad(llvm_val[begin_stmt.get()]); } if (stmt->const_end) { end = tlctx->get_constant(stmt->end_value); @@ -1331,7 +1374,7 @@ class CodeGenLLVM : public IRVisitor, public ModuleBuilder { auto end_stmt = Stmt::make( stmt->end_offset, VectorType(1, DataType::i32)); end_stmt->accept(this); - end = builder->CreateLoad(end_stmt->value); + end = builder->CreateLoad(llvm_val[end_stmt.get()]); } return std::tuple(begin, end); } @@ -1495,11 +1538,11 @@ class CodeGenLLVM : public IRVisitor, public ModuleBuilder { void visit(LoopIndexStmt *stmt) override { if (stmt->is_struct_for) { - stmt->value = builder->CreateLoad(builder->CreateGEP( + llvm_val[stmt] = builder->CreateLoad(builder->CreateGEP( current_coordinates, {tlctx->get_constant(0), tlctx->get_constant(0), tlctx->get_constant(stmt->index)})); } else { - stmt->value = builder->CreateLoad( + llvm_val[stmt] = builder->CreateLoad( current_offloaded_stmt->loop_vars_llvm[stmt->index]); } } @@ -1512,7 +1555,7 @@ class CodeGenLLVM : public IRVisitor, public ModuleBuilder { TI_ASSERT(stmt->width() == 1); auto ptr_type = llvm::PointerType::get( tlctx->get_data_type(stmt->ret_type.data_type), 0); - stmt->value = builder->CreatePointerCast(buffer, ptr_type); + llvm_val[stmt] = builder->CreatePointerCast(buffer, ptr_type); } void visit(InternalFuncStmt *stmt) override { @@ -1524,58 +1567,58 @@ class CodeGenLLVM : public IRVisitor, public ModuleBuilder { auto type = llvm::ArrayType::get(llvm::Type::getInt8Ty(*llvm_context), stmt->size_in_bytes()); auto alloca = create_entry_block_alloca(type, sizeof(int64)); - stmt->value = builder->CreateBitCast( + llvm_val[stmt] = builder->CreateBitCast( alloca, llvm::PointerType::getInt8PtrTy(*llvm_context)); - call("stack_init", stmt->value); + call("stack_init", llvm_val[stmt]); } void visit(StackPopStmt *stmt) override { - call("stack_pop", stmt->stack->value); + call("stack_pop", llvm_val[stmt->stack]); } void visit(StackPushStmt *stmt) override { auto stack = stmt->stack->as(); - call("stack_push", stack->value, tlctx->get_constant(stack->max_size), + call("stack_push", llvm_val[stack], tlctx->get_constant(stack->max_size), tlctx->get_constant(stack->element_size_in_bytes())); - auto primal_ptr = call("stack_top_primal", stack->value, + auto primal_ptr = call("stack_top_primal", llvm_val[stack], tlctx->get_constant(stack->element_size_in_bytes())); primal_ptr = builder->CreateBitCast( primal_ptr, llvm::PointerType::get( tlctx->get_data_type(stmt->ret_type.data_type), 0)); - builder->CreateStore(stmt->v->value, primal_ptr); + builder->CreateStore(llvm_val[stmt->v], primal_ptr); } void visit(StackLoadTopStmt *stmt) override { auto stack = stmt->stack->as(); - auto primal_ptr = call("stack_top_primal", stack->value, + auto primal_ptr = call("stack_top_primal", llvm_val[stack], tlctx->get_constant(stack->element_size_in_bytes())); primal_ptr = builder->CreateBitCast( primal_ptr, llvm::PointerType::get( tlctx->get_data_type(stmt->ret_type.data_type), 0)); - stmt->value = builder->CreateLoad(primal_ptr); + llvm_val[stmt] = builder->CreateLoad(primal_ptr); } void visit(StackLoadTopAdjStmt *stmt) override { auto stack = stmt->stack->as(); - auto adjoint = call("stack_top_adjoint", stack->value, + auto adjoint = call("stack_top_adjoint", llvm_val[stack], tlctx->get_constant(stack->element_size_in_bytes())); adjoint = builder->CreateBitCast( adjoint, llvm::PointerType::get( tlctx->get_data_type(stmt->ret_type.data_type), 0)); - stmt->value = builder->CreateLoad(adjoint); + llvm_val[stmt] = builder->CreateLoad(adjoint); } void visit(StackAccAdjointStmt *stmt) override { auto stack = stmt->stack->as(); auto adjoint_ptr = - call("stack_top_adjoint", stack->value, + call("stack_top_adjoint", llvm_val[stack], tlctx->get_constant(stack->element_size_in_bytes())); adjoint_ptr = builder->CreateBitCast( adjoint_ptr, llvm::PointerType::get( tlctx->get_data_type(stack->ret_type.data_type), 0)); auto old_val = builder->CreateLoad(adjoint_ptr); TI_ASSERT(is_real(stmt->v->ret_type.data_type)); - auto new_val = builder->CreateFAdd(old_val, stmt->v->value); + auto new_val = builder->CreateFAdd(old_val, llvm_val[stmt->v]); builder->CreateStore(new_val, adjoint_ptr); } diff --git a/taichi/codegen/codegen_llvm_cuda.cpp b/taichi/codegen/codegen_llvm_cuda.cpp index 9e422f14d8cee..c3570164a8864 100644 --- a/taichi/codegen/codegen_llvm_cuda.cpp +++ b/taichi/codegen/codegen_llvm_cuda.cpp @@ -114,7 +114,7 @@ class CodeGenLLVMCUDA : public CodeGenLLVM { std::string format; - auto value = stmt->stmt->value; + auto value = llvm_val[stmt->stmt]; if (stmt->stmt->ret_type.data_type == DataType::i32) { format = "%d"; @@ -139,7 +139,7 @@ class CodeGenLLVMCUDA : public CodeGenLLVM { auto format_str = "[debug] " + stmt->str + " = " + format + "\n"; - stmt->value = ModuleBuilder::call( + llvm_val[stmt] = ModuleBuilder::call( builder.get(), "vprintf", builder->CreateGlobalStringPtr(format_str, "format_string"), builder->CreateBitCast(values, @@ -148,50 +148,50 @@ class CodeGenLLVMCUDA : public CodeGenLLVM { void emit_extra_unary(UnaryOpStmt *stmt) override { // functions from libdevice - auto input = stmt->operand->value; + auto input = llvm_val[stmt->operand]; auto input_taichi_type = stmt->operand->ret_type.data_type; auto op = stmt->op_type; -#define UNARY_STD(x) \ - else if (op == UnaryOpType::x) { \ - if (input_taichi_type == DataType::f32) { \ - stmt->value = \ - builder->CreateCall(get_runtime_function("__nv_" #x "f"), input); \ - } else if (input_taichi_type == DataType::f64) { \ - stmt->value = \ - builder->CreateCall(get_runtime_function("__nv_" #x), input); \ - } else if (input_taichi_type == DataType::i32) { \ - stmt->value = builder->CreateCall(get_runtime_function(#x), input); \ - } else { \ - TI_NOT_IMPLEMENTED \ - } \ +#define UNARY_STD(x) \ + else if (op == UnaryOpType::x) { \ + if (input_taichi_type == DataType::f32) { \ + llvm_val[stmt] = \ + builder->CreateCall(get_runtime_function("__nv_" #x "f"), input); \ + } else if (input_taichi_type == DataType::f64) { \ + llvm_val[stmt] = \ + builder->CreateCall(get_runtime_function("__nv_" #x), input); \ + } else if (input_taichi_type == DataType::i32) { \ + llvm_val[stmt] = builder->CreateCall(get_runtime_function(#x), input); \ + } else { \ + TI_NOT_IMPLEMENTED \ + } \ } if (op == UnaryOpType::abs) { if (input_taichi_type == DataType::f32) { - stmt->value = + llvm_val[stmt] = builder->CreateCall(get_runtime_function("__nv_fabsf"), input); } else if (input_taichi_type == DataType::f64) { - stmt->value = + llvm_val[stmt] = builder->CreateCall(get_runtime_function("__nv_fabs"), input); } else if (input_taichi_type == DataType::i32) { - stmt->value = + llvm_val[stmt] = builder->CreateCall(get_runtime_function("__nv_abs"), input); } else { TI_NOT_IMPLEMENTED } } else if (op == UnaryOpType::sqrt) { if (input_taichi_type == DataType::f32) { - stmt->value = + llvm_val[stmt] = builder->CreateCall(get_runtime_function("__nv_sqrtf"), input); } else if (input_taichi_type == DataType::f64) { - stmt->value = + llvm_val[stmt] = builder->CreateCall(get_runtime_function("__nv_sqrt"), input); } else { TI_NOT_IMPLEMENTED } } else if (op == UnaryOpType::logic_not) { if (input_taichi_type == DataType::i32) { - stmt->value = + llvm_val[stmt] = builder->CreateCall(get_runtime_function("logic_not_i32"), input); } else { TI_NOT_IMPLEMENTED @@ -225,88 +225,94 @@ class CodeGenLLVMCUDA : public CodeGenLLVM { if (stmt->op_type == AtomicOpType::add) { if (is_integral(stmt->val->ret_type.data_type)) { old_value = builder->CreateAtomicRMW( - llvm::AtomicRMWInst::BinOp::Add, stmt->dest->value, - stmt->val->value, llvm::AtomicOrdering::SequentiallyConsistent); + llvm::AtomicRMWInst::BinOp::Add, llvm_val[stmt->dest], + llvm_val[stmt->val], + llvm::AtomicOrdering::SequentiallyConsistent); } else if (stmt->val->ret_type.data_type == DataType::f32) { auto dt = tlctx->get_data_type(DataType::f32); - old_value = - builder->CreateIntrinsic(Intrinsic::nvvm_atomic_load_add_f32, - {llvm::PointerType::get(dt, 0)}, - {stmt->dest->value, stmt->val->value}); + old_value = builder->CreateIntrinsic( + Intrinsic::nvvm_atomic_load_add_f32, + {llvm::PointerType::get(dt, 0)}, + {llvm_val[stmt->dest], llvm_val[stmt->val]}); } else if (stmt->val->ret_type.data_type == DataType::f64) { auto dt = tlctx->get_data_type(DataType::f64); - old_value = - builder->CreateIntrinsic(Intrinsic::nvvm_atomic_load_add_f64, - {llvm::PointerType::get(dt, 0)}, - {stmt->dest->value, stmt->val->value}); + old_value = builder->CreateIntrinsic( + Intrinsic::nvvm_atomic_load_add_f64, + {llvm::PointerType::get(dt, 0)}, + {llvm_val[stmt->dest], llvm_val[stmt->val]}); } else { TI_NOT_IMPLEMENTED } } else if (stmt->op_type == AtomicOpType::min) { if (is_integral(stmt->val->ret_type.data_type)) { old_value = builder->CreateAtomicRMW( - llvm::AtomicRMWInst::BinOp::Min, stmt->dest->value, - stmt->val->value, llvm::AtomicOrdering::SequentiallyConsistent); + llvm::AtomicRMWInst::BinOp::Min, llvm_val[stmt->dest], + llvm_val[stmt->val], + llvm::AtomicOrdering::SequentiallyConsistent); } else if (stmt->val->ret_type.data_type == DataType::f32) { old_value = builder->CreateCall(get_runtime_function("atomic_min_f32"), - {stmt->dest->value, stmt->val->value}); + {llvm_val[stmt->dest], llvm_val[stmt->val]}); } else if (stmt->val->ret_type.data_type == DataType::f64) { old_value = builder->CreateCall(get_runtime_function("atomic_min_f64"), - {stmt->dest->value, stmt->val->value}); + {llvm_val[stmt->dest], llvm_val[stmt->val]}); } else { TI_NOT_IMPLEMENTED } } else if (stmt->op_type == AtomicOpType::max) { if (is_integral(stmt->val->ret_type.data_type)) { old_value = builder->CreateAtomicRMW( - llvm::AtomicRMWInst::BinOp::Max, stmt->dest->value, - stmt->val->value, llvm::AtomicOrdering::SequentiallyConsistent); + llvm::AtomicRMWInst::BinOp::Max, llvm_val[stmt->dest], + llvm_val[stmt->val], + llvm::AtomicOrdering::SequentiallyConsistent); } else if (stmt->val->ret_type.data_type == DataType::f32) { old_value = builder->CreateCall(get_runtime_function("atomic_max_f32"), - {stmt->dest->value, stmt->val->value}); + {llvm_val[stmt->dest], llvm_val[stmt->val]}); } else if (stmt->val->ret_type.data_type == DataType::f64) { old_value = builder->CreateCall(get_runtime_function("atomic_max_f64"), - {stmt->dest->value, stmt->val->value}); + {llvm_val[stmt->dest], llvm_val[stmt->val]}); } else { TI_NOT_IMPLEMENTED } } else if (stmt->op_type == AtomicOpType::bit_and) { if (is_integral(stmt->val->ret_type.data_type)) { old_value = builder->CreateAtomicRMW( - llvm::AtomicRMWInst::BinOp::And, stmt->dest->value, - stmt->val->value, llvm::AtomicOrdering::SequentiallyConsistent); + llvm::AtomicRMWInst::BinOp::And, llvm_val[stmt->dest], + llvm_val[stmt->val], + llvm::AtomicOrdering::SequentiallyConsistent); } else { TI_NOT_IMPLEMENTED } } else if (stmt->op_type == AtomicOpType::bit_or) { if (is_integral(stmt->val->ret_type.data_type)) { old_value = builder->CreateAtomicRMW( - llvm::AtomicRMWInst::BinOp::Or, stmt->dest->value, - stmt->val->value, llvm::AtomicOrdering::SequentiallyConsistent); + llvm::AtomicRMWInst::BinOp::Or, llvm_val[stmt->dest], + llvm_val[stmt->val], + llvm::AtomicOrdering::SequentiallyConsistent); } else { TI_NOT_IMPLEMENTED } } else if (stmt->op_type == AtomicOpType::bit_xor) { if (is_integral(stmt->val->ret_type.data_type)) { old_value = builder->CreateAtomicRMW( - llvm::AtomicRMWInst::BinOp::Xor, stmt->dest->value, - stmt->val->value, llvm::AtomicOrdering::SequentiallyConsistent); + llvm::AtomicRMWInst::BinOp::Xor, llvm_val[stmt->dest], + llvm_val[stmt->val], + llvm::AtomicOrdering::SequentiallyConsistent); } else { TI_NOT_IMPLEMENTED } } else { TI_NOT_IMPLEMENTED } - stmt->value = old_value; + llvm_val[stmt] = old_value; } } void visit(RandStmt *stmt) override { - stmt->value = + llvm_val[stmt] = create_call(fmt::format("cuda_rand_{}", data_type_short_name(stmt->ret_type.data_type)), {get_context()}); diff --git a/taichi/ir/ir.h b/taichi/ir/ir.h index 5bbfc607c77b6..28e87672913e7 100644 --- a/taichi/ir/ir.h +++ b/taichi/ir/ir.h @@ -516,7 +516,6 @@ class Stmt : public IRNode { bool erased; std::string tb; Stmt *adjoint; - llvm::Value *value; bool is_ptr; Stmt(const Stmt &stmt) = delete; @@ -524,7 +523,6 @@ class Stmt : public IRNode { Stmt() { adjoint = nullptr; parent = nullptr; - value = nullptr; instance_id = instance_id_counter++; id = instance_id; operand_bitmap = 0;