Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix behavior of put_along_axis and take_along_axis 易用性提升No.43 #59163

Merged
merged 14 commits into from
Dec 4, 2023
4 changes: 2 additions & 2 deletions paddle/phi/kernels/cpu/put_along_axis_grad_kernel.cc
Original file line number Diff line number Diff line change
Expand Up @@ -60,10 +60,10 @@ void PutAlongAxisGradKernel(const Context& dev_ctx,
value_grad->Resize(index.dims());
dev_ctx.template Alloc<T>(value_grad);
if (index_type == DataType::INT32) {
phi::funcs::cpu_gather_kernel<T, int32_t>(
phi::funcs::cpu_scatter_value_grad_kernel<T, int32_t>(
out_grad, axis, index, *value_grad, dev_ctx);
} else if (index_type == DataType::INT64) {
phi::funcs::cpu_gather_kernel<T, int64_t>(
phi::funcs::cpu_scatter_value_grad_kernel<T, int64_t>(
out_grad, axis, index, *value_grad, dev_ctx);
}
}
Expand Down
100 changes: 88 additions & 12 deletions paddle/phi/kernels/funcs/gather_scatter_functor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -80,8 +80,10 @@ struct cpu_gather_scatter_functor {
}
int64_t select_dim_size = index_dims[dim];
// index matrix has different shape with self matrix or src matrix.
int replaced_select_dim_size =
is_scatter_like ? self_dims[dim] : src_dims[dim];
int self_select_dim_size = self_dims[dim];
int src_select_dim_size = src_dims[dim];
int64_t outer_dim_size_self = 1;
int64_t outer_dim_size_src = 1;
int64_t inner_dim_size = 1;
int64_t outer_dim_size = 1;
for (int i = 0; i < dim; ++i) {
Expand All @@ -90,10 +92,10 @@ struct cpu_gather_scatter_functor {

for (int i = dim + 1; i < index_dims.size(); i++) {
outer_dim_size *= index_dims[i];
outer_dim_size_self *= self_dims[i];
outer_dim_size_src *= src_dims[i];
}
int64_t index_idx = 0;
int64_t self_idx = 0, src_idx = 0;

// N layer loop squeezed into 3 layers loop
for (int64_t i = 0; i < inner_dim_size; i++) {
for (int64_t j = 0; j < select_dim_size; j++) {
Expand All @@ -117,13 +119,21 @@ struct cpu_gather_scatter_functor {

// This index might be out of bounds for the index matrix, so here we
// multiply by the replaced_select_dim_size.
int64_t replace_index = k + index * outer_dim_size +
i * outer_dim_size * replaced_select_dim_size;
int64_t replace_index_self, replace_index_src;
if (is_scatter_like) {
replace_index_self = k + index * outer_dim_size_self +
i * outer_dim_size_self * self_select_dim_size;

replace_index_src = k + j * outer_dim_size_src +
i * outer_dim_size_src * src_select_dim_size;
} else {
replace_index_self = index_idx;

self_idx = is_scatter_like ? replace_index : index_idx;
src_idx = is_scatter_like ? index_idx : replace_index;
reduce_op((tensor_t*)(self_data + self_idx), // NOLINT
(tensor_t*)(src_data + src_idx)); // NOLINT
replace_index_src = k + index * outer_dim_size_src +
i * outer_dim_size_src * src_select_dim_size;
}
reduce_op((tensor_t*)(self_data + replace_index_self), // NOLINT
(tensor_t*)(src_data + replace_index_src)); // NOLINT
index_idx++;
}
}
Expand Down Expand Up @@ -193,6 +203,7 @@ void cpu_scatter_input_grad_kernel(phi::DenseTensor self UNUSED,

int64_t inner_dim_size = 1;
int64_t outer_dim_size = 1;
int64_t outer_dim_size_data = 1;
int64_t select_dim_size = index_dims[dim];
int64_t output_select_dim_size = output_dims[dim];
for (int i = 0; i < dim; ++i) {
Expand All @@ -201,27 +212,92 @@ void cpu_scatter_input_grad_kernel(phi::DenseTensor self UNUSED,

for (int i = dim + 1; i < index_dims.size(); i++) {
outer_dim_size *= index_dims[i];
outer_dim_size_data *= output_dims[i];
}

int64_t index_idx = 0;
for (int64_t i = 0; i < inner_dim_size; i++) {
for (int64_t j = 0; j < select_dim_size; j++) {
for (int64_t k = 0; k < outer_dim_size; k++) {
int64_t index = index_data[index_idx];
int64_t replace_index = k + index * outer_dim_size +
i * outer_dim_size * output_select_dim_size;
int64_t replace_index =
k + index * outer_dim_size_data +
i * outer_dim_size_data * output_select_dim_size;
output_data[replace_index] = 0;
index_idx++;
}
}
}
}

// Fills `output` with the gradient routed back through a scatter along `dim`.
//
// For every element (i, j, k) of `index` (i: dimensions before `dim`
// flattened, j: position along `dim`, k: dimensions after `dim` flattened) the
// kernel reads `self` at (i, index[i, j, k], k) and writes `output` at
// (i, j, k). The index tensor is walked from its last element to its first,
// and each element of `self` is handed out at most once: the first visit in
// this reverse walk copies the value, every later visit to the same `self`
// element writes 0 instead. In forward order this means that, among index
// entries that target the same `self` element, only the last one receives the
// value.
//
// Parameters:
//   self   - tensor that is read (the put_along_axis grad kernel passes
//            out_grad here).
//   dim    - axis along which `index` selects positions in `self`.
//   index  - integer tensor (element type index_t) of positions along `dim`.
//   output - tensor that is written (the put_along_axis grad kernel passes
//            value_grad here); only elements addressed through `index` are
//            touched.
//   ctx    - unused.
//
// NOTE(review): index values are not range-checked here; this assumes the
// caller guarantees 0 <= index < self.dims()[dim] - confirm.
template <typename tensor_t, typename index_t>
void cpu_scatter_value_grad_kernel(phi::DenseTensor self,
                                   int dim,
                                   const phi::DenseTensor& index,
                                   phi::DenseTensor output,
                                   const phi::DeviceContext& ctx UNUSED) {
  auto* self_data = self.data<tensor_t>();
  auto* index_data = index.data<index_t>();
  auto* output_data = output.data<tensor_t>();

  auto index_dims = index.dims();
  auto self_dims = self.dims();
  auto output_dims = output.dims();

  // One flag per element of `self`, marking whether its value has already
  // been handed out. The trailing `()` value-initialises every flag to false,
  // which avoids a manual clearing loop (and the int/int64_t counter mismatch
  // such a loop is prone to for very large tensors).
  const int64_t self_size = self.numel();
  bool* is_self_grad_used = new bool[self_size]();

  // Collapse the N-d iteration space into three extents, measured on `index`:
  // dimensions before `dim`, `dim` itself, and dimensions after `dim`.
  int64_t inner_dim_size = 1;
  for (int i = 0; i < dim; ++i) {
    inner_dim_size *= index_dims[i];
  }
  const int64_t select_dim_size = index_dims[dim];
  const int64_t self_select_dim_size = self_dims[dim];
  const int64_t output_select_dim_size = output_dims[dim];

  // `self` and `output` may differ in shape from `index`, so each tensor needs
  // its own stride for the dimensions after `dim`.
  int64_t outer_dim_size = 1;
  int64_t outer_dim_size_self = 1;
  int64_t outer_dim_size_output = 1;
  for (int i = dim + 1; i < index_dims.size(); i++) {
    outer_dim_size *= index_dims[i];
    outer_dim_size_self *= self_dims[i];
    outer_dim_size_output *= output_dims[i];
  }

  // Walk `index` back to front so the last entry (in forward order) that
  // targets a given `self` element is the one that receives its value.
  int64_t index_idx = index.numel() - 1;
  for (int64_t i = inner_dim_size - 1; i >= 0; i--) {
    const int64_t self_base = i * outer_dim_size_self * self_select_dim_size;
    const int64_t output_base =
        i * outer_dim_size_output * output_select_dim_size;
    for (int64_t j = select_dim_size - 1; j >= 0; j--) {
      const int64_t output_row = output_base + j * outer_dim_size_output;
      for (int64_t k = outer_dim_size - 1; k >= 0; k--) {
        int64_t index = index_data[index_idx];
        const int64_t replace_index_self =
            k + index * outer_dim_size_self + self_base;
        const int64_t replace_index_output = k + output_row;
        if (!is_self_grad_used[replace_index_self]) {
          output_data[replace_index_output] = self_data[replace_index_self];
          is_self_grad_used[replace_index_self] = true;
        } else {
          // This `self` element was already consumed by a later index entry.
          output_data[replace_index_output] = 0;
        }
        index_idx--;
      }
    }
  }
  delete[] is_self_grad_used;
}

// Apply Instantiate_Template_Function to every CPU gather/scatter kernel
// template in this file.
// NOTE(review): the macro is defined outside this view; presumably it emits
// explicit instantiations for the supported tensor/index type combinations -
// confirm against its definition.
Instantiate_Template_Function(cpu_gather_kernel)
Instantiate_Template_Function(cpu_scatter_assign_kernel)
Instantiate_Template_Function(cpu_scatter_add_kernel)
Instantiate_Template_Function(cpu_scatter_mul_kernel)
Instantiate_Template_Function(cpu_scatter_input_grad_kernel)
Instantiate_Template_Function(cpu_scatter_value_grad_kernel)

} // namespace funcs
} // namespace phi
Loading