Skip to content

Commit

Permalink
Force inputs to be copied on device before the optimize step
Browse files: browse the repository at this point in the history
  • Loading branch information
Rbiessy committed Oct 24, 2024
1 parent 2149e39 commit bad6bfb
Show file tree
Hide file tree
Showing 3 changed files with 53 additions and 72 deletions.
43 changes: 18 additions & 25 deletions tests/unit_tests/sparse_blas/source/sparse_spmm_usm.cpp
Original file line number | Diff line number | Diff line change
Expand Up @@ -85,26 +85,21 @@ int test_spmm(sycl::device *dev, sparse_matrix_format_t format, intType nrows_A,
fpType *b_usm = b_usm_uptr.get();
fpType *c_usm = c_usm_uptr.get();

std::vector<sycl::event> mat_dependencies;
std::vector<sycl::event> spmm_dependencies;
std::vector<sycl::event> dependencies;
// Copy host to device
mat_dependencies.push_back(
dependencies.push_back(
main_queue.memcpy(ia_usm, ia_host.data(), ia_host.size() * sizeof(intType)));
mat_dependencies.push_back(
dependencies.push_back(
main_queue.memcpy(ja_usm, ja_host.data(), ja_host.size() * sizeof(intType)));
mat_dependencies.push_back(
main_queue.memcpy(a_usm, a_host.data(), a_host.size() * sizeof(fpType)));
spmm_dependencies.push_back(
main_queue.memcpy(b_usm, b_host.data(), b_host.size() * sizeof(fpType)));
spmm_dependencies.push_back(
main_queue.memcpy(c_usm, c_host.data(), c_host.size() * sizeof(fpType)));
dependencies.push_back(main_queue.memcpy(a_usm, a_host.data(), a_host.size() * sizeof(fpType)));
dependencies.push_back(main_queue.memcpy(b_usm, b_host.data(), b_host.size() * sizeof(fpType)));
dependencies.push_back(main_queue.memcpy(c_usm, c_host.data(), c_host.size() * sizeof(fpType)));

fpType *alpha_host_or_usm_ptr = &alpha;
fpType *beta_host_or_usm_ptr = &beta;
if (test_scalar_on_device) {
spmm_dependencies.push_back(
main_queue.memcpy(alpha_usm_uptr.get(), &alpha, sizeof(fpType)));
spmm_dependencies.push_back(main_queue.memcpy(beta_usm_uptr.get(), &beta, sizeof(fpType)));
dependencies.push_back(main_queue.memcpy(alpha_usm_uptr.get(), &alpha, sizeof(fpType)));
dependencies.push_back(main_queue.memcpy(beta_usm_uptr.get(), &beta, sizeof(fpType)));
alpha_host_or_usm_ptr = alpha_usm_uptr.get();
beta_host_or_usm_ptr = beta_usm_uptr.get();
}
Expand Down Expand Up @@ -138,12 +133,10 @@ int test_spmm(sycl::device *dev, sparse_matrix_format_t format, intType nrows_A,
sycl::event ev_opt;
CALL_RT_OR_CT(ev_opt = oneapi::mkl::sparse::spmm_optimize, main_queue, transpose_A,
transpose_B, &alpha, A_view, A_handle, B_handle, &beta, C_handle, alg, descr,
workspace_usm.get(), mat_dependencies);
workspace_usm.get(), dependencies);

spmm_dependencies.push_back(ev_opt);
CALL_RT_OR_CT(ev_spmm = oneapi::mkl::sparse::spmm, main_queue, transpose_A, transpose_B,
&alpha, A_view, A_handle, B_handle, &beta, C_handle, alg, descr,
spmm_dependencies);
&alpha, A_view, A_handle, B_handle, &beta, C_handle, alg, descr, { ev_opt });

if (reset_data) {
intType reset_nnz = generate_random_matrix<fpType, intType>(
Expand All @@ -163,14 +156,14 @@ int test_spmm(sycl::device *dev, sparse_matrix_format_t format, intType nrows_A,
}
nnz = reset_nnz;

mat_dependencies.clear();
mat_dependencies.push_back(main_queue.memcpy(
ia_usm, ia_host.data(), ia_host.size() * sizeof(intType), ev_spmm));
mat_dependencies.push_back(main_queue.memcpy(
ja_usm, ja_host.data(), ja_host.size() * sizeof(intType), ev_spmm));
mat_dependencies.push_back(
dependencies.clear();
dependencies.push_back(main_queue.memcpy(ia_usm, ia_host.data(),
ia_host.size() * sizeof(intType), ev_spmm));
dependencies.push_back(main_queue.memcpy(ja_usm, ja_host.data(),
ja_host.size() * sizeof(intType), ev_spmm));
dependencies.push_back(
main_queue.memcpy(a_usm, a_host.data(), a_host.size() * sizeof(fpType), ev_spmm));
mat_dependencies.push_back(
dependencies.push_back(
main_queue.memcpy(c_usm, c_host.data(), c_host.size() * sizeof(fpType), ev_spmm));
set_matrix_data(main_queue, format, A_handle, nrows_A, ncols_A, nnz, index, ia_usm,
ja_usm, a_usm);
Expand All @@ -185,7 +178,7 @@ int test_spmm(sycl::device *dev, sparse_matrix_format_t format, intType nrows_A,

CALL_RT_OR_CT(ev_opt = oneapi::mkl::sparse::spmm_optimize, main_queue, transpose_A,
transpose_B, &alpha, A_view, A_handle, B_handle, &beta, C_handle, alg,
descr, workspace_usm.get(), mat_dependencies);
descr, workspace_usm.get(), dependencies);

CALL_RT_OR_CT(ev_spmm = oneapi::mkl::sparse::spmm, main_queue, transpose_A, transpose_B,
&alpha, A_view, A_handle, B_handle, &beta, C_handle, alg, descr,
Expand Down
42 changes: 18 additions & 24 deletions tests/unit_tests/sparse_blas/source/sparse_spmv_usm.cpp
Original file line number | Diff line number | Diff line change
Expand Up @@ -78,26 +78,21 @@ int test_spmv(sycl::device *dev, sparse_matrix_format_t format, intType nrows_A,
fpType *x_usm = x_usm_uptr.get();
fpType *y_usm = y_usm_uptr.get();

std::vector<sycl::event> mat_dependencies;
std::vector<sycl::event> spmv_dependencies;
std::vector<sycl::event> dependencies;
// Copy host to device
mat_dependencies.push_back(
dependencies.push_back(
main_queue.memcpy(ia_usm, ia_host.data(), ia_host.size() * sizeof(intType)));
mat_dependencies.push_back(
dependencies.push_back(
main_queue.memcpy(ja_usm, ja_host.data(), ja_host.size() * sizeof(intType)));
mat_dependencies.push_back(
main_queue.memcpy(a_usm, a_host.data(), a_host.size() * sizeof(fpType)));
spmv_dependencies.push_back(
main_queue.memcpy(x_usm, x_host.data(), x_host.size() * sizeof(fpType)));
spmv_dependencies.push_back(
main_queue.memcpy(y_usm, y_host.data(), y_host.size() * sizeof(fpType)));
dependencies.push_back(main_queue.memcpy(a_usm, a_host.data(), a_host.size() * sizeof(fpType)));
dependencies.push_back(main_queue.memcpy(x_usm, x_host.data(), x_host.size() * sizeof(fpType)));
dependencies.push_back(main_queue.memcpy(y_usm, y_host.data(), y_host.size() * sizeof(fpType)));

fpType *alpha_host_or_usm_ptr = &alpha;
fpType *beta_host_or_usm_ptr = &beta;
if (test_scalar_on_device) {
spmv_dependencies.push_back(
main_queue.memcpy(alpha_usm_uptr.get(), &alpha, sizeof(fpType)));
spmv_dependencies.push_back(main_queue.memcpy(beta_usm_uptr.get(), &beta, sizeof(fpType)));
dependencies.push_back(main_queue.memcpy(alpha_usm_uptr.get(), &alpha, sizeof(fpType)));
dependencies.push_back(main_queue.memcpy(beta_usm_uptr.get(), &beta, sizeof(fpType)));
alpha_host_or_usm_ptr = alpha_usm_uptr.get();
beta_host_or_usm_ptr = beta_usm_uptr.get();
}
Expand Down Expand Up @@ -130,12 +125,11 @@ int test_spmv(sycl::device *dev, sparse_matrix_format_t format, intType nrows_A,
sycl::event ev_opt;
CALL_RT_OR_CT(ev_opt = oneapi::mkl::sparse::spmv_optimize, main_queue, transpose_val,
alpha_host_or_usm_ptr, A_view, A_handle, x_handle, beta_host_or_usm_ptr,
y_handle, alg, descr, workspace_usm.get(), mat_dependencies);
y_handle, alg, descr, workspace_usm.get(), dependencies);

spmv_dependencies.push_back(ev_opt);
CALL_RT_OR_CT(ev_spmv = oneapi::mkl::sparse::spmv, main_queue, transpose_val,
alpha_host_or_usm_ptr, A_view, A_handle, x_handle, beta_host_or_usm_ptr,
y_handle, alg, descr, spmv_dependencies);
y_handle, alg, descr, { ev_opt });

if (reset_data) {
intType reset_nnz = generate_random_matrix<fpType, intType>(
Expand All @@ -155,14 +149,14 @@ int test_spmv(sycl::device *dev, sparse_matrix_format_t format, intType nrows_A,
}
nnz = reset_nnz;

mat_dependencies.clear();
mat_dependencies.push_back(main_queue.memcpy(
ia_usm, ia_host.data(), ia_host.size() * sizeof(intType), ev_spmv));
mat_dependencies.push_back(main_queue.memcpy(
ja_usm, ja_host.data(), ja_host.size() * sizeof(intType), ev_spmv));
mat_dependencies.push_back(
dependencies.clear();
dependencies.push_back(main_queue.memcpy(ia_usm, ia_host.data(),
ia_host.size() * sizeof(intType), ev_spmv));
dependencies.push_back(main_queue.memcpy(ja_usm, ja_host.data(),
ja_host.size() * sizeof(intType), ev_spmv));
dependencies.push_back(
main_queue.memcpy(a_usm, a_host.data(), a_host.size() * sizeof(fpType), ev_spmv));
mat_dependencies.push_back(
dependencies.push_back(
main_queue.memcpy(y_usm, y_host.data(), y_host.size() * sizeof(fpType), ev_spmv));
set_matrix_data(main_queue, format, A_handle, nrows_A, ncols_A, nnz, index, ia_usm,
ja_usm, a_usm);
Expand All @@ -177,7 +171,7 @@ int test_spmv(sycl::device *dev, sparse_matrix_format_t format, intType nrows_A,

CALL_RT_OR_CT(ev_opt = oneapi::mkl::sparse::spmv_optimize, main_queue, transpose_val,
alpha_host_or_usm_ptr, A_view, A_handle, x_handle, beta_host_or_usm_ptr,
y_handle, alg, descr, workspace_usm.get(), mat_dependencies);
y_handle, alg, descr, workspace_usm.get(), dependencies);

CALL_RT_OR_CT(ev_spmv = oneapi::mkl::sparse::spmv, main_queue, transpose_val,
alpha_host_or_usm_ptr, A_view, A_handle, x_handle, beta_host_or_usm_ptr,
Expand Down
40 changes: 17 additions & 23 deletions tests/unit_tests/sparse_blas/source/sparse_spsv_usm.cpp
Original file line number | Diff line number | Diff line change
Expand Up @@ -82,24 +82,19 @@ int test_spsv(sycl::device *dev, sparse_matrix_format_t format, intType m, doubl
fpType *x_usm = x_usm_uptr.get();
fpType *y_usm = y_usm_uptr.get();

std::vector<sycl::event> mat_dependencies;
std::vector<sycl::event> spsv_dependencies;
std::vector<sycl::event> dependencies;
// Copy host to device
mat_dependencies.push_back(
dependencies.push_back(
main_queue.memcpy(ia_usm, ia_host.data(), ia_host.size() * sizeof(intType)));
mat_dependencies.push_back(
dependencies.push_back(
main_queue.memcpy(ja_usm, ja_host.data(), ja_host.size() * sizeof(intType)));
mat_dependencies.push_back(
main_queue.memcpy(a_usm, a_host.data(), a_host.size() * sizeof(fpType)));
spsv_dependencies.push_back(
main_queue.memcpy(x_usm, x_host.data(), x_host.size() * sizeof(fpType)));
spsv_dependencies.push_back(
main_queue.memcpy(y_usm, y_host.data(), y_host.size() * sizeof(fpType)));
dependencies.push_back(main_queue.memcpy(a_usm, a_host.data(), a_host.size() * sizeof(fpType)));
dependencies.push_back(main_queue.memcpy(x_usm, x_host.data(), x_host.size() * sizeof(fpType)));
dependencies.push_back(main_queue.memcpy(y_usm, y_host.data(), y_host.size() * sizeof(fpType)));

fpType *alpha_host_or_usm_ptr = &alpha;
if (test_scalar_on_device) {
spsv_dependencies.push_back(
main_queue.memcpy(alpha_usm_uptr.get(), &alpha, sizeof(fpType)));
dependencies.push_back(main_queue.memcpy(alpha_usm_uptr.get(), &alpha, sizeof(fpType)));
alpha_host_or_usm_ptr = alpha_usm_uptr.get();
}

Expand Down Expand Up @@ -128,12 +123,11 @@ int test_spsv(sycl::device *dev, sparse_matrix_format_t format, intType m, doubl
sycl::event ev_opt;
CALL_RT_OR_CT(ev_opt = oneapi::mkl::sparse::spsv_optimize, main_queue, transpose_val,
alpha_host_or_usm_ptr, A_view, A_handle, x_handle, y_handle, alg, descr,
workspace_usm.get(), mat_dependencies);
workspace_usm.get(), dependencies);

spsv_dependencies.push_back(ev_opt);
CALL_RT_OR_CT(ev_spsv = oneapi::mkl::sparse::spsv, main_queue, transpose_val,
alpha_host_or_usm_ptr, A_view, A_handle, x_handle, y_handle, alg, descr,
spsv_dependencies);
{ ev_opt });

if (reset_data) {
intType reset_nnz = generate_random_matrix<fpType, intType>(
Expand All @@ -152,14 +146,14 @@ int test_spsv(sycl::device *dev, sparse_matrix_format_t format, intType m, doubl
}
nnz = reset_nnz;

mat_dependencies.clear();
mat_dependencies.push_back(main_queue.memcpy(
ia_usm, ia_host.data(), ia_host.size() * sizeof(intType), ev_spsv));
mat_dependencies.push_back(main_queue.memcpy(
ja_usm, ja_host.data(), ja_host.size() * sizeof(intType), ev_spsv));
mat_dependencies.push_back(
dependencies.clear();
dependencies.push_back(main_queue.memcpy(ia_usm, ia_host.data(),
ia_host.size() * sizeof(intType), ev_spsv));
dependencies.push_back(main_queue.memcpy(ja_usm, ja_host.data(),
ja_host.size() * sizeof(intType), ev_spsv));
dependencies.push_back(
main_queue.memcpy(a_usm, a_host.data(), a_host.size() * sizeof(fpType), ev_spsv));
mat_dependencies.push_back(
dependencies.push_back(
main_queue.memcpy(y_usm, y_host.data(), y_host.size() * sizeof(fpType), ev_spsv));
set_matrix_data(main_queue, format, A_handle, m, m, nnz, index, ia_usm, ja_usm, a_usm);

Expand All @@ -173,7 +167,7 @@ int test_spsv(sycl::device *dev, sparse_matrix_format_t format, intType m, doubl

CALL_RT_OR_CT(ev_opt = oneapi::mkl::sparse::spsv_optimize, main_queue, transpose_val,
alpha_host_or_usm_ptr, A_view, A_handle, x_handle, y_handle, alg, descr,
workspace_usm.get(), mat_dependencies);
workspace_usm.get(), dependencies);

CALL_RT_OR_CT(ev_spsv = oneapi::mkl::sparse::spsv, main_queue, transpose_val,
alpha_host_or_usm_ptr, A_view, A_handle, x_handle, y_handle, alg, descr,
Expand Down

0 comments on commit bad6bfb

Please sign in to comment.