Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

If a cuda launch error occurs, verify if cuda device requires top_k t… #479

Merged
merged 4 commits into from
May 17, 2022
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 13 additions & 6 deletions csrc/backend_ops/tensorrt/batched_nms/allClassNMS.cu
Original file line number Diff line number Diff line change
Expand Up @@ -212,7 +212,19 @@ pluginStatus_t allClassNMS_gpu(cudaStream_t stream, const int num, const int num
(T_BBOX *)bbox_data, (T_SCORE *)beforeNMS_scores, (int *)beforeNMS_index_array,
(T_SCORE *)afterNMS_scores, (int *)afterNMS_index_array, flipXY);

CSC(cudaGetLastError(), STATUS_FAILURE);
cudaError_t code = cudaGetLastError();
if (code != cudaSuccess) {
// Verify if cuda dev0 requires top_k to be reduced;
// sm_53 (Jetson Nano) and sm_62 (Jetson TX2) requires reduced top_k < 1000
auto __cuda_arch__ = get_cuda_arch(0);
if ((__cuda_arch__ == 530 || __cuda_arch__ == 620) && top_k >= 1000) {
printf(
"Warning: pre_top_k need to be reduced for devices with arch 5.3, 6.2, got "
"pre_top_k=%d\n",
top_k);
}
return STATUS_FAILURE;
}
return STATUS_SUCCESS;
}

Expand Down Expand Up @@ -250,11 +262,6 @@ pluginStatus_t allClassNMS(cudaStream_t stream, const int num, const int num_cla
const bool isNormalized, const DataType DT_SCORE, const DataType DT_BBOX,
void *bbox_data, void *beforeNMS_scores, void *beforeNMS_index_array,
void *afterNMS_scores, void *afterNMS_index_array, bool flipXY) {
auto __cuda_arch__ = get_cuda_arch(0); // assume there is only one arch 7.2 device
if (__cuda_arch__ == 720 && top_k >= 1000) {
printf("Warning: pre_top_k need to be reduced for devices with arch 7.2, got pre_top_k=%d\n",
top_k);
}
nmsLaunchConfigSSD lc = nmsLaunchConfigSSD(DT_SCORE, DT_BBOX, allClassNMS_gpu<float, float>);
for (unsigned i = 0; i < nmsFuncVec.size(); ++i) {
if (lc == nmsFuncVec[i]) {
Expand Down