Commit
add_n supports mixed dtype inputs (#59190)
zhiqiu authored Nov 22, 2023
1 parent f2ae424 commit 8540b29
Showing 2 changed files with 107 additions and 2 deletions.
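
At the Python API level, this change lets paddle.add_n accept a list whose first tensor is float32 while the remaining tensors are float16/bfloat16, the pattern used for master-gradient accumulation in mixed-precision training. A minimal illustrative sketch of the newly supported call, assuming a CUDA build of Paddle (shapes and variable names here are made up, not part of the patch):

import numpy as np
import paddle

# One fp32 "master" tensor followed by several fp16 addends.
data = np.random.random([16, 256]).astype('float32')
master = paddle.to_tensor(data)                                        # fp32 first input
halves = [paddle.to_tensor(data.astype('float16')) for _ in range(3)]  # fp16 remaining inputs

out = paddle.add_n([master] + halves)  # mixed-dtype inputs now handled by the GPU kernel
print(out.dtype)                       # expected to follow the fp32 first input
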
89 changes: 89 additions & 0 deletions paddle/phi/kernels/gpu/add_n_kernel.cu
@@ -38,6 +38,28 @@ __global__ void SumArrayCUDAKernel(
}
}

template <class T, class HALF>
__global__ void SumArrayMixedTypeCUDAKernel(const T *in_0,
void **in_others,
T *out,
int64_t N,
size_t in_others_size,
bool read_dst) {
using MPType = typename phi::dtype::MPTypeTrait<T>::Type;
CUDA_KERNEL_LOOP_TYPE(idx, N, int64_t) {
MPType total(read_dst ? static_cast<MPType>(out[idx])
: static_cast<MPType>(0));
total += static_cast<MPType>(in_0[idx]);
for (int i = 0; i < in_others_size; ++i) {
const HALF *tmp = static_cast<HALF *>(in_others[i]);
if (tmp) {
total += static_cast<MPType>(tmp[idx]);
}
}
out[idx] = static_cast<T>(total);
}
}

template <class T>
__global__ void SumSelectedRowsCUDAKernel(T **sr_in_out,
int64_t N,
@@ -125,6 +147,73 @@ void AddNKernel(const Context &dev_ctx,
constant_functor(dev_ctx, out, static_cast<T>(0));
}

// Support mixed inputs for master grad accumulation
// conditions:
  // 1. all inputs are DenseTensor and their number is >= 2
// 2. the first tensor is fp32 type and the others are fp16/bf16 type
if (in_num >= 2 && DenseTensor::classof(x[0]) &&
x[0]->dtype() == phi::DataType::FLOAT32 &&
x[1]->dtype() != phi::DataType::FLOAT32) {
auto in_other_dtype = x[1]->dtype();
int64_t numel = static_cast<const DenseTensor *>(x[0])->numel();
bool all_dense_tensor = true;
std::vector<const void *> in_data;
const T *in_0 = static_cast<const DenseTensor *>(x[0])->data<T>();
for (int i = 1; i < in_num; ++i) {
PADDLE_ENFORCE_EQ(
in_other_dtype,
x[i]->dtype(),
errors::InvalidArgument("The dtype of inputs should be the same, "
"but received the dtype of input 1 is %s, "
"input %d is %s",
i,
x[i]->dtype()));
if (DenseTensor::classof(x[i])) {
auto &in_i = *(static_cast<const DenseTensor *>(x[i]));
if (in_i.IsInitialized()) {
in_data.emplace_back(in_i.data());
}
} else {
all_dense_tensor = false;
break;
}
}

if (all_dense_tensor && (in_other_dtype == phi::DataType::BFLOAT16 ||
in_other_dtype == phi::DataType::FLOAT16)) {
auto tmp_in_array = phi::memory_utils::Alloc(
dev_ctx.GetPlace(), in_data.size() * sizeof(void *));
memory_utils::Copy(dev_ctx.GetPlace(),
tmp_in_array->ptr(),
phi::CPUPlace(),
reinterpret_cast<void *>(in_data.data()),
in_data.size() * sizeof(void *),
dev_ctx.stream());

void **in_array_data = reinterpret_cast<void **>(tmp_in_array->ptr());
ComputeKernelParameter(numel);
VLOG(4) << "Call SumArrayMixedTypeCUDAKernel";
if (in_other_dtype == phi::DataType::FLOAT16) {
SumArrayMixedTypeCUDAKernel<T, phi::float16>
<<<grids, blocks, 0, stream>>>(in_0,
in_array_data,
out->data<T>(),
numel,
in_data.size(),
in_place);
} else if (in_other_dtype == phi::DataType::BFLOAT16) {
SumArrayMixedTypeCUDAKernel<T, phi::bfloat16>
<<<grids, blocks, 0, stream>>>(in_0,
in_array_data,
out->data<T>(),
numel,
in_data.size(),
in_place);
}
return;
}
}

std::vector<const T *> in_data;
std::vector<int> selectrow_index;
int64_t lod_length = 0;
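Semantically, the new fast path is equivalent to upcasting every fp16/bf16 addend to float32 before summing; the kernel above does this through MPType so the accumulation never drops to half precision. A rough NumPy sketch of that reference behaviour (function name and arguments are illustrative only):

import numpy as np

def add_n_mixed_reference(first_fp32, low_precision_rest):
    # Accumulate in fp32, mirroring the MPType accumulation in the CUDA kernel.
    acc = first_fp32.astype(np.float32).copy()
    for t in low_precision_rest:        # fp16 (or bf16) addends
        acc += t.astype(np.float32)     # upcast before adding
    return acc                          # result keeps the fp32 dtype of the first input
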
20 changes: 18 additions & 2 deletions test/legacy_test/test_add_n_op.py
@@ -24,13 +24,17 @@ def setUp(self):
self.l = 32
self.x_np = np.random.random([self.l, 16, 256])

def check_main(self, x_np, dtype, axis=None):
def check_main(self, x_np, dtype, axis=None, mixed_dtype=False):
paddle.disable_static()
x = []
for i in range(x_np.shape[0]):
if mixed_dtype and i == 0:
val = paddle.to_tensor(x_np[i].astype('float32'))
else:
val = paddle.to_tensor(x_np[i].astype(dtype))
val.stop_gradient = False
x.append(val)

y = paddle.add_n(x)
x_g = paddle.grad(y, x)
y_np = y.numpy().astype(dtype)
@@ -50,6 +54,18 @@ def test_add_n_fp16(self):
for i in range(len(x_g_np_32)):
np.testing.assert_allclose(x_g_np_16[i], x_g_np_32[i], rtol=1e-03)

def test_add_n_fp16_mixed_dtype(self):
if not paddle.is_compiled_with_cuda():
return
y_np_16, x_g_np_16 = self.check_main(
self.x_np, 'float16', mixed_dtype=True
)
y_np_32, x_g_np_32 = self.check_main(self.x_np, 'float32')

np.testing.assert_allclose(y_np_16, y_np_32, rtol=1e-03)
for i in range(len(x_g_np_32)):
np.testing.assert_allclose(x_g_np_16[i], x_g_np_32[i], rtol=1e-03)

def test_add_n_api(self):
if not paddle.is_compiled_with_cuda():
return
