Skip to content

Commit

Permalink
Add new pass for math to rocdl. (llvm#93)
Browse files Browse the repository at this point in the history
Add new pass for math to rocdl.
  • Loading branch information
jsjodin authored Jun 18, 2024
1 parent c01ecf8 commit 9167cb0
Show file tree
Hide file tree
Showing 8 changed files with 221 additions and 1 deletion.
1 change: 1 addition & 0 deletions flang/lib/Optimizer/CodeGen/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ add_flang_library(FIRCodeGen
MLIRMathToFuncs
MLIRMathToLLVM
MLIRMathToLibm
MLIRMathToROCDL
MLIROpenMPToLLVM
MLIRBuiltinToLLVMIRTranslation
MLIRLLVMToLLVMIRTranslation
Expand Down
12 changes: 11 additions & 1 deletion flang/lib/Optimizer/CodeGen/CodeGen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
#include "mlir/Conversion/MathToFuncs/MathToFuncs.h"
#include "mlir/Conversion/MathToLLVM/MathToLLVM.h"
#include "mlir/Conversion/MathToLibm/MathToLibm.h"
#include "mlir/Conversion/MathToROCDL/MathToROCDL.h"
#include "mlir/Conversion/OpenMPToLLVM/ConvertOpenMPToLLVM.h"
#include "mlir/Conversion/ReconcileUnrealizedCasts/ReconcileUnrealizedCasts.h"
#include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h"
Expand Down Expand Up @@ -3609,6 +3610,14 @@ class FIRToLLVMLowering
// as passes here.
mlir::OpPassManager mathConvertionPM("builtin.module");

bool isAMDGCN = fir::getTargetTriple(mod).isAMDGCN();
// If compiling for AMD target some math operations must be lowered to AMD
// GPU library calls, the rest can be converted to LLVM intrinsics, which
// is handled in the mathToLLVM conversion. The lowering to libm calls is
// not needed since all math operations are handled this way.
if (isAMDGCN)
mathConvertionPM.addPass(mlir::createConvertMathToROCDL());

// Convert math::FPowI operations to inline implementation
// only if the exponent's width is greater than 32, otherwise,
// it will be lowered to LLVM intrinsic operation by a later conversion.
Expand Down Expand Up @@ -3648,7 +3657,8 @@ class FIRToLLVMLowering
pattern);
// Math operations that have not been converted yet must be converted
// to Libm.
mlir::populateMathToLibmConversionPatterns(pattern);
if (!isAMDGCN)
mlir::populateMathToLibmConversionPatterns(pattern);
mlir::populateComplexToLLVMConversionPatterns(typeConverter, pattern);
mlir::populateVectorToLLVMConversionPatterns(typeConverter, pattern);

Expand Down
21 changes: 21 additions & 0 deletions mlir/include/mlir/Conversion/MathToROCDL/MathToROCDL.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
//===- MathToROCDL.h - Utils to convert from the complex dialect --------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#ifndef MLIR_CONVERSION_MATHTOROCDL_MATHTOROCDL_H_
#define MLIR_CONVERSION_MATHTOROCDL_MATHTOROCDL_H_

#include <memory>

namespace mlir {
class Pass;

#define GEN_PASS_DECL_CONVERTMATHTOROCDL
#include "mlir/Conversion/Passes.h.inc"

} // namespace mlir

#endif // MLIR_CONVERSION_MATHTOROCDL_MATHTOROCDL_H_
1 change: 1 addition & 0 deletions mlir/include/mlir/Conversion/Passes.h
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@
#include "mlir/Conversion/MathToFuncs/MathToFuncs.h"
#include "mlir/Conversion/MathToLLVM/MathToLLVM.h"
#include "mlir/Conversion/MathToLibm/MathToLibm.h"
#include "mlir/Conversion/MathToROCDL/MathToROCDL.h"
#include "mlir/Conversion/MathToSPIRV/MathToSPIRVPass.h"
#include "mlir/Conversion/MemRefToEmitC/MemRefToEmitCPass.h"
#include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h"
Expand Down
16 changes: 16 additions & 0 deletions mlir/include/mlir/Conversion/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -721,6 +721,22 @@ def ConvertMathToLLVMPass : Pass<"convert-math-to-llvm"> {
];
}

//===----------------------------------------------------------------------===//
// MathToROCDL
//===----------------------------------------------------------------------===//

def ConvertMathToROCDL : Pass<"convert-math-to-rocdl", "ModuleOp"> {
let summary = "Convert Math dialect to rocdl calls";
let description = [{
This pass converts supported Math ops to rocdl calls.
}];
let dependentDialects = [
"func::FuncDialect",
"math::MathDialect",
"vector::VectorDialect",
];
}

//===----------------------------------------------------------------------===//
// MathToSPIRV
//===----------------------------------------------------------------------===//
Expand Down
1 change: 1 addition & 0 deletions mlir/lib/Conversion/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ add_subdirectory(LLVMCommon)
add_subdirectory(MathToFuncs)
add_subdirectory(MathToLibm)
add_subdirectory(MathToLLVM)
add_subdirectory(MathToROCDL)
add_subdirectory(MathToSPIRV)
add_subdirectory(MemRefToEmitC)
add_subdirectory(MemRefToLLVM)
Expand Down
23 changes: 23 additions & 0 deletions mlir/lib/Conversion/MathToROCDL/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
add_mlir_conversion_library(MLIRMathToROCDL
MathToROCDL.cpp

ADDITIONAL_HEADER_DIRS
${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/MathToROCDL

DEPENDS
MLIRConversionPassIncGen

LINK_COMPONENTS
Core

LINK_LIBS PUBLIC
MLIRDialectUtils
MLIRFuncDialect
MLIRGPUToGPURuntimeTransforms
MLIRMathDialect
MLIRLLVMCommonConversion
MLIRPass
MLIRTransformUtils
MLIRVectorDialect
MLIRVectorUtils
)
147 changes: 147 additions & 0 deletions mlir/lib/Conversion/MathToROCDL/MathToROCDL.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,147 @@
//===-- MathToROCDL.cpp - conversion from Math to rocdl calls -------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Conversion/MathToROCDL/MathToROCDL.h"
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
#include "mlir/Conversion/LLVMCommon/LoweringOptions.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/BuiltinDialect.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h"

#include "mlir/Conversion/GPUCommon/GPUCommonPass.h"
#include "../GPUCommon/GPUOpsLowering.h"
#include "../GPUCommon/IndexIntrinsicsOpLowering.h"
#include "../GPUCommon/OpToFuncCallLowering.h"

namespace mlir {
#define GEN_PASS_DEF_CONVERTMATHTOROCDL
#include "mlir/Conversion/Passes.h.inc"
} // namespace mlir

using namespace mlir;

#define DEBUG_TYPE "math-to-rocdl"
#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")

template <typename OpTy>
static void populateOpPatterns(LLVMTypeConverter &converter,
RewritePatternSet &patterns, StringRef f32Func,
StringRef f64Func) {
patterns.add<ScalarizeVectorOpLowering<OpTy>>(converter);
patterns.add<OpToFuncCallLowering<OpTy>>(converter, f32Func, f64Func);
}

static void populateMathToROCDLConversionPatterns(LLVMTypeConverter &converter,
RewritePatternSet &patterns) {
// Handled by mathToLLVM: math::AbsIOp
// Handled by mathToLLVM: math::CopySignOp
// Handled by mathToLLVM: math::CountLeadingZerosOp
// Handled by mathToLLVM: math::CountTrailingZerosOp
// Handled by mathToLLVM: math::CgPopOp
// Handled by mathToLLVM: math::FmaOp
// FIXME: math::IPowIOp
// FIXME: math::FPowIOp
// Handled by mathToLLVM: math::RoundEvenOp
// Handled by mathToLLVM: math::RoundOp
// Handled by mathToLLVM: math::TruncOp

populateOpPatterns<math::AbsFOp>(converter, patterns, "__ocml_fabs_f32",
"__ocml_fabs_f64");
populateOpPatterns<math::AcosOp>(converter, patterns, "__ocml_acos_f32",
"__ocml_acos_f64");
populateOpPatterns<math::AcoshOp>(converter, patterns, "__ocml_acosh_f32",
"__ocml_acosh_f64");
populateOpPatterns<math::AsinOp>(converter, patterns, "__ocml_asin_f32",
"__ocml_asin_f64");
populateOpPatterns<math::AsinhOp>(converter, patterns, "__ocml_asinh_f32",
"__ocml_asinh_f64");
populateOpPatterns<math::AtanOp>(converter, patterns, "__ocml_atan_f32",
"__ocml_atan_f64");
populateOpPatterns<math::AtanhOp>(converter, patterns, "__ocml_atanh_f32",
"__ocml_atanh_f64");
populateOpPatterns<math::Atan2Op>(converter, patterns, "__ocml_atan2_f32",
"__ocml_atan2_f64");
populateOpPatterns<math::CbrtOp>(converter, patterns, "__ocml_cbrt_f32",
"__ocml_cbrt_f64");
populateOpPatterns<math::CeilOp>(converter, patterns, "__ocml_ceil_f32",
"__ocml_ceil_f64");
populateOpPatterns<math::CosOp>(converter, patterns, "__ocml_cos_f32",
"__ocml_cos_f64");
populateOpPatterns<math::CoshOp>(converter, patterns, "__ocml_cosh_f32",
"__ocml_cosh_f64");
populateOpPatterns<math::SinhOp>(converter, patterns, "__ocml_sinh_f32",
"__ocml_sinh_f64");
populateOpPatterns<math::ExpOp>(converter, patterns, "__ocml_exp_f32",
"__ocml_exp_f64");
populateOpPatterns<math::Exp2Op>(converter, patterns, "__ocml_exp2_f32",
"__ocml_exp2_f64");
populateOpPatterns<math::ExpM1Op>(converter, patterns, "__ocml_expm1_f32",
"__ocml_expm1_f64");
populateOpPatterns<math::FloorOp>(converter, patterns, "__ocml_floor_f32",
"__ocml_floor_f64");
// FIXME: Different pass or new op in math?
// populateOpPatterns<arith::RemFOp>(converter, patterns, "__ocml_fmod_f32",
// "__ocml_fmod_f64");
populateOpPatterns<math::LogOp>(converter, patterns, "__ocml_log_f32",
"__ocml_log_f64");
populateOpPatterns<math::Log10Op>(converter, patterns, "__ocml_log10_f32",
"__ocml_log10_f64");
populateOpPatterns<math::Log1pOp>(converter, patterns, "__ocml_log1p_f32",
"__ocml_log1p_f64");
populateOpPatterns<math::Log2Op>(converter, patterns, "__ocml_log2_f32",
"__ocml_log2_f64");
populateOpPatterns<math::PowFOp>(converter, patterns, "__ocml_pow_f32",
"__ocml_pow_f64");
populateOpPatterns<math::RsqrtOp>(converter, patterns, "__ocml_rsqrt_f32",
"__ocml_rsqrt_f64");
populateOpPatterns<math::SinOp>(converter, patterns, "__ocml_sin_f32",
"__ocml_sin_f64");
populateOpPatterns<math::SqrtOp>(converter, patterns, "__ocml_sqrt_f32",
"__ocml_sqrt_f64");
populateOpPatterns<math::TanhOp>(converter, patterns, "__ocml_tanh_f32",
"__ocml_tanh_f64");
populateOpPatterns<math::TanOp>(converter, patterns, "__ocml_tan_f32",
"__ocml_tan_f64");
populateOpPatterns<math::ErfOp>(converter, patterns, "__ocml_erf_f32",
"__ocml_erf_f64");
}

namespace {
struct ConvertMathToROCDLPass
: public impl::ConvertMathToROCDLBase<ConvertMathToROCDLPass> {
ConvertMathToROCDLPass() = default;
void runOnOperation() override;
};
} // namespace

void ConvertMathToROCDLPass::runOnOperation() {
auto m = getOperation();
MLIRContext *ctx = m.getContext();


RewritePatternSet patterns(&getContext());
LowerToLLVMOptions options(ctx, DataLayout(m));
LLVMTypeConverter converter(ctx, options);
populateMathToROCDLConversionPatterns(converter, patterns);

ConversionTarget target(getContext());
target.addLegalDialect<arith::ArithDialect, BuiltinDialect, func::FuncDialect,
vector::VectorDialect, LLVM::LLVMDialect>();
target.addIllegalOp<LLVM::CosOp, LLVM::ExpOp, LLVM::Exp2Op, LLVM::FAbsOp,
LLVM::FCeilOp, LLVM::FFloorOp, LLVM::FRemOp, LLVM::LogOp,
LLVM::Log10Op, LLVM::Log2Op, LLVM::PowOp, LLVM::SinOp,
LLVM::SqrtOp>();
if (failed(applyPartialConversion(m, target, std::move(patterns))))
signalPassFailure();
}

0 comments on commit 9167cb0

Please sign in to comment.