Skip to content

Commit

Permalink
support more data type
Browse files Browse the repository at this point in the history
  • Loading branch information
masa authored and masahi committed Dec 24, 2020
1 parent 3e7d1f8 commit abceac9
Showing 1 changed file with 16 additions and 31 deletions.
47 changes: 16 additions & 31 deletions src/runtime/contrib/thrust/thrust.cu
Original file line number Diff line number Diff line change
Expand Up @@ -267,10 +267,11 @@ void thrust_scan(DLTensor* data,

// This is for constructing a sequence 0, 0, 0,...,1, 1, 1,...,2, 2, 2,...,
// without materializing the sequence vector
auto counting_iter = thrust::counting_iterator<int>(0);
auto key_iter = thrust::make_transform_iterator(counting_iter, [scan_size] __device__(int i) {
auto counting_iter = thrust::counting_iterator<int64_t>(0);
auto linear_index_to_scan_key = [scan_size] __host__ __device__(int64_t i) {
return i / scan_size;
});
};
auto key_iter = thrust::make_transform_iterator(counting_iter, linear_index_to_scan_key);
int64_t size = 1;
for (int i = 0; i < data->ndim; ++i) size *= data->shape[i];

Expand All @@ -297,37 +298,21 @@ TVM_REGISTER_GLOBAL("tvm.contrib.thrust.sum_scan")
thrust_scan<int, int>(data, output, exclusive);
} else if (out_dtype == "int64") {
thrust_scan<int, int64_t>(data, output, exclusive);
} else if (out_dtype == "float32") {
thrust_scan<int, float>(data, output, exclusive);
} else {
LOG(FATAL) << "Unsupported output dtype: " << out_dtype;
}
// } else if (in_dtype == "int64") {
// if (out_dtype == "int32") {
// thrust_scan<int64_t, int>(keys_in, values_in, keys_out, values_out,
// for_scatter);
// } else if (out_dtype == "int64") {
// thrust_scan<int64_t, int64_t>(keys_in, values_in, keys_out, values_out,
// for_scatter);
// } else if (out_dtype == "float32") {
// thrust_scan<int64_t, float>(keys_in, values_in, keys_out, values_out,
// for_scatter);
// } else {
// LOG(FATAL) << "Unsupported value dtype: " << out_dtype;
// }
// } else if (in_dtype == "float32") {
// if (out_dtype == "int32") {
// thrust_scan<float, int>(keys_in, values_in, keys_out, values_out,
// for_scatter);
// } else if (out_dtype == "int64") {
// thrust_scan<float, int64_t>(keys_in, values_in, keys_out, values_out,
// for_scatter);
// } else if (out_dtype == "float32") {
// thrust_scan<float, float>(keys_in, values_in, keys_out, values_out,
// for_scatter);
// } else {
// LOG(FATAL) << "Unsupported value dtype: " << out_dtype;
// }
} else if (in_dtype == "int64") {
if (out_dtype == "int64") {
thrust_scan<int64_t, int64_t>(data, output, exclusive);
} else {
LOG(FATAL) << "Unsupported output dtype: " << out_dtype;
}
} else if (in_dtype == "float32") {
if (out_dtype == "float32") {
thrust_scan<float, float>(data, output, exclusive);
} else {
LOG(FATAL) << "Unsupported output dtype: " << out_dtype;
}
} else {
LOG(FATAL) << "Unsupported input dtype: " << in_dtype;
}
Expand Down

0 comments on commit abceac9

Please sign in to comment.