Skip to content

Commit

Permalink
[Phi] Move gather op kernel into phi (#40500)
Browse files Browse the repository at this point in the history
* add phi gather kernel

* update year

* remove original gather opkernel

* add gather grad phi kernels

* remove origin gather grad kernel

* fix failed npu and xpu

* fix xpu compile failed
  • Loading branch information
chenwhql authored Mar 15, 2022
1 parent dde9cec commit 0c703fe
Show file tree
Hide file tree
Showing 13 changed files with 405 additions and 302 deletions.
14 changes: 2 additions & 12 deletions paddle/fluid/operators/gather_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,10 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/operators/gather_op.h"
#include <memory>
#include <string>
#include <vector>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/phi/core/ddim.h"

Expand Down Expand Up @@ -198,17 +198,7 @@ REGISTER_OPERATOR(gather, ops::GatherOp, ops::GatherOpMaker,
ops::GatherGradOpMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(gather_grad, ops::GatherGradOp,
ops::GatherGradNoNeedBufferVarInferer);
REGISTER_OP_CPU_KERNEL(gather, ops::GatherOpKernel<float>,
ops::GatherOpKernel<double>, ops::GatherOpKernel<int>,
ops::GatherOpKernel<uint8_t>,
ops::GatherOpKernel<int64_t>,
ops::GatherOpKernel<phi::dtype::bfloat16>);
REGISTER_OP_CPU_KERNEL(gather_grad, ops::GatherGradientOpKernel<float>,
ops::GatherGradientOpKernel<double>,
ops::GatherGradientOpKernel<int>,
ops::GatherGradientOpKernel<uint8_t>,
ops::GatherGradientOpKernel<int64_t>,
ops::GatherGradientOpKernel<phi::dtype::bfloat16>);

REGISTER_OP_VERSION(gather)
.AddCheckpoint(R"ROC(upgrad gather, add a new input [Axis])ROC",
paddle::framework::compatible::OpVersionDesc().NewInput(
Expand Down
152 changes: 0 additions & 152 deletions paddle/fluid/operators/gather_op.cu

This file was deleted.

133 changes: 0 additions & 133 deletions paddle/fluid/operators/gather_op.h

This file was deleted.

3 changes: 2 additions & 1 deletion paddle/fluid/operators/gather_op_npu.cc
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,11 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/operators/gather_op.h"
#include <memory>
#include <string>
#include <vector>

#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/tensor_util.h"
#include "paddle/fluid/platform/device/npu/npu_info.h"
#include "paddle/fluid/platform/device/npu/npu_op_runner.h"
Expand Down
5 changes: 2 additions & 3 deletions paddle/fluid/operators/gather_op_npu_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -24,16 +24,15 @@ limitations under the License. */
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/operators/gather_op.h"
#include "paddle/fluid/string/printf.h"
#include "paddle/phi/kernels/funcs/math_function.h"

namespace f = paddle::framework;
namespace p = paddle::platform;

USE_OP(gather);
USE_OP_ITSELF(gather);
USE_OP_DEVICE_KERNEL(gather, NPU);
USE_OP(gather_grad);
USE_OP_ITSELF(gather_grad);
USE_OP_DEVICE_KERNEL(gather_grad, NPU);

template <typename T>
Expand Down
5 changes: 4 additions & 1 deletion paddle/fluid/operators/gather_op_xpu.cc
Original file line number Diff line number Diff line change
Expand Up @@ -13,15 +13,18 @@ See the License for the specific language governing permissions and
limitations under the License. */

#ifdef PADDLE_WITH_XPU
#include "paddle/fluid/operators/gather_op.h"
#include <memory>
#include <string>
#include <vector>

#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/phi/core/ddim.h"
namespace paddle {
namespace operators {

using Tensor = framework::Tensor;

template <typename T>
class GatherOpXPUKernel : public framework::OpKernel<T> {
using XPUType = typename XPUTypeTrait<T>::Type;
Expand Down
Loading

0 comments on commit 0c703fe

Please sign in to comment.