diff --git a/python/tvm/_ffi/runtime_ctypes.py b/python/tvm/_ffi/runtime_ctypes.py index 59dc652aeb0bc..6192fa984f1b6 100644 --- a/python/tvm/_ffi/runtime_ctypes.py +++ b/python/tvm/_ffi/runtime_ctypes.py @@ -262,9 +262,23 @@ def max_thread_dimensions(self): """ return json.loads(self._GetDeviceAttr(self.device_type, self.device_id, 8)) - def sync(self): + def create_stream(self): + """Create a new runtime stream at the context.""" + stream = ctypes.c_void_p() + check_call(_LIB.TVMStreamCreate(self.device_type, self.device_id, ctypes.byref(stream))) + return stream + + def free_stream(self, stream): + """Free a created stream handle.""" + check_call(_LIB.TVMStreamFree(self.device_type, self.device_id, stream)) + + def set_stream(self, stream): + """Set a created stream handle.""" + check_call(_LIB.TVMSetStream(self.device_type, self.device_id, stream)) + + def sync(self, stream=None): """Synchronize until jobs finished at the context.""" - check_call(_LIB.TVMSynchronize(self.device_type, self.device_id, None)) + check_call(_LIB.TVMSynchronize(self.device_type, self.device_id, stream)) def __eq__(self, other): return ( diff --git a/python/tvm/auto_scheduler/measure.py b/python/tvm/auto_scheduler/measure.py index 83f1bcec7ebcf..94d2b5be94797 100644 --- a/python/tvm/auto_scheduler/measure.py +++ b/python/tvm/auto_scheduler/measure.py @@ -1076,6 +1076,8 @@ def _timed_rpc_run( if error_no == 0: try: + stream = dev.create_stream() + dev.set_stream(stream) random_fill = remote.get_function("tvm.contrib.random.random_fill") assert ( random_fill @@ -1108,6 +1110,11 @@ def _timed_rpc_run( "task_inputs not fully matched, check if there's any unexpected error" ) dev.sync() + + # First run for check that the kernel is correct + func.entry_func(*args) + dev.sync() + costs = time_f(*args).results # clean up remote files @@ -1119,6 +1126,8 @@ def _timed_rpc_run( costs = (MAX_FLOAT,) error_no = MeasureErrorNo.RUNTIME_DEVICE error_msg = make_traceback_info() + finally: + dev.free_stream(stream) shutil.rmtree(os.path.dirname(build_res.filename)) toc = time.time() diff --git a/src/runtime/metal/metal_common.h b/src/runtime/metal/metal_common.h index 55f9022a6b96d..f67f41cad9fa9 100644 --- a/src/runtime/metal/metal_common.h +++ b/src/runtime/metal/metal_common.h @@ -45,6 +45,30 @@ namespace tvm { namespace runtime { namespace metal { +/*! + * \brief Structure for error handling in queues + */ +class Stream { + public: + Stream(id device) : error_happened_(false) { queue_ = [device newCommandQueue]; } + ~Stream() { [queue_ release]; } + id GetCommandBuffer() { + id cb = [queue_ commandBuffer]; + [cb addCompletedHandler:^(id buffer) { + if (buffer.status == MTLCommandBufferStatusError) SetErrorStatus(); + }]; + return cb; + } + bool IsErrorHappened() { return error_happened_; } + + private: + void SetErrorStatus() { error_happened_ = true; } + // Queue + id queue_; + // Check if error happened in one previous run + bool error_happened_; +}; + /*! * \brief Process global Metal workspace. */ @@ -52,8 +76,6 @@ class MetalWorkspace final : public DeviceAPI { public: // the devices std::vector > devices; - // the queues - std::vector > queues; // Warp size constant std::vector warp_size; // Whether it is initialized. @@ -62,13 +84,6 @@ class MetalWorkspace final : public DeviceAPI { std::mutex mutex; // Destructor ~MetalWorkspace(); - // Get command queue for given device. - id GetCommandQueue(Device dev) { - ICHECK_EQ(dev.device_type, kDLMetal); - ICHECK(dev.device_id >= 0 && static_cast(dev.device_id) < queues.size()) - << "Invalid Metal device_id=" << dev.device_id; - return queues[dev.device_id]; - } // Get device for given device id GetDevice(Device dev) { ICHECK_EQ(dev.device_type, kDLMetal); @@ -84,9 +99,13 @@ class MetalWorkspace final : public DeviceAPI { void GetAttr(Device dev, DeviceAttrKind kind, TVMRetValue* rv) final; void* AllocDataSpace(Device dev, size_t nbytes, size_t alignment, DLDataType type_hint) final; void FreeDataSpace(Device dev, void* ptr) final; + TVMStreamHandle CreateStream(Device dev) final; + void FreeStream(Device dev, TVMStreamHandle stream) final; void StreamSync(Device dev, TVMStreamHandle stream) final; + void SetStream(Device dev, TVMStreamHandle stream) final; void* AllocWorkspace(Device dev, size_t size, DLDataType type_hint) final; void FreeWorkspace(Device dev, void* data) final; + // get the global workspace static MetalWorkspace* Global(); @@ -94,6 +113,10 @@ class MetalWorkspace final : public DeviceAPI { void CopyDataFromTo(const void* from, size_t from_size, void* to, size_t to_size, size_t size, Device dev_from, Device dev_to, DLDataType type_hint, TVMStreamHandle stream) final; + + private: + // Pointers to default allocated streams + std::vector default_streams_; }; /*! \brief Thread local workspace */ @@ -101,6 +124,8 @@ class MetalThreadEntry { public: /*! \brief The current device */ Device device; + /*! \brief The current stream */ + std::vector stream; /*! \brief The shared buffer used for copy. */ std::vector > temp_buffer_; /*! \brief workspace pool */ diff --git a/src/runtime/metal/metal_device_api.mm b/src/runtime/metal/metal_device_api.mm index cf8520864e99e..03d91a2d365e9 100644 --- a/src/runtime/metal/metal_device_api.mm +++ b/src/runtime/metal/metal_device_api.mm @@ -121,8 +121,8 @@ int GetWarpSize(id dev) { for (auto x : devices) { [x release]; } - for (auto x : queues) { - [x release]; + for (auto x : default_streams_) { + delete x; } } @@ -136,13 +136,17 @@ int GetWarpSize(id dev) { // on iPhone id d = MTLCreateSystemDefaultDevice(); devices.push_back(d); - queues.push_back([d newCommandQueue]); + Stream* stream = new Stream(d); + MetalThreadEntry::ThreadLocal()->stream.push_back(stream); + default_streams_.push_back(stream); #else NSArray >* devs = MTLCopyAllDevices(); for (size_t i = 0; i < devs.count; ++i) { id d = [devs objectAtIndex:i]; devices.push_back(d); - queues.push_back([d newCommandQueue]); + Stream* stream = new Stream(d); + MetalThreadEntry::ThreadLocal()->stream.push_back(stream); + default_streams_.push_back(stream); LOG(INFO) << "Intializing Metal device " << i << ", name=" << [d.name UTF8String]; warp_size.push_back(GetWarpSize(d)); } @@ -183,16 +187,25 @@ int GetWarpSize(id dev) { } } +Stream* getStream(TVMStreamHandle stream, int device_id) { + if (stream != nullptr) + return static_cast(stream); + else + return MetalThreadEntry::ThreadLocal()->stream[device_id]; +} + void MetalWorkspace::CopyDataFromTo(const void* from, size_t from_offset, void* to, size_t to_offset, size_t size, Device dev_from, Device dev_to, DLDataType type_hint, TVMStreamHandle stream) { @autoreleasepool { this->Init(); - ICHECK(stream == nullptr); Device dev = dev_from; + Stream* s = getStream(stream, dev.device_id); + if (s->IsErrorHappened()) { + LOG(FATAL) << "Error! Some problems on GPU happaned! Cannot copy data to current stream"; + } if (dev_from.device_type == kDLCPU) dev = dev_to; - id queue = GetCommandQueue(dev); - id cb = [queue commandBuffer]; + id cb = s->GetCommandBuffer(); int from_dev_type = static_cast(dev_from.device_type); int to_dev_type = static_cast(dev_to.device_type); @@ -249,17 +262,34 @@ int GetWarpSize(id dev) { } } +TVMStreamHandle MetalWorkspace::CreateStream(Device dev) { + Stream* stream = new Stream(devices[dev.device_id]); + return static_cast(stream); +} + +void MetalWorkspace::FreeStream(Device dev, TVMStreamHandle stream) { + ICHECK(stream != nullptr); + Stream* s = static_cast(stream); + delete s; +} + void MetalWorkspace::StreamSync(Device dev, TVMStreamHandle stream) { @autoreleasepool { - ICHECK(stream == nullptr); + Stream* s = getStream(stream, dev.device_id); // commit an empty command buffer and wait until it completes. - id queue = GetCommandQueue(dev); - id cb = [queue commandBuffer]; + id cb = s->GetCommandBuffer(); [cb commit]; [cb waitUntilCompleted]; + if (s->IsErrorHappened()) { + LOG(FATAL) << "Error! Some problems on GPU happaned!"; + } } } +void MetalWorkspace::SetStream(Device dev, TVMStreamHandle stream) { + MetalThreadEntry::ThreadLocal()->stream[dev.device_id] = static_cast(stream); +} + void* MetalWorkspace::AllocWorkspace(Device dev, size_t size, DLDataType type_hint) { return MetalThreadEntry::ThreadLocal()->pool.AllocWorkspace(dev, size); } diff --git a/src/runtime/metal/metal_module.mm b/src/runtime/metal/metal_module.mm index a8b01815bf688..29d726a0ee97a 100644 --- a/src/runtime/metal/metal_module.mm +++ b/src/runtime/metal/metal_module.mm @@ -185,6 +185,8 @@ void operator()(TVMArgs args, TVMRetValue* rv, const ArgUnion64* pack_args) cons @autoreleasepool { metal::MetalThreadEntry* t = metal::MetalThreadEntry::ThreadLocal(); int device_id = t->device.device_id; + auto stream = static_cast(t->stream[device_id]); + if (stream->IsErrorHappened()) return; if (scache_[device_id] == nil) { scache_[device_id] = m_->GetPipelineState(device_id, func_name_); } @@ -192,8 +194,7 @@ void operator()(TVMArgs args, TVMRetValue* rv, const ArgUnion64* pack_args) cons int blockSize = wl.block_dim(0) * wl.block_dim(1) * wl.block_dim(2); auto maxTotalThreadsPerThreadgroup = scache_[device_id].maxTotalThreadsPerThreadgroup; CHECK_LE(blockSize, maxTotalThreadsPerThreadgroup); - id queue = w_->GetCommandQueue(t->device); - id cb = [queue commandBuffer]; + id cb = stream->GetCommandBuffer(); id encoder = [cb computeCommandEncoder]; [encoder setComputePipelineState:scache_[device_id]]; for (size_t i = 0; i < num_buffer_args_; ++i) { diff --git a/src/runtime/minrpc/minrpc_server.h b/src/runtime/minrpc/minrpc_server.h index 732e1e49d4a40..1dfee70a20e2d 100644 --- a/src/runtime/minrpc/minrpc_server.h +++ b/src/runtime/minrpc/minrpc_server.h @@ -297,10 +297,22 @@ class MinRPCServer { this->SyscallDevFreeData(values, tcodes, num_args); break; } + case RPCCode::kDevCreateStream: { + this->SyscallDevCreateStream(values, tcodes, num_args); + break; + } + case RPCCode::kDevFreeStream: { + this->SyscallDevFreeStream(values, tcodes, num_args); + break; + } case RPCCode::kDevStreamSync: { this->SyscallDevStreamSync(values, tcodes, num_args); break; } + case RPCCode::kDevSetStream: { + this->SyscallDevSetStream(values, tcodes, num_args); + break; + } case RPCCode::kCopyAmongRemote: { this->SyscallCopyAmongRemote(values, tcodes, num_args); break; @@ -444,6 +456,39 @@ class MinRPCServer { } } + void SyscallDevCreateStream(TVMValue* values, int* tcodes, int num_args) { + MINRPC_CHECK(num_args == 1); + MINRPC_CHECK(tcodes[0] == kDLDevice); + + DLDevice dev = values[0].v_device; + void* handle; + + int call_ecode = TVMStreamCreate(dev.device_type, dev.device_id, &handle); + + if (call_ecode == 0) { + this->ReturnHandle(handle); + } else { + this->ReturnLastTVMError(); + } + } + + void SyscallDevFreeStream(TVMValue* values, int* tcodes, int num_args) { + MINRPC_CHECK(num_args == 2); + MINRPC_CHECK(tcodes[0] == kDLDevice); + MINRPC_CHECK(tcodes[1] == kTVMOpaqueHandle); + + DLDevice dev = values[0].v_device; + void* handle = values[1].v_handle; + + int call_ecode = TVMStreamFree(dev.device_type, dev.device_id, handle); + + if (call_ecode == 0) { + this->ReturnVoid(); + } else { + this->ReturnLastTVMError(); + } + } + void SyscallDevStreamSync(TVMValue* values, int* tcodes, int num_args) { MINRPC_CHECK(num_args == 2); MINRPC_CHECK(tcodes[0] == kDLDevice); @@ -461,6 +506,23 @@ class MinRPCServer { } } + void SyscallDevSetStream(TVMValue* values, int* tcodes, int num_args) { + MINRPC_CHECK(num_args == 2); + MINRPC_CHECK(tcodes[0] == kDLDevice); + MINRPC_CHECK(tcodes[1] == kTVMOpaqueHandle); + + DLDevice dev = values[0].v_device; + void* handle = values[1].v_handle; + + int call_ecode = TVMSetStream(dev.device_type, dev.device_id, handle); + + if (call_ecode == 0) { + this->ReturnVoid(); + } else { + this->ReturnLastTVMError(); + } + } + void ThrowError(RPCServerStatus code, RPCCode info = RPCCode::kNone) { io_->Exit(static_cast(code)); } diff --git a/src/runtime/minrpc/rpc_reference.h b/src/runtime/minrpc/rpc_reference.h index e42508a739596..494b87e54b608 100644 --- a/src/runtime/minrpc/rpc_reference.h +++ b/src/runtime/minrpc/rpc_reference.h @@ -49,7 +49,10 @@ enum class RPCCode : int { kDevGetAttr, kDevAllocData, kDevFreeData, + kDevCreateStream, + kDevFreeStream, kDevStreamSync, + kDevSetStream, kCopyAmongRemote, kDevAllocDataWithScope, }; @@ -104,8 +107,14 @@ inline const char* RPCCodeToString(RPCCode code) { return "kDevAllocData"; case RPCCode::kDevFreeData: return "kDevFreeData"; + case RPCCode::kDevCreateStream: + return "kDevCreateStream"; + case RPCCode::kDevFreeStream: + return "kDevFreeStream"; case RPCCode::kDevStreamSync: return "kDevStreamSync"; + case RPCCode::kDevSetStream: + return "kDevSetStream"; case RPCCode::kCopyAmongRemote: return "kCopyAmongRemote"; case RPCCode::kDevAllocDataWithScope: diff --git a/src/runtime/rpc/rpc_device_api.cc b/src/runtime/rpc/rpc_device_api.cc index 1d6fb85d94955..a2d1ac17ef7f9 100644 --- a/src/runtime/rpc/rpc_device_api.cc +++ b/src/runtime/rpc/rpc_device_api.cc @@ -111,11 +111,26 @@ class RPCDeviceAPI final : public DeviceAPI { } } + TVMStreamHandle CreateStream(Device dev) { + auto remote_dev = RemoveRPCSessionMask(dev); + return GetSess(dev)->GetDeviceAPI(remote_dev)->CreateStream(remote_dev); + } + + void FreeStream(Device dev, TVMStreamHandle stream) { + auto remote_dev = RemoveRPCSessionMask(dev); + GetSess(dev)->GetDeviceAPI(remote_dev)->FreeStream(remote_dev, stream); + } + void StreamSync(Device dev, TVMStreamHandle stream) final { auto remote_dev = RemoveRPCSessionMask(dev); GetSess(dev)->GetDeviceAPI(remote_dev)->StreamSync(remote_dev, stream); } + void SetStream(Device dev, TVMStreamHandle stream) { + auto remote_dev = RemoveRPCSessionMask(dev); + GetSess(dev)->GetDeviceAPI(remote_dev)->SetStream(remote_dev, stream); + } + protected: void CopyDataFromTo(const void* from, size_t from_offset, void* to, size_t to_offset, size_t num_bytes, Device dev_from, Device dev_to, DLDataType type_hint, diff --git a/src/runtime/rpc/rpc_endpoint.cc b/src/runtime/rpc/rpc_endpoint.cc index b5768146b3f76..ba33b5065ebb9 100644 --- a/src/runtime/rpc/rpc_endpoint.cc +++ b/src/runtime/rpc/rpc_endpoint.cc @@ -920,6 +920,24 @@ void RPCCopyAmongRemote(RPCSession* handler, TVMArgs args, TVMRetValue* rv) { handler->GetDeviceAPI(dev)->CopyDataFromTo(from, to, stream); } +void RPCDevCreateStream(RPCSession* handler, TVMArgs args, TVMRetValue* rv) { + Device dev = args[0]; + void* data = handler->GetDeviceAPI(dev)->CreateStream(dev); + *rv = data; +} + +void RPCDevFreeStream(RPCSession* handler, TVMArgs args, TVMRetValue* rv) { + Device dev = args[0]; + TVMStreamHandle stream = args[1]; + handler->GetDeviceAPI(dev)->FreeStream(dev, stream); +} + +void RPCDevSetStream(RPCSession* handler, TVMArgs args, TVMRetValue* rv) { + Device dev = args[0]; + TVMStreamHandle stream = args[1]; + handler->GetDeviceAPI(dev)->SetStream(dev, stream); +} + void RPCEndpoint::EventHandler::HandleSyscall(RPCCode code) { // Event handler sit at clean state at this point. switch (code) { @@ -945,9 +963,18 @@ void RPCEndpoint::EventHandler::HandleSyscall(RPCCode code) { case RPCCode::kDevFreeData: SysCallHandler(RPCDevFreeData); break; + case RPCCode::kDevCreateStream: + SysCallHandler(RPCDevCreateStream); + break; + case RPCCode::kDevFreeStream: + SysCallHandler(RPCDevFreeStream); + break; case RPCCode::kDevStreamSync: this->HandleSyscallStreamSync(); break; + case RPCCode::kDevSetStream: + SysCallHandler(RPCDevSetStream); + break; case RPCCode::kCopyAmongRemote: SysCallHandler(RPCCopyAmongRemote); break; @@ -1033,10 +1060,22 @@ class RPCClientSession : public RPCSession, public DeviceAPI { endpoint_->SysCallRemote(RPCCode::kCopyAmongRemote, from, to, stream); } + TVMStreamHandle CreateStream(Device dev) final { + return endpoint_->SysCallRemote(RPCCode::kDevCreateStream, dev); + } + + void FreeStream(Device dev, TVMStreamHandle stream) final { + endpoint_->SysCallRemote(RPCCode::kDevFreeStream, dev, stream); + } + void StreamSync(Device dev, TVMStreamHandle stream) final { endpoint_->SysCallRemote(RPCCode::kDevStreamSync, dev, stream); } + void SetStream(Device dev, TVMStreamHandle stream) final { + endpoint_->SysCallRemote(RPCCode::kDevSetStream, dev, stream); + } + DeviceAPI* GetDeviceAPI(Device dev, bool allow_missing) final { return this; } bool IsLocalSession() const final { return false; }