Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[METAL] Update metal runtime to directly store kernel map #14727

Merged
merged 2 commits into from
Apr 26, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 6 additions & 5 deletions src/runtime/metal/metal_module.h
Original file line number Diff line number Diff line change
Expand Up @@ -41,13 +41,14 @@ static constexpr const int kMetalMaxNumDevice = 32;
/*!
* \brief create a metal module from data.
*
* \param data The data content.
* \param fmt The format of the data, can be "metal" or "metallib"
* \param smap The map from name to each shader kernel.
* \param fmap The map function information map of each function.
* \param source Optional, source file
* \param fmt The format of the source, can be "metal" or "metallib"
* \param source Optional, source file, concatenaed for debug dump
*/
Module MetalModuleCreate(std::string data, std::string fmt,
std::unordered_map<std::string, FunctionInfo> fmap, std::string source);
Module MetalModuleCreate(std::unordered_map<std::string, std::string> smap,
std::unordered_map<std::string, FunctionInfo> fmap, std::string fmt,
std::string source);
} // namespace runtime
} // namespace tvm
#endif // TVM_RUNTIME_METAL_METAL_MODULE_H_
108 changes: 52 additions & 56 deletions src/runtime/metal/metal_module.mm
Original file line number Diff line number Diff line change
Expand Up @@ -30,23 +30,26 @@
#include "../file_utils.h"
#include "../meta_data.h"
#include "../pack_args.h"
#include "../source_utils.h"
#include "../thread_storage_scope.h"
#include "metal_common.h"

namespace tvm {
namespace runtime {

// The version of metal module
// for future compatibility checking
// bump when we change the binary format.
static constexpr const char* kMetalModuleVersion = "0.1.0";

// Module to support thread-safe multi-GPU execution.
// The runtime will contain a per-device module table
// The modules will be lazily loaded
class MetalModuleNode final : public runtime::ModuleNode {
public:
explicit MetalModuleNode(std::string data, std::string fmt,
std::unordered_map<std::string, FunctionInfo> fmap, std::string source)
: data_(data), fmt_(fmt), fmap_(fmap), source_(source) {
parsed_kernels_ = SplitKernels(data);
}
explicit MetalModuleNode(std::unordered_map<std::string, std::string> smap,
std::unordered_map<std::string, FunctionInfo> fmap, std::string fmt,
std::string source)
: smap_(smap), fmap_(fmap), fmt_(fmt), source_(source) {}
const char* type_key() const final { return "metal"; }

/*! \brief Get the property of the runtime module. */
Expand All @@ -57,27 +60,19 @@ int GetPropertyMask() const final {
PackedFunc GetFunction(const std::string& name, const ObjectPtr<Object>& sptr_to_self) final;

void SaveToFile(const std::string& file_name, const std::string& format) final {
std::string fmt = GetFileFormat(file_name, format);
ICHECK_EQ(fmt, fmt_) << "Can only save to format=" << fmt_;
std::string meta_file = GetMetaFilePath(file_name);
SaveMetaDataToFile(meta_file, fmap_);
SaveBinaryToFile(file_name, data_);
LOG(FATAL) << "Do not support save to file, use save to binary and export instead";
}

void SaveToBinary(dmlc::Stream* stream) final {
stream->Write(fmt_);
std::string version = kMetalModuleVersion;
stream->Write(version);
stream->Write(smap_);
stream->Write(fmap_);
stream->Write(data_);
stream->Write(fmt_);
}
std::string GetSource(const std::string& format) final {
if (format == fmt_) return data_;
if (source_.length() != 0) {
return source_;
} else if (fmt_ == "metal") {
return data_;
} else {
return "";
}
// return text source if available.
return source_;
}

// get a from primary context in device_id
Expand All @@ -95,15 +90,11 @@ void SaveToBinary(dmlc::Stream* stream) final {
// compile
NSError* err_msg = nil;
id<MTLLibrary> lib = nil;
std::string source;
auto kernel = parsed_kernels_.find(func_name);
// If we cannot find this kernel in parsed_kernels_, it means that all kernels going together
// without explicit separator. In this case we use data_ with all kernels. It done for backward
// compatibility.
if (kernel != parsed_kernels_.end())
source = kernel->second;
else
source = data_;
auto kernel = smap_.find(func_name);
// Directly lookup kernels
ICHECK(kernel != smap_.end());
const std::string& source = kernel->second;

if (fmt_ == "metal") {
MTLCompileOptions* opts = [MTLCompileOptions alloc];
opts.languageVersion = MTLLanguageVersion2_3;
Expand All @@ -115,7 +106,8 @@ void SaveToBinary(dmlc::Stream* stream) final {
error:&err_msg];
[opts dealloc];
if (lib == nil) {
LOG(FATAL) << "Fail to compile metal lib:" << [[err_msg localizedDescription] UTF8String];
LOG(FATAL) << "Fail to compile metal source:"
<< [[err_msg localizedDescription] UTF8String];
}
if (err_msg != nil) {
LOG(INFO) << "Warning: " << [[err_msg localizedDescription] UTF8String];
Expand Down Expand Up @@ -161,20 +153,18 @@ void SaveToBinary(dmlc::Stream* stream) final {
}
}
};
// the binary data
std::string data_;
// The format
std::string fmt_;
// the source shader data, can be mtl or binary
std::unordered_map<std::string, std::string> smap_;
// function information table.
std::unordered_map<std::string, FunctionInfo> fmap_;
// The format
std::string fmt_;
// The source
std::string source_;
// function information.
std::vector<DeviceEntry> finfo_;
// internal mutex when updating the module
std::mutex mutex_;
// parsed kernel data
std::unordered_map<std::string, std::string> parsed_kernels_;
};

// a wrapped function class to get packed func.
Expand Down Expand Up @@ -272,39 +262,45 @@ void operator()(TVMArgs args, TVMRetValue* rv, const ArgUnion64* pack_args) cons
return pf;
}

Module MetalModuleCreate(std::string data, std::string fmt,
std::unordered_map<std::string, FunctionInfo> fmap, std::string source) {
Module MetalModuleCreate(std::unordered_map<std::string, std::string> smap,
std::unordered_map<std::string, FunctionInfo> fmap, std::string fmt,
std::string source) {
ObjectPtr<Object> n;
AUTORELEASEPOOL {
metal::MetalWorkspace::Global()->Init();
n = make_object<MetalModuleNode>(data, fmt, fmap, source);
n = make_object<MetalModuleNode>(smap, fmap, fmt, source);
};
return Module(n);
}

// Load module from module.
Module MetalModuleLoadFile(const std::string& file_name, const std::string& format) {
std::string data;
std::unordered_map<std::string, FunctionInfo> fmap;
std::string fmt = GetFileFormat(file_name, format);
std::string meta_file = GetMetaFilePath(file_name);
LoadBinaryFromFile(file_name, &data);
LoadMetaDataFromFile(meta_file, &fmap);
return MetalModuleCreate(data, fmt, fmap, "");
}
TVM_REGISTER_GLOBAL("runtime.module.create_metal_module")
.set_body_typed([](Map<String, String> smap, std::string fmap_json, std::string fmt,
std::string source) {
std::istringstream stream(fmap_json);
std::unordered_map<std::string, FunctionInfo> fmap;
dmlc::JSONReader reader(&stream);
reader.Read(&fmap);
return MetalModuleCreate(
std::unordered_map<std::string, std::string>(smap.begin(), smap.end()), fmap, fmt,
source);
});

Module MetalModuleLoadBinary(void* strm) {
dmlc::Stream* stream = static_cast<dmlc::Stream*>(strm);
std::string data;
// version is reserved for future changes and
// is discarded for now
std::string ver;
std::unordered_map<std::string, std::string> smap;
std::unordered_map<std::string, FunctionInfo> fmap;
std::string fmt;
stream->Read(&fmt);

stream->Read(&ver);
stream->Read(&smap);
stream->Read(&fmap);
stream->Read(&data);
return MetalModuleCreate(data, fmt, fmap, "");
}
stream->Read(&fmt);

TVM_REGISTER_GLOBAL("runtime.module.loadfile_metal").set_body_typed(MetalModuleLoadFile);
return MetalModuleCreate(smap, fmap, fmt, "");
}

TVM_REGISTER_GLOBAL("runtime.module.loadbinary_metal").set_body_typed(MetalModuleLoadBinary);
} // namespace runtime
Expand Down
7 changes: 4 additions & 3 deletions src/target/opt/build_metal_off.cc
Original file line number Diff line number Diff line change
Expand Up @@ -26,10 +26,11 @@
namespace tvm {
namespace runtime {

Module MetalModuleCreate(std::string data, std::string fmt,
std::unordered_map<std::string, FunctionInfo> fmap, std::string source) {
Module MetalModuleCreate(std::unordered_map<std::string, std::string> smap,
std::unordered_map<std::string, FunctionInfo> fmap, std::string fmt,
std::string source) {
LOG(WARNING) << "Metal runtime not enabled, return a source module...";
return codegen::DeviceSourceModuleCreate(data, fmt, fmap, "metal");
return codegen::DeviceSourceModuleCreate(source, fmt, fmap, "metal");
}

} // namespace runtime
Expand Down
31 changes: 17 additions & 14 deletions src/target/source/codegen_metal.cc
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@

#include <algorithm>
#include <string>
#include <unordered_map>
#include <vector>

#include "../../runtime/metal/metal_module.h"
Expand Down Expand Up @@ -336,33 +337,35 @@ runtime::Module BuildMetal(IRModule mod, Target target) {
using tvm::runtime::Registry;
bool output_ssa = false;

std::stringstream code;
std::stringstream source;
std::string fmt = "metal";
std::ostringstream source_maker;
std::unordered_map<std::string, std::string> smap;
const auto* fmetal_compile = Registry::Get("tvm_callback_metal_compile");
std::string fmt = fmetal_compile ? "metallib" : "metal";

for (auto kv : mod->functions) {
ICHECK(kv.second->IsInstance<PrimFuncNode>()) << "CodeGenMetal: Can only take PrimFunc";
code << "// Function: " << kv.first->name_hint << std::endl;
auto global_symbol = kv.second->GetAttr<String>(tvm::attr::kGlobalSymbol);
ICHECK(global_symbol.defined());
std::string func_name = global_symbol.value();

source_maker << "// Function: " << func_name << "\n";
CodeGenMetal cg(target);
cg.Init(output_ssa);
auto f = Downcast<PrimFunc>(kv.second);
auto calling_conv = f->GetAttr<Integer>(tvm::attr::kCallingConv);
ICHECK(calling_conv == CallingConv::kDeviceKernelLaunch)
<< "CodeGenMetal: expect calling_conv equals CallingConv::kDeviceKernelLaunch";

cg.AddFunction(f);
std::string fsource = cg.Finish();
if (const auto* f = Registry::Get("tvm_callback_metal_compile")) {
source << fsource;
fsource = (*f)(fsource).operator std::string();
fmt = "metallib";
source_maker << fsource << "\n";
if (fmetal_compile) {
fsource = (*fmetal_compile)(fsource).operator std::string();
}
code << fsource;
smap[func_name] = fsource;
}

std::string code_str = code.str();
if (const auto* f = Registry::Get("tvm_callback_metal_postproc")) {
code_str = (*f)(code_str).operator std::string();
}
return MetalModuleCreate(code_str, fmt, ExtractFuncInfo(mod), source.str());
return MetalModuleCreate(smap, ExtractFuncInfo(mod), fmt, source_maker.str());
}

TVM_REGISTER_GLOBAL("target.build.metal").set_body_typed(BuildMetal);
Expand Down