From f095595fc6ca7b8ec760be8ae2094bff1d38ec40 Mon Sep 17 00:00:00 2001 From: AndrewZhaoLuo Date: Tue, 19 Oct 2021 01:07:01 -0700 Subject: [PATCH] [Codegen][LLVM] Add ability to turn on fast math flags (#9223) * flags to turn off and on * turn fast math on always * llvm more opts * move to default codegen opt * TODO * add fast math options to llvm target * move to using new target attributes * llvm fast math target opt code * add -O flags * fix todo lint * support llvm 4.0, 5.0 * use same opt level as target machine * revert TargetOptions * fix thing * prevent regression in llvm * togglable opt-levels Co-authored-by: Andrew Zhao Luo --- src/target/llvm/codegen_llvm.cc | 25 ++++++++++++++++-- src/target/llvm/codegen_llvm.h | 7 +++++ src/target/llvm/llvm_common.cc | 20 ++++++++++++-- src/target/llvm/llvm_module.cc | 47 ++++++++++++++++++++++++++++++++- src/target/target_kind.cc | 9 +++++++ 5 files changed, 103 insertions(+), 5 deletions(-) diff --git a/src/target/llvm/codegen_llvm.cc b/src/target/llvm/codegen_llvm.cc index c94c5a685d1b..6c64f6798e47 100644 --- a/src/target/llvm/codegen_llvm.cc +++ b/src/target/llvm/codegen_llvm.cc @@ -77,6 +77,8 @@ void CodeGenLLVM::Init(const std::string& module_name, llvm::TargetMachine* tm, this->InitTarget(tm); } +void CodeGenLLVM::SetFastMathFlag(llvm::FastMathFlags fmf) { builder_->setFastMathFlags(fmf); } + void CodeGenLLVM::InitTarget(llvm::TargetMachine* tm) { module_->setTargetTriple(tm->getTargetTriple().str()); module_->setDataLayout(tm->createDataLayout()); @@ -343,7 +345,26 @@ void CodeGenLLVM::Optimize() { // place optimization pass llvm::PassManagerBuilder builder; - builder.OptLevel = 3; + + // Use the same opt-level as specified in TargetMachine for running passes + llvm::CodeGenOpt::Level opt_level = target_machine_->getOptLevel(); + + switch (opt_level) { + case llvm::CodeGenOpt::Level::None: + builder.OptLevel = 0; + break; + case llvm::CodeGenOpt::Level::Less: + builder.OptLevel = 1; + break; + + case llvm::CodeGenOpt::Level::Default: + builder.OptLevel = 2; + break; + + default: + // CodeGenOpt::Level::Aggressive + builder.OptLevel = 3; + } #if TVM_LLVM_VERSION >= 50 builder.Inliner = llvm::createFunctionInliningPass(builder.OptLevel, 0, false); @@ -410,7 +431,7 @@ llvm::Type* CodeGenLLVM::DTypeToLLVMType(const DataType& dtype) const { } else { return etype; } -} +} // namespace codegen llvm::Type* CodeGenLLVM::GetLLVMType(const Type& type) const { if (auto* ptr = type.as()) { diff --git a/src/target/llvm/codegen_llvm.h b/src/target/llvm/codegen_llvm.h index 177b53056354..4a9df65951c0 100644 --- a/src/target/llvm/codegen_llvm.h +++ b/src/target/llvm/codegen_llvm.h @@ -78,6 +78,13 @@ class CodeGenLLVM : public ExprFunctor, */ virtual void Init(const std::string& module_name, llvm::TargetMachine* tm, llvm::LLVMContext* ctx, bool system_lib, bool dynamic_lookup, bool target_c_runtime); + + /*! + * \brief Turn on fast math flags for floating point operations. + * \param fmf FastMathFlags to use for code generation. + */ + void SetFastMathFlag(llvm::FastMathFlags fmf); + /*! * \brief Compile and add function f to the current module. * \param f The function to be added. diff --git a/src/target/llvm/llvm_common.cc b/src/target/llvm/llvm_common.cc index be80a8bc767e..06b2be2d9fb6 100644 --- a/src/target/llvm/llvm_common.cc +++ b/src/target/llvm/llvm_common.cc @@ -106,6 +106,8 @@ void ParseLLVMTargetOptions(const Target& target, std::string* triple, std::stri #if TVM_LLVM_VERSION < 50 opt.LessPreciseFPMADOption = true; #endif + // In clang, these are fed from LangOpts which describe language specific features + // TODO(AndrewZhaoLuo): figure out how these relate to fast math flags opt.AllowFPOpFusion = llvm::FPOpFusion::Fast; opt.UnsafeFPMath = false; opt.NoInfsFPMath = false; @@ -139,8 +141,22 @@ std::unique_ptr GetLLVMTargetMachine(const Target& target, ICHECK(allow_null) << err << " target_triple=" << target_triple; return nullptr; } - llvm::TargetMachine* tm = - llvm_target->createTargetMachine(target_triple, mcpu, mattr, opt, llvm::Reloc::PIC_); + + Integer llvm_opt_level = target->GetAttr("opt-level").value_or(Integer(3)); + llvm::CodeGenOpt::Level llvm_opt; + if (llvm_opt_level <= 0) { + llvm_opt = llvm::CodeGenOpt::None; + } else if (llvm_opt_level == 1) { + llvm_opt = llvm::CodeGenOpt::Less; + } else if (llvm_opt_level == 2) { + llvm_opt = llvm::CodeGenOpt::Default; + } else { + // llvm_opt_level >= 3 + llvm_opt = llvm::CodeGenOpt::Aggressive; + } + + llvm::TargetMachine* tm = llvm_target->createTargetMachine( + target_triple, mcpu, mattr, opt, llvm::Reloc::PIC_, llvm::CodeModel::Small, llvm_opt); return std::unique_ptr(tm); } diff --git a/src/target/llvm/llvm_module.cc b/src/target/llvm/llvm_module.cc index 0e4bca4396f5..657778df0e93 100644 --- a/src/target/llvm/llvm_module.cc +++ b/src/target/llvm/llvm_module.cc @@ -258,8 +258,53 @@ class LLVMModuleNode final : public runtime::ModuleNode { // makes sense when we start to use multiple modules. cg->Init("TVMMod", tm_.get(), ctx_.get(), system_lib, system_lib, target_c_runtime); - cg->AddFunctionsOrdered(funcs.begin(), funcs.end()); + // See https://llvm.org/docs/LangRef.html#fast-math-flags for details + Bool fast_math_all = target->GetAttr("fast-math").value_or(Bool(false)); + Bool fast_math_nnan = target->GetAttr("fast-math-nnan").value_or(Bool(false)); + Bool fast_math_ninf = target->GetAttr("fast-math-ninf").value_or(Bool(false)); + Bool fast_math_nsz = target->GetAttr("fast-math-nsz").value_or(Bool(false)); + Bool fast_math_arcp = target->GetAttr("fast-math-arcp").value_or(Bool(false)); + + llvm::FastMathFlags fmf; + if (fast_math_all) { +#if TVM_LLVM_VERSION >= 60 + fmf.setFast(); +#else + fmf.setUnsafeAlgebra(); +#endif + } + + if (fast_math_nnan) { + fmf.setNoNaNs(); + } + if (fast_math_ninf) { + fmf.setNoInfs(); + } + if (fast_math_nsz) { + fmf.setNoSignedZeros(); + } + if (fast_math_arcp) { + fmf.setAllowReciprocal(); + } + +#if TVM_LLVM_VERSION >= 60 + Bool fast_math_contract = target->GetAttr("fast-math-contract").value_or(Bool(false)); + Bool fast_math_afn = target->GetAttr("fast-math-afn").value_or(Bool(false)); + Bool fast_math_reassoc = target->GetAttr("fast-math-reassoc").value_or(Bool(false)); + if (fast_math_contract) { + fmf.setAllowContract(); + } + if (fast_math_afn) { + fmf.setApproxFunc(); + } + if (fast_math_reassoc) { + fmf.setAllowReassoc(); + } +#endif + cg->SetFastMathFlag(fmf); + + cg->AddFunctionsOrdered(funcs.begin(), funcs.end()); if (entry_func.length() != 0) { cg->AddMainFunction(entry_func); } diff --git a/src/target/target_kind.cc b/src/target/target_kind.cc index 7cd329f83738..4403af26d1a8 100644 --- a/src/target/target_kind.cc +++ b/src/target/target_kind.cc @@ -230,6 +230,15 @@ TVM_REGISTER_TARGET_KIND("llvm", kDLCPU) .add_attr_option("link-params", Bool(false)) .add_attr_option("unpacked-api") .add_attr_option("interface-api") + // Fast math flags, see https://llvm.org/docs/LangRef.html#fast-math-flags + .add_attr_option("fast-math") // implies all the below + .add_attr_option("fast-math-nnan") + .add_attr_option("fast-math-ninf") + .add_attr_option("fast-math-nsz") + .add_attr_option("fast-math-arcp") + .add_attr_option("fast-math-contract") + .add_attr_option("fast-math-reassoc") + .add_attr_option("opt-level") .set_default_keys({"cpu"}); TVM_REGISTER_TARGET_KIND("c", kDLCPU)