Skip to content

Commit

Permalink
[METAL] Fix issue with GPU fails
Browse files Browse the repository at this point in the history
Added first run to auto scheduler. This run is necessary for checking
that the generated kernel is correct. When we just run time evaluator
with incorrect kernel then it is possible that our application on iOS
device will be added to ignore list because of big number of committed
incorrect kernels. One run before running auto scheduling helps us to
avoid this problem.

Added complete handlers to all command buffers in Metal runtime. It
helps to handle GPU errors and report about this error to the host
application.

In case when error happened, we have to create a new stream. Added
mechanism for error handling and streams creating from python interface.
  • Loading branch information
echuraev committed Apr 9, 2021
1 parent 461d06e commit 27cd682
Show file tree
Hide file tree
Showing 9 changed files with 227 additions and 23 deletions.
18 changes: 16 additions & 2 deletions python/tvm/_ffi/runtime_ctypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -262,9 +262,23 @@ def max_thread_dimensions(self):
"""
return json.loads(self._GetDeviceAttr(self.device_type, self.device_id, 8))

def sync(self):
def create_stream(self):
"""Create a new runtime stream at the context."""
stream = ctypes.c_void_p()
check_call(_LIB.TVMStreamCreate(self.device_type, self.device_id, ctypes.byref(stream)))
return stream

def free_stream(self, stream):
"""Free a created stream handle."""
check_call(_LIB.TVMStreamFree(self.device_type, self.device_id, stream))

def set_stream(self, stream):
"""Set a created stream handle."""
check_call(_LIB.TVMSetStream(self.device_type, self.device_id, stream))

def sync(self, stream=None):
"""Synchronize until jobs finished at the context."""
check_call(_LIB.TVMSynchronize(self.device_type, self.device_id, None))
check_call(_LIB.TVMSynchronize(self.device_type, self.device_id, stream))

def __eq__(self, other):
return (
Expand Down
9 changes: 9 additions & 0 deletions python/tvm/auto_scheduler/measure.py
Original file line number Diff line number Diff line change
Expand Up @@ -1076,6 +1076,8 @@ def _timed_rpc_run(

if error_no == 0:
try:
stream = dev.create_stream()
dev.set_stream(stream)
random_fill = remote.get_function("tvm.contrib.random.random_fill")
assert (
random_fill
Expand Down Expand Up @@ -1108,6 +1110,11 @@ def _timed_rpc_run(
"task_inputs not fully matched, check if there's any unexpected error"
)
dev.sync()

# First run for check that the kernel is correct
func.entry_func(*args)
dev.sync()

costs = time_f(*args).results

# clean up remote files
Expand All @@ -1119,6 +1126,8 @@ def _timed_rpc_run(
costs = (MAX_FLOAT,)
error_no = MeasureErrorNo.RUNTIME_DEVICE
error_msg = make_traceback_info()
finally:
dev.free_stream(stream)

shutil.rmtree(os.path.dirname(build_res.filename))
toc = time.time()
Expand Down
43 changes: 34 additions & 9 deletions src/runtime/metal/metal_common.h
Original file line number Diff line number Diff line change
Expand Up @@ -45,15 +45,37 @@
namespace tvm {
namespace runtime {
namespace metal {
/*!
* \brief Structure for error handling in queues
*/
class Stream {
public:
explicit Stream(id<MTLDevice> device) : error_happened_(false) { queue_ = [device newCommandQueue]; }
~Stream() { [queue_ release]; }
id<MTLCommandBuffer> GetCommandBuffer() {
id<MTLCommandBuffer> cb = [queue_ commandBuffer];
[cb addCompletedHandler:^(id<MTLCommandBuffer> buffer) {
if (buffer.status == MTLCommandBufferStatusError) SetErrorStatus();
}];
return cb;
}
bool IsErrorHappened() { return error_happened_; }

private:
void SetErrorStatus() { error_happened_ = true; }
// Queue
id<MTLCommandQueue> queue_;
// Check if error happened in one previous run
bool error_happened_;
};

/*!
* \brief Process global Metal workspace.
*/
class MetalWorkspace final : public DeviceAPI {
public:
// the devices
std::vector<id<MTLDevice> > devices;
// the queues
std::vector<id<MTLCommandQueue> > queues;
// Warp size constant
std::vector<int> warp_size;
// Whether it is initialized.
Expand All @@ -62,13 +84,6 @@ class MetalWorkspace final : public DeviceAPI {
std::mutex mutex;
// Destructor
~MetalWorkspace();
// Get command queue for given device.
id<MTLCommandQueue> GetCommandQueue(Device dev) {
ICHECK_EQ(dev.device_type, kDLMetal);
ICHECK(dev.device_id >= 0 && static_cast<size_t>(dev.device_id) < queues.size())
<< "Invalid Metal device_id=" << dev.device_id;
return queues[dev.device_id];
}
// Get device for given device
id<MTLDevice> GetDevice(Device dev) {
ICHECK_EQ(dev.device_type, kDLMetal);
Expand All @@ -84,23 +99,33 @@ class MetalWorkspace final : public DeviceAPI {
void GetAttr(Device dev, DeviceAttrKind kind, TVMRetValue* rv) final;
void* AllocDataSpace(Device dev, size_t nbytes, size_t alignment, DLDataType type_hint) final;
void FreeDataSpace(Device dev, void* ptr) final;
TVMStreamHandle CreateStream(Device dev) final;
void FreeStream(Device dev, TVMStreamHandle stream) final;
void StreamSync(Device dev, TVMStreamHandle stream) final;
void SetStream(Device dev, TVMStreamHandle stream) final;
void* AllocWorkspace(Device dev, size_t size, DLDataType type_hint) final;
void FreeWorkspace(Device dev, void* data) final;

// get the global workspace
static MetalWorkspace* Global();

protected:
void CopyDataFromTo(const void* from, size_t from_size, void* to, size_t to_size, size_t size,
Device dev_from, Device dev_to, DLDataType type_hint,
TVMStreamHandle stream) final;

private:
// Pointers to default allocated streams
std::vector<Stream*> default_streams_;
};

/*! \brief Thread local workspace */
class MetalThreadEntry {
public:
/*! \brief The current device */
Device device;
/*! \brief The current stream */
std::vector<Stream*> stream;
/*! \brief The shared buffer used for copy. */
std::vector<id<MTLBuffer> > temp_buffer_;
/*! \brief workspace pool */
Expand Down
50 changes: 40 additions & 10 deletions src/runtime/metal/metal_device_api.mm
Original file line number Diff line number Diff line change
Expand Up @@ -121,8 +121,8 @@ int GetWarpSize(id<MTLDevice> dev) {
for (auto x : devices) {
[x release];
}
for (auto x : queues) {
[x release];
for (auto x : default_streams_) {
delete x;
}
}

Expand All @@ -136,13 +136,17 @@ int GetWarpSize(id<MTLDevice> dev) {
// on iPhone
id<MTLDevice> d = MTLCreateSystemDefaultDevice();
devices.push_back(d);
queues.push_back([d newCommandQueue]);
Stream* stream = new Stream(d);
MetalThreadEntry::ThreadLocal()->stream.push_back(stream);
default_streams_.push_back(stream);
#else
NSArray<id<MTLDevice> >* devs = MTLCopyAllDevices();
for (size_t i = 0; i < devs.count; ++i) {
id<MTLDevice> d = [devs objectAtIndex:i];
devices.push_back(d);
queues.push_back([d newCommandQueue]);
Stream* stream = new Stream(d);
MetalThreadEntry::ThreadLocal()->stream.push_back(stream);
default_streams_.push_back(stream);
LOG(INFO) << "Intializing Metal device " << i << ", name=" << [d.name UTF8String];
warp_size.push_back(GetWarpSize(d));
}
Expand Down Expand Up @@ -183,16 +187,25 @@ int GetWarpSize(id<MTLDevice> dev) {
}
}

Stream* getStream(TVMStreamHandle stream, int device_id) {
if (stream != nullptr)
return static_cast<Stream*>(stream);
else
return MetalThreadEntry::ThreadLocal()->stream[device_id];
}

void MetalWorkspace::CopyDataFromTo(const void* from, size_t from_offset, void* to,
size_t to_offset, size_t size, Device dev_from, Device dev_to,
DLDataType type_hint, TVMStreamHandle stream) {
@autoreleasepool {
this->Init();
ICHECK(stream == nullptr);
Device dev = dev_from;
Stream* s = getStream(stream, dev.device_id);
if (s->IsErrorHappened()) {
LOG(FATAL) << "Error! Some problems on GPU happaned! Cannot copy data to current stream";
}
if (dev_from.device_type == kDLCPU) dev = dev_to;
id<MTLCommandQueue> queue = GetCommandQueue(dev);
id<MTLCommandBuffer> cb = [queue commandBuffer];
id<MTLCommandBuffer> cb = s->GetCommandBuffer();
int from_dev_type = static_cast<int>(dev_from.device_type);
int to_dev_type = static_cast<int>(dev_to.device_type);

Expand Down Expand Up @@ -249,17 +262,34 @@ int GetWarpSize(id<MTLDevice> dev) {
}
}

TVMStreamHandle MetalWorkspace::CreateStream(Device dev) {
Stream* stream = new Stream(devices[dev.device_id]);
return static_cast<TVMStreamHandle>(stream);
}

void MetalWorkspace::FreeStream(Device dev, TVMStreamHandle stream) {
ICHECK(stream != nullptr);
Stream* s = static_cast<Stream*>(stream);
delete s;
}

void MetalWorkspace::StreamSync(Device dev, TVMStreamHandle stream) {
@autoreleasepool {
ICHECK(stream == nullptr);
Stream* s = getStream(stream, dev.device_id);
// commit an empty command buffer and wait until it completes.
id<MTLCommandQueue> queue = GetCommandQueue(dev);
id<MTLCommandBuffer> cb = [queue commandBuffer];
id<MTLCommandBuffer> cb = s->GetCommandBuffer();
[cb commit];
[cb waitUntilCompleted];
if (s->IsErrorHappened()) {
LOG(FATAL) << "Error! Some problems on GPU happaned!";
}
}
}

void MetalWorkspace::SetStream(Device dev, TVMStreamHandle stream) {
MetalThreadEntry::ThreadLocal()->stream[dev.device_id] = static_cast<Stream*>(stream);
}

void* MetalWorkspace::AllocWorkspace(Device dev, size_t size, DLDataType type_hint) {
return MetalThreadEntry::ThreadLocal()->pool.AllocWorkspace(dev, size);
}
Expand Down
5 changes: 3 additions & 2 deletions src/runtime/metal/metal_module.mm
Original file line number Diff line number Diff line change
Expand Up @@ -185,15 +185,16 @@ void operator()(TVMArgs args, TVMRetValue* rv, const ArgUnion64* pack_args) cons
@autoreleasepool {
metal::MetalThreadEntry* t = metal::MetalThreadEntry::ThreadLocal();
int device_id = t->device.device_id;
auto stream = static_cast<metal::Stream*>(t->stream[device_id]);
if (stream->IsErrorHappened()) return;
if (scache_[device_id] == nil) {
scache_[device_id] = m_->GetPipelineState(device_id, func_name_);
}
ThreadWorkLoad wl = thread_axis_cfg_.Extract(args);
int blockSize = wl.block_dim(0) * wl.block_dim(1) * wl.block_dim(2);
auto maxTotalThreadsPerThreadgroup = scache_[device_id].maxTotalThreadsPerThreadgroup;
CHECK_LE(blockSize, maxTotalThreadsPerThreadgroup);
id<MTLCommandQueue> queue = w_->GetCommandQueue(t->device);
id<MTLCommandBuffer> cb = [queue commandBuffer];
id<MTLCommandBuffer> cb = stream->GetCommandBuffer();
id<MTLComputeCommandEncoder> encoder = [cb computeCommandEncoder];
[encoder setComputePipelineState:scache_[device_id]];
for (size_t i = 0; i < num_buffer_args_; ++i) {
Expand Down
62 changes: 62 additions & 0 deletions src/runtime/minrpc/minrpc_server.h
Original file line number Diff line number Diff line change
Expand Up @@ -297,10 +297,22 @@ class MinRPCServer {
this->SyscallDevFreeData(values, tcodes, num_args);
break;
}
case RPCCode::kDevCreateStream: {
this->SyscallDevCreateStream(values, tcodes, num_args);
break;
}
case RPCCode::kDevFreeStream: {
this->SyscallDevFreeStream(values, tcodes, num_args);
break;
}
case RPCCode::kDevStreamSync: {
this->SyscallDevStreamSync(values, tcodes, num_args);
break;
}
case RPCCode::kDevSetStream: {
this->SyscallDevSetStream(values, tcodes, num_args);
break;
}
case RPCCode::kCopyAmongRemote: {
this->SyscallCopyAmongRemote(values, tcodes, num_args);
break;
Expand Down Expand Up @@ -444,6 +456,39 @@ class MinRPCServer {
}
}

void SyscallDevCreateStream(TVMValue* values, int* tcodes, int num_args) {
MINRPC_CHECK(num_args == 1);
MINRPC_CHECK(tcodes[0] == kDLDevice);

DLDevice dev = values[0].v_device;
void* handle;

int call_ecode = TVMStreamCreate(dev.device_type, dev.device_id, &handle);

if (call_ecode == 0) {
this->ReturnHandle(handle);
} else {
this->ReturnLastTVMError();
}
}

void SyscallDevFreeStream(TVMValue* values, int* tcodes, int num_args) {
MINRPC_CHECK(num_args == 2);
MINRPC_CHECK(tcodes[0] == kDLDevice);
MINRPC_CHECK(tcodes[1] == kTVMOpaqueHandle);

DLDevice dev = values[0].v_device;
void* handle = values[1].v_handle;

int call_ecode = TVMStreamFree(dev.device_type, dev.device_id, handle);

if (call_ecode == 0) {
this->ReturnVoid();
} else {
this->ReturnLastTVMError();
}
}

void SyscallDevStreamSync(TVMValue* values, int* tcodes, int num_args) {
MINRPC_CHECK(num_args == 2);
MINRPC_CHECK(tcodes[0] == kDLDevice);
Expand All @@ -461,6 +506,23 @@ class MinRPCServer {
}
}

void SyscallDevSetStream(TVMValue* values, int* tcodes, int num_args) {
MINRPC_CHECK(num_args == 2);
MINRPC_CHECK(tcodes[0] == kDLDevice);
MINRPC_CHECK(tcodes[1] == kTVMOpaqueHandle);

DLDevice dev = values[0].v_device;
void* handle = values[1].v_handle;

int call_ecode = TVMSetStream(dev.device_type, dev.device_id, handle);

if (call_ecode == 0) {
this->ReturnVoid();
} else {
this->ReturnLastTVMError();
}
}

void ThrowError(RPCServerStatus code, RPCCode info = RPCCode::kNone) {
io_->Exit(static_cast<int>(code));
}
Expand Down
9 changes: 9 additions & 0 deletions src/runtime/minrpc/rpc_reference.h
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,10 @@ enum class RPCCode : int {
kDevGetAttr,
kDevAllocData,
kDevFreeData,
kDevCreateStream,
kDevFreeStream,
kDevStreamSync,
kDevSetStream,
kCopyAmongRemote,
kDevAllocDataWithScope,
};
Expand Down Expand Up @@ -104,8 +107,14 @@ inline const char* RPCCodeToString(RPCCode code) {
return "kDevAllocData";
case RPCCode::kDevFreeData:
return "kDevFreeData";
case RPCCode::kDevCreateStream:
return "kDevCreateStream";
case RPCCode::kDevFreeStream:
return "kDevFreeStream";
case RPCCode::kDevStreamSync:
return "kDevStreamSync";
case RPCCode::kDevSetStream:
return "kDevSetStream";
case RPCCode::kCopyAmongRemote:
return "kCopyAmongRemote";
case RPCCode::kDevAllocDataWithScope:
Expand Down
Loading

0 comments on commit 27cd682

Please sign in to comment.