[PTen]Separate origin Kernel and add Kernel for C++ API (#39002)
* add kernel for c++ api

* fix compile bugs

* fix kunlun compile bugs

* perfect cmake

* fix compile bugs when run ci-inference

* fix compile bugs

* add non-raw kernel for fluid op

* fix compile bugs

* fix compile bugs

* fix unit test bug
YuanRisheng authored Jan 21, 2022
1 parent 854a7ab commit a0f586b
Showing 18 changed files with 453 additions and 190 deletions.
61 changes: 39 additions & 22 deletions cmake/pten_kernel.cmake
@@ -103,38 +103,55 @@ function(kernel_library TARGET)
   list(LENGTH gpu_srcs gpu_srcs_len)
   list(LENGTH xpu_srcs xpu_srcs_len)
 
-  if (${common_srcs_len} GREATER 0)
-    # If the kernel has a device independent public implementation,
-    # we will use this implementation and will not adopt the implementation
-    # under specific devices
+  # Build Target according different src organization
+  if((${cpu_srcs_len} GREATER 0 OR ${gpu_srcs_len} GREATER 0 OR
+    ${xpu_srcs_len} GREATER 0) AND ${common_srcs_len} GREATER 0)
+    # If the common_srcs depends on specific device srcs, build target using this rule.
+    if (WITH_GPU)
+      if (${cpu_srcs_len} GREATER 0 OR ${gpu_srcs_len} GREATER 0)
+        nv_library(${TARGET}_part SRCS ${cpu_srcs} ${gpu_srcs} DEPS ${kernel_library_DEPS} ${kernel_deps})
+        nv_library(${TARGET} SRCS ${common_srcs} DEPS ${TARGET}_part)
+      endif()
+    elseif (WITH_ROCM)
+      if (${cpu_srcs_len} GREATER 0 OR ${gpu_srcs_len} GREATER 0)
+        hip_library(${TARGET}_part SRCS ${cpu_srcs} ${gpu_srcs} DEPS ${kernel_library_DEPS} ${kernel_deps})
+        hip_library(${TARGET} SRCS ${common_srcs} DEPS ${TARGET}_part)
+      endif()
+    else()
+      if (${cpu_srcs_len} GREATER 0 OR ${xpu_srcs_len} GREATER 0)
+        cc_library(${TARGET}_part SRCS ${cpu_srcs} ${xpu_srcs} DEPS ${kernel_library_DEPS} ${kernel_deps})
+        cc_library(${TARGET} SRCS ${common_srcs} DEPS ${TARGET}_part)
+      endif()
+    endif()
+  elseif (${cpu_srcs_len} GREATER 0 OR ${gpu_srcs_len} GREATER 0 OR ${xpu_srcs_len} GREATER 0)
     if (WITH_GPU)
-      nv_library(${TARGET} SRCS ${common_srcs} DEPS ${kernel_library_DEPS} ${kernel_deps})
+      if (${cpu_srcs_len} GREATER 0 OR ${gpu_srcs_len} GREATER 0)
+        nv_library(${TARGET} SRCS ${cpu_srcs} ${gpu_srcs} DEPS ${kernel_library_DEPS} ${kernel_deps})
+      endif()
     elseif (WITH_ROCM)
-      hip_library(${TARGET} SRCS ${common_srcs} DEPS ${kernel_library_DEPS} ${kernel_deps})
+      if (${cpu_srcs_len} GREATER 0 OR ${gpu_srcs_len} GREATER 0)
+        hip_library(${TARGET} SRCS ${cpu_srcs} ${gpu_srcs} DEPS ${kernel_library_DEPS} ${kernel_deps})
+      endif()
     else()
-      cc_library(${TARGET} SRCS ${common_srcs} DEPS ${kernel_library_DEPS} ${kernel_deps})
+      if (${cpu_srcs_len} GREATER 0 OR ${xpu_srcs_len} GREATER 0)
+        cc_library(${TARGET} SRCS ${cpu_srcs} ${xpu_srcs} DEPS ${kernel_library_DEPS} ${kernel_deps})
+      endif()
     endif()
   else()
-    # If the kernel has a header file declaration, but no corresponding
-    # implementation can be found, this is not allowed
-    if (${cpu_srcs_len} EQUAL 0 AND ${gpu_srcs_len} EQUAL 0 AND
-        ${xpu_srcs_len} EQUAL 0)
-      message(FATAL_ERROR "Cannot find any implementation for ${TARGET}")
+    if (${common_srcs_len} EQUAL 0)
+      message(FATAL_ERROR "Cannot find any implementation for ${TARGET}")
     else()
+      # If the kernel has a device independent public implementation,
+      # we will use this implementation and will not adopt the implementation
+      # under specific devices
       if (WITH_GPU)
-        if (${cpu_srcs_len} GREATER 0 OR ${gpu_srcs_len} GREATER 0)
-          nv_library(${TARGET} SRCS ${cpu_srcs} ${gpu_srcs} DEPS ${kernel_library_DEPS} ${kernel_deps})
-        endif()
+        nv_library(${TARGET} SRCS ${common_srcs} DEPS ${kernel_library_DEPS} ${kernel_deps})
      elseif (WITH_ROCM)
-        if (${cpu_srcs_len} GREATER 0 OR ${gpu_srcs_len} GREATER 0)
-          hip_library(${TARGET} SRCS ${cpu_srcs} ${gpu_srcs} DEPS ${kernel_library_DEPS} ${kernel_deps})
-        endif()
+        hip_library(${TARGET} SRCS ${common_srcs} DEPS ${kernel_library_DEPS} ${kernel_deps})
       else()
-        if (${cpu_srcs_len} GREATER 0 OR ${xpu_srcs_len} GREATER 0)
-          cc_library(${TARGET} SRCS ${cpu_srcs} ${xpu_srcs} DEPS ${kernel_library_DEPS} ${kernel_deps})
-        endif()
+        cc_library(${TARGET} SRCS ${common_srcs} DEPS ${kernel_library_DEPS} ${kernel_deps})
       endif()
       endif()
    endif()
   endif()
 
   if (${common_srcs_len} GREATER 0 OR ${cpu_srcs_len} GREATER 0 OR
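Note: the new first branch handles kernels whose device-independent sources must link against device-specific sources: the per-device files are compiled into an intermediate ${TARGET}_part library and the common sources are layered on top of it. A minimal sketch of the source layout this rule expects (paths and DEPS names are illustrative, not taken from this commit):

```cmake
# Hypothetical layout scanned by kernel_library():
#   kernels/example_kernel.cc      -> common_srcs (device-independent)
#   kernels/cpu/example_kernel.cc  -> cpu_srcs
#   kernels/gpu/example_kernel.cu  -> gpu_srcs
#
# With WITH_GPU=ON and both device and common sources present, this now
# builds example_kernel_part from the cpu/gpu sources, then builds
# example_kernel from the common sources and links it against the part.
kernel_library(example_kernel DEPS dense_tensor kernel_context)
```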
2 changes: 1 addition & 1 deletion paddle/fluid/operators/cholesky_solve_op.h
@@ -202,7 +202,7 @@ class CholeskySolveGradKernel : public framework::OpKernel<T> {
     commonterm_for_range(commonterm_functor);
     commonterm_conj = helper.Transpose(commonterm_conj);
 
-    pten::AddKernel<T>(
+    pten::AddRawKernel<T>(
         static_cast<const typename paddle::framework::ConvertToPtenContext<
             DeviceContext>::TYPE &>(dev_ctx),
         commonterm, commonterm_conj, -1, &commonterm);
2 changes: 1 addition & 1 deletion paddle/fluid/operators/elementwise/elementwise_add_op.h
@@ -61,7 +61,7 @@ class ElementwiseAddKernel : public framework::OpKernel<T> {
     auto pt_x = paddle::experimental::MakePtenDenseTensor(*x);
     auto pt_y = paddle::experimental::MakePtenDenseTensor(*y);
     auto pt_z = paddle::experimental::MakePtenDenseTensor(*z);
-    pten::AddKernel<T>(
+    pten::AddRawKernel<T>(
         static_cast<const typename framework::ConvertToPtenContext<
             DeviceContext>::TYPE &>(dev_ctx),
         *pt_x.get(), *pt_y.get(), axis, pt_z.get());
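Note: this and the sibling elementwise files all make the same substitution: the fluid operator keeps its explicit axis attribute, so it now calls the *_raw kernel, while the trimmed kernel without axis backs the new C++ API. A minimal sketch of the convention (simplified; the actual kernel definitions are in parts of the commit not shown in this excerpt):

```cpp
// The raw variant keeps the broadcast axis for fluid operators ...
template <typename T, typename Context>
void AddRawKernel(const Context& dev_ctx, const DenseTensor& x,
                  const DenseTensor& y, int axis, DenseTensor* out);

// ... while the C++-API kernel drops the parameter and forwards the
// default axis of -1.
template <typename T, typename Context>
void AddKernel(const Context& dev_ctx, const DenseTensor& x,
               const DenseTensor& y, DenseTensor* out) {
  AddRawKernel<T>(dev_ctx, x, y, /*axis=*/-1, out);
}
```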
2 changes: 1 addition & 1 deletion paddle/fluid/operators/elementwise/elementwise_div_op.h
@@ -51,7 +51,7 @@ class ElementwiseDivKernel : public framework::OpKernel<T> {
     auto pt_x = paddle::experimental::MakePtenDenseTensor(*x);
     auto pt_y = paddle::experimental::MakePtenDenseTensor(*y);
     auto pt_z = paddle::experimental::MakePtenDenseTensor(*z);
-    pten::DivideKernel<T>(
+    pten::DivideRawKernel<T>(
         static_cast<const typename framework::ConvertToPtenContext<
             DeviceContext>::TYPE&>(dev_ctx),
         *pt_x.get(), *pt_y.get(), axis, pt_z.get());
4 changes: 2 additions & 2 deletions paddle/fluid/operators/elementwise/elementwise_mul_op.cu
@@ -51,8 +51,8 @@ class ElementwiseMulKernel<platform::CUDADeviceContext, T>
       auto pt_x = paddle::experimental::MakePtenDenseTensor(*x_lod);
       auto pt_y = paddle::experimental::MakePtenDenseTensor(*y_lod);
       auto pt_z = paddle::experimental::MakePtenDenseTensor(*z_lod);
-      pten::MultiplyKernel<T>(cuda_ctx, *pt_x.get(), *pt_y.get(), axis,
-                              pt_z.get());
+      pten::MultiplyRawKernel<T>(cuda_ctx, *pt_x.get(), *pt_y.get(), axis,
+                                 pt_z.get());
     } else {
       PADDLE_THROW(platform::errors::InvalidArgument(
           "X's type[%s] is not supported by elementwise_op. X's type should be "
2 changes: 1 addition & 1 deletion paddle/fluid/operators/elementwise/elementwise_mul_op.h
@@ -124,7 +124,7 @@ class ElementwiseMulKernel : public framework::OpKernel<T> {
     auto pt_x = paddle::experimental::MakePtenDenseTensor(*x_lod);
     auto pt_y = paddle::experimental::MakePtenDenseTensor(*y);
     auto pt_z = paddle::experimental::MakePtenDenseTensor(*z_lod);
-    pten::MultiplyKernel<T>(
+    pten::MultiplyRawKernel<T>(
         static_cast<const typename framework::ConvertToPtenContext<
             DeviceContext>::TYPE&>(dev_ctx),
         *pt_x.get(), *pt_y.get(), axis, pt_z.get());
24 changes: 20 additions & 4 deletions paddle/fluid/operators/elementwise/elementwise_op.h
@@ -140,26 +140,42 @@ class ElementwiseOp : public framework::OperatorWithKernel {
 
   framework::KernelSignature GetExpectedPtenKernelArgs(
       const framework::ExecutionContext &ctx) const override {
+    int axis = ctx.Attr<int>("axis");
     if (Type() == "elementwise_add") {
       if (ctx.InputVar("X")->IsType<framework::LoDTensor>()) {
-        return framework::KernelSignature("add", {"X", "Y"}, {"axis"}, {"Out"});
+        if (axis == -1) {
+          return framework::KernelSignature("add", {"X", "Y"}, {}, {"Out"});
+        }
+        return framework::KernelSignature("add_raw", {"X", "Y"}, {"axis"},
+                                          {"Out"});
       }
     }
     if (Type() == "elementwise_sub") {
       if (ctx.InputVar("X")->IsType<framework::LoDTensor>()) {
-        return framework::KernelSignature("subtract", {"X", "Y"}, {"axis"},
+        if (axis == -1) {
+          return framework::KernelSignature("subtract", {"X", "Y"}, {},
                                           {"Out"});
+        }
+        return framework::KernelSignature("subtract_raw", {"X", "Y"}, {"axis"},
+                                          {"Out"});
       }
     }
     if (Type() == "elementwise_div") {
       if (ctx.InputVar("X")->IsType<framework::LoDTensor>()) {
-        return framework::KernelSignature("divide", {"X", "Y"}, {"axis"},
+        if (axis == -1) {
+          return framework::KernelSignature("divide", {"X", "Y"}, {}, {"Out"});
+        }
+        return framework::KernelSignature("divide_raw", {"X", "Y"}, {"axis"},
                                           {"Out"});
       }
     }
     if (Type() == "elementwise_mul") {
       if (ctx.InputVar("X")->IsType<framework::LoDTensor>()) {
-        return framework::KernelSignature("multiply", {"X", "Y"}, {"axis"},
+        if (axis == -1) {
+          return framework::KernelSignature("multiply", {"X", "Y"}, {},
                                           {"Out"});
+        }
+        return framework::KernelSignature("multiply_raw", {"X", "Y"}, {"axis"},
+                                          {"Out"});
       }
     }
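Note: only the default axis of -1 (right-aligned broadcast) can be served by the trimmed C++-API kernels; any explicit axis still needs the raw variant. Condensed as a hypothetical helper (not part of the commit):

```cpp
#include <string>

// Hypothetical restatement of the dispatch rule above.
std::string ChooseElementwiseKernel(const std::string& base, int axis) {
  return axis == -1 ? base : base + "_raw";  // e.g. "add" vs. "add_raw"
}
```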
2 changes: 1 addition & 1 deletion paddle/fluid/operators/elementwise/elementwise_sub_op.h
@@ -51,7 +51,7 @@ class ElementwiseSubKernel : public framework::OpKernel<T> {
     auto pt_x = paddle::experimental::MakePtenDenseTensor(*x);
     auto pt_y = paddle::experimental::MakePtenDenseTensor(*y);
     auto pt_z = paddle::experimental::MakePtenDenseTensor(*z);
-    pten::SubtractKernel<T>(
+    pten::SubtractRawKernel<T>(
         static_cast<const typename framework::ConvertToPtenContext<
             DeviceContext>::TYPE&>(dev_ctx),
         *pt_x.get(), *pt_y.get(), axis, pt_z.get());
4 changes: 2 additions & 2 deletions paddle/fluid/operators/lu_op.h
@@ -221,7 +221,7 @@ void Tensor_Add(const DeviceContext& dev_ctx, const framework::Tensor& src1,
   out->Resize(src1.dims());
   out->mutable_data<T>(dev_ctx.GetPlace());
 
-  pten::AddKernel<
+  pten::AddRawKernel<
       T, typename paddle::framework::ConvertToPtenContext<DeviceContext>::TYPE>(
       static_cast<const typename paddle::framework::ConvertToPtenContext<
           DeviceContext>::TYPE&>(dev_ctx),
@@ -234,7 +234,7 @@ void Tensor_Sub(const DeviceContext& dev_ctx, const framework::Tensor& src1,
   out->Resize(src1.dims());
   out->mutable_data<T>(dev_ctx.GetPlace());
 
-  pten::SubtractKernel<
+  pten::SubtractRawKernel<
      T, typename paddle::framework::ConvertToPtenContext<DeviceContext>::TYPE>(
       static_cast<const typename paddle::framework::ConvertToPtenContext<
           DeviceContext>::TYPE&>(dev_ctx),
13 changes: 11 additions & 2 deletions paddle/fluid/operators/reduce_ops/reduce_op.h
@@ -551,17 +551,26 @@
 
   framework::KernelSignature GetExpectedPtenKernelArgs(
       const framework::ExecutionContext& ctx) const override {
+    bool reduce_all = ctx.Attr<bool>("reduce_all");
     if (Type() == "reduce_sum") {
       if (ctx.InputVar("X")->IsType<framework::LoDTensor>()) {
+        if (!reduce_all) {
+          return framework::KernelSignature(
+              "sum", {"X"}, {"dim", "keep_dim", "out_dtype"}, {"Out"});
+        }
         return framework::KernelSignature(
-            "sum", {"X"}, {"dim", "keep_dim", "reduce_all", "out_dtype"},
+            "sum_raw", {"X"}, {"dim", "keep_dim", "reduce_all", "out_dtype"},
             {"Out"});
       }
     }
     if (Type() == "reduce_mean") {
       if (ctx.InputVar("X")->IsType<framework::LoDTensor>()) {
+        if (!reduce_all) {
+          return framework::KernelSignature("mean", {"X"}, {"dim", "keep_dim"},
+                                            {"Out"});
+        }
         return framework::KernelSignature(
-            "mean", {"X"}, {"dim", "keep_dim", "reduce_all"}, {"Out"});
+            "mean_raw", {"X"}, {"dim", "keep_dim", "reduce_all"}, {"Out"});
       }
     }
     // TODO(chentianyu03): support other cases after selected rows added
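Note: the reduce operators follow the same scheme with reduce_all in place of axis: only reduce_all == false maps onto the trimmed kernels. A sketch of the corresponding kernel pair, assuming the same forwarding convention as the elementwise kernels (simplified, not verbatim from the commit):

```cpp
// The raw variant keeps reduce_all for fluid operators ...
template <typename T, typename Context>
void MeanRawKernel(const Context& dev_ctx, const DenseTensor& x,
                   const std::vector<int64_t>& dims, bool keep_dim,
                   bool reduce_all, DenseTensor* out);

// ... and the C++-API kernel pins it to false.
template <typename T, typename Context>
void MeanKernel(const Context& dev_ctx, const DenseTensor& x,
                const std::vector<int64_t>& dims, bool keep_dim,
                DenseTensor* out) {
  MeanRawKernel<T>(dev_ctx, x, dims, keep_dim, /*reduce_all=*/false, out);
}
```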
6 changes: 0 additions & 6 deletions paddle/pten/api/include/kernel_signature.h
@@ -30,7 +30,6 @@ using DeviceContext = paddle::platform::DeviceContext;
 using add_kernel = void (*)(const DeviceContext&,
                             const DenseTensor&,
                             const DenseTensor&,
-                            int,
                             DenseTensor*);
 
 using cast_kernel = void (*)(const DeviceContext&,
@@ -46,7 +45,6 @@ using concat_kernel = void (*)(const DeviceContext&,
 using divide_kernel = void (*)(const DeviceContext&,
                                const DenseTensor&,
                                const DenseTensor&,
-                               int,
                                DenseTensor*);
 
 using dot_kernel = void (*)(const DeviceContext&,
@@ -82,13 +80,11 @@ using mean_kernel = void (*)(const DeviceContext&,
                              const DenseTensor&,
                              const std::vector<int64_t>&,
                              bool,
-                             bool,
                              DenseTensor*);
 
 using multiply_kernel = void (*)(const DeviceContext&,
                                  const DenseTensor&,
                                  const DenseTensor&,
-                                 int,
                                  DenseTensor*);
 
 using reshape_kernel = void (*)(const DeviceContext&,
@@ -107,14 +103,12 @@ using sum_kernel = void (*)(const DeviceContext&,
                             const DenseTensor&,
                             const std::vector<int64_t>&,
                             bool,
-                            bool,
                             DataType,
                             DenseTensor*);
 
 using subtract_kernel = void (*)(const DeviceContext&,
                                  const DenseTensor&,
                                  const DenseTensor&,
-                                 int,
                                  DenseTensor*);
 
 using conj_kernel = void (*)(const DeviceContext&,
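Note: these deletions keep the function-pointer aliases in sync with the trimmed C++-API kernels: the elementwise aliases lose the int axis parameter and the reduce aliases lose the bool reduce_all flag. A before/after illustration for add_kernel (the _before/_after names are invented for contrast):

```cpp
// Before this commit the alias still carried the broadcast axis:
using add_kernel_before = void (*)(const DeviceContext&,
                                   const DenseTensor&,
                                   const DenseTensor&,
                                   int,  // axis, now only in AddRawKernel
                                   DenseTensor*);

// After, it matches the trimmed AddKernel used by the C++ API:
using add_kernel_after = void (*)(const DeviceContext&,
                                  const DenseTensor&,
                                  const DenseTensor&,
                                  DenseTensor*);
```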
12 changes: 6 additions & 6 deletions paddle/pten/core/kernel_alias_name.h
@@ -20,10 +20,10 @@ namespace pten {
 // the key is kernel_name in fluid, the value is the kernel_name in pten
 // the key is sorted by key's alphabet
 const std::unordered_map<std::string, std::string> kernel_alias_name_map = {
-    {"elementwise_add", "add"},
-    {"elementwise_div", "divide"},
-    {"elementwise_mul", "muliply"},
-    {"elementwise_sub", "subtract"},
+    {"elementwise_add", "add_raw"},
+    {"elementwise_div", "divide_raw"},
+    {"elementwise_mul", "muliply_raw"},
+    {"elementwise_sub", "subtract_raw"},
     {"fill_any_like", "full_like"},
     {"fill_constant", "full"},
     {"flatten_contiguous_range", "flatten"},
@@ -32,8 +32,8 @@ const std::unordered_map<std::string, std::string> kernel_alias_name_map = {
     {"matmul_v2_grad", "matmul_grad"},
     {"matmul_v2_grad_grad", "matmul_double_grad"},
     {"matmul_v2_triple_grad", "matmul_triple_grad"},
-    {"reduce_mean", "mean"},
-    {"reduce_sum", "sum"},
+    {"reduce_mean", "mean_raw"},
+    {"reduce_sum", "sum_raw"},
     {"reshape2", "reshape"},
     {"reshape2_grad", "reshape_grad"},
     {"reshape2_grad_grad", "reshape_double_grad"},
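Note: since fluid operators keep their explicit attributes, their alias entries now point at the *_raw kernel names ("muliply" reproduces a typo already present in the map, in both the old and new values). A hypothetical helper showing how such a map is consulted; the real lookup is part of pten's kernel-registration plumbing, not this commit:

```cpp
#include <string>
#include <unordered_map>

// Hypothetical: translate a fluid op name to its pten kernel name, falling
// back to the original name when no alias is registered.
inline std::string ToPtenKernelName(
    const std::unordered_map<std::string, std::string>& alias_map,
    const std::string& fluid_op_name) {
  auto it = alias_map.find(fluid_op_name);
  return it == alias_map.end() ? fluid_op_name : it->second;
}

// e.g. ToPtenKernelName(kernel_alias_name_map, "reduce_sum") -> "sum_raw"
```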