Fix Flaky Topk (#12798)
* fix flaky topk

* try to fix

* remove the usage of IndexFill

* fix

* add docstring
sxjscience authored and eric-haibin-lin committed Oct 18, 2018
1 parent 7463810 commit 1ebbf94
Showing 4 changed files with 62 additions and 42 deletions.
72 changes: 40 additions & 32 deletions src/operator/tensor/ordering_op-inl.h
@@ -176,6 +176,22 @@ inline void ParseTopKParam(const TShape& src_shape, const TopKParam& param, TSha

using namespace mshadow;


struct fill_ind_to_one {
template<typename DType>
MSHADOW_XINLINE static void Map(int i, const int* indices, DType* out) {
out[indices[i]] = static_cast<DType>(1);
}
};

struct fill_ind {
template<typename DType>
MSHADOW_XINLINE static void Map(int i, const int* indices, const DType* val,
int req, DType* out) {
KERNEL_ASSIGN(out[indices[i]], req, val[i]);
}
};

template<typename DType>
MSHADOW_FORCE_INLINE void TopKSort(const Tensor<cpu, 1, DType>& dat,
const Tensor<cpu, 1, int>& ind,
@@ -313,7 +329,8 @@ MSHADOW_FORCE_INLINE void TopKSort(const Tensor<gpu, 1, DType>& dat,
const int M(dat.size(0)/N);
if (full_sort) {
// Divide workspace into two parts. The first one is needed to store batch ids.
const int id_size(sizeof(int)*ind.size(0));
size_t alignment = std::max(sizeof(DType), sizeof(int));
size_t id_size = PadBytes(sizeof(int) * ind.size(0), alignment);
Tensor<gpu, 1, int> batch_id(reinterpret_cast<int*>(work.dptr_), Shape1(ind.size(0)), s);
Tensor<gpu, 1, char> sort_work(work.dptr_+id_size, Shape1(work.size(0)-id_size), s);
mxnet::op::SortByKey(dat, ind, is_ascend, &sort_work);
@@ -364,12 +381,12 @@ void TopKImpl(const RunContext &ctx,
Tensor<xpu, 1, char> temp_workspace;
Tensor<xpu, 1, DType> sorted_dat;
Tensor<xpu, 1, int> indices, sel_indices;
Tensor<xpu, 2, DType> mask_val;
int batch_size, element_num; // number of batches + the size of each batch
int axis = 0;
bool do_transpose = false;
bool is_ascend = false;
int k = 0;
size_t alignment = std::max(sizeof(DType), sizeof(int));
TShape target_shape;
ParseTopKParam(src.shape_, param,
&target_shape, &batch_size, &element_num, &axis, &k, &do_transpose, &is_ascend);
@@ -387,32 +404,28 @@ void TopKImpl(const RunContext &ctx,
temp_size = std::max(temp_size,
mxnet::op::SortByKeyWorkspaceSize<DType, int, xpu>(src.Size()));
// Additional temp space for gpu full sorts for batch ids.
temp_size += sizeof(int) * src.Size();
temp_size += PadBytes(sizeof(int) * src.Size(), alignment);
// Temp space for cpu sorts.
temp_size = std::max(temp_size, sizeof(DType) * static_cast<size_t>(src.Size()));
index_t workspace_size = temp_size + sizeof(DType) * src.Size() + sizeof(int) * src.Size();
temp_size = std::max(temp_size, static_cast<size_t>(sizeof(DType) * src.Size()));
size_t workspace_size = temp_size + PadBytes(sizeof(DType) * src.Size(), alignment)
+ PadBytes(sizeof(int) * src.Size(), alignment);
if (param.ret_typ == topk_enum::kReturnMask) {
workspace_size += sizeof(int) * batch_size * k + sizeof(DType) * batch_size * k;
workspace_size += PadBytes(sizeof(int) * batch_size * k, alignment);
}
workspace = resource.get_space_typed<xpu, 1, char>(Shape1(workspace_size), s);
char* workspace_curr_ptr = workspace.dptr_;
sorted_dat = Tensor<xpu, 1, DType>(reinterpret_cast<DType*>(workspace_curr_ptr),
Shape1(src.Size()), s); // contain sorted dat
workspace_curr_ptr += sizeof(DType) * src.Size();
workspace_curr_ptr += PadBytes(sizeof(DType) * src.Size(), alignment);
indices = Tensor<xpu, 1, int>(reinterpret_cast<int*>(workspace_curr_ptr),
Shape1(src.Size()), s); // indices in the original matrix
workspace_curr_ptr += sizeof(int) * src.Size();
workspace_curr_ptr += PadBytes(sizeof(int) * src.Size(), alignment);

if (param.ret_typ == topk_enum::kReturnMask) {
sel_indices = Tensor<xpu, 1, int>(reinterpret_cast<int*>(workspace_curr_ptr),
Shape1(batch_size * k), s);
workspace_curr_ptr += sizeof(int) * batch_size * k;
mask_val = Tensor<xpu, 2, DType>(reinterpret_cast<DType*>(workspace_curr_ptr),
Shape2(batch_size * k, 1), s);
workspace_curr_ptr += sizeof(DType) * batch_size * k;
mask_val = scalar<DType>(1);
workspace_curr_ptr += PadBytes(sizeof(int) * batch_size * k, alignment);
CHECK_EQ(sel_indices.CheckContiguous(), true);
CHECK_EQ(mask_val.CheckContiguous(), true);
}

if (std::is_same<xpu, cpu>::value) {
@@ -458,8 +471,7 @@ void TopKImpl(const RunContext &ctx,
// Cast `ret_indices` from int to real_t could introduce conversion error when the element_num
// is large enough.
if (param.ret_typ == topk_enum::kReturnMask) {
Tensor<xpu, 2, DType> ret_mask =
ret[0].get_with_shape<xpu, 2, DType>(Shape2(ret[0].Size(), 1), s);
Tensor<xpu, 1, DType> ret_mask = ret[0].FlatTo1D<xpu, DType>(s);
ret_mask = scalar<DType>(0);
sel_indices = reshape(slice<1>(
inplace_reshape(indices,
@@ -475,7 +487,8 @@ void TopKImpl(const RunContext &ctx,
if (req[0] == kNullOp) {
return;
} else if (req[0] == kWriteTo) {
IndexFill(ret_mask, sel_indices, mask_val);
mxnet_op::Kernel<fill_ind_to_one, xpu>::Launch(s, batch_size * k,
sel_indices.dptr_, ret_mask.dptr_);
} else {
LOG(FATAL) << "req=" << req[0] << " is not supported yet.";
}
@@ -605,14 +618,11 @@ void TopKBackwardImpl(const OpContext &ctx,
<< "The total element_num is " << element_num << ", but the selected IDType can only represent "
<< mxnet::common::MaxIntegerValue<IDType>() << " elements";
Tensor<xpu, 1, int> workspace =
ctx.requested[0].get_space_typed<xpu, 1, int>(Shape1(batch_size * k * 2 + batch_size), s);
ctx.requested[0].get_space_typed<xpu, 1, int>(Shape1(batch_size * k + batch_size), s);
Tensor<xpu, 1, int> sel_indices =
Tensor<xpu, 1, int>(workspace.dptr_, Shape1(batch_size * k), s);
Tensor<xpu, 1, int> batch_shift =
Tensor<xpu, 1, int>(workspace.dptr_ + batch_size * k, Shape1(batch_size), s);
Tensor<xpu, 1, int> dummy_index =
Tensor<xpu, 1, int>(workspace.dptr_ + batch_size * k + batch_size,
Shape1(batch_size * k), s);

Tensor<xpu, 2, DType> out_grad =
inputs[0].get_with_shape<xpu, 2, DType>(Shape2(inputs[0].shape_.Size(), 1), s);
@@ -641,17 +651,15 @@ void TopKBackwardImpl(const OpContext &ctx,
Shape1(batch_size * k));
}
CHECK_EQ(sel_indices.CheckContiguous(), true);
if (kWriteTo == req[0]) {
in_grad = scalar<DType>(0);
IndexFill(in_grad, sel_indices, out_grad);
} else if (kAddTo == req[0]) {
// TODO(sxjscience) We can use AddTakeGrad in the future.
// However, the current implementation of AddTakeGrad is not so efficient.
mxnet_op::Kernel<range_fwd, xpu>::Launch(s, sel_indices.shape_.Size(), 1, 0, 1, kWriteTo,
dummy_index.dptr_);
mxnet::op::AddTakeGradLargeBatch(in_grad, sel_indices, dummy_index, out_grad);
} else if (kNullOp == req[0]) {
return;
if (kWriteTo == req[0] || kAddTo == req[0]) {
if (kWriteTo == req[0]) {
in_grad = scalar<DType>(0);
}
mxnet_op::Kernel<fill_ind, xpu>::Launch(s, batch_size * k,
sel_indices.dptr_,
out_grad.dptr_,
req[0],
in_grad.dptr_);
} else {
LOG(FATAL) << "Not Implemented!";
}
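
For context beyond the diff: the new fill_ind kernel scatters val[i] into out[indices[i]] via KERNEL_ASSIGN, which is what replaces IndexFill in both the kReturnMask forward path and the backward path above. Below is a minimal CPU-only sketch of that scatter pattern; fill_ind_reference and the sample data are hypothetical and not part of MXNet.

#include <cstdio>
#include <vector>

// CPU analogue of Kernel<fill_ind, xpu>::Launch in the diff above: for each i,
// scatter val[i] to out[indices[i]]. The real kernel uses KERNEL_ASSIGN so that
// req == kAddTo accumulates instead of overwriting; add_to stands in for that.
template <typename DType>
void fill_ind_reference(const std::vector<int>& indices,
                        const std::vector<DType>& val,
                        bool add_to,
                        std::vector<DType>* out) {
  for (size_t i = 0; i < indices.size(); ++i) {
    if (add_to) {
      (*out)[indices[i]] += val[i];
    } else {
      (*out)[indices[i]] = val[i];
    }
  }
}

int main() {
  // 2 rows of 3 elements, flattened; top-2 indices per row with batch offsets applied.
  std::vector<float> in_grad(6, 0.0f);
  std::vector<int> sel_indices = {1, 2, 4, 5};
  std::vector<float> out_grad = {10.f, 20.f, 30.f, 40.f};
  fill_ind_reference(sel_indices, out_grad, /*add_to=*/false, &in_grad);
  for (float g : in_grad) std::printf("%g ", g);  // prints: 0 10 20 0 30 40
  std::printf("\n");
  return 0;
}
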
10 changes: 6 additions & 4 deletions src/operator/tensor/sort_op-inl.cuh
@@ -74,8 +74,9 @@ SortByKeyWorkspaceSize(const size_t num_keys) {
size_t sortpairs_bytes = 0;
cub::DeviceRadixSort::SortPairs<KDType, VDType>(NULL, sortpairs_bytes,
NULL, NULL, NULL, NULL, num_keys);
size_t keys_bytes = num_keys*sizeof(KDType);
size_t values_bytes = num_keys*sizeof(VDType);
size_t alignment = std::max(sizeof(KDType), sizeof(VDType));
size_t keys_bytes = PadBytes(num_keys*sizeof(KDType), alignment);
size_t values_bytes = PadBytes(num_keys*sizeof(VDType), alignment);
return (keys_bytes + values_bytes + sortpairs_bytes);
#endif
}
@@ -96,8 +97,9 @@ SortByKeyImpl(mshadow::Tensor<gpu, 1, KDType> keys,
// Workspace given, sort using CUB
CHECK_EQ(workspace->CheckContiguous(), true);
// workspace = [keys_out, values_out, temporary_storage]
size_t keys_bytes = keys.size(0)*sizeof(KDType);
size_t values_bytes = keys.size(0)*sizeof(VDType);
size_t alignment = std::max(sizeof(KDType), sizeof(VDType));
size_t keys_bytes = PadBytes(keys.size(0)*sizeof(KDType), alignment);
size_t values_bytes = PadBytes(keys.size(0)*sizeof(VDType), alignment);
// Get the size of internal storage (for checking purposes only)
size_t sortpairs_bytes = 0;
if (is_ascend) {
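
For context beyond the diff: the workspace layout noted in the comment above, workspace = [keys_out, values_out, temporary_storage], with the newly padded section sizes. A standalone sketch; pad_bytes copies the PadBytes formula added in sort_op.h, and the double-key/int-value choice and num_keys are made-up inputs.

#include <algorithm>
#include <cstddef>
#include <cstdio>

// Copy of the PadBytes rule from sort_op.h: round num_bytes up to a multiple of alignment.
static size_t pad_bytes(size_t num_bytes, size_t alignment) {
  return num_bytes + (alignment - num_bytes % alignment) % alignment;
}

int main() {
  // Hypothetical sort of 1001 double keys against 1001 int values.
  const size_t num_keys = 1001;
  const size_t alignment = std::max(sizeof(double), sizeof(int));              // 8
  const size_t keys_bytes = pad_bytes(num_keys * sizeof(double), alignment);   // 8008
  const size_t values_bytes = pad_bytes(num_keys * sizeof(int), alignment);    // 4004 -> 4008
  // workspace layout: [keys_out | values_out | temporary_storage]
  // values_out and the CUB scratch area now both start on 8-byte boundaries.
  std::printf("values_out offset = %zu, temp storage offset = %zu\n",
              keys_bytes, keys_bytes + values_bytes);
  return 0;
}
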
12 changes: 12 additions & 0 deletions src/operator/tensor/sort_op.h
@@ -31,6 +31,18 @@
#include <type_traits>

namespace mxnet {

/*!
* \brief Return the required number of bytes for aligning an object.
Because CUDA requires mandatory memory alignment, this function can be
used to determine the number of bytes to allocate in char*.
* \param num_bytes size of the object in bytes
* \param alignment desired alignment, like 2, 4, 8
*/
MSHADOW_XINLINE size_t PadBytes(size_t num_bytes, size_t alignment) {
return num_bytes + (alignment - num_bytes % alignment) % alignment;
}

namespace op {
/*!
* \brief CPU/GPU: Sort key-value pairs stored in separate places. (Stable sort is performed!)
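
For context beyond the diff: a worked example of the rounding rule that PadBytes (added above) implements. The local pad_bytes copies that formula so the snippet builds without MXNet headers; the byte counts are made up.

#include <cassert>
#include <cstddef>

// Mirrors PadBytes from sort_op.h: num_bytes rounded up to the next multiple of alignment.
static size_t pad_bytes(size_t num_bytes, size_t alignment) {
  return num_bytes + (alignment - num_bytes % alignment) % alignment;
}

int main() {
  assert(pad_bytes(8, 8) == 8);    // already aligned: unchanged
  assert(pad_bytes(10, 8) == 16);  // padded up to the next 8-byte boundary
  assert(pad_bytes(6, 4) == 8);    // e.g. three 2-byte values followed by 4-byte ints
  // TopKImpl passes every sub-buffer size through PadBytes with
  // alignment = max(sizeof(DType), sizeof(int)), so each reinterpret_cast into
  // the shared char* workspace lands on an address aligned for its element type.
  return 0;
}
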
10 changes: 4 additions & 6 deletions tests/python/unittest/test_ndarray.py
@@ -723,7 +723,7 @@ def get_large_matrix():
gt = gt_topk(large_matrix_npy, axis=1, ret_typ="indices", k=5, is_ascend=False)
assert_almost_equal(nd_ret_topk, gt)

for dtype in [ np.int32, np.int64, np.float32, np.float64]:
for dtype in [np.int32, np.int64, np.float32, np.float64]:
a_npy = get_values(ensure_unique=True, dtype=dtype)
a_nd = mx.nd.array(a_npy, ctx=ctx, dtype=dtype)

@@ -754,9 +754,6 @@ def get_large_matrix():
assert_almost_equal(nd_ret_topk, gt)

# test for ret_typ=mask
# test needs to be re-enabled once flaky topk gets fixed
# tracked in https://github.com/apache/incubator-mxnet/pull/12446
'''
nd_ret_topk = mx.nd.topk(a_nd, axis=1, ret_typ="mask", k=3, is_ascend=True).asnumpy()
assert nd_ret_topk.dtype == dtype
gt = gt_topk(a_npy, axis=1, ret_typ="mask", k=3, is_ascend=True)
@@ -767,7 +764,7 @@ def get_large_matrix():
nd_ret_topk = mx.nd.topk(a_nd, axis=None, ret_typ="mask", k=21, is_ascend=False).asnumpy()
gt = gt_topk(a_npy, axis=None, ret_typ="mask", k=21, is_ascend=False)
assert_almost_equal(nd_ret_topk, gt)
'''

# test for ret_typ=both
nd_ret_topk_val, nd_ret_topk_ind = mx.nd.topk(a_nd, axis=1, ret_typ="both", k=3, is_ascend=True)
nd_ret_topk_val = nd_ret_topk_val.asnumpy()
@@ -800,6 +797,7 @@ def get_large_matrix():
# test for argsort
for idtype in [np.int32, np.float16, np.float32, np.float64]:
nd_ret_argsort = mx.nd.argsort(a_nd, axis=3, is_ascend=True, dtype=idtype).asnumpy()
assert nd_ret_argsort.dtype == idtype
gt = gt_topk(a_npy, axis=3, ret_typ="indices", k=dat_size, is_ascend=True)
assert_almost_equal(nd_ret_argsort, gt)
nd_ret_argsort = mx.nd.argsort(a_nd, axis=None, is_ascend=False, dtype=idtype).asnumpy()
@@ -863,7 +861,7 @@ def get_large_matrix():
# Repeat those tests that don't involve indices. These should pass even with
# duplicated input data values (over many repeated runs with different random seeds,
# this will be tested).
for dtype in [ np.int32, np.int64, np.float32, np.float64]:
for dtype in [np.int32, np.int64, np.float32, np.float64]:
a_npy = get_values(ensure_unique=False, dtype=dtype)
a_nd = mx.nd.array(a_npy, ctx=ctx, dtype=dtype)

