Skip to content

Commit

Permalink
fix bug of reduce_sum when src_dtype != dst_dtype and reduce_num == 1 (
Browse files Browse the repository at this point in the history
  • Loading branch information
GuoxiaWang authored Sep 30, 2021
1 parent 87cc8d4 commit e8efba5
Showing 1 changed file with 11 additions and 2 deletions.
13 changes: 11 additions & 2 deletions paddle/fluid/operators/reduce_ops/reduce_op.cu.h
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ namespace cub = hipcub;
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/framework/tensor_util.h"
#include "paddle/fluid/operators/amp/fp16_type_traits.h"
#include "paddle/fluid/operators/cast_op.h"
#include "paddle/fluid/operators/kernel_primitives/kernel_primitives.h"
#include "paddle/fluid/platform/cuda_device_function.h"
#include "paddle/fluid/platform/fast_divmod.h"
Expand Down Expand Up @@ -705,8 +706,16 @@ void TensorReduceFunctorImpl(const framework::Tensor& x, framework::Tensor* y,

if (config.reduce_num == 1) {
auto out_dims = y->dims();
framework::TensorCopy(x, y->place(), y);
y->Resize(out_dims);
if (x.type() == y->type()) {
framework::TensorCopy(x, y->place(), y);
y->Resize(out_dims);
} else {
auto* dev_ctx = static_cast<platform::CUDADeviceContext*>(
paddle::platform::DeviceContextPool::Instance().Get(x.place()));
framework::VisitDataType(
static_cast<framework::proto::VarType::Type>(y->type()),
CastOpFunctor<platform::CUDADeviceContext, Tx>(&x, y, *dev_ctx));
}
return;
}

Expand Down

0 comments on commit e8efba5

Please sign in to comment.