diff --git a/src/runtime/contrib/thrust/thrust.cu b/src/runtime/contrib/thrust/thrust.cu index fa591a4c4321..3c56cd0c563b 100644 --- a/src/runtime/contrib/thrust/thrust.cu +++ b/src/runtime/contrib/thrust/thrust.cu @@ -268,9 +268,10 @@ void thrust_scan(DLTensor* data, // This is for constructing a sequence 0, 0, 0,...,1, 1, 1,...,2, 2, 2,..., // without materializing the sequence vector auto counting_iter = thrust::counting_iterator(0); + // Without __host__ annotation, cub crashes auto linear_index_to_scan_key = [scan_size] __host__ __device__(int64_t i) { return i / scan_size; - }; + }; // NOLINT(*) auto key_iter = thrust::make_transform_iterator(counting_iter, linear_index_to_scan_key); int64_t size = 1; for (int i = 0; i < data->ndim; ++i) size *= data->shape[i];