Commit
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into adjust_output_name_parsing
jim19930609 committed Mar 3, 2022
2 parents b03c79f + cac00e0 commit 504feac
Showing 38 changed files with 921 additions and 495 deletions.
29 changes: 29 additions & 0 deletions paddle/fluid/distributed/collective/ProcessGroup.h
@@ -117,6 +117,35 @@ class ProcessGroup {
"ProcessGroup%s does not support receive", GetBackendName()));
}

virtual std::shared_ptr<ProcessGroup::Task> AllGather(
std::vector<Tensor>& in_tensors /* tensors */, // NOLINT
std::vector<Tensor>& out_tensors /* tensors */) { // NOLINT
PADDLE_THROW(platform::errors::InvalidArgument(
"ProcessGroup%s does not support AllGather", GetBackendName()));
}

virtual std::shared_ptr<ProcessGroup::Task> AllToAll(
std::vector<Tensor>& in /* tensors */, // NOLINT
std::vector<Tensor>& out /* tensors */) { // NOLINT
PADDLE_THROW(platform::errors::InvalidArgument(
"ProcessGroup%s does not support AllToAll", GetBackendName()));
}

virtual std::shared_ptr<ProcessGroup::Task> Reduce(
std::vector<Tensor>& tensors /* tensors */, // NOLINT
const ReduceOptions& opts) { // NOLINT
PADDLE_THROW(platform::errors::InvalidArgument(
"ProcessGroup%s does not support Reduce", GetBackendName()));
}

virtual std::shared_ptr<ProcessGroup::Task> Scatter(
std::vector<Tensor>& in_tensors /* tensors */, // NOLINT
std::vector<Tensor>& out_tensors /* tensors */, // NOLINT
const ScatterOptions&) { // NOLINT
PADDLE_THROW(platform::errors::InvalidArgument(
"ProcessGroup%s does not support Scatter", GetBackendName()));
}

protected:
const int rank_;
const int size_;
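These base-class methods only add throwing defaults, so a concrete backend overrides the collectives it supports and inherits the InvalidArgument error for the rest. The pattern is illustrated by the self-contained toy below; it uses no Paddle types (ToyGroup, ToyNcclGroup, and ToyTask are invented for illustration), and the real override set for NCCL appears in ProcessGroupNCCL.h/.cc further down.

#include <memory>
#include <stdexcept>
#include <vector>

// Toy sketch of the same pattern: the base class supplies a throwing default,
// and a backend overrides only what it actually implements.
struct ToyTask {};

class ToyGroup {
 public:
  virtual ~ToyGroup() = default;
  virtual std::shared_ptr<ToyTask> AllGather(std::vector<float>& in,
                                             std::vector<float>& out) {
    // Mirrors the PADDLE_THROW(InvalidArgument(...)) defaults above.
    throw std::invalid_argument("this backend does not support AllGather");
  }
};

class ToyNcclGroup : public ToyGroup {
 public:
  std::shared_ptr<ToyTask> AllGather(std::vector<float>& in,
                                     std::vector<float>& out) override {
    out = in;  // stand-in for the real communication
    return std::make_shared<ToyTask>();
  }
};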
143 changes: 143 additions & 0 deletions paddle/fluid/distributed/collective/ProcessGroupNCCL.cc
@@ -473,5 +473,148 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Recv(
return task;
}

std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::AllGather(
std::vector<Tensor>& in_tensors, std::vector<Tensor>& out_tensors) {
PADDLE_ENFORCE_EQ(
CheckTensorsInCudaPlace(in_tensors), true,
platform::errors::InvalidArgument("All inputs should be in CudaPlace."));
PADDLE_ENFORCE_EQ(
CheckTensorsInCudaPlace(out_tensors), true,
platform::errors::InvalidArgument("All outputs should be in CudaPlace."));
return Collective(
in_tensors, out_tensors,
[&](const Tensor& input, Tensor& output, ncclComm_t comm,
const gpuStream_t& stream) {
auto input_tensor =
std::dynamic_pointer_cast<phi::DenseTensor>(input.impl());
auto output_tensor =
std::dynamic_pointer_cast<phi::DenseTensor>(output.impl());
return platform::dynload::ncclAllGather(
input_tensor->data(), output_tensor->data(), input_tensor->numel(),
platform::ToNCCLDataType(input.type()), comm, stream);
},
CommType::ALLGATHER);
}

void* GetPointerByOffset(void* raw_pointer, size_t offset,
experimental::DataType type) {
if (type == experimental::DataType::FLOAT32) {
return reinterpret_cast<void*>(reinterpret_cast<float*>(raw_pointer) +
offset);
} else if (type == experimental::DataType::FLOAT64) {
return reinterpret_cast<void*>(reinterpret_cast<double*>(raw_pointer) +
offset);
} else if (type == experimental::DataType::INT32) {
return reinterpret_cast<void*>(reinterpret_cast<int32_t*>(raw_pointer) +
offset);
} else if (type == experimental::DataType::INT64) {
return reinterpret_cast<void*>(reinterpret_cast<int64_t*>(raw_pointer) +
offset);
} else if (type == experimental::DataType::FLOAT16) {
return reinterpret_cast<void*>(reinterpret_cast<int16_t*>(raw_pointer) +
offset);
} else {
PADDLE_THROW(platform::errors::Unimplemented(
"This datatype in nccl is not supported."));
}
}

std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::AllToAll(
std::vector<Tensor>& in_tensors, std::vector<Tensor>& out_tensors) {
PADDLE_ENFORCE_EQ(
CheckTensorsInCudaPlace(in_tensors), true,
platform::errors::InvalidArgument("All inputs should be in CudaPlace."));
PADDLE_ENFORCE_EQ(
CheckTensorsInCudaPlace(out_tensors), true,
platform::errors::InvalidArgument("All inputs should be in CudaPlace."));
return Collective(
in_tensors, out_tensors,
[&](const Tensor& input, Tensor& output, ncclComm_t comm,
const gpuStream_t& stream) {
auto input_tensor =
std::dynamic_pointer_cast<phi::DenseTensor>(input.impl());
auto output_tensor =
std::dynamic_pointer_cast<phi::DenseTensor>(output.impl());
size_t offset = 0;
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclGroupStart());
for (auto i = 0; i < size_; i++) {
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclSend(
GetPointerByOffset(input_tensor->data(), offset, input.type()),
input_tensor->numel() / size_,
platform::ToNCCLDataType(input.type()), i, comm, stream));
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclRecv(
GetPointerByOffset(output_tensor->data(), offset, input.type()),
input_tensor->numel() / size_,
platform::ToNCCLDataType(input.type()), i, comm, stream));
offset += input_tensor->numel() / size_;
}
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclGroupEnd());
},
CommType::ALLTOALL);
}

std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Reduce(
std::vector<Tensor>& tensors, const ReduceOptions& opts) {
PADDLE_ENFORCE_EQ(
CheckTensorsInCudaPlace(tensors), true,
platform::errors::InvalidArgument("All inputs should be in CudaPlace."));
return Collective(
tensors, tensors,
[&](const Tensor& input, Tensor& output, ncclComm_t comm,
const gpuStream_t& stream) {
auto input_tensor =
std::dynamic_pointer_cast<phi::DenseTensor>(input.impl());
auto output_tensor =
std::dynamic_pointer_cast<phi::DenseTensor>(output.impl());
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclReduce(
input_tensor->data(), output_tensor->data(), input.numel(),
platform::ToNCCLDataType(input.type()),
ToNCCLRedType(opts.reduce_op), opts.root_rank, comm, stream));
},
CommType::REDUCE);
}

std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Scatter(
std::vector<Tensor>& in_tensors, std::vector<Tensor>& out_tensors,
const ScatterOptions& opts) {
PADDLE_ENFORCE_EQ(
CheckTensorsInCudaPlace(in_tensors), true,
platform::errors::InvalidArgument("All inputs should be in CudaPlace."));
PADDLE_ENFORCE_EQ(
CheckTensorsInCudaPlace(out_tensors), true,
platform::errors::InvalidArgument("All inputs should be in CudaPlace."));
return Collective(
in_tensors, out_tensors,
[&](const Tensor& input, Tensor& output, ncclComm_t comm,
const gpuStream_t& stream) {
auto input_tensor =
std::dynamic_pointer_cast<phi::DenseTensor>(input.impl());
auto output_tensor =
std::dynamic_pointer_cast<phi::DenseTensor>(output.impl());
size_t offset = 0;
if (rank_ == opts.root_rank) {
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclGroupStart());
for (auto i = 0; i < size_; i++) {
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclSend(
GetPointerByOffset(input_tensor->data(), offset, input.type()),
input_tensor->numel() / size_,
platform::ToNCCLDataType(input.type()), i, comm, stream));
offset += input_tensor->numel() / size_;
}
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclRecv(
output_tensor->data(), input_tensor->numel() / size_,
platform::ToNCCLDataType(input.type()), opts.root_rank, comm,
stream));
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclGroupEnd());
} else {
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclRecv(
output_tensor->data(), input_tensor->numel() / size_,
platform::ToNCCLDataType(input.type()), opts.root_rank, comm,
stream));
}
},
CommType::SCATTER);
}

} // namespace distributed
} // namespace paddle
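GetPointerByOffset above advances a raw device pointer by `offset` elements of the tensor's element type, which is what lets AllToAll and Scatter address the i-th rank's shard at i * numel() / size_; the FLOAT16 branch steps with int16_t because a half element is also two bytes wide. A self-contained sketch of the same arithmetic, with no Paddle or NCCL types (AdvanceByElements is an illustrative name, not part of the diff):

#include <cstddef>

// Generic form of the per-element pointer arithmetic used above: move `raw`
// forward by `offset` elements, each `elem_size` bytes wide.
inline void* AdvanceByElements(void* raw, std::size_t offset,
                               std::size_t elem_size) {
  return static_cast<char*>(raw) + offset * elem_size;
}

// For a float buffer, AdvanceByElements(buf, 3, sizeof(float)) lands 12 bytes
// past the start, exactly like the FLOAT32 branch of GetPointerByOffset.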
14 changes: 14 additions & 0 deletions paddle/fluid/distributed/collective/ProcessGroupNCCL.h
@@ -98,6 +98,20 @@ class ProcessGroupNCCL : public ProcessGroup {
std::shared_ptr<ProcessGroup::Task> Recv(std::vector<Tensor>& tensors,
int src_rank) override;

std::shared_ptr<ProcessGroup::Task> AllGather(
std::vector<Tensor>& in_tensors,
std::vector<Tensor>& out_tensors) override;

std::shared_ptr<ProcessGroup::Task> AllToAll(
std::vector<Tensor>& in, std::vector<Tensor>& out) override;

std::shared_ptr<ProcessGroup::Task> Reduce(
std::vector<Tensor>& tensors, const ReduceOptions& opts) override;

std::shared_ptr<ProcessGroup::Task> Scatter(std::vector<Tensor>& in_tensors,
std::vector<Tensor>& out_tensors,
const ScatterOptions&) override;

protected:
virtual std::shared_ptr<ProcessGroupNCCL::NCCLTask> CreateTask(
std::vector<Place> places, int rank, CommType opType,
9 changes: 9 additions & 0 deletions paddle/fluid/distributed/collective/Types.h
@@ -36,5 +36,14 @@ struct BarrierOptions {
std::vector<int> place_ids;
};

struct ReduceOptions {
ReduceOp reduce_op = ReduceOp::SUM;
int root_rank = 0;
};

struct ScatterOptions {
int root_rank = 0;
};

} // namespace distributed
} // namespace paddle
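With these option structs, a call site only fills in the non-default fields. A hedged usage sketch follows; `pg`, `in`, and `out` are assumed to be an already-constructed process group and pre-allocated GPU tensor lists (none of which appear in this diff), and waiting on the returned tasks is omitted:

// Sketch only: driving the new collectives with the option structs above.
ReduceOptions reduce_opts;    // reduce_op defaults to ReduceOp::SUM
reduce_opts.root_rank = 0;    // rank that receives the reduced result
auto reduce_task = pg->Reduce(in, reduce_opts);

ScatterOptions scatter_opts;
scatter_opts.root_rank = 0;   // rank that supplies the data to scatter
auto scatter_task = pg->Scatter(in, out, scatter_opts);

auto gather_task = pg->AllGather(in, out);  // no option struct needed
auto a2a_task = pg->AllToAll(in, out);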
2 changes: 2 additions & 0 deletions paddle/fluid/inference/tensorrt/engine.h
@@ -54,6 +54,8 @@ TRT_DT FluidDataType2TRT(FluidDT type) {
return TRT_DT::kFLOAT;
case FluidDT::VarType_Type_INT32:
return TRT_DT::kINT32;
case FluidDT::VarType_Type_FP16:
return TRT_DT::kHALF;
default:
return TRT_DT::kINT32;
}
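The new case routes framework FP16 variables to TensorRT's native half type instead of letting them fall through to the kINT32 default. In effect, using only the names from the switch above:

// Before this change VarType_Type_FP16 hit the `default:` branch and was
// reported as kINT32; with it:
TRT_DT dt = FluidDataType2TRT(FluidDT::VarType_Type_FP16);  // == TRT_DT::kHALF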
42 changes: 22 additions & 20 deletions paddle/fluid/operators/conv_cudnn_op.cu
@@ -25,10 +25,10 @@ limitations under the License. */
#include "paddle/fluid/operators/conv_cudnn_helper.h"
#endif
#include "paddle/fluid/operators/conv_op.h"
#include "paddle/fluid/operators/math/padding.h"
#include "paddle/fluid/platform/cudnn_workspace_helper.h"
#include "paddle/fluid/platform/float16.h"
#include "paddle/fluid/platform/profiler/event_tracing.h"
#include "paddle/phi/kernels/funcs/padding.h"

DECLARE_bool(cudnn_deterministic);
DECLARE_uint64(conv_workspace_size_limit);
@@ -148,7 +148,7 @@ class CUDNNConvOpKernel : public framework::OpKernel<T> {
in_data_dims, strides, ksize);

int data_dim = strides.size(); // 2d or 3d
bool is_sys_pad = math::IsSymmetricPadding(paddings, data_dim);
bool is_sys_pad = phi::funcs::IsSymmetricPadding(paddings, data_dim);

Tensor transformed_input;
std::vector<int> padding_common(data_dim, 0);
@@ -196,13 +196,13 @@ class CUDNNConvOpKernel : public framework::OpKernel<T> {
T pad_value(0.0);
switch (rank) {
case 4: {
math::PadFunction<paddle::platform::CUDADeviceContext, T, 4>(
ctx, input_pad, transformed_input_channel, pad_value,
phi::funcs::PadFunction<paddle::platform::CUDADeviceContext, T, 4>(
dev_ctx, input_pad, transformed_input_channel, pad_value,
&transformed_input);
} break;
case 5: {
math::PadFunction<paddle::platform::CUDADeviceContext, T, 5>(
ctx, input_pad, transformed_input_channel, pad_value,
phi::funcs::PadFunction<paddle::platform::CUDADeviceContext, T, 5>(
dev_ctx, input_pad, transformed_input_channel, pad_value,
&transformed_input);
} break;
default:
@@ -488,7 +488,7 @@ class CUDNNConvGradOpKernel : public framework::OpKernel<T> {
// cuDNN only supports padding the same amount on every dimension.
// So we create a new padded input tensor.
int data_dim = strides.size(); // 2d or 3d
bool is_sys_pad = math::IsSymmetricPadding(paddings, data_dim);
bool is_sys_pad = phi::funcs::IsSymmetricPadding(paddings, data_dim);
Tensor transformed_input(input->type());
Tensor transformed_input_grad(input->type());
std::vector<int> padding_common(data_dim, 0);
@@ -544,13 +544,13 @@ class CUDNNConvGradOpKernel : public framework::OpKernel<T> {
T pad_value(0.0);
switch (rank) {
case 4: {
math::PadFunction<paddle::platform::CUDADeviceContext, T, 4>(
ctx, input_pad, transformed_input_channel, pad_value,
phi::funcs::PadFunction<paddle::platform::CUDADeviceContext, T, 4>(
dev_ctx, input_pad, transformed_input_channel, pad_value,
&transformed_input);
} break;
case 5: {
math::PadFunction<paddle::platform::CUDADeviceContext, T, 5>(
ctx, input_pad, transformed_input_channel, pad_value,
phi::funcs::PadFunction<paddle::platform::CUDADeviceContext, T, 5>(
dev_ctx, input_pad, transformed_input_channel, pad_value,
&transformed_input);
} break;
default:
@@ -956,7 +956,7 @@ class CUDNNConvDoubleGradOpKernel : public framework::OpKernel<T> {
in_data_dims, strides, ksize);

int data_dim = strides.size(); // 2d or 3d
bool is_sys_pad = math::IsSymmetricPadding(paddings, data_dim);
bool is_sys_pad = phi::funcs::IsSymmetricPadding(paddings, data_dim);
Tensor transformed_X(X->type());
Tensor transformed_ddX(X->type());

@@ -1004,20 +1004,22 @@ class CUDNNConvDoubleGradOpKernel : public framework::OpKernel<T> {
T pad_value(0.0);
switch (rank) {
case 4: {
math::PadFunction<paddle::platform::CUDADeviceContext, T, 4>(
ctx, input_pad, transformed_X_channel, pad_value, &transformed_X);
phi::funcs::PadFunction<paddle::platform::CUDADeviceContext, T, 4>(
dev_ctx, input_pad, transformed_X_channel, pad_value,
&transformed_X);
if (ddX) {
math::PadFunction<paddle::platform::CUDADeviceContext, T, 4>(
ctx, input_pad, transformed_ddX_channel, pad_value,
phi::funcs::PadFunction<paddle::platform::CUDADeviceContext, T, 4>(
dev_ctx, input_pad, transformed_ddX_channel, pad_value,
&transformed_ddX);
}
} break;
case 5: {
math::PadFunction<paddle::platform::CUDADeviceContext, T, 5>(
ctx, input_pad, transformed_X_channel, pad_value, &transformed_X);
phi::funcs::PadFunction<paddle::platform::CUDADeviceContext, T, 5>(
dev_ctx, input_pad, transformed_X_channel, pad_value,
&transformed_X);
if (ddX) {
math::PadFunction<paddle::platform::CUDADeviceContext, T, 5>(
ctx, input_pad, transformed_ddX_channel, pad_value,
phi::funcs::PadFunction<paddle::platform::CUDADeviceContext, T, 5>(
dev_ctx, input_pad, transformed_ddX_channel, pad_value,
&transformed_ddX);
}
} break;
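This file's hunks replace the fluid math padding helpers with their phi::funcs counterparts, passing the device context (dev_ctx) rather than the execution context (ctx). The recurring before/after pattern, copied from the hunks above:

// Before (removed):
//   math::PadFunction<paddle::platform::CUDADeviceContext, T, 4>(
//       ctx, input_pad, transformed_input_channel, pad_value,
//       &transformed_input);
// After (added):
//   phi::funcs::PadFunction<paddle::platform::CUDADeviceContext, T, 4>(
//       dev_ctx, input_pad, transformed_input_channel, pad_value,
//       &transformed_input);
// The same substitution covers math::IsSymmetricPadding ->
// phi::funcs::IsSymmetricPadding and the 5-D (T, 5) instantiations.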

1 comment on commit 504feac

@paddle-bot-old
Congratulation! Your pull request passed all required CI. You could ask reviewer(s) to approve and merge. 🎉
