Skip to content

Commit

Permalink
[RUNTIME] Better error message in cuda launch (#513)
Browse files Browse the repository at this point in the history
  • Loading branch information
tqchen authored Oct 5, 2017
1 parent 9435972 commit 7d42f9f
Showing 1 changed file with 20 additions and 2 deletions.
22 changes: 20 additions & 2 deletions src/runtime/cuda/cuda_module.cc
Original file line number Diff line number Diff line change
Expand Up @@ -169,15 +169,33 @@ class CUDAWrappedFunc {
}
CUstream strm = static_cast<CUstream>(CUDAThreadEntry::ThreadLocal()->stream);
ThreadWorkLoad wl = thread_axis_cfg_.Extract(args);
CUDA_DRIVER_CALL(cuLaunchKernel(
CUresult result = cuLaunchKernel(
fcache_[device_id],
wl.grid_dim(0),
wl.grid_dim(1),
wl.grid_dim(2),
wl.block_dim(0),
wl.block_dim(1),
wl.block_dim(2),
0, strm, void_args, 0));
0, strm, void_args, 0);
if (result != CUDA_SUCCESS && result != CUDA_ERROR_DEINITIALIZED) {
const char *msg;
cuGetErrorName(result, &msg);
std::ostringstream os;
os << "CUDALaunch Error: " << msg << "\n"
<< " grid=(" << wl.grid_dim(0) << ","
<< wl.grid_dim(1) << "," << wl.grid_dim(2) << "), "
<< " block=(" << wl.block_dim(0) << ","
<< wl.block_dim(1) << "," << wl.block_dim(2) << ")\n";
std::string cuda = m_->GetSource("");
if (cuda.length() != 0) {
os << "// func_name=" << func_name_ << "\n"
<< "// CUDA Source\n"
<< "// -----------\n"
<< cuda;
}
LOG(FATAL) << os.str();
}
}

private:
Expand Down

0 comments on commit 7d42f9f

Please sign in to comment.