Skip to content

Commit

Permalink
add non-raw kernel for fluid op
Browse files Browse the repository at this point in the history
  • Loading branch information
YuanRisheng committed Jan 21, 2022
1 parent 9ef36c0 commit f1f8874
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 0 deletions.
15 changes: 15 additions & 0 deletions paddle/fluid/operators/elementwise/elementwise_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -140,26 +140,41 @@ class ElementwiseOp : public framework::OperatorWithKernel {

framework::KernelSignature GetExpectedPtenKernelArgs(
const framework::ExecutionContext &ctx) const override {
int axis = ctx.Attr<int>("axis");
if (Type() == "elementwise_add") {
if (ctx.InputVar("X")->IsType<framework::LoDTensor>()) {
if (axis == -1) {
return framework::KernelSignature("add", {"X", "Y"}, {}, {"Out"});
}
return framework::KernelSignature("add_raw", {"X", "Y"}, {"axis"},
{"Out"});
}
}
if (Type() == "elementwise_sub") {
if (ctx.InputVar("X")->IsType<framework::LoDTensor>()) {
if (axis == -1) {
return framework::KernelSignature("subtract", {"X", "Y"}, {},
{"Out"});
}
return framework::KernelSignature("subtract_raw", {"X", "Y"}, {"axis"},
{"Out"});
}
}
if (Type() == "elementwise_div") {
if (ctx.InputVar("X")->IsType<framework::LoDTensor>()) {
if (axis == -1) {
return framework::KernelSignature("divide", {"X", "Y"}, {}, {"Out"});
}
return framework::KernelSignature("divide_raw", {"X", "Y"}, {"axis"},
{"Out"});
}
}
if (Type() == "elementwise_mul") {
if (ctx.InputVar("X")->IsType<framework::LoDTensor>()) {
if (axis == -1) {
return framework::KernelSignature("multiply", {"X", "Y"}, {},
{"Out"});
}
return framework::KernelSignature("multiply_raw", {"X", "Y"}, {"axis"},
{"Out"});
}
Expand Down
9 changes: 9 additions & 0 deletions paddle/fluid/operators/reduce_ops/reduce_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -551,15 +551,24 @@ class ReduceOp : public framework::OperatorWithKernel {

framework::KernelSignature GetExpectedPtenKernelArgs(
const framework::ExecutionContext& ctx) const override {
bool reduce_all = ctx.Attr<int>("reduce_all");
if (Type() == "reduce_sum") {
if (ctx.InputVar("X")->IsType<framework::LoDTensor>()) {
if (!reduce_all) {
return framework::KernelSignature(
"sum", {"X"}, {"dim", "keep_dim", "out_dtype"}, {"Out"});
}
return framework::KernelSignature(
"sum_raw", {"X"}, {"dim", "keep_dim", "reduce_all", "out_dtype"},
{"Out"});
}
}
if (Type() == "reduce_mean") {
if (ctx.InputVar("X")->IsType<framework::LoDTensor>()) {
if (!reduce_all) {
return framework::KernelSignature("mean", {"X"}, {"dim", "keep_dim"},
{"Out"});
}
return framework::KernelSignature(
"mean_raw", {"X"}, {"dim", "keep_dim", "reduce_all"}, {"Out"});
}
Expand Down

0 comments on commit f1f8874

Please sign in to comment.