diff --git a/include/rmm/exec_policy.hpp b/include/rmm/exec_policy.hpp index eacdfa187..21ebfd295 100644 --- a/include/rmm/exec_policy.hpp +++ b/include/rmm/exec_policy.hpp @@ -28,6 +28,8 @@ #include #include +#include + namespace rmm { /** * @addtogroup thrust_integrations @@ -47,6 +49,8 @@ using thrust_exec_policy_t = * that uses RMM for temporary memory allocation on the specified stream. */ class exec_policy : public thrust_exec_policy_t { + using async_resource_ref = cuda::mr::async_resource_ref; + public: /** * @brief Construct a new execution policy object @@ -54,8 +58,8 @@ class exec_policy : public thrust_exec_policy_t { * @param stream The stream on which to allocate temporary memory * @param mr The resource to use for allocating temporary memory */ - explicit exec_policy(cuda_stream_view stream = cuda_stream_default, - rmm::mr::device_memory_resource* mr = mr::get_current_device_resource()) + explicit exec_policy(cuda_stream_view stream = cuda_stream_default, + async_resource_ref mr = rmm::mr::get_current_device_resource()) : thrust_exec_policy_t( thrust::cuda::par(rmm::mr::thrust_allocator(stream, mr)).on(stream.value())) { @@ -77,10 +81,11 @@ using thrust_exec_policy_nosync_t = * are not required for correctness. */ class exec_policy_nosync : public thrust_exec_policy_nosync_t { + using async_resource_ref = cuda::mr::async_resource_ref; + public: - explicit exec_policy_nosync( - cuda_stream_view stream = cuda_stream_default, - rmm::mr::device_memory_resource* mr = mr::get_current_device_resource()) + explicit exec_policy_nosync(cuda_stream_view stream = cuda_stream_default, + async_resource_ref mr = rmm::mr::get_current_device_resource()) : thrust_exec_policy_nosync_t( thrust::cuda::par_nosync(rmm::mr::thrust_allocator(stream, mr)).on(stream.value())) {