Skip to content

Commit

Permalink
Moves from elements_with_stride to independent_group_elements
Browse files Browse the repository at this point in the history
  • Loading branch information
fwyzard authored and ericcano committed Jan 19, 2024
1 parent 4af44b3 commit 9af4710
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 14 deletions.
32 changes: 20 additions & 12 deletions HeterogeneousCore/AlpakaInterface/interface/radixSort.h
Original file line number Diff line number Diff line change
@@ -1,9 +1,13 @@
#ifndef HeterogeneousCore_AlpakaInterface_interface_radixSort_h
#define HeterogeneousCore_AlpakaInterface_interface_radixSort_h

#include <algorithm>
#include <cstdint>
#include <numeric>
#include <type_traits>

#include <alpaka/alpaka.hpp>

#include "HeterogeneousCore/AlpakaInterface/interface/workdivision.h"

namespace cms::alpakatools {
Expand Down Expand Up @@ -125,17 +129,17 @@ namespace cms::alpakatools {
auto k = ind2;

// Initializer index order to trivial increment.
for (auto idx: elements_with_stride(acc, size)) { j[idx] = idx; }
for (auto idx: independent_group_elements(acc, size)) { j[idx] = idx; }
alpaka::syncBlockThreads(acc);

// Iterate on the slices of the data.
while (alpaka::syncBlockThreadsPredicate<alpaka::BlockAnd>(acc, (currentSortingPass < totalSortingPassses))) {
for (auto idx: elements_with_stride(acc, binsNumber)) { c[idx] = 0; }
for (auto idx: independent_group_elements(acc, binsNumber)) { c[idx] = 0; }
alpaka::syncBlockThreads(acc);
const auto sortingPassShift = binBits * currentSortingPass;

// fill bins (count elements in each bin)
for (auto idx: elements_with_stride(acc, size)) {
for (auto idx: independent_group_elements(acc, size)) {
auto bin = (a[j[idx]] >> sortingPassShift) & binsMask;
alpaka::atomicAdd(acc, &c[bin], 1, alpaka::hierarchy::Threads{});
}
Expand All @@ -154,7 +158,7 @@ namespace cms::alpakatools {
// prefix scan "optimized"???...
// TODO: we might be able to reuse the warpPrefixScan function
// Warp level prefix scan
for (auto idx: elements_with_stride(acc, binsNumber)) {
for (auto idx: independent_group_elements(acc, binsNumber)) {
auto x = c[idx];
auto laneId = idx & warpMask;

Expand All @@ -168,7 +172,7 @@ namespace cms::alpakatools {
alpaka::syncBlockThreads(acc);

// Block level completion of prefix scan (add last sum of each preceding warp)
for (auto idx: elements_with_stride(acc, binsNumber)) {
for (auto idx: independent_group_elements(acc, binsNumber)) {
auto ss = (idx / warpSize) * warpSize - 1;
c[idx] = ct[idx];
for (int i = ss; i > 0; i -= warpSize)
Expand Down Expand Up @@ -198,15 +202,15 @@ namespace cms::alpakatools {
// Iterate on bin-sized slices to (size - 1) / binSize + 1 iterations
while (alpaka::syncBlockThreadsPredicate<alpaka::BlockAnd>(acc, ibs >= 0)) {
// Init
for (auto idx: elements_with_stride(acc, binsNumber)) {
for (auto idx: independent_group_elements(acc, binsNumber)) {
cu[idx] = -1;
ct[idx] = -1;
}
alpaka::syncBlockThreads(acc);

// Find the highest index for all the threads dealing with a given bin (in cu[])
// Also record the bin for each thread (in ct[])
for (auto idx: elements_with_stride(acc, binsNumber)) {
for (auto idx: independent_group_elements(acc, binsNumber)) {
int i = ibs - idx;
int32_t bin = -1;
if (i >= 0) {
Expand All @@ -226,7 +230,7 @@ namespace cms::alpakatools {


// FIXME: we can slash a memory access.
for (auto idx: elements_with_stride(acc, binsNumber)) {
for (auto idx: independent_group_elements(acc, binsNumber)) {
int i = ibs - idx;
// Are we still in inside the data?
if (i >= 0) {
Expand Down Expand Up @@ -302,7 +306,7 @@ namespace cms::alpakatools {

// TODO this copy is (doubly?) redundant with the reorder
if (j != ind) // odd number of sorting passes, we need to move the result to the right array (ind[])
for (auto idx: elements_with_stride(acc, size)) { ind[idx] = ind2[idx]; };
for (auto idx: independent_group_elements(acc, size)) { ind[idx] = ind2[idx]; };

alpaka::syncBlockThreads(acc);

Expand Down Expand Up @@ -360,8 +364,9 @@ namespace cms::alpakatools {
const TAcc& acc, T const* a, uint16_t* ind, uint16_t* ind2, uint32_t size) {
static_assert(requires_single_thread_per_block_v<TAcc>, "CPU sort (not a radixSort) called wtth wrong accelerator");
// Initialize the index array
for (std::size_t i = 0; i < size; i++) ind[i] = i;
printf("std::sort(a=%p, ind=%p, indmax=%p, size=%d)\n", a, ind, ind + size, size);
std::iota(ind, ind + size, 0);
/*
printf("std::stable_sort(a=%p, ind=%p, indmax=%p, size=%d)\n", a, ind, ind + size, size);
for (uint32_t i=0; i<10 && i<size; i++) {
printf ("a[%d]=%ld ", i, (long int)a[i]);
}
Expand All @@ -370,11 +375,14 @@ namespace cms::alpakatools {
printf ("ind[%d]=%d ", i, ind[i]);
}
printf("\n");
std::sort(ind, ind+size, [a](uint16_t i0, uint16_t i1) { return a[i0] < a[i1]; });
*/
std::stable_sort(ind, ind+size, [a](uint16_t i0, uint16_t i1) { return a[i0] < a[i1]; });
/*
for (uint32_t i=0; i<10 && i<size; i++) {
printf ("ind[%d]=%d ", i, ind[i]);
}
printf("\n");
*/
}

template <typename TAcc, typename T, int NS = sizeof(T)>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -193,7 +193,7 @@ void go(Queue & queue, bool useShared) {
if (offsets_h[ib + 1] > offsets_h[ib])
inds.insert(ind_h[offsets_h[ib]]);
for (auto j = offsets_h[ib] + 1; j < offsets_h[ib + 1]; j++) {
if (inds.count(ind_h[j])) {
if (inds.count(ind_h[j]) != 0) {
printf("i=%d ib=%d ind_h[j=%d]=%d: duplicate indice!\n",
i, ib, j, ind_h[j]);
std::vector<int> counts;
Expand All @@ -206,6 +206,7 @@ void go(Queue & queue, bool useShared) {
printf("counts[%ld]=%d ", j2, counts[j2]);
}
printf("\n");
printf("inds.count(ind_h[j] = %lu\n", inds.count(ind_h[j]));
}
assert(0 == inds.count(ind_h[j]));
inds.insert(ind_h[j]);
Expand All @@ -218,7 +219,6 @@ void go(Queue & queue, bool useShared) {
std::cout << "i=" << i << " not ordered at ib=" << ib << " in [" << offsets_h[ib] << ", " << offsets_h[ib + 1] - 1
<< "] j=" << j << " ind[j]=" << ind_h[j] << " (k1 < k2) : a1=" << a[ind_h[j]] << " k1=" << k1
<< "a2= " << a[ind_h[j - 1]] << " k2=" << k2 << std::endl;
assert(false);
}
}
if (!inds.empty()) {
Expand Down

0 comments on commit 9af4710

Please sign in to comment.