Skip to content

Commit

Permalink
[Codegen][LLVM] Add ability to turn on fast math flags (#9223)
Browse files Browse the repository at this point in the history
* flags to turn off and on

* turn fast math on always

* llvm more opts

* move to default codegen opt

* TODO

* add fast math options to llvm target

* move to using new target attributes

* llvm fast math target opt code

* add -O flags

* fix todo lint

* support llvm 4.0, 5.0

* use same opt level as target machine

* revert TargetOptions

* fix thing

* prevent regression in llvm

* togglable opt-levels

Co-authored-by: Andrew Zhao Luo <[email protected]>
  • Loading branch information
AndrewZhaoLuo and Andrew Zhao Luo authored Oct 19, 2021
1 parent 75cf964 commit f095595
Show file tree
Hide file tree
Showing 5 changed files with 103 additions and 5 deletions.
25 changes: 23 additions & 2 deletions src/target/llvm/codegen_llvm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,8 @@ void CodeGenLLVM::Init(const std::string& module_name, llvm::TargetMachine* tm,
this->InitTarget(tm);
}

void CodeGenLLVM::SetFastMathFlag(llvm::FastMathFlags fmf) { builder_->setFastMathFlags(fmf); }

void CodeGenLLVM::InitTarget(llvm::TargetMachine* tm) {
module_->setTargetTriple(tm->getTargetTriple().str());
module_->setDataLayout(tm->createDataLayout());
Expand Down Expand Up @@ -343,7 +345,26 @@ void CodeGenLLVM::Optimize() {

// place optimization pass
llvm::PassManagerBuilder builder;
builder.OptLevel = 3;

// Use the same opt-level as specified in TargetMachine for running passes
llvm::CodeGenOpt::Level opt_level = target_machine_->getOptLevel();

switch (opt_level) {
case llvm::CodeGenOpt::Level::None:
builder.OptLevel = 0;
break;
case llvm::CodeGenOpt::Level::Less:
builder.OptLevel = 1;
break;

case llvm::CodeGenOpt::Level::Default:
builder.OptLevel = 2;
break;

default:
// CodeGenOpt::Level::Aggressive
builder.OptLevel = 3;
}

#if TVM_LLVM_VERSION >= 50
builder.Inliner = llvm::createFunctionInliningPass(builder.OptLevel, 0, false);
Expand Down Expand Up @@ -410,7 +431,7 @@ llvm::Type* CodeGenLLVM::DTypeToLLVMType(const DataType& dtype) const {
} else {
return etype;
}
}
} // namespace codegen

llvm::Type* CodeGenLLVM::GetLLVMType(const Type& type) const {
if (auto* ptr = type.as<PrimTypeNode>()) {
Expand Down
7 changes: 7 additions & 0 deletions src/target/llvm/codegen_llvm.h
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,13 @@ class CodeGenLLVM : public ExprFunctor<llvm::Value*(const PrimExpr&)>,
*/
virtual void Init(const std::string& module_name, llvm::TargetMachine* tm, llvm::LLVMContext* ctx,
bool system_lib, bool dynamic_lookup, bool target_c_runtime);

/*!
* \brief Turn on fast math flags for floating point operations.
* \param fmf FastMathFlags to use for code generation.
*/
void SetFastMathFlag(llvm::FastMathFlags fmf);

/*!
* \brief Compile and add function f to the current module.
* \param f The function to be added.
Expand Down
20 changes: 18 additions & 2 deletions src/target/llvm/llvm_common.cc
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,8 @@ void ParseLLVMTargetOptions(const Target& target, std::string* triple, std::stri
#if TVM_LLVM_VERSION < 50
opt.LessPreciseFPMADOption = true;
#endif
// In clang, these are fed from LangOpts which describe language specific features
// TODO(AndrewZhaoLuo): figure out how these relate to fast math flags
opt.AllowFPOpFusion = llvm::FPOpFusion::Fast;
opt.UnsafeFPMath = false;
opt.NoInfsFPMath = false;
Expand Down Expand Up @@ -139,8 +141,22 @@ std::unique_ptr<llvm::TargetMachine> GetLLVMTargetMachine(const Target& target,
ICHECK(allow_null) << err << " target_triple=" << target_triple;
return nullptr;
}
llvm::TargetMachine* tm =
llvm_target->createTargetMachine(target_triple, mcpu, mattr, opt, llvm::Reloc::PIC_);

Integer llvm_opt_level = target->GetAttr<Integer>("opt-level").value_or(Integer(3));
llvm::CodeGenOpt::Level llvm_opt;
if (llvm_opt_level <= 0) {
llvm_opt = llvm::CodeGenOpt::None;
} else if (llvm_opt_level == 1) {
llvm_opt = llvm::CodeGenOpt::Less;
} else if (llvm_opt_level == 2) {
llvm_opt = llvm::CodeGenOpt::Default;
} else {
// llvm_opt_level >= 3
llvm_opt = llvm::CodeGenOpt::Aggressive;
}

llvm::TargetMachine* tm = llvm_target->createTargetMachine(
target_triple, mcpu, mattr, opt, llvm::Reloc::PIC_, llvm::CodeModel::Small, llvm_opt);
return std::unique_ptr<llvm::TargetMachine>(tm);
}

Expand Down
47 changes: 46 additions & 1 deletion src/target/llvm/llvm_module.cc
Original file line number Diff line number Diff line change
Expand Up @@ -258,8 +258,53 @@ class LLVMModuleNode final : public runtime::ModuleNode {
// makes sense when we start to use multiple modules.
cg->Init("TVMMod", tm_.get(), ctx_.get(), system_lib, system_lib, target_c_runtime);

cg->AddFunctionsOrdered(funcs.begin(), funcs.end());
// See https://llvm.org/docs/LangRef.html#fast-math-flags for details
Bool fast_math_all = target->GetAttr<Bool>("fast-math").value_or(Bool(false));
Bool fast_math_nnan = target->GetAttr<Bool>("fast-math-nnan").value_or(Bool(false));
Bool fast_math_ninf = target->GetAttr<Bool>("fast-math-ninf").value_or(Bool(false));
Bool fast_math_nsz = target->GetAttr<Bool>("fast-math-nsz").value_or(Bool(false));
Bool fast_math_arcp = target->GetAttr<Bool>("fast-math-arcp").value_or(Bool(false));

llvm::FastMathFlags fmf;
if (fast_math_all) {
#if TVM_LLVM_VERSION >= 60
fmf.setFast();
#else
fmf.setUnsafeAlgebra();
#endif
}

if (fast_math_nnan) {
fmf.setNoNaNs();
}
if (fast_math_ninf) {
fmf.setNoInfs();
}
if (fast_math_nsz) {
fmf.setNoSignedZeros();
}
if (fast_math_arcp) {
fmf.setAllowReciprocal();
}

#if TVM_LLVM_VERSION >= 60
Bool fast_math_contract = target->GetAttr<Bool>("fast-math-contract").value_or(Bool(false));
Bool fast_math_afn = target->GetAttr<Bool>("fast-math-afn").value_or(Bool(false));
Bool fast_math_reassoc = target->GetAttr<Bool>("fast-math-reassoc").value_or(Bool(false));
if (fast_math_contract) {
fmf.setAllowContract();
}
if (fast_math_afn) {
fmf.setApproxFunc();
}
if (fast_math_reassoc) {
fmf.setAllowReassoc();
}
#endif

cg->SetFastMathFlag(fmf);

cg->AddFunctionsOrdered(funcs.begin(), funcs.end());
if (entry_func.length() != 0) {
cg->AddMainFunction(entry_func);
}
Expand Down
9 changes: 9 additions & 0 deletions src/target/target_kind.cc
Original file line number Diff line number Diff line change
Expand Up @@ -230,6 +230,15 @@ TVM_REGISTER_TARGET_KIND("llvm", kDLCPU)
.add_attr_option<Bool>("link-params", Bool(false))
.add_attr_option<Bool>("unpacked-api")
.add_attr_option<String>("interface-api")
// Fast math flags, see https://llvm.org/docs/LangRef.html#fast-math-flags
.add_attr_option<Bool>("fast-math") // implies all the below
.add_attr_option<Bool>("fast-math-nnan")
.add_attr_option<Bool>("fast-math-ninf")
.add_attr_option<Bool>("fast-math-nsz")
.add_attr_option<Bool>("fast-math-arcp")
.add_attr_option<Bool>("fast-math-contract")
.add_attr_option<Bool>("fast-math-reassoc")
.add_attr_option<Integer>("opt-level")
.set_default_keys({"cpu"});

TVM_REGISTER_TARGET_KIND("c", kDLCPU)
Expand Down

0 comments on commit f095595

Please sign in to comment.