diff --git a/sycl/test-e2e/Matrix/joint_matrix_transposeAB.cpp b/sycl/test-e2e/Matrix/joint_matrix_transposeAB.cpp index 6781c8ca40f3..51714b6ba9d9 100644 --- a/sycl/test-e2e/Matrix/joint_matrix_transposeAB.cpp +++ b/sycl/test-e2e/Matrix/joint_matrix_transposeAB.cpp @@ -22,7 +22,8 @@ template class MT; -template +template void matrix_transpose(T *input, T *out_col_major, queue q) { static_assert((NR % TR) == 0); static_assert((NC % TC) == 0); @@ -52,19 +53,25 @@ void matrix_transpose(T *input, T *out_col_major, queue q) { sub_group sg = spmd_item.get_sub_group(); joint_matrix matrix_input; - joint_matrix matrix_col_major; + joint_matrix + matrix_col_major; - auto input_offset = (sg_startx * TR / VF) * NC * VF + sg_starty / sg_size * TC * VF; - auto col_major_offset = (sg_startx * TR) + (sg_starty / sg_size * TC) * NR; + auto input_offset = + (sg_startx * TR / VF) * NC * VF + sg_starty / sg_size * TC * VF; + auto col_major_offset = + (sg_startx * TR) + (sg_starty / sg_size * TC) * NR; joint_matrix_load(sg, matrix_input, p_input + input_offset, NC * VF); joint_matrix_copy(sg, matrix_input, matrix_col_major); - joint_matrix_store(sg, matrix_col_major, p_out_col_major + col_major_offset, NR); + joint_matrix_store(sg, matrix_col_major, + p_out_col_major + col_major_offset, NR); }); // parallel for }).wait(); } -template void test() { +template +void test() { static constexpr size_t SCALE = 2; static constexpr size_t MATRIX_R = TR * SCALE; static constexpr size_t MATRIX_C = TC * SCALE; @@ -72,21 +79,24 @@ template (MATRIX_R * MATRIX_C, q); - T* vnni = malloc_shared(MATRIX_R * MATRIX_C, q); + T *vnni = malloc_shared(MATRIX_R * MATRIX_C, q); T *col_major = malloc_shared(MATRIX_C * MATRIX_R, q); T *ref_col_major = malloc_shared(MATRIX_C * MATRIX_R, q); matrix_rand(MATRIX_R, MATRIX_C, in, (T)5.0); if constexpr (VF != 1) { matrix_vnni(MATRIX_R, MATRIX_C, in, vnni, VF); - matrix_transpose(vnni, col_major, q); + matrix_transpose( + vnni, col_major, q); } else { - matrix_transpose(in, col_major, q); + matrix_transpose( + in, col_major, q); } matrix_transpose(MATRIX_R, MATRIX_C, ref_col_major, in); std::cout << "compare results for: " << TR << " x " << TC << std::endl; - assert(matrix_compare(MATRIX_C, MATRIX_R, col_major, ref_col_major)); + assert( + matrix_compare(MATRIX_C, MATRIX_R, col_major, ref_col_major)); free(in, q); free(vnni, q);