-
Notifications
You must be signed in to change notification settings - Fork 75
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Added warp::shfl functionality. #1273
Merged
Merged
Changes from all commits
Commits
Show all changes
5 commits
Select commit
Hold shift + click to select a range
205666b
Added warp::shfl functionality.
frobnitzem beec8c2
removed unused param in Shfl<WarpSingleThread>
frobnitzem 8b68865
Specified 32-bit precision ints on shfl() and added width parameter.
frobnitzem 70f01ba
Compared float by error tolerance to avoid compiler warning.
frobnitzem 5f4e89f
remove magic number
psychocoderHPC File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,130 @@ | ||
/* Copyright 2021 David M. Rogers | ||
bernhardmgruber marked this conversation as resolved.
Show resolved
Hide resolved
|
||
* | ||
* This file is part of Alpaka. | ||
* | ||
* This Source Code Form is subject to the terms of the Mozilla Public | ||
* License, v. 2.0. If a copy of the MPL was not distributed with this | ||
* file, You can obtain one at http://mozilla.org/MPL/2.0/. | ||
*/ | ||
|
||
#include <alpaka/test/KernelExecutionFixture.hpp> | ||
#include <alpaka/test/acc/TestAccs.hpp> | ||
#include <alpaka/test/queue/Queue.hpp> | ||
#include <alpaka/warp/Traits.hpp> | ||
|
||
#include <catch2/catch.hpp> | ||
|
||
#include <cstdint> | ||
#include <limits> | ||
|
||
//############################################################################# | ||
class ShflSingleThreadWarpTestKernel | ||
{ | ||
public: | ||
//------------------------------------------------------------------------- | ||
ALPAKA_NO_HOST_ACC_WARNING | ||
template<typename TAcc> | ||
ALPAKA_FN_ACC auto operator()(TAcc const& acc, bool* success) const -> void | ||
{ | ||
std::int32_t const warpExtent = alpaka::warp::getSize(acc); | ||
ALPAKA_CHECK(*success, warpExtent == 1); | ||
|
||
ALPAKA_CHECK(*success, alpaka::warp::shfl(acc, 12, 0) == 12); | ||
ALPAKA_CHECK(*success, alpaka::warp::shfl(acc, 42, -1) == 42); | ||
float ans = alpaka::warp::shfl(acc, 3.3f, 0); | ||
ALPAKA_CHECK(*success, alpaka::math::abs(acc, ans - 3.3f) < 1e-8f); | ||
psychocoderHPC marked this conversation as resolved.
Show resolved
Hide resolved
|
||
} | ||
}; | ||
|
||
//############################################################################# | ||
class ShflMultipleThreadWarpTestKernel | ||
{ | ||
public: | ||
//----------------------------------------------------------------------------- | ||
ALPAKA_NO_HOST_ACC_WARNING | ||
template<typename TAcc> | ||
ALPAKA_FN_ACC auto operator()(TAcc const& acc, bool* success) const -> void | ||
{ | ||
auto const localThreadIdx = alpaka::getIdx<alpaka::Block, alpaka::Threads>(acc); | ||
auto const blockExtent = alpaka::getWorkDiv<alpaka::Block, alpaka::Threads>(acc); | ||
std::int32_t const warpExtent = alpaka::warp::getSize(acc); | ||
// Test relies on having a single warp per thread block | ||
ALPAKA_CHECK(*success, static_cast<std::int32_t>(blockExtent.prod()) == warpExtent); | ||
auto const threadIdxInWarp = std::int32_t(alpaka::mapIdx<1u>(localThreadIdx, blockExtent)[0]); | ||
|
||
ALPAKA_CHECK(*success, warpExtent > 1); | ||
|
||
ALPAKA_CHECK(*success, alpaka::warp::shfl(acc, 42, 0) == 42); | ||
ALPAKA_CHECK(*success, alpaka::warp::shfl(acc, threadIdxInWarp, 0) == 0); | ||
ALPAKA_CHECK(*success, alpaka::warp::shfl(acc, threadIdxInWarp, 1) == 1); | ||
// Note the CUDA and HIP API-s differ on lane wrapping, but both agree it should not segfault | ||
// https://github.com/ROCm-Developer-Tools/HIP-CPU/issues/14 | ||
ALPAKA_CHECK(*success, alpaka::warp::shfl(acc, 5, -1) == 5); | ||
|
||
auto const epsilon = std::numeric_limits<float>::epsilon(); | ||
|
||
// Test various widths | ||
for(int width = 1; width < warpExtent; width *= 2) | ||
{ | ||
for(int idx = 0; idx < width; idx++) | ||
{ | ||
int const off = width * (threadIdxInWarp / width); | ||
ALPAKA_CHECK(*success, alpaka::warp::shfl(acc, threadIdxInWarp, idx, width) == idx + off); | ||
float const ans = alpaka::warp::shfl(acc, 4.0f - float(threadIdxInWarp), idx, width); | ||
float const expect = 4.0f - float(idx + off); | ||
ALPAKA_CHECK(*success, alpaka::math::abs(acc, ans - expect) < epsilon); | ||
} | ||
} | ||
|
||
// Some threads quit the kernel to test that the warp operations | ||
// properly operate on the active threads only | ||
if(threadIdxInWarp >= warpExtent / 2) | ||
return; | ||
|
||
for(int idx = 0; idx < warpExtent / 2; idx++) | ||
{ | ||
ALPAKA_CHECK(*success, alpaka::warp::shfl(acc, threadIdxInWarp, idx) == idx); | ||
float const ans = alpaka::warp::shfl(acc, 4.0f - float(threadIdxInWarp), idx); | ||
float const expect = 4.0f - float(idx); | ||
ALPAKA_CHECK(*success, alpaka::math::abs(acc, ans - expect) < epsilon); | ||
} | ||
} | ||
}; | ||
|
||
//----------------------------------------------------------------------------- | ||
TEMPLATE_LIST_TEST_CASE("shfl", "[warp]", alpaka::test::TestAccs) | ||
{ | ||
using Acc = TestType; | ||
using Dev = alpaka::Dev<Acc>; | ||
using Pltf = alpaka::Pltf<Dev>; | ||
using Dim = alpaka::Dim<Acc>; | ||
using Idx = alpaka::Idx<Acc>; | ||
|
||
Dev const dev(alpaka::getDevByIdx<Pltf>(0u)); | ||
auto const warpExtent = alpaka::getWarpSize(dev); | ||
if(warpExtent == 1) | ||
{ | ||
Idx const gridThreadExtentPerDim = 4; | ||
alpaka::test::KernelExecutionFixture<Acc> fixture(alpaka::Vec<Dim, Idx>::all(gridThreadExtentPerDim)); | ||
ShflSingleThreadWarpTestKernel kernel; | ||
REQUIRE(fixture(kernel)); | ||
} | ||
else | ||
{ | ||
// Work around gcc 7.5 trying and failing to offload for OpenMP 4.0 | ||
#if BOOST_COMP_GNUC && (BOOST_COMP_GNUC == BOOST_VERSION_NUMBER(7, 5, 0)) && defined ALPAKA_ACC_ANY_BT_OMP5_ENABLED | ||
return; | ||
#else | ||
using ExecutionFixture = alpaka::test::KernelExecutionFixture<Acc>; | ||
auto const gridBlockExtent = alpaka::Vec<Dim, Idx>::all(2); | ||
// Enforce one warp per thread block | ||
auto blockThreadExtent = alpaka::Vec<Dim, Idx>::ones(); | ||
blockThreadExtent[0] = static_cast<Idx>(warpExtent); | ||
auto const threadElementExtent = alpaka::Vec<Dim, Idx>::ones(); | ||
auto workDiv = typename ExecutionFixture::WorkDiv{gridBlockExtent, blockThreadExtent, threadElementExtent}; | ||
auto fixture = ExecutionFixture{workDiv}; | ||
ShflMultipleThreadWarpTestKernel kernel; | ||
REQUIRE(fixture(kernel)); | ||
#endif | ||
} | ||
} |
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
IMO we need to add to the documentation that this function
shfl
is collective what means all threads need to call the function and also from the same code branch.The reason is that for CUDA the implementation is using
activemask
and for HIP all threads in a warp needs to call the function. Usingactivemask
means if threads from the if and else branch call the function they will not see each other.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I updated these docs to include this warning.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think I forgot to add a similar warning to the previously existing warp collectives. You comment also alllies to those, right @psychocoderHPC ?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@sbastrakov Yes this should be added to other warp functions too. Currently, only CUDA allows calling warp functions from different branches. It is fine if all threads of the warp are in the same branch but as soon as the threads diverge the behavior is undefined (for HIP and CUDA devices before sm_70) .