Fix NaN for softmax with long softmax_dim. #57851

Merged · 4 commits · Oct 9, 2023
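Summary (reconstructed from the diff below): KeMatrixSoftmaxForward seeded its max-reduction with std::numeric_limits<AccT>::min(), which for floating-point types is the smallest positive value rather than the most negative one, so a row whose entries are all negative computed a wrong maximum and its softmax could come out as NaN. The seed is changed to -infinity. In support of the fix, the MaxWithOne helper moves from broadcast_function.h into the shared aligned_vector.h header, the FIXED_BLOCK_DIM / FIXED_VEC_SIZE dispatch macros are replaced by a compile-time vector size plus a runtime block size (getBlockSize is renamed CalcBlockSize), and SwitchWarpSoftmaxForward now throws on unsupported sizes instead of silently falling through.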
5 changes: 5 additions & 0 deletions paddle/phi/kernels/funcs/aligned_vector.h
@@ -30,6 +30,11 @@ struct NeedVectorized {
static constexpr bool value = sizeof(T) <= sizeof(float);
};

+template <int N>
+struct MaxWithOne {
+  static constexpr auto kValue = (N >= 1 ? N : 1);
+};
+
// Aligned vector generates vectorized load/store on CUDA.
template <typename T, int Size>
struct alignas(sizeof(T) * Size) AlignedVector {
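MaxWithOne moves here from broadcast_function.h (deleted below) so the softmax kernel can share it. The clamp matters because the softmax code computes its vector width as MATRIX_SOFTMAX_ALIGN_BYTES / sizeof(T); for any element type wider than 16 bytes that integer division yields 0, and AlignedVector<T, 0> would declare a zero-length array. A minimal sketch of the behavior (BigPayload is a hypothetical type, used only to trigger the clamp):

// Sketch: MaxWithOne clamps a compile-time lane count to at least 1.
struct BigPayload { double a, b, c; };  // sizeof(BigPayload) == 24 > 16

static_assert(MaxWithOne<16 / sizeof(float)>::kValue == 4,
              "a 16-byte vector holds four floats");
static_assert(MaxWithOne<16 / sizeof(BigPayload)>::kValue == 1,
              "16 / 24 == 0 in integer division, clamped to 1 so "
              "AlignedVector<T, 0> is never instantiated");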
5 changes: 0 additions & 5 deletions paddle/phi/kernels/funcs/broadcast_function.h
@@ -530,11 +530,6 @@ HOSTDEVICE static int64_t ConvertSrcIdxToDstIdx(
return dst_idx;
}

-template <int N>
-struct MaxWithOne {
-  static constexpr auto kValue = (N >= 1 ? N : 1);
-};
-
template <int Index, int VecSize>
struct ReadVecDataWithInt64Index {
template <typename Array1, typename Array2, typename Array3, typename ArgsT>
92 changes: 28 additions & 64 deletions paddle/phi/kernels/gpudnn/softmax_gpudnn.h
@@ -30,29 +30,6 @@ limitations under the License. */
#define MATRIX_SOFTMAX_ALIGN_BYTES 16
#define MATRIX_SOFTMAX_THREAHOLD 100000

-#define FIXED_BLOCK_DIM_BASE(dim, ...) \
-  case (dim): {                        \
-    constexpr auto kBlockDim = (dim);  \
-    __VA_ARGS__;                       \
-  } break
-
-#define FIXED_VEC_SIZE_BASE(vec_size, ...) \
-  case (vec_size): {                       \
-    constexpr auto VecSize = (vec_size);   \
-    __VA_ARGS__;                           \
-  } break
-
-#define FIXED_BLOCK_DIM(...)                \
-  FIXED_BLOCK_DIM_BASE(512, ##__VA_ARGS__); \
-  FIXED_BLOCK_DIM_BASE(256, ##__VA_ARGS__); \
-  FIXED_BLOCK_DIM_BASE(128, ##__VA_ARGS__); \
-  FIXED_BLOCK_DIM_BASE(64, ##__VA_ARGS__);  \
-  FIXED_BLOCK_DIM_BASE(32, ##__VA_ARGS__)
-
-#define FIXED_VEC_SIZE(...)              \
-  FIXED_VEC_SIZE_BASE(8, ##__VA_ARGS__); \
-  FIXED_VEC_SIZE_BASE(4, ##__VA_ARGS__)
-
namespace phi {

using ScopedTensorDescriptor = phi::backends::gpu::ScopedTensorDescriptor;
@@ -112,7 +89,7 @@ static inline int Log2Ceil(int value) {
return log2_value;
}

-inline int getBlockSize(int vec_size, uint64_t dim_size) {
+inline int CalcBlockSize(int vec_size, uint64_t dim_size) {
uint64_t block_size = 1;
uint64_t max_block_size =
std::min(dim_size / vec_size, static_cast<uint64_t>(1024));
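CalcBlockSize (renamed from getBlockSize; the rest of its body is collapsed in this view) caps the block size at min(dim_size / vec_size, 1024). A worked example under the visible contract: dim_size = 5000 with vec_size = 4 gives max_block_size = min(1250, 1024) = 1024, so the launcher below ends up with a block of at most 1024 threads, presumably rounded to a power of two by the collapsed loop.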
@@ -461,14 +438,11 @@ __device__ __forceinline__ void ThreadVecWrite(T* out,
}
}

-template <typename T,
-          typename AccT,
-          typename IndexType,
-          int BatchSize,
-          int VecSize,
-          bool LogMode = false>
+template <typename T, typename AccT, typename IndexType, bool LogMode = false>
__global__ void KeMatrixSoftmaxForward(T* softmax, const T* src, int dim_size) {
-  using VecT = phi::AlignedVector<T, VecSize>;
+  constexpr int kVecSize =
+      MaxWithOne<MATRIX_SOFTMAX_ALIGN_BYTES / sizeof(T)>::kValue;
+  using VecT = phi::AlignedVector<T, kVecSize>;

int bid = blockIdx.x;
T* batch_input = const_cast<T*>(src) + bid * dim_size;
@@ -480,16 +454,16 @@ __global__ void KeMatrixSoftmaxForward(T* softmax, const T* src, int dim_size) {
((uint64_t)batch_output) % MATRIX_SOFTMAX_ALIGN_BYTES / sizeof(T);

// get max value
-  AccT thread_max = ThreadVecReduce<MaxFunctor, T, AccT, VecSize>(
+  AccT thread_max = ThreadVecReduce<MaxFunctor, T, AccT, kVecSize>(
      batch_input,
      dim_size,
      input_align_shift,
      MaxFunctor<T, AccT>(),
-      std::numeric_limits<AccT>::min());
+      -std::numeric_limits<AccT>::infinity());
BlockReduceMax<AccT>(&thread_max);

// get exp value and sum all
-  AccT thread_exp = ThreadVecReduce<SumExpFunctor, T, AccT, VecSize>(
+  AccT thread_exp = ThreadVecReduce<SumExpFunctor, T, AccT, kVecSize>(
      batch_input,
      dim_size,
      input_align_shift,
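This is the core NaN fix: std::numeric_limits<AccT>::min() is the smallest positive floating-point value (about 1.2e-38 for float), not the most negative one, so it is the wrong identity for a max-reduction. If every element of a row is strongly negative, the seed wins the reduction, every exp(v - max) underflows to zero, the row sum is zero, and the final division produces 0/0 = NaN. A host-side sketch of the failure and the fix (a standalone illustration, not part of the diff):

#include <algorithm>
#include <cmath>
#include <cstdio>
#include <limits>
#include <vector>

// seed plays the role of the identity value handed to ThreadVecReduce.
static void softmax_row(const std::vector<float>& x, float seed) {
  float m = seed;
  for (float v : x) m = std::max(m, v);
  float sum = 0.f;
  for (float v : x) sum += std::exp(v - m);
  for (float v : x) std::printf("%g ", std::exp(v - m) / sum);
  std::printf("  (max=%g, sum=%g)\n", m, sum);
}

int main() {
  std::vector<float> row(4, -1000.f);  // every entry strongly negative
  // Buggy seed: ~1.2e-38 beats every element, exp(-1000) underflows to 0,
  // the row sums to 0, and each output is 0/0 = NaN.
  softmax_row(row, std::numeric_limits<float>::min());
  // Fixed seed: -infinity never wins, max == -1000, exp(0) == 1 per entry,
  // and the output is the uniform 0.25.
  softmax_row(row, -std::numeric_limits<float>::infinity());
  return 0;
}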
@@ -501,19 +475,19 @@ __global__ void KeMatrixSoftmaxForward(T* softmax, const T* src, int dim_size) {
if (LogMode) {
LogSoftmaxForwardFunctor<AccT, T> reduction(thread_max, thread_exp);
if (input_align_shift == output_align_shift) {
-      ThreadVecWriteVec<LogSoftmaxForwardFunctor, T, AccT, VecSize>(
+      ThreadVecWriteVec<LogSoftmaxForwardFunctor, T, AccT, kVecSize>(
          batch_output, batch_input, dim_size, input_align_shift, reduction);
    } else {
-      ThreadVecWrite<LogSoftmaxForwardFunctor, T, AccT, VecSize>(
+      ThreadVecWrite<LogSoftmaxForwardFunctor, T, AccT, kVecSize>(
batch_output, batch_input, dim_size, reduction);
}
} else {
SoftmaxForwardFunctor<AccT, T> reduction(thread_max, thread_exp);
if (input_align_shift == output_align_shift) {
-      ThreadVecWriteVec<SoftmaxForwardFunctor, T, AccT, VecSize>(
+      ThreadVecWriteVec<SoftmaxForwardFunctor, T, AccT, kVecSize>(
          batch_output, batch_input, dim_size, input_align_shift, reduction);
    } else {
-      ThreadVecWrite<SoftmaxForwardFunctor, T, AccT, VecSize>(
+      ThreadVecWrite<SoftmaxForwardFunctor, T, AccT, kVecSize>(
batch_output, batch_input, dim_size, reduction);
}
}
@@ -785,9 +759,9 @@ void SwitchWarpSoftmaxForward(const IndexType blocks,
                              const IndexType batch_size,
                              const IndexType stride,
                              const IndexType element_count,
-                              IndexType Log2Elements) {
+                              IndexType log2_element_count) {
  using AccT = typename phi::dtype::MPTypeTrait<T>::Type;
-  switch (Log2Elements) {
+  switch (log2_element_count) {
SOFTMAX_WARP_FORWARD_CASE(0, AccT);
SOFTMAX_WARP_FORWARD_CASE(1, AccT);
SOFTMAX_WARP_FORWARD_CASE(2, AccT);
@@ -800,6 +774,10 @@ void SwitchWarpSoftmaxForward(const IndexType blocks,
SOFTMAX_WARP_FORWARD_CASE(9, AccT);
SOFTMAX_WARP_FORWARD_CASE(10, AccT);
    default:
+      PADDLE_THROW(phi::errors::Unimplemented(
+          "Unsupported softmax dim: element_count=%d, log2_element_count=%d!",
+          element_count,
+          log2_element_count));
break;
}
}
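The cases span log2_element_count 0 through 10, i.e. element_count up to 2^10 = 1024. Before this change an out-of-range value fell through the empty default without launching any kernel, leaving the output buffer uninitialized, which reads back as garbage or NaN; the new PADDLE_THROW turns that silent miss into a loud error. A spot-check of the mapping, assuming Log2Ceil (defined earlier in this header) returns the smallest n with 2^n >= value:

// Hypothetical spot-check of the dispatch range (illustration only):
//   element_count == 1    -> Log2Ceil == 0  -> case 0
//   element_count == 1000 -> Log2Ceil == 10 -> case 10 (2^10 == 1024)
//   element_count == 1025 -> Log2Ceil == 11 -> no case matches; previously
//   a silent no-op, now an Unimplemented error.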
@@ -824,9 +802,9 @@ void SwitchWarpSoftmaxBackward(const int blocks,
                               const int batch_size,
                               const int stride,
                               const int element_count,
-                               int Log2Elements) {
+                               int log2_element_count) {
  using AccT = typename phi::dtype::MPTypeTrait<T>::Type;
-  switch (Log2Elements) {
+  switch (log2_element_count) {
SOFTMAX_WARP_BACKWARD_CASE(0, AccT);
SOFTMAX_WARP_BACKWARD_CASE(1, AccT);
SOFTMAX_WARP_BACKWARD_CASE(2, AccT);
@@ -839,6 +817,9 @@ void SwitchWarpSoftmaxBackward(const int blocks,
SOFTMAX_WARP_BACKWARD_CASE(9, AccT);
SOFTMAX_WARP_BACKWARD_CASE(10, AccT);
default:
+      // PADDLE_THROW(phi::errors::Unimplemented(
+      //     "Unsupported softmax dim: element_count=%d, log2_element_count=%d!",
+      //     element_count, log2_element_count));
break;
}
}
@@ -1202,24 +1183,11 @@ template <typename T, typename IndexType, bool LogMode>
void LaunchKeMatrixSoftmaxForwardKernel(
const GPUContext& dev_ctx, T* out, const T* input, int N, int dim_size) {
using AccT = typename phi::dtype::MPTypeTrait<T>::Type;
-  const int vec_size = MATRIX_SOFTMAX_ALIGN_BYTES / sizeof(T);
-  switch (getBlockSize(vec_size, dim_size)) {
-    FIXED_BLOCK_DIM(switch (vec_size) {
-      FIXED_VEC_SIZE(
-          KeMatrixSoftmaxForward<T,
-                                 AccT,
-                                 IndexType,
-                                 kBlockDim,
-                                 VecSize,
-                                 LogMode>
-          <<<N, kBlockDim, 0, dev_ctx.stream()>>>(out, input, dim_size));
-      default:
-        break;
-    });
-    default:
-      PADDLE_THROW(
-          errors::Fatal("the input dim has error in the softmax cuda kernel."));
-  }
+  constexpr int kVecSize =
+      MaxWithOne<MATRIX_SOFTMAX_ALIGN_BYTES / sizeof(T)>::kValue;
+  int block_dim = CalcBlockSize(kVecSize, dim_size);
+  KeMatrixSoftmaxForward<T, AccT, IndexType, LogMode>
+      <<<N, block_dim, 0, dev_ctx.stream()>>>(out, input, dim_size);
}
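With the vector width now fixed at compile time by the element type alone, the two-level macro dispatch above (one kernel instantiation per block-dim/vec-size pair, plus a Fatal default for unmatched dims) collapses to a single instantiation: the block dimension becomes an ordinary runtime launch argument, so no dispatch can miss. Hypothetical spot-checks of kVecSize under the 16-byte alignment (not part of the diff):

// kVecSize per dtype with MATRIX_SOFTMAX_ALIGN_BYTES == 16:
static_assert(MaxWithOne<16 / sizeof(phi::dtype::float16)>::kValue == 8,
              "half precision: eight lanes");
static_assert(MaxWithOne<16 / sizeof(float)>::kValue == 4, "four lanes");
static_assert(MaxWithOne<16 / sizeof(double)>::kValue == 2, "two lanes");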

#if CUDNN_VERSION < 8100
@@ -1450,9 +1418,5 @@ void SoftmaxBackwardCUDAKernelDriver(const GPUContext& dev_ctx,
dev_ctx, dx_data, dout.data<T>(), out.data<T>(), N, dim, D);
}
}
-#undef FIXED_BLOCK_DIM_BASE
-#undef FIXED_BLOCK_DIM
-#undef FIXED_VEC_SIZE_BASE
-#undef FIXED_VEC_SIZE

} // namespace phi