Enable using bare pointers for GPU kernels with static shape, and then use that support #690

Merged: 4 commits, Aug 24, 2022
@@ -50,13 +50,15 @@ using LoweringCallback = std::function<std::unique_ptr<llvm::Module>(
/// This pass does not generate code to call GPU runtime APIs directly but
/// instead uses a small wrapper library that exports a stable and conveniently
/// typed ABI on top of GPU runtimes such as CUDA or ROCm (HIP).
std::unique_ptr<OperationPass<ModuleOp>> createGpuToLLVMConversionPass();
std::unique_ptr<OperationPass<ModuleOp>>
createGpuToLLVMConversionPass(bool kernelBarePtrCallConv = false);

/// Collect a set of patterns to convert from the GPU dialect to LLVM and
/// populate converter for gpu types.
void populateGpuToLLVMConversionPatterns(LLVMTypeConverter &converter,
RewritePatternSet &patterns,
StringRef gpuBinaryAnnotation = {});
StringRef gpuBinaryAnnotation = {},
bool kernelBarePtrCallConv = false);

} // namespace mlir
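For downstream pipelines, the new flag is threaded through the pass constructor. A minimal sketch of opting in on the host side, assuming the upstream GPUCommon header location and with all surrounding pipeline setup omitted:

// Sketch: enabling the bare-pointer launch ABI from a C++ pipeline.
#include "mlir/Conversion/GPUCommon/GPUCommonPass.h"
#include "mlir/Pass/PassManager.h"

void addHostSideGpuLowering(mlir::PassManager &pm) {
  // Must match the kernel-side use-bare-ptr-memref-call-conv setting.
  pm.addPass(mlir::createGpuToLLVMConversionPass(
      /*kernelBarePtrCallConv=*/true));
}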

@@ -41,6 +41,7 @@ std::unique_ptr<OperationPass<gpu::GPUModuleOp>>
createLowerGpuOpsToROCDLOpsPass(
const std::string &chipset = "gfx900",
unsigned indexBitwidth = kDeriveIndexBitwidthFromDataLayout,
bool useBarePtrCallConv = false,
gpu::amd::Runtime runtime = gpu::amd::Runtime::Unknown);

} // namespace mlir
@@ -80,6 +80,13 @@ class LLVMTypeConverter : public TypeConverter {

const LowerToLLVMOptions &getOptions() const { return options; }

/// Set the lowering options to `newOptions`. Note: using this after some
/// conversions have been performed can lead to inconsistencies in the IR.
void dangerousSetOptions(LowerToLLVMOptions newOptions) {
options = std::move(newOptions);
}

/// Promote the LLVM representation of all operands including promoting MemRef
/// descriptors to stack and use pointers to struct to avoid the complexity
/// of the platform-specific C/C++ ABI lowering related to struct argument
@@ -126,7 +133,7 @@ class LLVMTypeConverter : public TypeConverter {
const DataLayout &layout);

/// Check if a memref type can be converted to a bare pointer.
bool canConvertToBarePtr(BaseMemRefType type);
static bool canConvertToBarePtr(BaseMemRefType type);

protected:
/// Pointer to the LLVM dialect.
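`dangerousSetOptions` earns its name: the converter may already have handed out types derived from the old options, so changing them mid-conversion can leave the IR inconsistent. The intended usage, which the launch-op lowering later in this PR follows, is a tight save/override/restore around a single call. A minimal sketch, assuming an already-constructed `converter`:

// Save, override for one operation, then restore immediately.
LowerToLLVMOptions saved = converter.getOptions();
LowerToLLVMOptions bare = saved;
bare.useBarePtrCallConv = true;
converter.dangerousSetOptions(bare);
// ... only the conversions that must see the bare-pointer option ...
converter.dangerousSetOptions(saved); // restore before anything else runs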
4 changes: 4 additions & 0 deletions external/llvm-project/mlir/include/mlir/Conversion/Passes.td
@@ -383,6 +383,10 @@ def ConvertGpuOpsToROCDLOps : Pass<"convert-gpu-to-rocdl", "gpu::GPUModuleOp"> {
Option<"indexBitwidth", "index-bitwidth", "unsigned",
/*default=kDeriveIndexBitwidthFromDataLayout*/"0",
"Bitwidth of the index type, 0 to use size of machine word">,
Option<"useBarePtrCallConv", "use-bare-ptr-memref-call-conv", "bool",
/*default=*/"false",
"Replace memref arguments in GPU functions with bare pointers."
"All memrefs must have static shape">,
Option<"runtime", "runtime", "::mlir::gpu::amd::Runtime",
"::mlir::gpu::amd::Runtime::Unknown",
"Runtime code will be run on (default is Unknown, can also use HIP or OpenCl)",
@@ -9,6 +9,7 @@
#include "GPUOpsLowering.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/IR/Builders.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/FormatVariadic.h"

using namespace mlir;
@@ -142,6 +143,34 @@ GPUFuncOpLowering::matchAndRewrite(gpu::GPUFuncOp gpuFuncOp, OpAdaptor adaptor,
&signatureConversion)))
return failure();

// If bare memref pointers are being used, remap them back to memref
// descriptors. This must be done after signature conversion to get rid of
// the unrealized casts.
if (getTypeConverter()->getOptions().useBarePtrCallConv) {
OpBuilder::InsertionGuard guard(rewriter);
rewriter.setInsertionPointToStart(&llvmFuncOp.getBody().front());
for (const auto &en : llvm::enumerate(gpuFuncOp.getArgumentTypes())) {
auto memrefTy = en.value().dyn_cast<MemRefType>();
if (!memrefTy)
continue;
assert(memrefTy.hasStaticShape() &&
"Bare pointer conversion used with dynamically-shaped memrefs");
// Use a placeholder when replacing uses of the memref argument to prevent
// circular replacements.
auto remapping = signatureConversion.getInputMapping(en.index());
assert(remapping && remapping->size == 1 &&
"Type converter should produce 1-to-1 mapping for bare memrefs");
BlockArgument newArg =
llvmFuncOp.getBody().getArgument(remapping->inputNo);
auto placeholder = rewriter.create<LLVM::UndefOp>(
loc, getTypeConverter()->convertType(memrefTy));
rewriter.replaceUsesOfBlockArgument(newArg, placeholder);
Value desc = MemRefDescriptor::fromStaticShape(
rewriter, loc, *getTypeConverter(), memrefTy, newArg);
rewriter.replaceOp(placeholder, {desc});
}
}

rewriter.eraseOp(gpuFuncOp);
return success();
}
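The placeholder indirection above deserves a note: building the descriptor directly from `newArg` and then replacing all uses of `newArg` would also rewrite the descriptor's own operand, which is the circular replacement the comment warns about. The three steps, annotated (names as in the hunk above, so this is a restatement rather than standalone code):

// 1) Create a stand-in value of the descriptor type.
auto placeholder = rewriter.create<LLVM::UndefOp>(
    loc, getTypeConverter()->convertType(memrefTy));
// 2) Point existing uses of the argument at the stand-in, so step 3 can
//    consume newArg without creating a use cycle.
rewriter.replaceUsesOfBlockArgument(newArg, placeholder);
// 3) Build the real descriptor from the bare pointer and swap it in.
Value desc = MemRefDescriptor::fromStaticShape(
    rewriter, loc, *getTypeConverter(), memrefTy, newArg);
rewriter.replaceOp(placeholder, {desc});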
@@ -49,6 +49,12 @@ class GpuToLLVMConversionPass
public:
GpuToLLVMConversionPass() = default;

GpuToLLVMConversionPass(bool kernelBarePtrCallConv)
: GpuToLLVMConversionPass() {
if (this->kernelBarePtrCallConv.getNumOccurrences() == 0)
this->kernelBarePtrCallConv = kernelBarePtrCallConv;
}

GpuToLLVMConversionPass(const GpuToLLVMConversionPass &other)
: GpuToLLVMConversionPassBase(other) {}

@@ -60,6 +66,11 @@ class GpuToLLVMConversionPass
*this, "gpu-binary-annotation",
llvm::cl::desc("Annotation attribute string for GPU binary"),
llvm::cl::init(gpu::getDefaultGpuBinaryAnnotation())};
Option<bool> kernelBarePtrCallConv{
*this, "use-bare-pointers-for-kernels",
llvm::cl::desc("Use bare pointers to pass memref arguments to kernels. "
"The kernel must use the same setting for this option."),
llvm::cl::init(false)};
};
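Both new constructors in this PR use the same `getNumOccurrences() == 0` guard: a value already parsed from a textual pipeline spec wins over the value passed programmatically. Reduced to its essentials in a hypothetical pass (not part of this patch):

// Hypothetical pass illustrating the option-precedence idiom.
struct ExamplePass : PassWrapper<ExamplePass, OperationPass<ModuleOp>> {
  ExamplePass() = default;
  explicit ExamplePass(bool flagFromCaller) {
    // Keep a value already parsed from "--example-flag=..."; otherwise
    // adopt the caller's value.
    if (exampleFlag.getNumOccurrences() == 0)
      exampleFlag = flagFromCaller;
  }
  void runOnOperation() override {}
  Option<bool> exampleFlag{*this, "example-flag",
                           llvm::cl::desc("Illustrative flag."),
                           llvm::cl::init(false)};
};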

struct FunctionCallBuilder {
@@ -290,9 +301,11 @@ class ConvertLaunchFuncOpToGpuRuntimeCallPattern
: public ConvertOpToGpuRuntimeCallPattern<gpu::LaunchFuncOp> {
public:
ConvertLaunchFuncOpToGpuRuntimeCallPattern(LLVMTypeConverter &typeConverter,
StringRef gpuBinaryAnnotation)
StringRef gpuBinaryAnnotation,
bool kernelBarePtrCallConv)
: ConvertOpToGpuRuntimeCallPattern<gpu::LaunchFuncOp>(typeConverter),
gpuBinaryAnnotation(gpuBinaryAnnotation) {}
gpuBinaryAnnotation(gpuBinaryAnnotation),
kernelBarePtrCallConv(kernelBarePtrCallConv) {}

private:
Value generateParamsArray(gpu::LaunchFuncOp launchOp, OpAdaptor adaptor,
@@ -305,6 +318,7 @@ class ConvertLaunchFuncOpToGpuRuntimeCallPattern
ConversionPatternRewriter &rewriter) const override;

llvm::SmallString<32> gpuBinaryAnnotation;
bool kernelBarePtrCallConv;
};

class EraseGpuModuleOpPattern : public OpRewritePattern<gpu::GPUModuleOp> {
@@ -377,7 +391,8 @@ void GpuToLLVMConversionPass::runOnOperation() {
populateFuncToLLVMConversionPatterns(converter, patterns);
populateAsyncStructuralTypeConversionsAndLegality(converter, patterns,
target);
populateGpuToLLVMConversionPatterns(converter, patterns, gpuBinaryAnnotation);
populateGpuToLLVMConversionPatterns(converter, patterns, gpuBinaryAnnotation,
kernelBarePtrCallConv);

if (failed(
applyPartialConversion(getOperation(), target, std::move(patterns))))
@@ -635,9 +650,24 @@ Value ConvertLaunchFuncOpToGpuRuntimeCallPattern::generateParamsArray(
gpu::LaunchFuncOp launchOp, OpAdaptor adaptor, OpBuilder &builder) const {
auto loc = launchOp.getLoc();
auto numKernelOperands = launchOp.getNumKernelOperands();
auto arguments = getTypeConverter()->promoteOperands(
loc, launchOp.getOperands().take_back(numKernelOperands),
adaptor.getOperands().take_back(numKernelOperands), builder);
SmallVector<Value, 4> arguments;
if (kernelBarePtrCallConv) {
// Temporarily switch the converter to the bare-pointer convention just
// for this argument promotion.
LLVMTypeConverter *converter = getTypeConverter();
LowerToLLVMOptions options = converter->getOptions();
LowerToLLVMOptions overrideToMatchKernelOpts = options;
overrideToMatchKernelOpts.useBarePtrCallConv = true;
converter->dangerousSetOptions(overrideToMatchKernelOpts);
arguments = converter->promoteOperands(
loc, launchOp.getOperands().take_back(numKernelOperands),
adaptor.getOperands().take_back(numKernelOperands), builder);
converter->dangerousSetOptions(options);
} else {
arguments = getTypeConverter()->promoteOperands(
loc, launchOp.getOperands().take_back(numKernelOperands),
adaptor.getOperands().take_back(numKernelOperands), builder);
}

auto numArguments = arguments.size();
SmallVector<Type, 4> argumentTypes;
argumentTypes.reserve(numArguments);
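The manual save/restore in generateParamsArray works because nothing between the two dangerousSetOptions calls can exit early. A defensive refinement (an editorial suggestion, not part of this patch) is an RAII guard, so the options are restored on every path out of scope:

// Hypothetical RAII guard around the bare-pointer option override.
class BarePtrOptionsGuard {
public:
  explicit BarePtrOptionsGuard(LLVMTypeConverter &conv)
      : converter(conv), saved(conv.getOptions()) {
    LowerToLLVMOptions bare = saved;
    bare.useBarePtrCallConv = true;
    converter.dangerousSetOptions(bare);
  }
  ~BarePtrOptionsGuard() { converter.dangerousSetOptions(saved); }

private:
  LLVMTypeConverter &converter;
  LowerToLLVMOptions saved;
};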
@@ -873,13 +903,14 @@ LogicalResult ConvertSetDefaultDeviceOpToGpuRuntimeCallPattern::matchAndRewrite(
}

std::unique_ptr<mlir::OperationPass<mlir::ModuleOp>>
mlir::createGpuToLLVMConversionPass() {
return std::make_unique<GpuToLLVMConversionPass>();
mlir::createGpuToLLVMConversionPass(bool kernelBarePtrCallConv) {
return std::make_unique<GpuToLLVMConversionPass>(kernelBarePtrCallConv);
}

void mlir::populateGpuToLLVMConversionPatterns(
LLVMTypeConverter &converter, RewritePatternSet &patterns,
StringRef gpuBinaryAnnotation) {
void mlir::populateGpuToLLVMConversionPatterns(LLVMTypeConverter &converter,
RewritePatternSet &patterns,
StringRef gpuBinaryAnnotation,
bool kernelBarePtrCallConv) {
converter.addConversion(
[context = &converter.getContext()](gpu::AsyncTokenType type) -> Type {
return LLVM::LLVMPointerType::get(IntegerType::get(context, 8));
Expand All @@ -893,7 +924,7 @@ void mlir::populateGpuToLLVMConversionPatterns(
ConvertWaitAsyncOpToGpuRuntimeCallPattern,
ConvertWaitOpToGpuRuntimeCallPattern,
ConvertAsyncYieldToGpuRuntimeCallPattern>(converter);
patterns.add<ConvertLaunchFuncOpToGpuRuntimeCallPattern>(converter,
gpuBinaryAnnotation);
patterns.add<ConvertLaunchFuncOpToGpuRuntimeCallPattern>(
converter, gpuBinaryAnnotation, kernelBarePtrCallConv);
patterns.add<EraseGpuModuleOpPattern>(&converter.getContext());
}
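Out-of-tree passes that call the populate function directly pick up the new parameter the same way. A sketch of the equivalent of this pass's runOnOperation() wiring with the flag hard-enabled (conversion target and pattern application omitted):

void lowerGpuHostCode(mlir::ModuleOp module) {
  mlir::MLIRContext *ctx = module.getContext();
  mlir::LowerToLLVMOptions options(ctx);
  mlir::LLVMTypeConverter converter(ctx, options);
  mlir::RewritePatternSet patterns(ctx);
  mlir::populateGpuToLLVMConversionPatterns(
      converter, patterns, mlir::gpu::getDefaultGpuBinaryAnnotation(),
      /*kernelBarePtrCallConv=*/true);
  // applyPartialConversion(...) would follow, as in runOnOperation() above.
}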
@@ -56,6 +56,16 @@

using namespace mlir;

/// Returns true if the given `gpu.func` can be safely called using the bare
/// pointer calling convention.
static bool canBeCalledWithBarePointers(gpu::GPUFuncOp func) {
bool canBeBare = true;
for (Type type : func.getArgumentTypes())
if (auto memrefTy = type.dyn_cast<BaseMemRefType>())
canBeBare &= LLVMTypeConverter::canConvertToBarePtr(memrefTy);
return canBeBare;
}
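A small style note: the loop keeps scanning after the first non-convertible memref. llvm::all_of expresses the same predicate with early exit; a behavior-equivalent sketch (assumes llvm/ADT/STLExtras.h and the file's existing using namespace mlir):

// Equivalent formulation with early exit.
static bool canBeCalledWithBarePointers(gpu::GPUFuncOp func) {
  return llvm::all_of(func.getArgumentTypes(), [](Type type) {
    auto memrefTy = type.dyn_cast<BaseMemRefType>();
    return !memrefTy || LLVMTypeConverter::canConvertToBarePtr(memrefTy);
  });
}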

namespace {

/// Import the GPU Ops to ROCDL Patterns.
@@ -70,10 +80,16 @@ struct LowerGpuOpsToROCDLOpsPass
: public ConvertGpuOpsToROCDLOpsBase<LowerGpuOpsToROCDLOpsPass> {
LowerGpuOpsToROCDLOpsPass() = default;
LowerGpuOpsToROCDLOpsPass(const std::string &chipset, unsigned indexBitwidth,
bool useBarePtrCallConv,
gpu::amd::Runtime runtime) {
this->chipset = chipset;
this->indexBitwidth = indexBitwidth;
this->runtime = runtime;
if (this->chipset.getNumOccurrences() == 0)
this->chipset = chipset;
if (this->indexBitwidth.getNumOccurrences() == 0)
this->indexBitwidth = indexBitwidth;
if (this->useBarePtrCallConv.getNumOccurrences() == 0)
this->useBarePtrCallConv = useBarePtrCallConv;
if (this->runtime.getNumOccurrences() == 0)
this->runtime = runtime;
}

void runOnOperation() override {
@@ -97,7 +113,23 @@ struct LowerGpuOpsToROCDLOpsPass
ctx, DataLayout(cast<DataLayoutOpInterface>(m.getOperation())));
if (indexBitwidth != kDeriveIndexBitwidthFromDataLayout)
options.overrideIndexBitwidth(indexBitwidth);


if (useBarePtrCallConv) {
options.useBarePtrCallConv = true;
WalkResult canUseBarePointers =
m.walk([](gpu::GPUFuncOp func) -> WalkResult {
if (canBeCalledWithBarePointers(func))
return WalkResult::advance();
return WalkResult::interrupt();
});
if (canUseBarePointers.wasInterrupted()) {
emitError(UnknownLoc::get(ctx),
"bare pointer calling convention requires all memrefs to "
"have static shape and use the identity map");
return signalPassFailure();
}
}

LLVMTypeConverter converter(ctx, options);

RewritePatternSet patterns(ctx);
@@ -255,7 +287,8 @@ void mlir::populateGpuToROCDLConversionPatterns(
std::unique_ptr<OperationPass<gpu::GPUModuleOp>>
mlir::createLowerGpuOpsToROCDLOpsPass(const std::string &chipset,
unsigned indexBitwidth,
bool useBarePtrCallConv,
gpu::amd::Runtime runtime) {
return std::make_unique<LowerGpuOpsToROCDLOpsPass>(chipset, indexBitwidth,
runtime);
return std::make_unique<LowerGpuOpsToROCDLOpsPass>(
chipset, indexBitwidth, useBarePtrCallConv, runtime);
}
15 changes: 15 additions & 0 deletions external/llvm-project/mlir/test/Conversion/GPUToROCDL/memref.mlir
@@ -0,0 +1,15 @@
// RUN: mlir-opt %s -convert-gpu-to-rocdl -split-input-file | FileCheck %s
// RUN: mlir-opt %s \
// RUN: -convert-gpu-to-rocdl=use-bare-ptr-memref-call-conv=true \
// RUN: -split-input-file \
// RUN: | FileCheck %s --check-prefix=BARE

gpu.module @memref_conversions {
// CHECK: llvm.func @kern
// CHECK-SAME: (%{{.*}}: !llvm.ptr<f32>, %{{.*}}: !llvm.ptr<f32>, %{{.*}}: i64, %{{.*}}: i64, %{{.*}}: i64)
// BARE: llvm.func @kern
// BARE-SAME: (%{{.*}}: !llvm.ptr<f32>)
gpu.func @kern(%arg0: memref<8xf32>) kernel {
gpu.return
}
}
19 changes: 11 additions & 8 deletions external/llvm-project/mlir/test/Integration/GPU/ROCM/vecadd.mlir
@@ -1,24 +1,24 @@
// RUN: mlir-opt %s \
// RUN: -convert-scf-to-cf \
// RUN: -gpu-kernel-outlining \
// RUN: -pass-pipeline='gpu.module(strip-debuginfo,convert-gpu-to-rocdl,gpu-to-hsaco{chip=%chip})' \
// RUN: -gpu-to-llvm \
// RUN: -pass-pipeline='gpu.module(strip-debuginfo,convert-gpu-to-rocdl{use-bare-ptr-memref-call-conv=true},gpu-to-hsaco{chip=%chip})' \
// RUN: -gpu-to-llvm=use-bare-pointers-for-kernels=true \
// RUN: | mlir-cpu-runner \
// RUN: --shared-libs=%linalg_test_lib_dir/libmlir_rocm_runtime%shlibext \
// RUN: --shared-libs=%linalg_test_lib_dir/libmlir_runner_utils%shlibext \
// RUN: --entry-point-result=void \
// RUN: | FileCheck %s

func.func @vecadd(%arg0 : memref<?xf32>, %arg1 : memref<?xf32>, %arg2 : memref<?xf32>) {
func.func @vecadd(%arg0 : memref<5xf32>, %arg1 : memref<5xf32>, %arg2 : memref<5xf32>) {
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%block_dim = memref.dim %arg0, %c0 : memref<?xf32>
%block_dim = arith.constant 5 : index
gpu.launch blocks(%bx, %by, %bz) in (%grid_x = %c1, %grid_y = %c1, %grid_z = %c1)
threads(%tx, %ty, %tz) in (%block_x = %block_dim, %block_y = %c1, %block_z = %c1) {
%a = memref.load %arg0[%tx] : memref<?xf32>
%b = memref.load %arg1[%tx] : memref<?xf32>
%a = memref.load %arg0[%tx] : memref<5xf32>
%b = memref.load %arg1[%tx] : memref<5xf32>
%c = arith.addf %a, %b : f32
memref.store %c, %arg2[%tx] : memref<?xf32>
memref.store %c, %arg2[%tx] : memref<5xf32>
gpu.terminator
}
return
@@ -49,8 +49,11 @@ func.func @main() {
%9 = call @mgpuMemGetDeviceMemRef1dFloat(%3) : (memref<?xf32>) -> (memref<?xf32>)
%10 = call @mgpuMemGetDeviceMemRef1dFloat(%4) : (memref<?xf32>) -> (memref<?xf32>)
%11 = call @mgpuMemGetDeviceMemRef1dFloat(%5) : (memref<?xf32>) -> (memref<?xf32>)
%12 = memref.cast %9 : memref<?xf32> to memref<5xf32>
%13 = memref.cast %10 : memref<?xf32> to memref<5xf32>
%14 = memref.cast %11 : memref<?xf32> to memref<5xf32>

call @vecadd(%9, %10, %11) : (memref<?xf32>, memref<?xf32>, memref<?xf32>) -> ()
call @vecadd(%12, %13, %14) : (memref<5xf32>, memref<5xf32>, memref<5xf32>) -> ()
call @printMemrefF32(%8) : (memref<*xf32>) -> ()
return
}
4 changes: 4 additions & 0 deletions mlir/include/mlir-c/Dialect/MIGraphX.h
@@ -18,6 +18,10 @@
extern "C" {
#endif

// Version 2: Use bare pointer ABI (kernels take just a pointer to the data
// buffer, not an entire memref struct). Also introduces this constant.
#define MLIR_MIGRAPHX_DIALECT_API_VERSION 2

Review comment (Contributor): Can we have a comment here to keep the features added/deleted per each version?

MLIR_DECLARE_CAPI_DIALECT_REGISTRATION(MIGraphX, migraphx);

// Phase 0 functions : Assuming the given module contains only one function
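On the consumer side, the macro gives downstream C or C++ code a compile-time switch between the two ABIs. A hypothetical guard (the consumer and its requirement are assumptions, not code from this repository):

// Hypothetical downstream check against the MIGraphX C API version.
#include "mlir-c/Dialect/MIGraphX.h"

#if !defined(MLIR_MIGRAPHX_DIALECT_API_VERSION) || \
    MLIR_MIGRAPHX_DIALECT_API_VERSION < 2
// Version 1 and earlier pass whole memref descriptor structs to kernels.
#error "This integration expects the bare-pointer kernel ABI (version >= 2)"
#endif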
5 changes: 5 additions & 0 deletions mlir/include/mlir/Dialect/MIOpen/XMIRPipelines.h
@@ -14,6 +14,7 @@
#define MLIR_DIALECT_XMIR_PIPELINES_H_

#include "mlir/Pass/PassOptions.h"
#include "llvm/Support/CommandLine.h"

using namespace mlir::detail;
using namespace llvm::cl;
@@ -37,6 +38,10 @@ struct RunnerOptions : public PassPipelineOptions<RunnerOptions> {

PassOptions::Option<bool> cpuOnly{
*this, "cpu-only", desc("Generate CPU-only code "), init(false)};

PassOptions::Option<bool> barePtrMemrefs{
*this, "bare-ptr-memref-kernels",
desc("Use bare pointers to pass memrefs to GPU kernels"), init(true)};
};
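barePtrMemrefs defaults to true, so XMIR runner pipelines opt into the bare-pointer ABI unless told otherwise. The consuming code is not shown in this diff; a hypothetical wiring consistent with the rest of the patch (RunnerOptions assumed in scope):

// Hypothetical use of the runner option inside the pipeline builder.
void buildExampleRunnerPipeline(mlir::OpPassManager &pm,
                                const RunnerOptions &options) {
  // The host-side launch lowering must agree with the kernel-side ABI.
  pm.addPass(mlir::createGpuToLLVMConversionPass(options.barePtrMemrefs));
}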

/// Build the XMIR Runner Pipeline.
4 changes: 2 additions & 2 deletions mlir/lib/Dialect/MIOpen/Pipelines/Pipelines.cpp
@@ -175,8 +175,8 @@ void miopen::buildBackendPipeline(OpPassManager &pm,
* "--gpu-to-hsaco=triple=$triple chip=$chip features=$features opt-level=3"
*/
pm.addPass(createStripDebugInfoPass());
pm.addPass(
createLowerGpuOpsToROCDLOpsPass(options.chip, options.indexBitwidth));
pm.addPass(createLowerGpuOpsToROCDLOpsPass(
options.chip, options.indexBitwidth, /*useBarePtrCallConv=*/true));
pm.addPass(createGpuSerializeToHsacoPass(options.triple, options.chip,
options.features, options.optLevel));
}