Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

AMDGPU ukernels: Bazel build, separate bitcode files, c-embed archives. #19274

Merged
merged 2 commits into from
Nov 27, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@ Testing/

# Bazel artifacts
**/bazel-*
MODULE.bazel
MODULE.bazel.lock

# Executables
*.exe
Expand Down
74 changes: 74 additions & 0 deletions build_tools/bazel/iree_bitcode_library.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -255,6 +255,80 @@ def iree_cuda_bitcode_library(
**kwargs
)

def iree_amdgpu_bitcode_library(
name,
gpu_arch,
srcs,
copts = [],
out = None,
**kwargs):
"""Builds an AMDGPU LLVM bitcode library from an input file using clang.

Args:
name: Name of the target.
gpu_arch: Target AMDGPU architecture, e.g. gfx942.
srcs: Source files to pass to clang. Headers (*.h) are for dependency
tracking only. Current limitation: only one non-header source is
supported.
copts: Additional flags to pass to clang.
out: Output file name. Defaults to {source.c}.{gpu_arch}.bc.
**kwargs: any additional attributes to pass to the underlying rules.
"""

clang_tool = "@llvm-project//clang:clang"

base_copts = [
# Language: C23.
"-std=c23",

# Avoid dependencies.
"-nogpulib",

# Avoid ABI issues.
"-fno-short-wchar", # Shouldn't matter to us, but doesn't hurt.

# Target architecture/machine.
"-target",
"amdgcn-amd-amdhsa",
"-march=%s" % gpu_arch,
"-fgpu-rdc", # NOTE: may not be required for all targets.

# Optimized.
"-O3",
"-fno-ident",
"-fvisibility=hidden",

# Object file only in bitcode format.
"-c",
"-emit-llvm",
]

non_header_srcs = [src for src in srcs if not src.endswith(".h")]
if len(non_header_srcs) != 1:
fail("Expected exactly one non-header file in srcs, got srcs=[" + ", ".join(srcs) + "]")
src = non_header_srcs[0]

if not out:
out = "%s.%s.bc" % (src, gpu_arch)

native.genrule(
name = "gen_%s" % (out),
srcs = srcs,
outs = [out],
cmd = " ".join([
"$(location %s)" % (clang_tool),
"$(location %s)" % (src),
"-o $(location %s)" % (out),
"-I .",
] + base_copts + copts),
tools = [
clang_tool,
],
message = "Compiling %s to %s..." % (src, out),
output_to_bindir = 1,
**kwargs
)

def iree_link_bitcode(
name,
bitcode_files,
Expand Down
19 changes: 19 additions & 0 deletions build_tools/bazel_to_cmake/bazel_to_cmake_converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -610,6 +610,25 @@ def iree_cuda_bitcode_library(
f")\n\n"
)

def iree_amdgpu_bitcode_library(self, name, gpu_arch, srcs, copts=None, out=None):
name_block = self._convert_string_arg_block("NAME", name, quote=False)
gpu_arch_block = self._convert_string_arg_block(
"GPU_ARCH", gpu_arch, quote=False
)
srcs_block = self._convert_srcs_block(srcs)
out_block = self._convert_string_arg_block("OUT", out, quote=False)
copts_block = self._convert_string_list_block("COPTS", copts, sort=False)

self._converter.body += (
f"iree_amdgpu_bitcode_library(\n"
f"{name_block}"
f"{gpu_arch_block}"
f"{srcs_block}"
f"{out_block}"
f"{copts_block}"
f")\n\n"
)

def iree_link_bitcode(self, name, bitcode_files):
name_block = self._convert_string_arg_block("NAME", name, quote=False)
bitcode_files_block = self._convert_srcs_block(
Expand Down
92 changes: 92 additions & 0 deletions build_tools/cmake/iree_bitcode_library.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -227,6 +227,98 @@ function(iree_cuda_bitcode_library)
)
endfunction()

# iree_amdgpu_bitcode_library()
#
# Builds an AMDGPU LLVM bitcode library from an input file via clang.
#
# Parameters:
# NAME: Name of the target.
# GPU_ARCH: Target AMDGPU architecture, e.g. gfx942.
# SRCS: Source files to pass to clang. Headers (*.h) are for dependency
# tracking only. Current limitation: only one non-header source is
# supported.
# COPTS: Additional flags to pass to clang.
# OUT: Output file name. Defaults to {source.c}.{gpu_arch}.bc.
#
function(iree_amdgpu_bitcode_library)
bjacob marked this conversation as resolved.
Show resolved Hide resolved
cmake_parse_arguments(
_RULE
""
"NAME;OUT;GPU_ARCH"
"SRCS;COPTS"
${ARGN}
)

set(_SRC "")
foreach(_SRCS_ENTRY IN LISTS _RULE_SRCS)
if(_SRCS_ENTRY MATCHES "\.h$")
continue()
endif()
if (_SRC)
message(SEND_ERROR "Currently limitation: only one non-header file allowed in SRCS.")
endif()
set(_SRC "${_SRCS_ENTRY}")
endforeach()
if(NOT _SRC)
message(SEND_ERROR "Error: no non-header file found in SRCS=${_RULE_SRCS}.")
endif()

if(DEFINED _RULE_OUT)
set(_OUT "${_RULE_OUT}")
else()
set(_OUT "${_SRC}.${_RULE_GPU_ARCH}.bc")
endif()

set(_COPTS
# Language: C23
"-std=c23"

# Avoid dependencies.
"-nogpulib"

# Avoid ABI issues.
"-fno-short-wchar" # Shouldn't matter to us, but doesn't hurt.

# Target architecture/machine.
"-target"
"amdgcn-amd-amdhsa"
"-march=${_RULE_GPU_ARCH}"
"-fgpu-rdc" # NOTE: may not be required for all targets.

# Optimized.
"-O3"
"-fno-ident"
"-fvisibility=hidden"

# Object file only in bitcode format.
"-c"
"-emit-llvm"
)

add_custom_command(
OUTPUT
"${_OUT}"
COMMAND
"${IREE_CLANG_BINARY}"
${_COPTS}
"-I" "${IREE_SOURCE_DIR}"
"${CMAKE_CURRENT_SOURCE_DIR}/${_SRC}"
"-o" "${_OUT}"
DEPENDS
"${IREE_CLANG_BINARY}"
"${_RULE_SRCS}"
COMMENT
"Compiling ${_SRC} to ${_OUT}"
VERBATIM
)

# Only add iree_${NAME} as custom target doesn't support aliasing to
# iree::${NAME}.
iree_package_name(_PACKAGE_NAME)
add_custom_target("${_PACKAGE_NAME}_${_RULE_NAME}"
DEPENDS "${_OUT}"
)
endfunction()

# iree_link_bitcode()
#
Expand Down
4 changes: 4 additions & 0 deletions compiler/plugins/target/ROCM/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,10 @@ iree_compiler_cc_library(
"ROCMTargetUtils.h",
],
deps = [
"//compiler/plugins/target/ROCM/builtins/ukernel:iree_uk_amdgpu_gfx1030",
"//compiler/plugins/target/ROCM/builtins/ukernel:iree_uk_amdgpu_gfx1100",
"//compiler/plugins/target/ROCM/builtins/ukernel:iree_uk_amdgpu_gfx90a",
"//compiler/plugins/target/ROCM/builtins/ukernel:iree_uk_amdgpu_gfx942",
"//compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR:IREECodegenDialect",
"//compiler/src/iree/compiler/Codegen/Dialect/GPU/IR:IREEGPUDialect",
"//compiler/src/iree/compiler/Codegen/Dialect/GPU/TargetUtils:KnownTargets",
Expand Down
4 changes: 4 additions & 0 deletions compiler/plugins/target/ROCM/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,10 @@ iree_cc_library(
iree::compiler::Dialect::HAL::Utils::LLVMLinkerUtils
iree::compiler::PluginAPI
iree::compiler::Utils
iree::compiler::plugins::target::ROCM::builtins::ukernel::iree_uk_amdgpu_gfx1030
iree::compiler::plugins::target::ROCM::builtins::ukernel::iree_uk_amdgpu_gfx1100
iree::compiler::plugins::target::ROCM::builtins::ukernel::iree_uk_amdgpu_gfx90a
iree::compiler::plugins::target::ROCM::builtins::ukernel::iree_uk_amdgpu_gfx942
iree::schemas::amdgpu_executable_def_c_fbs
iree::schemas::executable_debug_info_c_fbs
iree::schemas::hip_executable_def_c_fbs
Expand Down
99 changes: 34 additions & 65 deletions compiler/plugins/target/ROCM/ROCMTargetUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,14 @@

#include "compiler/plugins/target/ROCM/ROCMTargetUtils.h"

#include "compiler/plugins/target/ROCM/builtins/ukernel/iree_uk_amdgpu_gfx1030.h"
#include "compiler/plugins/target/ROCM/builtins/ukernel/iree_uk_amdgpu_gfx1100.h"
#include "compiler/plugins/target/ROCM/builtins/ukernel/iree_uk_amdgpu_gfx90a.h"
#include "compiler/plugins/target/ROCM/builtins/ukernel/iree_uk_amdgpu_gfx942.h"
Comment on lines +9 to +12
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could group these into a single header file, perhaps with the contents of that header file generated by the build system.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There is by construction one header file per c_embed_data library. So the question becomes whether to group these 4 c_embed_data libraries into one. I have a specific reason not to, which will become apparent in the next PR: I want the c_embed_data library for a given target architecture to contain exactly the ukernels that are supported on that architecture, so that just by looking up the table of contents, LowerToUkernels will be able to tell whether a ukernel lowering should succeed or fail.

#include "iree/compiler/Codegen/Utils/GPUUtils.h"
#include "iree/compiler/Dialect/HAL/Utils/LLVMLinkerUtils.h"
#include "iree/compiler/Utils/ToolUtils.h"
#include "llvm/ADT/StringSwitch.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/Module.h"
#include "llvm/IRReader/IRReader.h"
Expand Down Expand Up @@ -79,76 +84,28 @@ static LogicalResult linkWithBitcodeFiles(Location loc, llvm::Module *module,
}

static LogicalResult linkBitcodeFile(Location loc, llvm::Linker &linker,
unsigned linkerFlags, StringRef path,
unsigned linkerFlags, StringRef filename,
StringRef contents,
llvm::TargetMachine &targetMachine,
llvm::LLVMContext &context) {
auto bitcodeBufferRef = llvm::MemoryBuffer::getFile(path);
if (auto ec = bitcodeBufferRef.getError()) {
return mlir::emitError(loc) << "failed reading user bitcode file `" << path
<< "`: " << ec.message();
}
llvm::MemoryBufferRef bitcodeBufferRef(contents, filename);
auto setAlwaysInline = [&](llvm::Module &module) {
if (targetMachine.getTargetCPU().contains("gfx10") ||
targetMachine.getTargetCPU().contains("gfx11")) {
// Some ROCM/HIP functions for gfx10 or gfx11 has accuracy issue if
// inlined.
return;
}
for (auto &func : module.getFunctionList()) {
// Some ROCM/HIP builtin functions have Optnone and NoInline for default.
if (targetMachine.getTargetTriple().isAMDGCN()) {
if (func.hasFnAttribute(llvm::Attribute::OptimizeNone)) {
func.removeFnAttr(llvm::Attribute::OptimizeNone);
}
if (targetMachine.getTargetTriple().isAMDGCN() &&
func.hasFnAttribute(llvm::Attribute::NoInline)) {
func.removeFnAttr(llvm::Attribute::NoInline);
}
}
func.addFnAttr(llvm::Attribute::AlwaysInline);
}
};
if (failed(linkBitcodeModule(
loc, linker, linkerFlags, targetMachine, path,
llvm::parseBitcodeFile(*bitcodeBufferRef->get(), context),
setAlwaysInline))) {
if (failed(
linkBitcodeModule(loc, linker, linkerFlags, targetMachine, filename,
llvm::parseBitcodeFile(bitcodeBufferRef, context),
setAlwaysInline))) {
return mlir::emitError(loc) << "failed linking in user bitcode file `"
<< path << "` for target triple '"
<< filename << "` for target triple '"
<< targetMachine.getTargetTriple().str() << "'";
}

return success();
}

static std::vector<std::string> getUkernelPaths(StringRef enabledUkernelsStr,
StringRef targetChip,
StringRef bitcodePath) {
std::vector<std::string> selectedUkernelNames;
if (enabledUkernelsStr == "all") {
const char *allUkernelNames[] = {"argmax"};
size_t numUkernels = sizeof(allUkernelNames) / sizeof(allUkernelNames[0]);
for (int i = 0; i < numUkernels; i++) {
selectedUkernelNames.push_back(allUkernelNames[i]);
}
} else {
while (!enabledUkernelsStr.empty()) {
auto split = enabledUkernelsStr.split(',');
selectedUkernelNames.push_back(split.first.str());
enabledUkernelsStr = split.second;
}
}

// Construct full path to ROCDL bitcode libraries.
std::vector<std::string> result;
std::string app = "/";
for (auto &kernelName : selectedUkernelNames) {
std::string filename =
"rocm_" + kernelName + "_ukernel_" + targetChip.str();
result.push_back(bitcodePath.str() + app + filename + ".bc");
}
return result;
}

static void overridePlatformGlobal(llvm::Module *module, StringRef globalName,
uint32_t newValue, llvm::Type *globalTy) {
// NOTE: the global will not be defined if it is not used in the module.
Expand Down Expand Up @@ -228,24 +185,36 @@ LogicalResult linkHIPBitcodeIfNeeded(Location loc, llvm::Module *module,
return linkWithBitcodeFiles(loc, module, bitcodePaths);
}

static std::tuple<const iree_file_toc_t *, int>
getUkernelBitcodeTOC(StringRef gpuArch) {
return llvm::StringSwitch<std::tuple<const iree_file_toc_t *, int>>(gpuArch)
.Case("gfx90a",
{iree_uk_amdgpu_gfx90a_create(), iree_uk_amdgpu_gfx90a_size()})
.Case("gfx942",
{iree_uk_amdgpu_gfx942_create(), iree_uk_amdgpu_gfx942_size()})
.Case("gfx1030",
{iree_uk_amdgpu_gfx1030_create(), iree_uk_amdgpu_gfx1030_size()})
.Case("gfx1100",
{iree_uk_amdgpu_gfx1100_create(), iree_uk_amdgpu_gfx1100_size()})
.Default({nullptr, 0});
}

// Links optimized Ukernel bitcode into the given module if the module needs it.
LogicalResult linkUkernelBitcodeFiles(Location loc, llvm::Module *module,
StringRef enabledUkernelsStr,
StringRef targetChip,
StringRef bitcodePath,
unsigned linkerFlags,
llvm::TargetMachine &targetMachine) {
// Early exit if Ukernel not supported on target chip.
if (!iree_compiler::hasUkernelSupportedRocmArch(targetChip)) {
return mlir::emitError(loc)
<< "ukernel '" << enabledUkernelsStr
<< "' not supported on target chip: " << targetChip;
auto [toc, toc_size] = getUkernelBitcodeTOC(targetChip);
if (!toc) {
return failure();
}
std::vector<std::string> ukernelPaths =
getUkernelPaths(enabledUkernelsStr, targetChip, bitcodePath);

llvm::Linker linker(*module);
for (auto &path : ukernelPaths) {
if (failed(linkBitcodeFile(loc, linker, linkerFlags, StringRef(path),
for (int i = 0; i < toc_size; ++i) {
if (failed(linkBitcodeFile(loc, linker, linkerFlags, toc[i].name,
llvm::StringRef(toc[i].data, toc[i].size),
targetMachine, module->getContext())))
return failure();
}
Expand Down
3 changes: 0 additions & 3 deletions compiler/plugins/target/ROCM/ROCMTargetUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -34,9 +34,6 @@ LogicalResult linkUkernelBitcodeFiles(Location loc, llvm::Module *module,
// a blob.
std::string createHsaco(Location loc, StringRef isa, StringRef name);

// Returns true if the rocm archtecture target is supported for ukernels.
bool hasUkernelSupportedRocmArch(IREE::HAL::ExecutableTargetAttr targetAttr);

} // namespace mlir::iree_compiler::IREE::HAL

#endif // IREE_COMPILER_PLUGINS_TARGET_ROCM_ROCMTARGETUTILS_H_
Loading
Loading