diff --git a/src/runtime/contrib/thrust/thrust.cu b/src/runtime/contrib/thrust/thrust.cu index c86ee73788f9..fa591a4c4321 100644 --- a/src/runtime/contrib/thrust/thrust.cu +++ b/src/runtime/contrib/thrust/thrust.cu @@ -267,10 +267,11 @@ void thrust_scan(DLTensor* data, // This is for constructing a sequence 0, 0, 0,...,1, 1, 1,...,2, 2, 2,..., // without materializing the sequence vector - auto counting_iter = thrust::counting_iterator(0); - auto key_iter = thrust::make_transform_iterator(counting_iter, [scan_size] __device__(int i) { + auto counting_iter = thrust::counting_iterator(0); + auto linear_index_to_scan_key = [scan_size] __host__ __device__(int64_t i) { return i / scan_size; - }); + }; + auto key_iter = thrust::make_transform_iterator(counting_iter, linear_index_to_scan_key); int64_t size = 1; for (int i = 0; i < data->ndim; ++i) size *= data->shape[i]; @@ -297,37 +298,21 @@ TVM_REGISTER_GLOBAL("tvm.contrib.thrust.sum_scan") thrust_scan(data, output, exclusive); } else if (out_dtype == "int64") { thrust_scan(data, output, exclusive); - } else if (out_dtype == "float32") { - thrust_scan(data, output, exclusive); } else { LOG(FATAL) << "Unsupported output dtype: " << out_dtype; } - // } else if (in_dtype == "int64") { - // if (out_dtype == "int32") { - // thrust_scan(keys_in, values_in, keys_out, values_out, - // for_scatter); - // } else if (out_dtype == "int64") { - // thrust_scan(keys_in, values_in, keys_out, values_out, - // for_scatter); - // } else if (out_dtype == "float32") { - // thrust_scan(keys_in, values_in, keys_out, values_out, - // for_scatter); - // } else { - // LOG(FATAL) << "Unsupported value dtype: " << out_dtype; - // } - // } else if (in_dtype == "float32") { - // if (out_dtype == "int32") { - // thrust_scan(keys_in, values_in, keys_out, values_out, - // for_scatter); - // } else if (out_dtype == "int64") { - // thrust_scan(keys_in, values_in, keys_out, values_out, - // for_scatter); - // } else if (out_dtype == "float32") { - // thrust_scan(keys_in, values_in, keys_out, values_out, - // for_scatter); - // } else { - // LOG(FATAL) << "Unsupported value dtype: " << out_dtype; - // } + } else if (in_dtype == "int64") { + if (out_dtype == "int64") { + thrust_scan(data, output, exclusive); + } else { + LOG(FATAL) << "Unsupported output dtype: " << out_dtype; + } + } else if (in_dtype == "float32") { + if (out_dtype == "float32") { + thrust_scan(data, output, exclusive); + } else { + LOG(FATAL) << "Unsupported output dtype: " << out_dtype; + } } else { LOG(FATAL) << "Unsupported input dtype: " << in_dtype; }