Skip to content

Commit

Permalink
format
Browse files Browse the repository at this point in the history
  • Loading branch information
YuriPlyakhin committed Jan 24, 2025
1 parent 84781dd commit d35cb0b
Showing 1 changed file with 20 additions and 10 deletions.
30 changes: 20 additions & 10 deletions sycl/test-e2e/Matrix/joint_matrix_transposeAB.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,8 @@

template <size_t TileRows, size_t TileCols> class MT;

template <size_t TR, size_t TC, typename T, size_t NR, size_t NC, use Use, layout LoadLayout, size_t VF>
template <size_t TR, size_t TC, typename T, size_t NR, size_t NC, use Use,
layout LoadLayout, size_t VF>
void matrix_transpose(T *input, T *out_col_major, queue q) {
static_assert((NR % TR) == 0);
static_assert((NC % TC) == 0);
Expand Down Expand Up @@ -52,41 +53,50 @@ void matrix_transpose(T *input, T *out_col_major, queue q) {

sub_group sg = spmd_item.get_sub_group();
joint_matrix<sub_group, T, Use, TR, TC, LoadLayout> matrix_input;
joint_matrix<sub_group, T, Use, TR, TC, layout::col_major> matrix_col_major;
joint_matrix<sub_group, T, Use, TR, TC, layout::col_major>
matrix_col_major;

auto input_offset = (sg_startx * TR / VF) * NC * VF + sg_starty / sg_size * TC * VF;
auto col_major_offset = (sg_startx * TR) + (sg_starty / sg_size * TC) * NR;
auto input_offset =
(sg_startx * TR / VF) * NC * VF + sg_starty / sg_size * TC * VF;
auto col_major_offset =
(sg_startx * TR) + (sg_starty / sg_size * TC) * NR;

joint_matrix_load(sg, matrix_input, p_input + input_offset, NC * VF);
joint_matrix_copy(sg, matrix_input, matrix_col_major);
joint_matrix_store(sg, matrix_col_major, p_out_col_major + col_major_offset, NR);
joint_matrix_store(sg, matrix_col_major,
p_out_col_major + col_major_offset, NR);
}); // parallel for
}).wait();
}

template <typename T, size_t TR, size_t TC, size_t VF, use Use, layout InputLayout> void test() {
template <typename T, size_t TR, size_t TC, size_t VF, use Use,
layout InputLayout>
void test() {
static constexpr size_t SCALE = 2;
static constexpr size_t MATRIX_R = TR * SCALE;
static constexpr size_t MATRIX_C = TC * SCALE;

queue q;

T *in = malloc_shared<T>(MATRIX_R * MATRIX_C, q);
T* vnni = malloc_shared<T>(MATRIX_R * MATRIX_C, q);
T *vnni = malloc_shared<T>(MATRIX_R * MATRIX_C, q);
T *col_major = malloc_shared<T>(MATRIX_C * MATRIX_R, q);
T *ref_col_major = malloc_shared<T>(MATRIX_C * MATRIX_R, q);

matrix_rand(MATRIX_R, MATRIX_C, in, (T)5.0);
if constexpr (VF != 1) {
matrix_vnni(MATRIX_R, MATRIX_C, in, vnni, VF);
matrix_transpose<TR, TC, T, MATRIX_R, MATRIX_C, Use, InputLayout, VF>(vnni, col_major, q);
matrix_transpose<TR, TC, T, MATRIX_R, MATRIX_C, Use, InputLayout, VF>(
vnni, col_major, q);
} else {
matrix_transpose<TR, TC, T, MATRIX_R, MATRIX_C, Use, InputLayout, VF>(in, col_major, q);
matrix_transpose<TR, TC, T, MATRIX_R, MATRIX_C, Use, InputLayout, VF>(
in, col_major, q);
}
matrix_transpose(MATRIX_R, MATRIX_C, ref_col_major, in);

std::cout << "compare results for: " << TR << " x " << TC << std::endl;
assert(matrix_compare<T, T, true>(MATRIX_C, MATRIX_R, col_major, ref_col_major));
assert(
matrix_compare<T, T, true>(MATRIX_C, MATRIX_R, col_major, ref_col_major));

free(in, q);
free(vnni, q);
Expand Down

0 comments on commit d35cb0b

Please sign in to comment.