
Commit

Merge branch 'put_along_axis' of https://github.com/YibinLiu666/Paddle into put_along_axis2

merge
YibinLiu666 committed Dec 2, 2023
2 parents 67933a4 + 28aadd6 commit 7ac73eb
Showing 5 changed files with 110 additions and 3 deletions.
18 changes: 18 additions & 0 deletions paddle/phi/kernels/funcs/gather_scatter_functor.cc
@@ -261,25 +261,30 @@ void cpu_scatter_input_grad_kernel(phi::DenseTensor self UNUSED,
                                   int dim,
                                   const phi::DenseTensor& index,
                                   phi::DenseTensor grad,
                                   const phi::DeviceContext& ctx UNUSED) {
  auto* index_data = index.data<index_t>();
  auto* grad_data = grad.data<tensor_t>();

  auto index_dims = index.dims();
  auto grad_dims = grad.dims();

  int64_t inner_dim_size = 1;
  int64_t outer_dim_size = 1;
  int64_t outer_dim_size_data = 1;
  int64_t select_dim_size = index_dims[dim];
  int64_t grad_select_dim_size = grad_dims[dim];
  for (int i = 0; i < dim; ++i) {
    inner_dim_size *= index_dims[i];
  }

  for (int i = dim + 1; i < index_dims.size(); i++) {
    outer_dim_size *= index_dims[i];
    outer_dim_size_data *= grad_dims[i];
  }

  int64_t index_idx = 0;
@@ -409,17 +414,21 @@ void cpu_scatter_value_grad_kernel(phi::DenseTensor self,
                                   int dim,
                                   const phi::DenseTensor& index,
                                   phi::DenseTensor grad,
                                   const phi::DeviceContext& ctx UNUSED) {
  auto* self_data = self.data<tensor_t>();
  auto* index_data = index.data<index_t>();
  auto* grad_data = grad.data<tensor_t>();

  auto index_dims = index.dims();
  auto self_dims = self.dims();
  auto grad_dims = grad.dims();

  int64_t self_size = self.numel();
  int64_t grad_size = grad.numel();
  bool* is_self_grad_used = new bool[self_size];

  for (int i = 0; i < self_size; i++) {
@@ -430,9 +439,11 @@ void cpu_scatter_value_grad_kernel(phi::DenseTensor self,
  int64_t outer_dim_size = 1;
  int64_t outer_dim_size_self = 1;
  int64_t outer_dim_size_grad = 1;
  int64_t select_dim_size = index_dims[dim];
  int64_t self_select_dim_size = self_dims[dim];
  int64_t grad_select_dim_size = grad_dims[dim];
  for (int i = 0; i < dim; ++i) {
    inner_dim_size *= index_dims[i];
  }
@@ -446,6 +457,9 @@ void cpu_scatter_value_grad_kernel(phi::DenseTensor self,
    grad_data[i] = static_cast<tensor_t>(0);
  }
  int64_t index_idx = index.numel() - 1;
  for (int64_t i = inner_dim_size - 1; i >= 0; i--) {
    for (int64_t j = select_dim_size - 1; j >= 0; j--) {
      for (int64_t k = outer_dim_size - 1; k >= 0; k--) {
@@ -456,7 +470,11 @@ void cpu_scatter_value_grad_kernel(phi::DenseTensor self,
        int64_t replace_index_grad =
            k + j * outer_dim_size_grad +
            i * outer_dim_size_grad * grad_select_dim_size;
        if (!is_self_grad_used[replace_index_self]) {
          grad_data[replace_index_grad] = self_data[replace_index_self];
          is_self_grad_used[replace_index_self] = true;
        }
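Both kernels linearize a logical (i, j, k) coordinate into a flat row-major offset, where i runs over the dimensions before dim, j over dim itself, and k over the dimensions after dim; each tensor substitutes its own outer and select sizes (outer_dim_size_data, outer_dim_size_grad, and so on). A minimal Python sketch of the arithmetic (an illustration of the computation shown above, not Paddle code):

# Sketch of the offset arithmetic used by the scatter kernels above:
#   inner_dim_size  = product of dims before `dim`  (range of i)
#   select_dim_size = size of `dim` itself          (range of j)
#   outer_dim_size  = product of dims after `dim`   (range of k)
def flat_offset(i, j, k, outer_dim_size, select_dim_size):
    # Row-major layout: ((i * select_dim_size + j) * outer_dim_size) + k
    return k + j * outer_dim_size + i * outer_dim_size * select_dim_size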
2 changes: 1 addition & 1 deletion paddle/phi/kernels/funcs/gather_scatter_functor.h
@@ -152,7 +152,7 @@ template <typename tensor_t, typename index_t>
void cpu_scatter_value_grad_kernel(phi::DenseTensor self,
                                   int dim,
                                   const phi::DenseTensor& index,
                                   phi::DenseTensor grad,
                                   const phi::DeviceContext& ctx);

template <typename tensor_t, typename index_t>
5 changes: 4 additions & 1 deletion test/legacy_test/prim_op_test.py
@@ -192,7 +192,10 @@ def parse_attri_value(name, op_inputs, op_proto_attrs):
                    tmp = input_arguments[idx_of_op_proto_arguments]
                    idx_of_op_proto_arguments += 1
                else:
                    # tmp = Empty()  # use the default value
                    tmp = parse_attri_value(
                        arg_name, op_proto_ins, op_proto_attrs
                    )

                if isinstance(tmp, Empty):
                    results.append(get_default(idx, api_defaults))
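The effect of the change: an argument missing from input_arguments is now looked up with parse_attri_value before falling back to the API default; only when that lookup also returns Empty does the isinstance branch above call get_default. parse_attri_value is defined earlier in this file and its body is not part of the diff, so the following is a hedged reconstruction of its likely shape, not the actual implementation:

# Assumed shape of parse_attri_value (a reconstruction; the real body is
# defined earlier in prim_op_test.py and not shown in this diff): prefer
# an explicit attribute, then an op input, else return Empty() so the
# default-value path still applies.
def parse_attri_value(name, op_inputs, op_proto_attrs):
    if name in op_proto_attrs:
        return op_proto_attrs[name]
    if name in op_inputs:
        return op_inputs[name]
    return Empty()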
42 changes: 41 additions & 1 deletion test/legacy_test/test_put_along_axis_op.py
@@ -74,12 +74,52 @@ def init_data(self):
        self.x_shape = (10, 10, 10)
        self.value_type = "float16"
        self.value = np.array([99]).astype(self.value_type)
        self.index_type = "int64"
        self.index = np.array([[[0]]]).astype(self.index_type)
        self.axis = 1
        self.axis_type = "int64"


class TestPutAlongAxisOpCase2(TestPutAlongAxisOp):
    def setUp(self):
        self.init_data()
        self.reduce_op = "assign"
        self.op_type = "put_along_axis"
        self.python_api = paddle.tensor.put_along_axis
        self.xnp = np.random.random(self.x_shape).astype(self.x_type)
        # numpy put_along_axis is an inplace operation.
        self.target = copy.deepcopy(self.xnp)
        for i in range(5):
            for j in range(5):
                for k in range(5):
                    self.target[i, self.index[i, j, k], k] = self.value[i, j, k]
        self.inputs = {
            'Input': self.xnp,
            'Index': self.index,
            'Value': self.value,
        }
        self.attrs = {
            'Axis': self.axis,
            'Reduce': self.reduce_op,
            'include_self': True,
            'broadcast': False,
        }
        self.outputs = {'Result': self.target}

    def init_data(self):
        self.dtype = 'float32'
        self.x_type = "float32"
        self.x_shape = (10, 10, 10)
        self.value_type = "float32"
        self.value = (
            np.arange(1, 126).reshape((5, 5, 5)).astype(self.value_type)
        )
        self.index_type = "int64"
        self.index = np.zeros((5, 5, 5)).astype(self.index_type)
        self.axis = 1
        self.axis_type = "int64"

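The triple loop above is the test's reference for reduce='assign' with broadcast=False on axis 1: each entry of the (5, 5, 5) index array overwrites exactly one element of the (10, 10, 10) target, which is why np.put_along_axis (whose index must broadcast against the array) is not used. The same computation as a standalone sketch (a hypothetical helper, not part of the test file):

import numpy as np

# Reference for put_along_axis(reduce='assign', broadcast=False) on axis=1.
# index/value may be smaller than arr in every dimension, so an explicit
# loop replaces np.put_along_axis.
def put_along_axis1_ref(arr, index, value):
    out = arr.copy()
    for i in range(index.shape[0]):
        for j in range(index.shape[1]):
            for k in range(index.shape[2]):
                out[i, index[i, j, k], k] = value[i, j, k]
    return out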

@unittest.skipIf(
    not core.is_compiled_with_cuda()
    or not core.is_bfloat16_supported(core.CUDAPlace(0)),
46 changes: 46 additions & 0 deletions test/legacy_test/test_take_along_axis_op.py
@@ -76,6 +76,34 @@ def init_data(self):
        self.axis_type = "int64"


class TestTakeAlongAxisOp(OpTest):
    def setUp(self):
        self.init_data()
        self.op_type = "take_along_axis"
        self.python_api = paddle.tensor.take_along_axis
        self.check_cinn = True
        self.xnp = np.random.random(self.x_shape).astype(self.x_type)
        self.target = np.zeros((2, 3, 4)).astype(self.x_type)
        for i in range(2):
            for j in range(3):
                for k in range(4):
                    self.target[i, j, k] = self.xnp[i, j, self.index[i, j, k]]
        self.inputs = {
            'Input': self.xnp,
            'Index': self.index,
        }
        self.attrs = {'Axis': self.axis, 'broadcast': False}
        self.outputs = {'Result': self.target}

    def init_data(self):
        self.x_type = "float64"
        self.x_shape = (10, 10, 10)
        self.index_type = "int64"
        self.index = np.random.randint(0, 10, (2, 3, 4)).astype(self.index_type)
        self.axis = 2
        self.axis_type = "int64"

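The gather direction mirrors this: with broadcast=False the (2, 3, 4) index need not match the (10, 10, 10) input, and the result takes the index array's shape, reading along axis 2. An equivalent standalone sketch (hypothetical helper):

import numpy as np

# Reference for take_along_axis(broadcast=False) on axis=2: out has the
# index array's shape; each element gathers arr[i, j, index[i, j, k]].
def take_along_axis2_ref(arr, index):
    out = np.zeros(index.shape, dtype=arr.dtype)
    for i in range(index.shape[0]):
        for j in range(index.shape[1]):
            for k in range(index.shape[2]):
                out[i, j, k] = arr[i, j, index[i, j, k]]
    return out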

@unittest.skipIf(
    not core.is_compiled_with_cuda()
    or not core.is_bfloat16_supported(core.CUDAPlace(0)),
@@ -252,6 +280,24 @@ def test_api_dygraph(self):
        np.testing.assert_allclose(out.numpy(), out_ref, rtol=0.001)
        paddle.enable_static()

    def test_error(self):
        paddle.disable_static(self.place[0])
        tensorx = paddle.to_tensor([[1, 2, 3], [4, 5, 6]]).astype("float32")
        indices = paddle.to_tensor([1]).astype("int32")
        # len(arr.shape) != len(indices.shape)
        with self.assertRaises(ValueError):
            res = paddle.take_along_axis(tensorx, indices, 0, False)
        # an element of indices is out of range
        with self.assertRaises(RuntimeError):
            indices = paddle.to_tensor([[100]]).astype("int32")
            res = paddle.take_along_axis(tensorx, indices, 0, False)
        # the shape of indices doesn't match the input
        with self.assertRaises(RuntimeError):
            indices = paddle.to_tensor(
                [[1, 0, 0, 0], [1, 0, 0, 0], [1, 0, 0, 0]]
            ).astype("int32")
            res = paddle.take_along_axis(tensorx, indices, 0, False)

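For contrast with the three failing calls above, a call that passes the checks keeps the index rank equal to the input's, stays in range along the axis, and fits within the input's other extents (illustrative values; the expected output assumes gather semantics of out[i, j] = tensorx[indices[i, j], j]):

# Valid under broadcast=False: rank matches, indices lie in [0, 2) for
# axis 0, and the (1, 3) index shape fits inside the (2, 3) input.
indices = paddle.to_tensor([[1, 0, 1]]).astype("int32")
out = paddle.take_along_axis(tensorx, indices, 0, False)  # -> [[4., 2., 6.]]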

if __name__ == "__main__":
paddle.enable_static()
