[Thrust] Use no sync exec policy and caching allocator #16386

Merged · 1 commit · Jan 11, 2024
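
This PR routes every Thrust call in `src/runtime/contrib/thrust/thrust.cu` through one shared execution policy built from three pieces: `thrust::cuda::par_nosync` (skip the implicit synchronization the default policy performs after each algorithm), Thrust's thread-local caching allocator (reuse temporary storage instead of paying cudaMalloc/cudaFree on every call), and TVM's thread-local CUDA stream. A minimal sketch of the before/after calling pattern, assuming Thrust >= 1.16 for `par_nosync`; the example function and data are illustrative, not from the diff:

```cuda
// Sketch: the calling pattern this PR moves away from vs. toward.
#include <thrust/detail/caching_allocator.h>
#include <thrust/device_vector.h>
#include <thrust/execution_policy.h>
#include <thrust/sort.h>

void sort_example(thrust::device_vector<float>& v, cudaStream_t stream) {
  // Before: default policy; blocks after the call and allocates temporary
  // storage with cudaMalloc every time.
  thrust::sort(v.begin(), v.end());

  // After: no implicit sync, temporaries come from a thread-local cache,
  // and the work is enqueued on the caller's stream.
  auto policy =
      thrust::cuda::par_nosync(thrust::detail::single_device_tls_caching_allocator())
          .on(stream);
  thrust::sort(policy, v.begin(), v.end());
}
```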
245 changes: 123 additions & 122 deletions src/runtime/contrib/thrust/thrust.cu
@@ -21,48 +21,55 @@
* \file Use external Thrust library call
*/

#include <dlpack/dlpack.h>
#include <thrust/detail/caching_allocator.h>
#include <thrust/device_ptr.h>
#include <thrust/device_vector.h>
#include <thrust/sort.h>
#include <thrust/gather.h>
#include <thrust/scan.h>
#include <thrust/sequence.h>

#include <thrust/sort.h>
#include <tvm/runtime/registry.h>
#include <dlpack/dlpack.h>

#include <algorithm>
#include <vector>
#include <functional>
#include <vector>

#include "../../cuda/cuda_common.h"

namespace tvm {
namespace contrib {

using namespace runtime;

auto get_thrust_exec_policy() {
return thrust::cuda::par_nosync(thrust::detail::single_device_tls_caching_allocator())
.on(GetCUDAStream());
}

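Under `par_nosync` nothing blocks the host, but calls issued on the same stream still execute in order, so the back-to-back algorithm calls below stay correct without per-call synchronization; the host only needs a sync before it reads results. A sketch of that contract (hypothetical helper, relying on `get_thrust_exec_policy()` and `GetCUDAStream()` from this file and the includes already present):

```cuda
// Sketch of the ordering contract under par_nosync.
template <typename T>
void argsort_nosync(thrust::device_ptr<T> values, thrust::device_ptr<int> indices, int n) {
  auto policy = get_thrust_exec_policy();
  // Neither call blocks the host, but both are enqueued on the same stream,
  // so the sort observes the completed sequence.
  thrust::sequence(policy, indices, indices + n);
  thrust::sort_by_key(policy, values, values + n, indices);
  // Synchronize only when the host must read the results. Inside TVM the
  // runtime synchronizes the stream later, which is what makes the default
  // policy's per-call synchronization redundant here.
  cudaStreamSynchronize(GetCUDAStream());
}
```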
// Performs sorting along axis -1 and returns both sorted values and indices.
template<typename DataType, typename IndicesType>
void thrust_sort(DLTensor* input,
DLTensor* out_values,
DLTensor* out_indices,
bool is_ascend,
template <typename DataType, typename IndicesType>
void thrust_sort(DLTensor* input, DLTensor* out_values, DLTensor* out_indices, bool is_ascend,
int n_values) {
thrust::device_ptr<DataType> data_ptr(static_cast<DataType *>(input->data));
thrust::device_ptr<DataType> values_ptr(static_cast<DataType *>(out_values->data));
thrust::device_ptr<IndicesType> indices_ptr(static_cast<IndicesType *>(out_indices->data));
thrust::device_ptr<DataType> data_ptr(static_cast<DataType*>(input->data));
thrust::device_ptr<DataType> values_ptr(static_cast<DataType*>(out_values->data));
thrust::device_ptr<IndicesType> indices_ptr(static_cast<IndicesType*>(out_indices->data));

auto policy = get_thrust_exec_policy();

size_t size = 1;
for (int i = 0; i < input->ndim; ++i) {
size *= input->shape[i];
}
thrust::copy(data_ptr, data_ptr + size, values_ptr);
thrust::copy(policy, data_ptr, data_ptr + size, values_ptr);

if (size == static_cast<size_t>(input->shape[input->ndim - 1])) {
// A fast path for single segment case
thrust::sequence(indices_ptr, indices_ptr + n_values);
if (is_ascend) {
thrust::sort_by_key(values_ptr, values_ptr + n_values, indices_ptr);
thrust::sort_by_key(policy, values_ptr, values_ptr + n_values, indices_ptr);
} else {
thrust::sort_by_key(values_ptr, values_ptr + n_values, indices_ptr,
thrust::sort_by_key(policy, values_ptr, values_ptr + n_values, indices_ptr,
thrust::greater<DataType>());
}
} else {
@@ -74,9 +81,9 @@ void thrust_sort(DLTensor* input,

// First, sort values and store the sorted order in argsort_order.
if (is_ascend) {
thrust::stable_sort_by_key(values_ptr, values_ptr + size, argsort_order.begin());
thrust::stable_sort_by_key(policy, values_ptr, values_ptr + size, argsort_order.begin());
} else {
thrust::stable_sort_by_key(values_ptr, values_ptr + size, argsort_order.begin(),
thrust::stable_sort_by_key(policy, values_ptr, values_ptr + size, argsort_order.begin(),
thrust::greater<DataType>());
}

@@ -85,36 +92,33 @@
auto counting_iter = thrust::counting_iterator<int64_t>(0);
auto linear_index_to_sort_axis_index = [n_values] __host__ __device__(int64_t i) {
return i % n_values;
}; // NOLINT(*)
auto init_indices_iter = thrust::make_transform_iterator(counting_iter,
linear_index_to_sort_axis_index);
}; // NOLINT(*)
auto init_indices_iter =
thrust::make_transform_iterator(counting_iter, linear_index_to_sort_axis_index);

// This will reorder indices 0, 1, 2 ... in the sorted order of values_ptr
thrust::gather(argsort_order.begin(), argsort_order.end(), init_indices_iter, indices_ptr);
thrust::gather(policy, argsort_order.begin(), argsort_order.end(), init_indices_iter,
indices_ptr);

thrust::device_vector<int> segment_ids(size);
auto linear_index_to_segment_id = [n_values] __host__ __device__(int64_t i) {
return i / n_values;
}; // NOLINT(*)
}; // NOLINT(*)
// We also reorder segment indices 0, 0, 0, 1, 1, 1 ... in the order of values_ptr
thrust::transform(argsort_order.begin(), argsort_order.end(), segment_ids.begin(),
thrust::transform(policy, argsort_order.begin(), argsort_order.end(), segment_ids.begin(),
linear_index_to_segment_id);

// The second sort, keyed by segment_ids, brings segment_ids back to 0, 0, 0, 1, 1, 1 ...
// values_ptr and indices_ptr will also be sorted in the order of segment_ids above
// Since sorting has been done in a stable way, relative orderings of values and indices
// in the segment do not change and hence they remain sorted.
auto key_val_zip = thrust::make_zip_iterator(thrust::make_tuple(values_ptr, indices_ptr));
thrust::stable_sort_by_key(segment_ids.begin(), segment_ids.end(), key_val_zip);
thrust::stable_sort_by_key(policy, segment_ids.begin(), segment_ids.end(), key_val_zip);
}
}

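The multi-segment branch above is the standard two-pass trick for batched argsort: one global stable sort of all values, then a stable sort by segment id to regroup the segments, with stability guaranteeing each segment stays internally sorted. A worked trace with invented numbers (two segments, n_values = 3, ascending) makes the data flow concrete:

```cuda
// Worked trace of the segmented path (illustrative values, not from the PR).
// Input values = [3 1 2 | 0 7 8]:
//
// 1. stable_sort_by_key over everything:
//      values        -> [0 1 2 3 7 8]
//      argsort_order -> [3 1 2 0 4 5]   (original linear positions)
// 2. gather with i % n_values:
//      indices       -> [0 1 2 0 1 2]   (position within its own segment)
// 3. transform with i / n_values:
//      segment_ids   -> [1 0 0 0 1 1]   (segments are interleaved after step 1)
// 4. stable_sort_by_key on segment_ids, with zip(values, indices) as values:
//      segment_ids   -> [0 0 0 1 1 1]
//      values        -> [1 2 3 | 0 7 8]
//      indices       -> [1 2 0 | 0 1 2]
//    Stability keeps each segment internally sorted, so this is the final
//    per-segment argsort.
```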
void thrust_sort_common(DLTensor* input,
DLTensor* values_out,
DLTensor* indices_out,
bool is_ascend,
int sort_len,
std::string data_dtype,
void thrust_sort_common(DLTensor* input, DLTensor* values_out, DLTensor* indices_out,
bool is_ascend, int sort_len, std::string data_dtype,
std::string out_dtype) {
if (data_dtype == "float32") {
if (out_dtype == "int32") {
@@ -152,7 +156,7 @@ void thrust_sort_common(DLTensor* input,
} else {
LOG(FATAL) << "Unsupported output dtype: " << out_dtype;
}
} else if (data_dtype == "int64") {
} else if (data_dtype == "int64") {
if (out_dtype == "int32") {
thrust_sort<int64_t, int32_t>(input, values_out, indices_out, is_ascend, sort_len);
} else if (out_dtype == "int64") {
@@ -169,8 +173,7 @@ void thrust_sort_common(DLTensor* input,
}
}

TVM_REGISTER_GLOBAL("tvm.contrib.thrust.sort")
.set_body([](TVMArgs args, TVMRetValue* ret) {
TVM_REGISTER_GLOBAL("tvm.contrib.thrust.sort").set_body([](TVMArgs args, TVMRetValue* ret) {
ICHECK_GE(args.num_args, 4);
DLTensor* input = args[0];
DLTensor* values_out = args[1];
@@ -181,97 +184,94 @@ TVM_REGISTER_GLOBAL("tvm.contrib.thrust.sort")
auto out_dtype = DLDataType2String(indices_out->dtype);

int n_values = input->shape[input->ndim - 1];
thrust_sort_common(input, values_out, indices_out, is_ascend, n_values,
data_dtype, out_dtype);
thrust_sort_common(input, values_out, indices_out, is_ascend, n_values, data_dtype, out_dtype);
});

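The sort is exposed only through the global registry, so callers look it up by name at runtime. A hypothetical C++ caller sketch (not in the PR; tensor allocation is elided, and building with the usual `USE_THRUST=ON` flag is assumed so the symbol is registered):

```cuda
#include <dlpack/dlpack.h>
#include <tvm/runtime/logging.h>
#include <tvm/runtime/registry.h>

void call_thrust_sort(DLTensor* input, DLTensor* values_out, DLTensor* indices_out,
                      bool is_ascend) {
  const tvm::runtime::PackedFunc* f =
      tvm::runtime::Registry::Get("tvm.contrib.thrust.sort");
  ICHECK(f != nullptr) << "thrust sort not registered; build with USE_THRUST=ON";
  (*f)(input, values_out, indices_out, is_ascend);
}
```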
template<typename KeyType, typename ValueType>
void thrust_stable_sort_by_key(DLTensor* keys_in,
DLTensor* values_in,
DLTensor* keys_out,
DLTensor* values_out,
bool for_scatter) {
template <typename KeyType, typename ValueType>
void thrust_stable_sort_by_key(DLTensor* keys_in, DLTensor* values_in, DLTensor* keys_out,
DLTensor* values_out, bool for_scatter) {
const auto size = keys_in->shape[0];
thrust::device_ptr<KeyType> keys_in_ptr(static_cast<KeyType *>(keys_in->data));
thrust::device_ptr<ValueType> values_in_ptr(static_cast<ValueType *>(values_in->data));
thrust::device_ptr<KeyType> keys_out_ptr(static_cast<KeyType *>(keys_out->data));
thrust::device_ptr<ValueType> values_out_ptr(static_cast<ValueType *>(values_out->data));
thrust::device_ptr<KeyType> keys_in_ptr(static_cast<KeyType*>(keys_in->data));
thrust::device_ptr<ValueType> values_in_ptr(static_cast<ValueType*>(values_in->data));
thrust::device_ptr<KeyType> keys_out_ptr(static_cast<KeyType*>(keys_out->data));
thrust::device_ptr<ValueType> values_out_ptr(static_cast<ValueType*>(values_out->data));

auto policy = get_thrust_exec_policy();

if (for_scatter) {
thrust::transform(keys_in_ptr, keys_in_ptr + size, keys_out_ptr, [size] __device__(KeyType k) {
if (k < 0) return k + static_cast<KeyType>(size);
return k;
});
thrust::transform(policy, keys_in_ptr, keys_in_ptr + size, keys_out_ptr,
[size] __device__(KeyType k) {
if (k < 0) return k + static_cast<KeyType>(size);
return k;
});
} else {
thrust::copy(keys_in_ptr, keys_in_ptr + size, keys_out_ptr);
thrust::copy(policy, keys_in_ptr, keys_in_ptr + size, keys_out_ptr);
}
thrust::copy(values_in_ptr, values_in_ptr + size, values_out_ptr);
thrust::copy(policy, values_in_ptr, values_in_ptr + size, values_out_ptr);

thrust::stable_sort_by_key(keys_out_ptr, keys_out_ptr + size, values_out_ptr);
thrust::stable_sort_by_key(policy, keys_out_ptr, keys_out_ptr + size, values_out_ptr);
}

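The `for_scatter` branch normalizes Python-style negative scatter destinations before the key sort, mapping k to k + size when k < 0. A standalone sketch of just that transform, with an invented trace; like thrust.cu itself it needs nvcc's extended-lambda support for the `__device__` lambda:

```cuda
// Invented trace: size = 5, keys [-1, 3, -5, 0, 2] -> [4, 3, 0, 0, 2].
#include <thrust/device_vector.h>
#include <thrust/transform.h>

void wrap_negative_indices(thrust::device_vector<int>& keys, int64_t size) {
  thrust::transform(keys.begin(), keys.end(), keys.begin(),
                    [size] __device__(int k) {
                      // Negative destinations count from the end, Python-style.
                      return k < 0 ? k + static_cast<int>(size) : k;
                    });
}
```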
TVM_REGISTER_GLOBAL("tvm.contrib.thrust.stable_sort_by_key")
.set_body([](TVMArgs args, TVMRetValue* ret) {
ICHECK_GE(args.num_args, 5);
DLTensor* keys_in = args[0];
DLTensor* values_in = args[1];
DLTensor* keys_out = args[2];
DLTensor* values_out = args[3];
bool for_scatter = args[4];

auto key_dtype = DLDataType2String(keys_in->dtype);
auto value_dtype = DLDataType2String(values_in->dtype);

if (key_dtype == "int32") {
if (value_dtype == "int32") {
thrust_stable_sort_by_key<int, int>(keys_in, values_in, keys_out, values_out,
for_scatter);
} else if (value_dtype == "int64") {
thrust_stable_sort_by_key<int, int64_t>(keys_in, values_in, keys_out, values_out,
.set_body([](TVMArgs args, TVMRetValue* ret) {
ICHECK_GE(args.num_args, 5);
DLTensor* keys_in = args[0];
DLTensor* values_in = args[1];
DLTensor* keys_out = args[2];
DLTensor* values_out = args[3];
bool for_scatter = args[4];

auto key_dtype = DLDataType2String(keys_in->dtype);
auto value_dtype = DLDataType2String(values_in->dtype);

if (key_dtype == "int32") {
if (value_dtype == "int32") {
thrust_stable_sort_by_key<int, int>(keys_in, values_in, keys_out, values_out,
for_scatter);
} else if (value_dtype == "float32") {
thrust_stable_sort_by_key<int, float>(keys_in, values_in, keys_out, values_out,
for_scatter);
} else {
LOG(FATAL) << "Unsupported value dtype: " << value_dtype;
}
} else if (key_dtype == "int64") {
if (value_dtype == "int32") {
thrust_stable_sort_by_key<int64_t, int>(keys_in, values_in, keys_out, values_out,
for_scatter);
} else if (value_dtype == "int64") {
thrust_stable_sort_by_key<int64_t, int64_t>(keys_in, values_in, keys_out, values_out,
} else if (value_dtype == "int64") {
thrust_stable_sort_by_key<int, int64_t>(keys_in, values_in, keys_out, values_out,
for_scatter);
} else if (value_dtype == "float32") {
thrust_stable_sort_by_key<int64_t, float>(keys_in, values_in, keys_out, values_out,
} else if (value_dtype == "float32") {
thrust_stable_sort_by_key<int, float>(keys_in, values_in, keys_out, values_out,
for_scatter);
} else {
LOG(FATAL) << "Unsupported value dtype: " << value_dtype;
}
} else if (key_dtype == "float32") {
if (value_dtype == "int32") {
thrust_stable_sort_by_key<float, int>(keys_in, values_in, keys_out, values_out,
for_scatter);
} else if (value_dtype == "int64") {
thrust_stable_sort_by_key<float, int64_t>(keys_in, values_in, keys_out, values_out,
for_scatter);
} else if (value_dtype == "float32") {
thrust_stable_sort_by_key<float, float>(keys_in, values_in, keys_out, values_out,
for_scatter);
} else {
LOG(FATAL) << "Unsupported value dtype: " << value_dtype;
}
} else {
LOG(FATAL) << "Unsupported key dtype: " << key_dtype;
}
});
} else {
LOG(FATAL) << "Unsupported value dtype: " << value_dtype;
}
} else if (key_dtype == "int64") {
if (value_dtype == "int32") {
thrust_stable_sort_by_key<int64_t, int>(keys_in, values_in, keys_out, values_out,
for_scatter);
} else if (value_dtype == "int64") {
thrust_stable_sort_by_key<int64_t, int64_t>(keys_in, values_in, keys_out, values_out,
for_scatter);
} else if (value_dtype == "float32") {
thrust_stable_sort_by_key<int64_t, float>(keys_in, values_in, keys_out, values_out,
for_scatter);
} else {
LOG(FATAL) << "Unsupported value dtype: " << value_dtype;
}
} else if (key_dtype == "float32") {
if (value_dtype == "int32") {
thrust_stable_sort_by_key<float, int>(keys_in, values_in, keys_out, values_out,
for_scatter);
} else if (value_dtype == "int64") {
thrust_stable_sort_by_key<float, int64_t>(keys_in, values_in, keys_out, values_out,
for_scatter);
} else if (value_dtype == "float32") {
thrust_stable_sort_by_key<float, float>(keys_in, values_in, keys_out, values_out,
for_scatter);
} else {
LOG(FATAL) << "Unsupported value dtype: " << value_dtype;
}
} else {
LOG(FATAL) << "Unsupported key dtype: " << key_dtype;
}
});

template<typename InType, typename OutType>
void thrust_scan(DLTensor* data,
DLTensor* output,
bool exclusive) {
thrust::device_ptr<InType> data_ptr(static_cast<InType *>(data->data));
thrust::device_ptr<OutType> output_ptr(static_cast<OutType *>(output->data));
template <typename InType, typename OutType>
void thrust_scan(DLTensor* data, DLTensor* output, bool exclusive) {
thrust::device_ptr<InType> data_ptr(static_cast<InType*>(data->data));
thrust::device_ptr<OutType> output_ptr(static_cast<OutType*>(output->data));
const auto scan_size = data->shape[data->ndim - 1];

if (scan_size == 0) return;
@@ -281,19 +281,20 @@ void thrust_scan(DLTensor* data,

const bool need_cast = std::is_same<InType, OutType>::value == false;

auto data_cast_ptr = thrust::make_transform_iterator(data_ptr, [] __host__ __device__(InType v) {
return static_cast<OutType>(v);
}); // NOLINT(*)
auto data_cast_ptr = thrust::make_transform_iterator(
data_ptr, [] __host__ __device__(InType v) { return static_cast<OutType>(v); }); // NOLINT(*)

auto policy = get_thrust_exec_policy();

if (size == static_cast<size_t>(data->shape[data->ndim - 1])) {
if (exclusive && need_cast) {
thrust::exclusive_scan(data_cast_ptr, data_cast_ptr + scan_size, output_ptr);
thrust::exclusive_scan(policy, data_cast_ptr, data_cast_ptr + scan_size, output_ptr);
} else if (exclusive && !need_cast) {
thrust::exclusive_scan(data_ptr, data_ptr + scan_size, output_ptr);
thrust::exclusive_scan(policy, data_ptr, data_ptr + scan_size, output_ptr);
} else if (!exclusive && need_cast) {
thrust::inclusive_scan(data_cast_ptr, data_cast_ptr + scan_size, output_ptr);
thrust::inclusive_scan(policy, data_cast_ptr, data_cast_ptr + scan_size, output_ptr);
} else {
thrust::inclusive_scan(data_ptr, data_ptr + scan_size, output_ptr);
thrust::inclusive_scan(policy, data_ptr, data_ptr + scan_size, output_ptr);
}
} else {
// Use thrust segmented scan to compute scan on the inner most axis
@@ -305,18 +306,18 @@
auto counting_iter = thrust::counting_iterator<size_t>(0);
// Without __host__ annotation, cub crashes
auto linear_index_to_scan_key = [scan_size] __host__ __device__(size_t i) {
return i / scan_size;
}; // NOLINT(*)
return i / scan_size;
}; // NOLINT(*)
auto key_iter = thrust::make_transform_iterator(counting_iter, linear_index_to_scan_key);

if (exclusive && need_cast) {
thrust::exclusive_scan_by_key(key_iter, key_iter + size, data_cast_ptr, output_ptr);
thrust::exclusive_scan_by_key(policy, key_iter, key_iter + size, data_cast_ptr, output_ptr);
} else if (exclusive && !need_cast) {
thrust::exclusive_scan_by_key(key_iter, key_iter + size, data_ptr, output_ptr);
thrust::exclusive_scan_by_key(policy, key_iter, key_iter + size, data_ptr, output_ptr);
} else if (!exclusive && need_cast) {
thrust::inclusive_scan_by_key(key_iter, key_iter + size, data_cast_ptr, output_ptr);
thrust::inclusive_scan_by_key(policy, key_iter, key_iter + size, data_cast_ptr, output_ptr);
} else {
thrust::inclusive_scan_by_key(key_iter, key_iter + size, data_ptr, output_ptr);
thrust::inclusive_scan_by_key(policy, key_iter, key_iter + size, data_ptr, output_ptr);
}
}
}
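
The segmented branch of `thrust_scan` turns a batched scan into a single `*_scan_by_key` call by synthesizing the segment key i / scan_size on the fly; the scan restarts whenever the key changes, i.e. at each row boundary. A minimal sketch of the pattern under the same policy helper (invented example: scan_size = 3 maps [1 2 3 | 4 5 6] to [0 1 3 | 0 4 9] in the exclusive case; in-place scans are permitted by Thrust):

```cuda
#include <thrust/device_vector.h>
#include <thrust/iterator/counting_iterator.h>
#include <thrust/iterator/transform_iterator.h>
#include <thrust/scan.h>

void segmented_exclusive_scan(thrust::device_vector<int>& data, size_t scan_size) {
  auto counting = thrust::counting_iterator<size_t>(0);
  // The segment key of element i is its row number; runs of equal keys share
  // one scan, and the scan restarts at every key change.
  auto keys = thrust::make_transform_iterator(
      counting, [scan_size] __host__ __device__(size_t i) { return i / scan_size; });
  auto policy = get_thrust_exec_policy();
  thrust::exclusive_scan_by_key(policy, keys, keys + data.size(), data.begin(), data.begin());
}
```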
3 changes: 3 additions & 0 deletions src/runtime/cuda/cuda_common.h
@@ -63,6 +63,9 @@ class CUDAThreadEntry {
// get the threadlocal workspace
static CUDAThreadEntry* ThreadLocal();
};

inline cudaStream_t GetCUDAStream() { return CUDAThreadEntry::ThreadLocal()->stream; }

} // namespace runtime
} // namespace tvm
#endif // TVM_RUNTIME_CUDA_CUDA_COMMON_H_
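
Worth noting why both pieces are thread-local: `GetCUDAStream()` returns `CUDAThreadEntry::ThreadLocal()->stream`, and `single_device_tls_caching_allocator()` keeps a per-thread allocation cache, so concurrent TVM threads get independent streams and independent caches with no cross-thread contention (my reading of the pairing, not stated in the PR). Any runtime code can enqueue onto the same stream, e.g.:

```cuda
// Illustrative helper, not part of the diff; include path abbreviated.
#include "cuda_common.h"

void async_copy_on_tvm_stream(void* dst, const void* src, size_t nbytes) {
  cudaStream_t s = tvm::runtime::GetCUDAStream();  // this thread's stream
  // Ordered after any Thrust work already enqueued on the same stream.
  cudaMemcpyAsync(dst, src, nbytes, cudaMemcpyDeviceToDevice, s);
}
```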