Skip to content

Commit

Permalink
take SYCL implementation from DPCT
Browse files Browse the repository at this point in the history
  • Loading branch information
AuroraPerego committed Dec 12, 2023
1 parent 97482f0 commit 8c31a85
Showing 1 changed file with 33 additions and 31 deletions.
64 changes: 33 additions & 31 deletions include/alpaka/warp/WarpGenericSycl.hpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,11 @@
/* Copyright 2023 Jan Stephan, Luca Ferragina, Andrea Bocci, Aurora Perego
* SPDX-License-Identifier: MPL-2.0
*
* The implementations of Shfl::shfl(), ShflUp::shfl_up(), ShflDown::shfl_down() and ShflXor::shfl_xor() are derived
* from Intel DPCT.
* Copyright (C) Intel Corporation.
* SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
* See https://llvm.org/LICENSE.txt for license information.
*/

#pragma once
Expand Down Expand Up @@ -123,12 +129,9 @@ namespace alpaka::warp::trait
The first starts at sub-group index 0 and the second at sub-group index 16. For srcLane = 4 the
first subdivision will access the value at sub-group index 4 and the second at sub-group index 20. */
auto const actual_group = warp.m_item_warp.get_sub_group();
auto const actual_item_id = static_cast<std::int32_t>(actual_group.get_local_linear_id());
auto const actual_group_id = actual_item_id / width;
auto const actual_src_id = static_cast<std::size_t>(srcLane + actual_group_id * width);
auto const src = sycl::id<1>{actual_src_id};

return sycl::select_from_group(actual_group, value, src);
std::uint32_t const w = static_cast<std::uint32_t>(width);
std::uint32_t const start_index = actual_group.get_local_linear_id() / w * w;
return sycl::select_from_group(actual_group, value, start_index + static_cast<std::uint32_t>(srcLane) % w);
}
};

Expand All @@ -142,15 +145,16 @@ namespace alpaka::warp::trait
std::uint32_t offset, /* must be the same for all work-items in the group */
std::int32_t width)
{
std::int32_t offset_int = static_cast<std::int32_t>(offset);
auto const actual_group = warp.m_item_warp.get_sub_group();
auto actual_item_id = static_cast<std::int32_t>(actual_group.get_local_linear_id());
auto const actual_group_id = actual_item_id / width;
auto const actual_src_id = actual_item_id - offset_int;
auto const src = actual_src_id >= actual_group_id * width
? sycl::id<1>{static_cast<std::size_t>(actual_src_id)}
: sycl::id<1>{static_cast<std::size_t>(actual_item_id)};
return sycl::select_from_group(actual_group, value, src);
std::uint32_t const w = static_cast<std::uint32_t>(width);
std::uint32_t const id = actual_group.get_local_linear_id();
std::uint32_t const start_index = id / w * w;
T result = sycl::shift_group_right(actual_group, value, offset);
if((id - start_index) < offset)
{
result = value;
}
return result;
}
};

Expand All @@ -164,33 +168,31 @@ namespace alpaka::warp::trait
std::uint32_t offset,
std::int32_t width)
{
std::int32_t offset_int = static_cast<std::int32_t>(offset);
auto const actual_group = warp.m_item_warp.get_sub_group();
auto actual_item_id = static_cast<std::int32_t>(actual_group.get_local_linear_id());
auto const actual_group_id = actual_item_id / width;
auto const actual_src_id = actual_item_id + offset_int;
auto const src = actual_src_id < (actual_group_id + 1) * width
? sycl::id<1>{static_cast<std::size_t>(actual_src_id)}
: sycl::id<1>{static_cast<std::size_t>(actual_item_id)};
return sycl::select_from_group(actual_group, value, src);
std::uint32_t const w = static_cast<std::uint32_t>(width);
std::uint32_t const id = actual_group.get_local_linear_id();
std::uint32_t const end_index = (id / w + 1) * w;
T result = sycl::shift_group_left(actual_group, value, offset);
if((id + offset) >= end_index)
{
result = value;
}
return result;
}
};

template<typename TDim>
struct ShflXor<warp::WarpGenericSycl<TDim>>
{
template<typename T>
static auto shfl_xor(
warp::WarpGenericSycl<TDim> const& warp,
T value,
std::int32_t mask,
std::int32_t /*width*/)
static auto shfl_xor(warp::WarpGenericSycl<TDim> const& warp, T value, std::int32_t mask, std::int32_t width)
{
auto const actual_group = warp.m_item_warp.get_sub_group();
auto actual_item_id = static_cast<std::int32_t>(actual_group.get_local_linear_id());
auto const actual_src_id = actual_item_id ^ mask;
auto const src = sycl::id<1>{static_cast<std::size_t>(actual_src_id)};
return sycl::select_from_group(actual_group, value, src);
std::uint32_t const w = static_cast<std::uint32_t>(width);
std::uint32_t const id = actual_group.get_local_linear_id();
std::uint32_t const start_index = id / w * w;
std::uint32_t const target_offset = (id % w) ^ static_cast<std::uint32_t>(mask);
return sycl::select_from_group(actual_group, value, target_offset < w ? start_index + target_offset : id);
}
};
} // namespace alpaka::warp::trait
Expand Down

0 comments on commit 8c31a85

Please sign in to comment.