diff --git a/python/mxnet/ndarray/numpy/_op.py b/python/mxnet/ndarray/numpy/_op.py
index feb2caa67b2f..94777991dffd 100644
--- a/python/mxnet/ndarray/numpy/_op.py
+++ b/python/mxnet/ndarray/numpy/_op.py
@@ -44,7 +44,7 @@
            'tril', 'identity', 'take', 'ldexp', 'vdot', 'inner', 'outer',
            'equal', 'not_equal', 'greater', 'less', 'greater_equal', 'less_equal', 'rot90', 'einsum',
            'true_divide', 'nonzero', 'quantile', 'percentile', 'shares_memory', 'may_share_memory',
-           'diff', 'resize', 'polyval', 'nan_to_num', 'isnan', 'isinf', 'isposinf', 'isneginf', 'isfinite',
+           'diff', 'ediff1d', 'resize', 'polyval', 'nan_to_num', 'isnan', 'isinf', 'isposinf', 'isneginf', 'isfinite',
            'where', 'bincount', 'pad']
 
 
@@ -6758,6 +6758,63 @@ def diff(a, n=1, axis=-1, prepend=None, append=None):  # pylint: disable=redefin
     return _npi.diff(a, n=n, axis=axis)
 
 
+@set_module('mxnet.ndarray.numpy')
+def ediff1d(ary, to_end=None, to_begin=None):
+    """
+    The differences between consecutive elements of an array.
+
+    Parameters
+    ----------
+    ary : ndarray
+        If necessary, will be flattened before the differences are taken.
+    to_end : ndarray or scalar, optional
+        Number(s) to append at the end of the returned differences.
+    to_begin : ndarray or scalar, optional
+        Number(s) to prepend at the beginning of the returned differences.
+
+    Returns
+    -------
+    ediff1d : ndarray
+        The differences. Loosely, this is ``ary.flat[1:] - ary.flat[:-1]``.
+
+    Examples
+    --------
+    >>> x = np.array([1, 2, 4, 7, 0])
+    >>> np.ediff1d(x)
+    array([ 1.,  2.,  3., -7.])
+
+    >>> np.ediff1d(x, to_begin=-99, to_end=np.array([88, 99]))
+    array([-99.,   1.,   2.,   3.,  -7.,  88.,  99.])
+
+    The returned array is always 1D.
+
+    >>> y = np.array([[1, 2, 4], [1, 6, 24]])
+    >>> np.ediff1d(y)
+    array([ 1.,  2., -3.,  5., 18.])
+
+    >>> np.ediff1d(x, to_begin=y)
+    array([ 1.,  2.,  4.,  1.,  6., 24.,  1.,  2.,  3., -7.])
+    """
+    from ...numpy import ndarray as np_ndarray
+    input_type = (isinstance(to_begin, np_ndarray), isinstance(to_end, np_ndarray))
+    # case 1: when both `to_begin` and `to_end` are arrays
+    if input_type == (True, True):
+        return _npi.ediff1d(ary, to_begin, to_end, to_begin_arr_given=True, to_end_arr_given=True,
+                            to_begin_scalar=None, to_end_scalar=None)
+    # case 2: only `to_end` is array but `to_begin` is scalar/None
+    elif input_type == (False, True):
+        return _npi.ediff1d(ary, to_end, to_begin_arr_given=False, to_end_arr_given=True,
+                            to_begin_scalar=to_begin, to_end_scalar=None)
+    # case 3: only `to_begin` is array but `to_end` is scalar/None
+    elif input_type == (True, False):
+        return _npi.ediff1d(ary, to_begin, to_begin_arr_given=True, to_end_arr_given=False,
+                            to_begin_scalar=None, to_end_scalar=to_end)
+    # case 4: both `to_begin` and `to_end` are scalar/None
+    else:
+        return _npi.ediff1d(ary, to_begin_arr_given=False, to_end_arr_given=False,
+                            to_begin_scalar=to_begin, to_end_scalar=to_end)
+
+
 @set_module('mxnet.ndarray.numpy')
 def resize(a, new_shape):
     """
diff --git a/python/mxnet/numpy/multiarray.py b/python/mxnet/numpy/multiarray.py
index 382dbc0ea472..f6523674e1d9 100644
--- a/python/mxnet/numpy/multiarray.py
+++ b/python/mxnet/numpy/multiarray.py
@@ -65,7 +65,7 @@
            'bitwise_and', 'bitwise_xor', 'bitwise_or', 'rad2deg', 'deg2rad',
            'unique', 'lcm', 'tril', 'identity', 'take', 'ldexp', 'vdot', 'inner', 'outer', 'equal', 'not_equal',
            'greater', 'less', 'greater_equal', 'less_equal', 'rot90', 'einsum', 'true_divide', 'nonzero',
-           'quantile', 'percentile', 'shares_memory', 'may_share_memory', 'diff', 'resize', 'matmul',
+           'quantile', 'percentile', 'shares_memory', 'may_share_memory', 'diff', 'ediff1d', 'resize', 'matmul',
            'nan_to_num', 'isnan', 'isinf', 'isposinf', 'isneginf', 'isfinite', 'polyval', 'where', 'bincount', 'pad']
 
 __all__ += fallback.__all__
@@ -8731,6 +8731,46 @@ def diff(a, n=1, axis=-1, prepend=None, append=None):  # pylint: disable=redefin
     return _mx_nd_np.diff(a, n=n, axis=axis)
 
 
+@set_module('mxnet.numpy')
+def ediff1d(ary, to_end=None, to_begin=None):
+    """
+    The differences between consecutive elements of an array.
+
+    Parameters
+    ----------
+    ary : ndarray
+        If necessary, will be flattened before the differences are taken.
+    to_end : ndarray or scalar, optional
+        Number(s) to append at the end of the returned differences.
+    to_begin : ndarray or scalar, optional
+        Number(s) to prepend at the beginning of the returned differences.
+
+    Returns
+    -------
+    ediff1d : ndarray
+        The differences. Loosely, this is ``ary.flat[1:] - ary.flat[:-1]``.
+
+    Examples
+    --------
+    >>> x = np.array([1, 2, 4, 7, 0])
+    >>> np.ediff1d(x)
+    array([ 1.,  2.,  3., -7.])
+
+    >>> np.ediff1d(x, to_begin=-99, to_end=np.array([88, 99]))
+    array([-99.,   1.,   2.,   3.,  -7.,  88.,  99.])
+
+    The returned array is always 1D.
+
+    >>> y = np.array([[1, 2, 4], [1, 6, 24]])
+    >>> np.ediff1d(y)
+    array([ 1.,  2., -3.,  5., 18.])
+
+    >>> np.ediff1d(x, to_begin=y)
+    array([ 1.,  2.,  4.,  1.,  6., 24.,  1.,  2.,  3., -7.])
+    """
+    return _mx_nd_np.ediff1d(ary, to_end=to_end, to_begin=to_begin)
+
+
 @set_module('mxnet.numpy')
 def resize(a, new_shape):
     """
diff --git a/python/mxnet/numpy_dispatch_protocol.py b/python/mxnet/numpy_dispatch_protocol.py
index e21d20d2e2af..76146b1652a6 100644
--- a/python/mxnet/numpy_dispatch_protocol.py
+++ b/python/mxnet/numpy_dispatch_protocol.py
@@ -172,6 +172,7 @@ def _run_with_array_ufunc_proto(*args, **kwargs):
     'quantile',
     'percentile',
     'diff',
+    'ediff1d',
     'resize',
     'where',
     'full_like',
diff --git a/python/mxnet/symbol/numpy/_symbol.py b/python/mxnet/symbol/numpy/_symbol.py
index 46fbc7d1ff7e..66105b2da33f 100644
--- a/python/mxnet/symbol/numpy/_symbol.py
+++ b/python/mxnet/symbol/numpy/_symbol.py
@@ -49,7 +49,7 @@
            'hypot', 'bitwise_and', 'bitwise_xor', 'bitwise_or', 'rad2deg', 'deg2rad', 'unique', 'lcm',
            'tril', 'identity', 'take', 'ldexp', 'vdot', 'inner', 'outer',
            'equal', 'not_equal', 'greater', 'less', 'greater_equal', 'less_equal', 'rot90', 'einsum',
-           'true_divide', 'quantile', 'percentile', 'shares_memory', 'may_share_memory', 'diff',
+           'true_divide', 'quantile', 'percentile', 'shares_memory', 'may_share_memory', 'diff', 'ediff1d',
            'resize', 'polyval', 'nan_to_num', 'isnan', 'isinf', 'isposinf', 'isneginf', 'isfinite',
            'where', 'bincount', 'pad']
 
@@ -6035,6 +6035,44 @@ def diff(a, n=1, axis=-1, prepend=None, append=None):  # pylint: disable=redefin
     return _npi.diff(a, n=n, axis=axis)
 
 
+@set_module('mxnet.symbol.numpy')
+def ediff1d(ary, to_end=None, to_begin=None):
+    """
+    The differences between consecutive elements of an array.
+
+    Parameters
+    ----------
+    ary : _Symbol
+        If necessary, will be flattened before the differences are taken.
+    to_end : _Symbol or scalar, optional
+        Number(s) to append at the end of the returned differences.
+    to_begin : _Symbol or scalar, optional
+        Number(s) to prepend at the beginning of the returned differences.
+
+    Returns
+    -------
+    ediff1d : _Symbol
+        The differences. Loosely, this is ``ary.flat[1:] - ary.flat[:-1]``.
+    """
+    input_type = (isinstance(to_begin, _Symbol), isinstance(to_end, _Symbol))
+    # case 1: when both `to_begin` and `to_end` are arrays
+    if input_type == (True, True):
+        return _npi.ediff1d(ary, to_begin, to_end, to_begin_arr_given=True, to_end_arr_given=True,
+                            to_begin_scalar=None, to_end_scalar=None)
+    # case 2: only `to_end` is array but `to_begin` is scalar/None
+    elif input_type == (False, True):
+        return _npi.ediff1d(ary, to_end, to_begin_arr_given=False, to_end_arr_given=True,
+                            to_begin_scalar=to_begin, to_end_scalar=None)
+    # case 3: only `to_begin` is array but `to_end` is scalar/None
+    elif input_type == (True, False):
+        return _npi.ediff1d(ary, to_begin, to_begin_arr_given=True, to_end_arr_given=False,
+                            to_begin_scalar=None, to_end_scalar=to_end)
+    # case 4: both `to_begin` and `to_end` are scalar/None
+    else:
+        return _npi.ediff1d(ary, to_begin_arr_given=False, to_end_arr_given=False,
+                            to_begin_scalar=to_begin, to_end_scalar=to_end)
+
+
 @set_module('mxnet.symbol.numpy')
 def resize(a, new_shape):
     """
diff --git a/src/operator/numpy/np_diff.cc b/src/operator/numpy/np_diff.cc
index a3dae332d842..63847ef9dd94 100644
--- a/src/operator/numpy/np_diff.cc
+++ b/src/operator/numpy/np_diff.cc
@@ -43,8 +43,8 @@ inline TShape NumpyDiffShapeImpl(const TShape& ishape,
 }
 
 inline bool DiffShape(const nnvm::NodeAttrs& attrs,
-                        std::vector<TShape>* in_attrs,
-                        std::vector<TShape>* out_attrs) {
+                      std::vector<TShape>* in_attrs,
+                      std::vector<TShape>* out_attrs) {
   CHECK_EQ(in_attrs->size(), 1U);
   CHECK_EQ(out_attrs->size(), 1U);
   if (!shape_is_known(in_attrs->at(0))) {
@@ -57,8 +57,8 @@ inline bool DiffShape(const nnvm::NodeAttrs& attrs,
 }
 
 inline bool DiffType(const nnvm::NodeAttrs& attrs,
-                       std::vector<int>* in_attrs,
-                       std::vector<int>* out_attrs) {
+                     std::vector<int>* in_attrs,
+                     std::vector<int>* out_attrs) {
   CHECK_EQ(in_attrs->size(), 1U);
   CHECK_EQ(out_attrs->size(), 1U);
 
diff --git a/src/operator/numpy/np_ediff1d_op-inl.h b/src/operator/numpy/np_ediff1d_op-inl.h
new file mode 100644
index 000000000000..f1d5c2998000
--- /dev/null
+++ b/src/operator/numpy/np_ediff1d_op-inl.h
@@ -0,0 +1,228 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file np_ediff1d_op-inl.h
+ * \brief Function definition of numpy-compatible ediff1d operator
+ */
+
+#ifndef MXNET_OPERATOR_NUMPY_NP_EDIFF1D_OP_INL_H_
+#define MXNET_OPERATOR_NUMPY_NP_EDIFF1D_OP_INL_H_
+
+#include <mxnet/base.h>
+#include <mxnet/operator_util.h>
+#include <vector>
+#include "../mxnet_op.h"
+#include "../operator_common.h"
+#include "../elemwise_op_common.h"
+
+namespace mxnet {
+namespace op {
+
+struct EDiff1DParam : public dmlc::Parameter<EDiff1DParam> {
+  bool to_begin_arr_given, to_end_arr_given;
+  dmlc::optional<double> to_begin_scalar;
+  dmlc::optional<double> to_end_scalar;
+  DMLC_DECLARE_PARAMETER(EDiff1DParam) {
+    DMLC_DECLARE_FIELD(to_begin_arr_given)
+      .set_default(false)
+      .describe("Whether the `to_begin` parameter is an array.");
+    DMLC_DECLARE_FIELD(to_end_arr_given)
+      .set_default(false)
+      .describe("Whether the `to_end` parameter is an array.");
+    DMLC_DECLARE_FIELD(to_begin_scalar)
+      .set_default(dmlc::optional<double>())
+      .describe("If `to_begin` is a scalar, the value of this parameter.");
+    DMLC_DECLARE_FIELD(to_end_scalar)
+      .set_default(dmlc::optional<double>())
+      .describe("If `to_end` is a scalar, the value of this parameter.");
+  }
+};
+
+template<typename DType>
+struct set_to_val {
+  MSHADOW_XINLINE static void Map(index_t i, DType *out, double val) {
+    out[i] = DType(val);
+  }
+};
+
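+// `copyArr` and `AssignScalar` are small helpers with CPU and GPU overloads:
+// `copyArr` copies `count` bytes (not elements) from `src` to `dest`, and
+// `AssignScalar` writes a single scalar value at position `idx` of `dest`.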
+template <typename DType>
+void copyArr(DType* dest, DType* src, size_t count,
+             mshadow::Stream<cpu> *s) {
+  memcpy(dest, src, count);
+}
+
+template <typename DType>
+void AssignScalar(DType* dest, index_t idx, double val,
+                  mshadow::Stream<cpu> *s) {
+  dest[idx] = DType(val);
+}
+
+#ifdef __CUDACC__
+template <typename DType>
+void copyArr(DType* dest, DType* src, size_t count,
+             mshadow::Stream<gpu> *s) {
+  CUDA_CALL(cudaMemcpyAsync(dest, src, count, cudaMemcpyDeviceToDevice,
+                            mshadow::Stream<gpu>::GetStream(s)));
+}
+
+template <typename DType>
+void AssignScalar(DType* dest, index_t idx, double val,
+                  mshadow::Stream<gpu> *s) {
+  mxnet_op::Kernel<set_to_val<DType>, gpu>::Launch(s, 1, dest + idx, val);
+}
+#endif
+
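+// Forward kernel: out_data[i + padding] = in_data[i + 1] - in_data[i], where
+// `padding` is the number of output elements already occupied by `to_begin`.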
+template<int req>
+struct ediff1d_forward {
+  template <typename DType>
+  MSHADOW_XINLINE static void Map(int i,
+                                  DType* out_data,
+                                  const DType* in_data,
+                                  const index_t padding) {
+    KERNEL_ASSIGN(out_data[i + padding], req, in_data[i + 1] - in_data[i]);
+  }
+};
+
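+// Forward pass. The 1-D output is laid out as
+// [to_begin (optional) | consecutive differences of the flattened input | to_end (optional)].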
+template<typename xpu>
+void EDiff1DForward(const nnvm::NodeAttrs& attrs,
+                    const OpContext& ctx,
+                    const std::vector<TBlob>& inputs,
+                    const std::vector<OpReqType>& req,
+                    const std::vector<TBlob>& outputs) {
+  using namespace mshadow;
+  using namespace mxnet_op;
+  CHECK_GE(inputs.size(), 1U);
+  CHECK_LE(inputs.size(), 3U);
+  CHECK_EQ(req.size(), 1U);
+  CHECK_EQ(outputs.size(), 1U);
+  const TBlob& in_data = inputs[0];
+  const TBlob& out_data = outputs[0];
+  mshadow::Stream<xpu> *s = ctx.get_stream<xpu>();
+  MSHADOW_TYPE_SWITCH(out_data.type_flag_, DType, {
+    const EDiff1DParam& param = nnvm::get<EDiff1DParam>(attrs.parsed);
+    size_t padding = 0;
+    size_t in_size = (in_data.Size() > 0) ? in_data.Size() - 1 : 0;
+    index_t idx = 1;  // used to index the rest of input arrays
+
+    if (param.to_begin_arr_given) {
+      // if the `to_begin` parameter is an array, copy its values to the beginning of the out array
+      copyArr<DType>(out_data.dptr<DType>(), inputs[idx].dptr<DType>(),
+                     inputs[idx].Size() * sizeof(DType), s);
+      padding += inputs[idx].Size();
+      idx += 1;
+    } else if (param.to_begin_scalar.has_value()) {
+      // if the `to_begin` parameter is a scalar, directly assign its value
+      AssignScalar(out_data.dptr<DType>(), 0, param.to_begin_scalar.value(), s);
+      padding += 1;
+    }
+
+    if (param.to_end_arr_given) {
+      // if the `to_end` parameter is an array, copy its values to the end of the out array
+      copyArr<DType>(out_data.dptr<DType>() + padding + in_size,
+                     inputs[idx].dptr<DType>(), inputs[idx].Size() * sizeof(DType), s);
+    } else if (param.to_end_scalar.has_value()) {
+      // if the `to_end` parameter is a scalar, directly assign its value
+      AssignScalar(out_data.dptr<DType>(), padding + in_size, param.to_end_scalar.value(), s);
+    }
+
+    MXNET_ASSIGN_REQ_SWITCH(req[0], req_type, {
+      Kernel<ediff1d_forward<req_type>, xpu>::Launch(
+        s, in_size, out_data.dptr<DType>(), in_data.dptr<DType>(), padding);
+    });
+  });
+}
+
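+// Backward kernel for the flattened input: element i contributes to at most two
+// differences, out[padding + i - 1] = in[i] - in[i - 1] (coefficient +1) and
+// out[padding + i] = in[i + 1] - in[i] (coefficient -1), so
+// igrad[i] = ograd[padding + i - 1] - ograd[padding + i], with the boundary
+// cases handled separately below.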
+template<int req>
+struct ediff1d_backward_arr {
+  template <typename DType>
+  MSHADOW_XINLINE static void Map(size_t i,
+                                  DType* igrad_dptr,
+                                  const DType* input_dptr,
+                                  const DType* ograd_dptr,
+                                  const size_t padding,
+                                  const size_t input_size) {
+    if (i == 0) {
+      KERNEL_ASSIGN(igrad_dptr[i], req, -ograd_dptr[i + padding]);
+    } else if (i == input_size - 1) {
+      KERNEL_ASSIGN(igrad_dptr[i], req, ograd_dptr[i - 1 + padding]);
+    } else {
+      KERNEL_ASSIGN(igrad_dptr[i], req, ograd_dptr[i - 1 + padding] - ograd_dptr[i + padding]);
+    }
+  }
+};
+
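+// Backward pass: slices the incoming gradient into the gradients of `to_begin`
+// (if given as an array), the flattened input, and `to_end` (if given as an array).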
+template<typename xpu>
+void EDiff1DBackward(const nnvm::NodeAttrs& attrs, const OpContext& ctx,
+                     const std::vector<TBlob>& inputs,
+                     const std::vector<OpReqType>& req,
+                     const std::vector<TBlob>& outputs) {
+  using namespace mshadow;
+  using namespace mxnet_op;
+  CHECK_GE(inputs.size(), 2U);
+  CHECK_LE(inputs.size(), 4U);
+  CHECK_GE(outputs.size(), 1U);
+  CHECK_LE(outputs.size(), 3U);
+  CHECK_EQ(req.size(), outputs.size());
+
+  mshadow::Stream<xpu> *s = ctx.get_stream<xpu>();
+  const EDiff1DParam& param = nnvm::get<EDiff1DParam>(attrs.parsed);
+
+  const TBlob& ograd = inputs[0];
+  const TBlob& input = inputs[1];
+  const TBlob& igrad = outputs[0];
+  size_t in_size = (input.Size() > 0) ? input.Size() - 1 : 0;
+
+  MSHADOW_REAL_TYPE_SWITCH(ograd.type_flag_, DType, {
+    MXNET_ASSIGN_REQ_SWITCH(req[0], req_type, {
+      size_t padding = 0;
+      index_t idx = 1;  // start from the second argument of `outputs`
+      if (param.to_begin_arr_given) {
+        copyArr<DType>(outputs[idx].dptr<DType>(),
+                       ograd.dptr<DType>(),
+                       outputs[idx].Size() * sizeof(DType), s);
+        padding += outputs[idx].Size();
+        idx += 1;
+      } else if (param.to_begin_scalar.has_value()) {
+        padding += 1;
+      }
+
+      if (param.to_end_arr_given) {
+        copyArr<DType>(outputs[idx].dptr<DType>(),
+                       ograd.dptr<DType>() + in_size + padding,
+                       outputs[idx].Size() * sizeof(DType), s);
+      }
+
+      if (input.Size() == 0) return;
+      if (input.Size() == 1) {
+        Kernel<set_to_val<DType>, xpu>::Launch(s, 1, igrad.dptr<DType>(), 0);
+      } else {
+        Kernel<ediff1d_backward_arr<req_type>, xpu>::Launch(
+          s, igrad.Size(), igrad.dptr<DType>(),
+          input.dptr<DType>(), ograd.dptr<DType>(),
+          padding, igrad.Size());
+      }
+    });
+  });
+}
+
+}  // namespace op
+}  // namespace mxnet
+
+#endif  // MXNET_OPERATOR_NUMPY_NP_EDIFF1D_OP_INL_H_
diff --git a/src/operator/numpy/np_ediff1d_op.cc b/src/operator/numpy/np_ediff1d_op.cc
new file mode 100644
index 000000000000..2fdcf7d082ca
--- /dev/null
+++ b/src/operator/numpy/np_ediff1d_op.cc
@@ -0,0 +1,159 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file np_ediff1d_op.cc
+ * \brief CPU implementation of numpy-compatible ediff1d operator
+ */
+
+#include "./np_ediff1d_op-inl.h"
+
+namespace mxnet {
+namespace op {
+
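+// Type inference: the output shares a single dtype with the data input and with
+// `to_begin`/`to_end` whenever they are given as arrays.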
+inline bool EDiff1DType(const nnvm::NodeAttrs& attrs,
+                        std::vector<int>* in_attrs,
+                        std::vector<int>* out_attrs) {
+  CHECK_GE(in_attrs->size(), 1U);
+  CHECK_LE(in_attrs->size(), 3U);
+  CHECK_EQ(out_attrs->size(), 1U);
+  TYPE_ASSIGN_CHECK(*out_attrs, 0, in_attrs->at(0));
+
+  const EDiff1DParam& param = nnvm::get<EDiff1DParam>(attrs.parsed);
+  if (param.to_begin_arr_given && param.to_end_arr_given) {
+      TYPE_ASSIGN_CHECK(*out_attrs, 0, in_attrs->at(1));
+      TYPE_ASSIGN_CHECK(*out_attrs, 0, in_attrs->at(2));
+  } else if (param.to_begin_arr_given || param.to_end_arr_given) {
+      TYPE_ASSIGN_CHECK(*out_attrs, 0, in_attrs->at(1));
+  }
+
+  TYPE_ASSIGN_CHECK(*in_attrs, 0, out_attrs->at(0));
+
+  return out_attrs->at(0) != -1 && in_attrs->at(0) != -1;
+}
+
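+// Output shape: a 1-D tensor of length
+// max(ary.Size() - 1, 0) + size(to_begin) + size(to_end),
+// where a scalar `to_begin`/`to_end` contributes one element and `None` contributes none.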
+inline TShape NumpyEDiff1DShapeImpl(std::vector<TShape>* in_attrs,
+                                    const bool to_begin_arr_given,
+                                    const bool to_end_arr_given,
+                                    dmlc::optional<double> to_begin_scalar,
+                                    dmlc::optional<double> to_end_scalar) {
+  size_t out = (in_attrs->at(0).Size() > 0) ? in_attrs->at(0).Size() - 1 : 0;
+  // case 1: when both `to_begin` and `to_end` are arrays
+  if (to_begin_arr_given && to_end_arr_given) {
+      out += in_attrs->at(1).Size() + in_attrs->at(2).Size();
+  // case 2: only one of the parameters is an array
+  } else if (to_begin_arr_given || to_end_arr_given) {
+      out += in_attrs->at(1).Size();
+      // if the other one is a scalar
+      if (to_begin_scalar.has_value() || to_end_scalar.has_value()) {
+          out += 1;
+      }
+  // case 3: neither of the parameters is an array
+  } else {
+      // case 3.1: both of the parameters are scalars
+      if (to_begin_scalar.has_value() && to_end_scalar.has_value()) {
+          out += 2;
+      // case 3.2: only one of the parameters is a scalar
+      } else if (to_begin_scalar.has_value() || to_end_scalar.has_value()) {
+          out += 1;
+      }
+      // case 3.3: they are both `None` -- skip
+  }
+  TShape oshape = TShape(1, out);
+  return oshape;
+}
+
+inline bool EDiff1DShape(const nnvm::NodeAttrs& attrs,
+                         std::vector<TShape>* in_attrs,
+                         std::vector<TShape>* out_attrs) {
+  CHECK_GE(in_attrs->size(), 1U);
+  CHECK_LE(in_attrs->size(), 3U);
+  CHECK_EQ(out_attrs->size(), 1U);
+  if (!shape_is_known(in_attrs->at(0))) {
+    return false;
+  }
+  const EDiff1DParam& param = nnvm::get<EDiff1DParam>(attrs.parsed);
+  SHAPE_ASSIGN_CHECK(*out_attrs, 0,
+                     NumpyEDiff1DShapeImpl(in_attrs,
+                                           param.to_begin_arr_given,
+                                           param.to_end_arr_given,
+                                           param.to_begin_scalar,
+                                           param.to_end_scalar));
+  return shape_is_known(out_attrs->at(0));
+}
+
+DMLC_REGISTER_PARAMETER(EDiff1DParam);
+
+NNVM_REGISTER_OP(_npi_ediff1d)
+.set_attr_parser(ParamParser<EDiff1DParam>)
+.set_num_inputs(
+  [](const nnvm::NodeAttrs& attrs) {
+     const EDiff1DParam& param = nnvm::get<EDiff1DParam>(attrs.parsed);
+     int num_inputs = 1;
+     if (param.to_begin_arr_given) num_inputs += 1;
+     if (param.to_end_arr_given) num_inputs += 1;
+     return num_inputs;
+  })
+.set_num_outputs(1)
+.set_attr<nnvm::FListInputNames>("FListInputNames",
+  [](const NodeAttrs& attrs) {
+    const EDiff1DParam& param = nnvm::get<EDiff1DParam>(attrs.parsed);
+    int num_inputs = 1;
+    if (param.to_begin_arr_given) num_inputs += 1;
+    if (param.to_end_arr_given) num_inputs += 1;
+    if (num_inputs == 1) return std::vector<std::string>{"input1"};
+    if (num_inputs == 2) return std::vector<std::string>{"input1", "input2"};
+    return std::vector<std::string>{"input1", "input2", "input3"};
+  })
+.set_attr<mxnet::FInferShape>("FInferShape", EDiff1DShape)
+.set_attr<nnvm::FInferType>("FInferType", EDiff1DType)
+.set_attr<FCompute>("FCompute<cpu>", EDiff1DForward<cpu>)
+.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseIn{"_npi_backward_ediff1d"})
+.set_attr<nnvm::FInplaceOption>("FInplaceOption",
+  [](const NodeAttrs& attrs) {
+    return std::vector<std::pair<int, int> >{{0, 0}};
+  })
+.add_argument("input1", "NDArray-or-Symbol", "Source input")
+.add_argument("input2", "NDArray-or-Symbol", "Source input")
+.add_argument("input3", "NDArray-or-Symbol", "Source input")
+.add_arguments(EDiff1DParam::__FIELDS__());
+
+NNVM_REGISTER_OP(_npi_backward_ediff1d)
+.set_attr_parser(ParamParser<EDiff1DParam>)
+.set_num_inputs(
+  [](const nnvm::NodeAttrs& attrs) {
+     const EDiff1DParam& param = nnvm::get<EDiff1DParam>(attrs.parsed);
+     int num_inputs = 2;
+     if (param.to_begin_arr_given) num_inputs += 1;
+     if (param.to_end_arr_given) num_inputs += 1;
+     return num_inputs;
+  })
+.set_num_outputs(
+  [](const nnvm::NodeAttrs& attrs) {
+     const EDiff1DParam& param = nnvm::get<EDiff1DParam>(attrs.parsed);
+     int num_inputs = 1;
+     if (param.to_begin_arr_given) num_inputs += 1;
+     if (param.to_end_arr_given) num_inputs += 1;
+     return num_inputs;
+  })
+.set_attr<nnvm::TIsBackward>("TIsBackward", true)
+.set_attr<mxnet::FCompute>("FCompute<cpu>", EDiff1DBackward<cpu>);
+
+}  // namespace op
+}  // namespace mxnet
diff --git a/src/operator/numpy/np_ediff1d_op.cu b/src/operator/numpy/np_ediff1d_op.cu
new file mode 100644
index 000000000000..711d5e9b2498
--- /dev/null
+++ b/src/operator/numpy/np_ediff1d_op.cu
@@ -0,0 +1,37 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file np_ediff1d_op.cu
+ * \brief GPU implementation of numpy-compatible ediff1d operator
+ */
+
+#include "./np_ediff1d_op-inl.h"
+
+namespace mxnet {
+namespace op {
+
+NNVM_REGISTER_OP(_npi_ediff1d)
+.set_attr<FCompute>("FCompute<gpu>", EDiff1DForward<gpu>);
+
+NNVM_REGISTER_OP(_npi_backward_ediff1d)
+.set_attr<FCompute>("FCompute<gpu>", EDiff1DBackward<gpu>);
+
+}  // namespace op
+}  // namespace mxnet
diff --git a/tests/python/unittest/test_numpy_interoperability.py b/tests/python/unittest/test_numpy_interoperability.py
index 5a811f97f54f..8b5b18d6e914 100644
--- a/tests/python/unittest/test_numpy_interoperability.py
+++ b/tests/python/unittest/test_numpy_interoperability.py
@@ -1892,6 +1892,19 @@ def _add_workload_diff():
         OpArgMngr.add_workload('diff', x, n=n)
 
 
+def _add_workload_ediff1d():
+    x = np.array([1, 3, 6, 7, 1])
+    OpArgMngr.add_workload('ediff1d', x)
+    OpArgMngr.add_workload('ediff1d', x, 2, 4)
+    OpArgMngr.add_workload('ediff1d', x, x, 3)
+    OpArgMngr.add_workload('ediff1d', x, x, x)
+    OpArgMngr.add_workload('ediff1d', np.array([1.1, 2.2, 3.0, -0.2, -0.1]))
+    x = np.random.randint(5, size=(5, 0, 4))
+    OpArgMngr.add_workload('ediff1d', x)
+    OpArgMngr.add_workload('ediff1d', x, 2, 4)
+    OpArgMngr.add_workload('ediff1d', x, x, 3)
+    OpArgMngr.add_workload('ediff1d', x, x, x)
+
+
 def _add_workload_resize():
     OpArgMngr.add_workload('resize', np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]], dtype=np.int32), (5, 1))
     OpArgMngr.add_workload('resize', np.eye(3), 3)
@@ -2840,6 +2853,7 @@ def _prepare_workloads():
     _add_workload_where()
     _add_workload_shape()
     _add_workload_diff()
+    _add_workload_ediff1d()
     _add_workload_quantile()
     _add_workload_percentile()
     _add_workload_resize()
diff --git a/tests/python/unittest/test_numpy_op.py b/tests/python/unittest/test_numpy_op.py
index 45b64c26bb88..6b56fcbddb4e 100644
--- a/tests/python/unittest/test_numpy_op.py
+++ b/tests/python/unittest/test_numpy_op.py
@@ -7330,6 +7330,120 @@ def hybrid_forward(self, F, a):
                         assert_almost_equal(mx_out.asnumpy(), np_out, rtol=rtol, atol=atol)
 
 
+@with_seed()
+@use_np
+def test_np_ediff1d():
+    def np_diff_backward(size, shape):
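+        # Expected analytic gradient of ediff1d w.r.t. its (flattened) input:
+        # -1 for the first element, +1 for the last, 0 elsewhere (all zeros when size <= 1).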
+        if size <= 1:
+            return _np.zeros(shape)
+        else:
+            ret = _np.ones(size - 1)
+            return _np.negative(_np.diff(ret, n=1, axis=-1, prepend=0, append=0)).reshape(shape)
+
+    # case 1: when both `to_begin` and `to_end` are arrays
+    class TestEDiff1DCASE1(HybridBlock):
+        def __init__(self):
+            super(TestEDiff1DCASE1, self).__init__()
+
+        def hybrid_forward(self, F, a, b, c):
+            return F.np.ediff1d(a, to_end=b, to_begin=c)
+
+    # case 2: only `to_end` is array but `to_begin` is scalar/None
+    class TestEDiff1DCASE2(HybridBlock):
+        def __init__(self, to_begin=None):
+            super(TestEDiff1DCASE2, self).__init__()
+            self._to_begin = to_begin
+
+        def hybrid_forward(self, F, a, b):
+            return F.np.ediff1d(a, to_end=b, to_begin=self._to_begin)
+
+    # case 3: only `to_begin` is array but `to_end` is scalar/None
+    class TestEDiff1DCASE3(HybridBlock):
+        def __init__(self, to_end=None):
+            super(TestEDiff1DCASE3, self).__init__()
+            self._to_end = to_end
+
+        def hybrid_forward(self, F, a, b):
+            return F.np.ediff1d(a, to_end=self._to_end, to_begin=b)
+
+    # case 4: both `to_begin` and `to_end` are scalar/None
+    class TestEDiff1DCASE4(HybridBlock):
+        def __init__(self, to_end=None, to_begin=None):
+            super(TestEDiff1DCASE4, self).__init__()
+            self._to_begin = to_begin
+            self._to_end = to_end
+
+        def hybrid_forward(self, F, a):
+            return F.np.ediff1d(a, to_end=self._to_end, to_begin=self._to_begin)
+
+    rtol = 1e-3
+    atol = 1e-5
+    mapper = {(True, True): TestEDiff1DCASE1,
+              (False, True): TestEDiff1DCASE2,
+              (True, False): TestEDiff1DCASE3,
+              (False, False): TestEDiff1DCASE4}
+    hybridize_list = [True, False]
+    shape_list = [(), (1,), (2, 3), 6, (7, 8), 10, (4, 0, 5)]
+    # dtype_list = [np.int32, np.int64, np.float16, np.float32, np.float64]
+    dtype_list = [np.float16, np.float32, np.float64]
+    append_list = [1, 2, None, (1, 2, 4), (4, 3), (), (5, 0), (6,)]
+
+    for hybridize, dtype, shape, to_begin, to_end in itertools.product(hybridize_list, dtype_list,
+                shape_list, append_list, append_list):
+        mx_arr = np.random.randint(5, size=shape).astype(dtype)
+        np_arr = mx_arr.asnumpy()
+        kwargs = {}
+        mx_args = [mx_arr]
+        np_args = [np_arr]
+        mx_args_imperative = [mx_arr]
+
+        if isinstance(to_end, tuple):
+            to_end = np.random.randint(5, size=to_end).astype(dtype)
+            mx_args.append(to_end)
+            np_args.append(to_end.asnumpy())
+        else:
+            kwargs["to_end"] = to_end
+            np_args.append(to_end)
+        mx_args_imperative.append(to_end)
+
+        if isinstance(to_begin, tuple):
+            to_begin = np.random.randint(5, size=to_begin).astype(dtype)
+            mx_args.append(to_begin)
+            np_args.append(to_begin.asnumpy())
+        else:
+            kwargs["to_begin"] = to_begin
+            np_args.append(to_begin)
+        mx_args_imperative.append(to_begin)
+
+        from mxnet.numpy import ndarray as np_ndarray
+        input_type = (isinstance(to_begin, np_ndarray), isinstance(to_end, np_ndarray))
+        test_np_ediff1d = mapper[input_type](**kwargs)
+
+        if hybridize:
+            test_np_ediff1d.hybridize()
+
+        np_out = _np.ediff1d(*np_args)
+        for arg in mx_args:
+            arg.attach_grad()
+
+        with mx.autograd.record():
+            mx_out = test_np_ediff1d(*mx_args)
+        assert mx_out.shape == np_out.shape
+        assert_almost_equal(mx_out.asnumpy(), np_out, atol=atol, rtol=rtol)
+        # test imperative
+        mx_out_imperative = np.ediff1d(*mx_args_imperative)
+        assert mx_out_imperative.shape == np_out.shape
+        assert_almost_equal(mx_out_imperative.asnumpy(), np_out, atol=atol, rtol=rtol)
+
+        mx_out.backward()
+        if dtype in [np.float16, np.float32, np.float64]:
+            for idx, arg in enumerate(mx_args):
+                if idx == 0:
+                    assert_almost_equal(arg.grad.asnumpy(), np_diff_backward(arg.size, arg.shape), atol=atol, rtol=rtol)
+                else:
+                    assert_almost_equal(arg.grad.asnumpy(), np.ones_like(arg), atol=atol, rtol=rtol)
+
+
 @with_seed()
 @use_np
 def test_np_column_stack():