diff --git a/paddle/phi/kernels/cpu/put_along_axis_grad_kernel.cc b/paddle/phi/kernels/cpu/put_along_axis_grad_kernel.cc index d44af05357a9a4..dd7b762849d16b 100644 --- a/paddle/phi/kernels/cpu/put_along_axis_grad_kernel.cc +++ b/paddle/phi/kernels/cpu/put_along_axis_grad_kernel.cc @@ -60,10 +60,10 @@ void PutAlongAxisGradKernel(const Context& dev_ctx, value_grad->Resize(index.dims()); dev_ctx.template Alloc(value_grad); if (index_type == DataType::INT32) { - phi::funcs::cpu_gather_kernel( + phi::funcs::cpu_scatter_value_grad_kernel( out_grad, axis, index, *value_grad, dev_ctx); - } else if (index_type == DataType::INT64) { - phi::funcs::cpu_gather_kernel( + } else { + phi::funcs::cpu_scatter_value_grad_kernel( out_grad, axis, index, *value_grad, dev_ctx); } } diff --git a/paddle/phi/kernels/funcs/gather_scatter_functor.cc b/paddle/phi/kernels/funcs/gather_scatter_functor.cc index 597b8f231760bf..be07c68b0fd338 100644 --- a/paddle/phi/kernels/funcs/gather_scatter_functor.cc +++ b/paddle/phi/kernels/funcs/gather_scatter_functor.cc @@ -80,8 +80,10 @@ struct cpu_gather_scatter_functor { } int64_t select_dim_size = index_dims[dim]; // index matrix has different shape with self matrix or src matrix. - int replaced_select_dim_size = - is_scatter_like ? self_dims[dim] : src_dims[dim]; + int self_select_dim_size = self_dims[dim]; + int src_select_dim_size = src_dims[dim]; + int64_t outer_dim_size_self = 1; + int64_t outer_dim_size_src = 1; int64_t inner_dim_size = 1; int64_t outer_dim_size = 1; for (int i = 0; i < dim; ++i) { @@ -90,10 +92,10 @@ struct cpu_gather_scatter_functor { for (int i = dim + 1; i < index_dims.size(); i++) { outer_dim_size *= index_dims[i]; + outer_dim_size_self *= self_dims[i]; + outer_dim_size_src *= src_dims[i]; } int64_t index_idx = 0; - int64_t self_idx = 0, src_idx = 0; - // N layer loop squeezed into 3 layers loop for (int64_t i = 0; i < inner_dim_size; i++) { for (int64_t j = 0; j < select_dim_size; j++) { @@ -117,13 +119,21 @@ struct cpu_gather_scatter_functor { // This index might out of bound of index matrix's index, so here // multiply the replaced_select_dim_size. - int64_t replace_index = k + index * outer_dim_size + - i * outer_dim_size * replaced_select_dim_size; + int64_t replace_index_self, replace_index_src; + if (is_scatter_like) { + replace_index_self = k + index * outer_dim_size_self + + i * outer_dim_size_self * self_select_dim_size; + + replace_index_src = k + j * outer_dim_size_src + + i * outer_dim_size_src * src_select_dim_size; + } else { + replace_index_self = index_idx; - self_idx = is_scatter_like ? replace_index : index_idx; - src_idx = is_scatter_like ? index_idx : replace_index; - reduce_op((tensor_t*)(self_data + self_idx), // NOLINT - (tensor_t*)(src_data + src_idx)); // NOLINT + replace_index_src = k + index * outer_dim_size_src + + i * outer_dim_size_src * src_select_dim_size; + } + reduce_op((tensor_t*)(self_data + replace_index_self), // NOLINT + (tensor_t*)(src_data + replace_index_src)); // NOLINT index_idx++; } } @@ -183,24 +193,26 @@ template void cpu_scatter_input_grad_kernel(phi::DenseTensor self UNUSED, int dim, const phi::DenseTensor& index, - phi::DenseTensor output, + phi::DenseTensor grad, const phi::DeviceContext& ctx UNUSED) { auto* index_data = index.data(); - auto* output_data = output.data(); + auto* grad_data = grad.data(); auto index_dims = index.dims(); - auto output_dims = output.dims(); + auto grad_dims = grad.dims(); int64_t inner_dim_size = 1; int64_t outer_dim_size = 1; + int64_t outer_dim_size_data = 1; int64_t select_dim_size = index_dims[dim]; - int64_t output_select_dim_size = output_dims[dim]; + int64_t grad_select_dim_size = grad_dims[dim]; for (int i = 0; i < dim; ++i) { inner_dim_size *= index_dims[i]; } for (int i = dim + 1; i < index_dims.size(); i++) { outer_dim_size *= index_dims[i]; + outer_dim_size_data *= grad_dims[i]; } int64_t index_idx = 0; @@ -208,20 +220,84 @@ void cpu_scatter_input_grad_kernel(phi::DenseTensor self UNUSED, for (int64_t j = 0; j < select_dim_size; j++) { for (int64_t k = 0; k < outer_dim_size; k++) { int64_t index = index_data[index_idx]; - int64_t replace_index = k + index * outer_dim_size + - i * outer_dim_size * output_select_dim_size; - output_data[replace_index] = 0; + int64_t replace_index = k + index * outer_dim_size_data + + i * outer_dim_size_data * grad_select_dim_size; + grad_data[replace_index] = 0; index_idx++; } } } } +template +void cpu_scatter_value_grad_kernel(phi::DenseTensor self, + int dim, + const phi::DenseTensor& index, + phi::DenseTensor grad, + const phi::DeviceContext& ctx UNUSED) { + auto* self_data = self.data(); + auto* index_data = index.data(); + auto* grad_data = grad.data(); + + auto index_dims = index.dims(); + auto self_dims = self.dims(); + auto grad_dims = grad.dims(); + + int64_t self_size = self.numel(); + int64_t grad_size = grad.numel(); + bool* is_self_grad_used = new bool[self_size]; + + for (int i = 0; i < self_size; i++) { + is_self_grad_used[i] = false; + } + + int64_t inner_dim_size = 1; + int64_t outer_dim_size = 1; + int64_t outer_dim_size_self = 1; + int64_t outer_dim_size_grad = 1; + int64_t select_dim_size = index_dims[dim]; + int64_t self_select_dim_size = self_dims[dim]; + int64_t grad_select_dim_size = grad_dims[dim]; + for (int i = 0; i < dim; ++i) { + inner_dim_size *= index_dims[i]; + } + + for (int i = dim + 1; i < index_dims.size(); i++) { + outer_dim_size *= index_dims[i]; + outer_dim_size_self *= self_dims[i]; + outer_dim_size_grad *= grad_dims[i]; + } + int64_t index_idx = index.numel() - 1; + for (int i = 0; i < grad_size; i++) { + grad_data[i] = static_cast(0); + } + for (int64_t i = inner_dim_size - 1; i >= 0; i--) { + for (int64_t j = select_dim_size - 1; j >= 0; j--) { + for (int64_t k = outer_dim_size - 1; k >= 0; k--) { + int64_t index = index_data[index_idx]; + int64_t replace_index_self = + k + index * outer_dim_size_self + + i * outer_dim_size_self * self_select_dim_size; + int64_t replace_index_grad = + k + j * outer_dim_size_grad + + i * outer_dim_size_grad * grad_select_dim_size; + if (!is_self_grad_used[replace_index_self]) { + grad_data[replace_index_grad] = self_data[replace_index_self]; + is_self_grad_used[replace_index_self] = true; + } + index_idx--; + } + } + } + delete[] is_self_grad_used; +} + Instantiate_Template_Function(cpu_gather_kernel) Instantiate_Template_Function(cpu_scatter_assign_kernel) Instantiate_Template_Function(cpu_scatter_add_kernel) Instantiate_Template_Function(cpu_scatter_mul_kernel) Instantiate_Template_Function(cpu_scatter_input_grad_kernel) + Instantiate_Template_Function(cpu_scatter_value_grad_kernel) } // namespace funcs } // namespace phi diff --git a/paddle/phi/kernels/funcs/gather_scatter_functor.cu b/paddle/phi/kernels/funcs/gather_scatter_functor.cu index b53de3beef9aa4..cbe866d4924d54 100644 --- a/paddle/phi/kernels/funcs/gather_scatter_functor.cu +++ b/paddle/phi/kernels/funcs/gather_scatter_functor.cu @@ -54,6 +54,79 @@ class ReduceMul { }; static ReduceMul reduce_mul; +template +__global__ void ScatterAssignGPUKernel(tensor_t* self_data, + int dim, + const index_t* index_data, + tensor_t* src_data, + int select_dim_size, + int self_select_dim_size, + int src_select_dim_size, + int64_t outer_dim_size, + int64_t outer_dim_size_self, + int64_t outer_dim_size_src, + int64_t numel, + int64_t numel_data, + const func_t& reduce_op) { + int tid = threadIdx.x + blockIdx.x * blockDim.x; + if (tid >= numel) return; + extern __shared__ int thread_ids[]; + + if (tid == 0) { + for (int i = 0; i < numel_data; i++) { + thread_ids[i] = 0; + } + } + __syncthreads(); + int64_t i, j, k; // The i, j, k here is the index of the 3 layers loop + // squeezed from the N layers loop. + /* tid = i * select_dim_size * outer_dim_size + j * outer_dim_size + k */ + i = tid / (select_dim_size * outer_dim_size); + int64_t remind = tid % (select_dim_size * outer_dim_size); + j = remind / outer_dim_size; + k = remind % outer_dim_size; + index_t index = index_data[tid]; + /* + gather computation formula: + + self[i][j][k] = src[index[i][j][k]][j][k] # if dim == 0 + self[i][j][k] = src[i][index[i][j][k]][k] # if dim == 1 + self[i][j][k] = src[i][j][index[i][j][k]] # if dim == 2 + + scatter computation formula: + + self[index[i][j][k]][j][k] = src[i][j][k] # if dim == 0 + self[i][index[i][j][k]][k] = src[i][j][k] # if dim == 1 + self[i][j][index[i][j][k]] = src[i][j][k] # if dim == 2 + + */ + // index matrix has different shape with self matrix or src matrix. + int64_t replace_index_self, replace_index_src; + if (is_scatter_like) { + replace_index_self = k + index * outer_dim_size_self + + i * outer_dim_size_self * self_select_dim_size; + + replace_index_src = k + j * outer_dim_size_src + + i * outer_dim_size_src * src_select_dim_size; + } else { + replace_index_self = tid; + + replace_index_src = k + index * outer_dim_size_src + + i * outer_dim_size_src * src_select_dim_size; + } + + atomicMax(thread_ids + replace_index_self, tid); + __syncthreads(); + + if (tid == thread_ids[replace_index_self]) { + reduce_op(static_cast(self_data + replace_index_self), + static_cast(src_data + replace_index_src)); + } +} + template = numel) return; @@ -93,12 +169,22 @@ __global__ void GatherScatterGPUKernel(tensor_t* self_data, */ // index matrix has different shape with self matrix or src matrix. - int64_t replace_index = k + index * outer_dim_size + - i * outer_dim_size * replaced_select_dim_size; - int64_t self_idx = is_scatter_like ? replace_index : tid; - int64_t src_idx = is_scatter_like ? tid : replace_index; - reduce_op(static_cast(self_data + self_idx), - static_cast(src_data + src_idx)); + int64_t replace_index_self, replace_index_src; + if (is_scatter_like) { + replace_index_self = k + index * outer_dim_size_self + + i * outer_dim_size_self * self_select_dim_size; + + replace_index_src = k + j * outer_dim_size_src + + i * outer_dim_size_src * src_select_dim_size; + } else { + replace_index_self = tid; + + replace_index_src = k + index * outer_dim_size_src + + i * outer_dim_size_src * src_select_dim_size; + } + + reduce_op(static_cast(self_data + replace_index_self), + static_cast(src_data + replace_index_src)); } template (ctx).stream(); - GatherScatterGPUKernel - <<>>(self_data, - dim, - index_data, - src_data, - inner_dim_size, - select_dim_size, - replaced_select_dim_size, - outer_dim_size, - index_size, - reduce_op); + if (method_name == "scatter_assign_gpu") { + int shared_mem_size = + is_scatter_like ? sizeof(int) * self_size : sizeof(int) * index_size; + ScatterAssignGPUKernel + <<>>(self_data, + dim, + index_data, + src_data, + select_dim_size, + self_select_dim_size, + src_select_dim_size, + outer_dim_size, + outer_dim_size_self, + outer_dim_size_src, + index_size, + self_size, + reduce_op); + } else { + GatherScatterGPUKernel + <<>>(self_data, + dim, + index_data, + src_data, + select_dim_size, + self_select_dim_size, + src_select_dim_size, + outer_dim_size, + outer_dim_size_self, + outer_dim_size_src, + index_size, + self_size, + reduce_op); + } } }; // struct gpu_gather_scatter_functor @@ -211,11 +323,12 @@ template __global__ void ScatterInputGradGPUKernel(tensor_t* grad_data, int dim, const index_t* index_data, - int64_t inner_dim_size, int select_dim_size, int grad_select_dim_size, int64_t outer_dim_size, - int64_t numel) { + int64_t outer_dim_size_data, + int64_t numel, + int64_t numel_data) { int tid = threadIdx.x + blockIdx.x * blockDim.x; if (tid >= numel) return; int64_t i, j, k; @@ -224,8 +337,9 @@ __global__ void ScatterInputGradGPUKernel(tensor_t* grad_data, j = remind / outer_dim_size; k = remind % outer_dim_size; index_t index = index_data[tid]; - int64_t replace_index = - k + index * outer_dim_size + i * outer_dim_size * grad_select_dim_size; + int64_t replace_index = k + index * outer_dim_size_data + + i * outer_dim_size_data * grad_select_dim_size; + grad_data[replace_index] = 0; } template @@ -240,9 +354,11 @@ void gpu_scatter_input_grad_kernel(phi::DenseTensor self, auto index_dims = index.dims(); auto grad_dims = grad.dims(); int64_t index_size = index.numel(); + int64_t grad_size = grad.numel(); int64_t inner_dim_size = 1; int64_t outer_dim_size = 1; + int64_t outer_dim_size_data = 1; int select_dim_size = index_dims[dim]; int grad_select_dim_size = grad_dims[dim]; for (int64_t i = 0; i < dim; ++i) { @@ -251,28 +367,124 @@ void gpu_scatter_input_grad_kernel(phi::DenseTensor self, for (int i = dim + 1; i < index_dims.size(); i++) { outer_dim_size *= index_dims[i]; + outer_dim_size_data *= grad_dims[i]; } int block = 512; int64_t n = inner_dim_size * select_dim_size * outer_dim_size; int64_t grid = (n + block - 1) / block; auto stream = reinterpret_cast(ctx).stream(); - + int shared_mem_size = sizeof(int) * grad_size; ScatterInputGradGPUKernel - <<>>(grad_data, - dim, - index_data, - inner_dim_size, - select_dim_size, - grad_select_dim_size, - outer_dim_size, - index_size); + <<>>(grad_data, + dim, + index_data, + select_dim_size, + grad_select_dim_size, + outer_dim_size, + outer_dim_size_data, + index_size, + grad_size); +} + +template +__global__ void ScatterValueGradGPUKernel(tensor_t* grad_data, + int dim, + const tensor_t* self_data, + const index_t* index_data, + int select_dim_size, + int self_select_dim_size, + int grad_select_dim_size, + int64_t outer_dim_size, + int64_t outer_dim_size_self, + int64_t outer_dim_size_grad, + int64_t numel, + int64_t numel_data) { + int tid = threadIdx.x + blockIdx.x * blockDim.x; + if (tid >= numel) return; + extern __shared__ int thread_ids[]; + + if (tid == 0) { + for (int i = 0; i < numel_data; i++) { + thread_ids[i] = 0; + } + } + __syncthreads(); + int64_t i, j, k; + i = tid / (select_dim_size * outer_dim_size); + int64_t remind = tid % (select_dim_size * outer_dim_size); + j = remind / outer_dim_size; + k = remind % outer_dim_size; + index_t index = index_data[tid]; + int64_t replace_index_self = k + index * outer_dim_size_self + + i * outer_dim_size_self * self_select_dim_size; + + atomicMax(thread_ids + replace_index_self, tid); + __syncthreads(); + + if (tid == thread_ids[replace_index_self]) { + int64_t replace_index_grad = k + j * outer_dim_size_grad + + i * outer_dim_size_grad * grad_select_dim_size; + grad_data[replace_index_grad] = self_data[replace_index_self]; + } +} +template +void gpu_scatter_value_grad_kernel(phi::DenseTensor self, + int dim, + const phi::DenseTensor& index, + phi::DenseTensor grad, + const phi::DeviceContext& ctx) { + auto* self_data = self.data(); + auto* index_data = index.data(); + auto* grad_data = grad.data(); + + auto index_dims = index.dims(); + auto self_dims = self.dims(); + auto grad_dims = grad.dims(); + int64_t index_size = index.numel(); + int64_t self_size = self.numel(); + + int64_t inner_dim_size = 1; + int64_t outer_dim_size = 1; + int64_t outer_dim_size_self = 1; + int64_t outer_dim_size_grad = 1; + int select_dim_size = index_dims[dim]; + int self_select_dim_size = self_dims[dim]; + int grad_select_dim_size = grad_dims[dim]; + for (int64_t i = 0; i < dim; ++i) { + inner_dim_size *= index_dims[i]; + } + + for (int i = dim + 1; i < index_dims.size(); i++) { + outer_dim_size *= index_dims[i]; + outer_dim_size_self *= self_dims[i]; + outer_dim_size_grad *= grad_dims[i]; + } + + int block = 512; + int64_t n = inner_dim_size * select_dim_size * outer_dim_size; + int64_t grid = (n + block - 1) / block; + auto stream = reinterpret_cast(ctx).stream(); + int shared_mem_size = sizeof(int) * self_size; + ScatterValueGradGPUKernel + <<>>(grad_data, + dim, + self_data, + index_data, + select_dim_size, + self_select_dim_size, + grad_select_dim_size, + outer_dim_size, + outer_dim_size_self, + outer_dim_size_grad, + index_size, + self_size); } Instantiate_Template_Function(gpu_gather_kernel) Instantiate_Template_Function(gpu_scatter_assign_kernel) Instantiate_Template_Function(gpu_scatter_add_kernel) Instantiate_Template_Function(gpu_scatter_mul_kernel) Instantiate_Template_Function(gpu_scatter_input_grad_kernel) - + Instantiate_Template_Function(gpu_scatter_value_grad_kernel) } // namespace funcs } // namespace phi diff --git a/paddle/phi/kernels/funcs/gather_scatter_functor.h b/paddle/phi/kernels/funcs/gather_scatter_functor.h index 56068f9459ebd5..054ccc196fcd00 100644 --- a/paddle/phi/kernels/funcs/gather_scatter_functor.h +++ b/paddle/phi/kernels/funcs/gather_scatter_functor.h @@ -75,7 +75,14 @@ template void cpu_scatter_input_grad_kernel(phi::DenseTensor self, int dim, const phi::DenseTensor& index, - phi::DenseTensor result, + phi::DenseTensor grad, + const phi::DeviceContext& ctx); + +template +void cpu_scatter_value_grad_kernel(phi::DenseTensor self, + int dim, + const phi::DenseTensor& index, + phi::DenseTensor grad, const phi::DeviceContext& ctx); template @@ -110,7 +117,14 @@ template void gpu_scatter_input_grad_kernel(phi::DenseTensor self, int dim, const phi::DenseTensor& index, - phi::DenseTensor result, + phi::DenseTensor grad, + const phi::DeviceContext& ctx); + +template +void gpu_scatter_value_grad_kernel(phi::DenseTensor self, + int dim, + const phi::DenseTensor& index, + phi::DenseTensor grad, const phi::DeviceContext& ctx); } // namespace funcs diff --git a/paddle/phi/kernels/gpu/put_along_axis_grad_kernel.cu b/paddle/phi/kernels/gpu/put_along_axis_grad_kernel.cu index 8321bcd1aa7acf..d86e0493786ebd 100644 --- a/paddle/phi/kernels/gpu/put_along_axis_grad_kernel.cu +++ b/paddle/phi/kernels/gpu/put_along_axis_grad_kernel.cu @@ -52,14 +52,10 @@ void PutAlongAxisGradKernel(const Context& dev_ctx, value_grad->Resize(index.dims()); dev_ctx.template Alloc(value_grad); if (index_type == DataType::INT32) { - phi::funcs::gpu_gather_kernel( - out_grad, - axis, - index, - *value_grad, - dev_ctx); // the gradient of scatter is gather - } else if (index_type == DataType::INT64) { - phi::funcs::gpu_gather_kernel( + phi::funcs::gpu_scatter_value_grad_kernel( + out_grad, axis, index, *value_grad, dev_ctx); + } else { + phi::funcs::gpu_scatter_value_grad_kernel( out_grad, axis, index, *value_grad, dev_ctx); } } diff --git a/python/paddle/tensor/manipulation.py b/python/paddle/tensor/manipulation.py index 685b10276c476f..81137ebb01fc58 100644 --- a/python/paddle/tensor/manipulation.py +++ b/python/paddle/tensor/manipulation.py @@ -5127,7 +5127,7 @@ def infer_broadcast_shape(arr, indices, axis): return broadcast_shape -def take_along_axis(arr, indices, axis): +def take_along_axis(arr, indices, axis, broadcast=True): """ Take values from the input array by given indices matrix along the designated axis. @@ -5136,6 +5136,7 @@ def take_along_axis(arr, indices, axis): indices (Tensor) : Indices to take along each 1d slice of arr. This must match the dimension of arr, and need to broadcast against arr. Supported data type are int and int64. axis (int) : The axis to take 1d slices along. + broadcast (bool, optional): whether the indices broadcast. Returns: Tensor, The indexed element, same dtype with arr @@ -5158,16 +5159,33 @@ def take_along_axis(arr, indices, axis): "`indices` and `arr` must have the same number of dimensions!" ) axis = non_negative_axis(arr, axis) - broadcast_shape = infer_broadcast_shape(arr, indices, axis) - if not broadcast_shape: - # if indices matrix have larger size than arr, arr should broadcast into indices shape. - broadcast_shape = indices.shape - if in_dynamic_or_pir_mode(): + if broadcast: + broadcast_shape = infer_broadcast_shape(arr, indices, axis) + if not broadcast_shape: + # if indices matrix have larger size than arr, arr should broadcast into indices shape. + broadcast_shape = indices.shape indices = paddle.broadcast_to(indices, broadcast_shape) broadcast_shape_list = list(broadcast_shape) broadcast_shape_list[axis] = list(arr.shape)[axis] broadcast_shape = tuple(broadcast_shape_list) arr = paddle.broadcast_to(arr, broadcast_shape) + else: + for i in range(len(arr.shape)): + if i != axis and arr.shape[i] < indices.shape[i]: + raise RuntimeError( + "Size does not match at dimension {} expected index {} to be smaller than self {} apart from dimension {}".format( + i, indices.shape, arr.shape, axis + ) + ) + + axis_max_size = arr.shape[axis] + if not (indices < axis_max_size).all(): + raise RuntimeError( + "one of element of indices is out of bounds for dimension {} with size {}".format( + axis, axis_max_size + ) + ) + if in_dynamic_or_pir_mode(): return _C_ops.take_along_axis(arr, indices, axis) else: check_variable_and_dtype( @@ -5187,11 +5205,6 @@ def take_along_axis(arr, indices, axis): check_variable_and_dtype( indices, 'index', ['int32', 'int64'], 'take_along_axis' ) - indices = paddle.broadcast_to(indices, broadcast_shape) - broadcast_shape_list = list(broadcast_shape) - broadcast_shape_list[axis] = list(arr.shape)[axis] - broadcast_shape = tuple(broadcast_shape_list) - arr = paddle.broadcast_to(arr, broadcast_shape) helper = LayerHelper('take_along_axis', **locals()) dtype = helper.input_dtype() result = helper.create_variable_for_type_inference(dtype) @@ -5204,7 +5217,15 @@ def take_along_axis(arr, indices, axis): return result -def put_along_axis(arr, indices, values, axis, reduce='assign'): +def put_along_axis( + arr, + indices, + values, + axis, + reduce='assign', + include_self=True, + broadcast=True, +): """ Put values into the destination array by given indices matrix along the designated axis. @@ -5214,6 +5235,8 @@ def put_along_axis(arr, indices, values, axis, reduce='assign'): and need to broadcast against arr. Supported data type are int and int64. axis (int) : The axis to put 1d slices along. reduce (str, optional): The reduce operation, default is 'assign', support 'add', 'assign', 'mul' and 'multiply'. + include_self (bool, optional): whether to reduce with the elements of arr. (Only support True now) + broadcast (bool, optional): whether to broadcast indices. Returns: Tensor, The indexed element, same dtype with arr @@ -5234,21 +5257,54 @@ def put_along_axis(arr, indices, values, axis, reduce='assign'): [60, 40, 50]]) """ + if not include_self: + raise ValueError("`include_self` is only support True now.") if len(arr.shape) != len(indices.shape): raise ValueError( "`indices` and `arr` must have the same number of dimensions!" ) axis = non_negative_axis(arr, axis) - broadcast_shape = infer_broadcast_shape(arr, indices, axis) - if in_dynamic_or_pir_mode(): - values = ( - paddle.to_tensor(values) - if not isinstance(values, (paddle.Tensor, paddle.pir.OpResult)) - else values - ) + if broadcast: + broadcast_shape = infer_broadcast_shape(arr, indices, axis) + if in_dynamic_or_pir_mode(): + values = ( + paddle.to_tensor(values) + if not isinstance(values, (paddle.Tensor, paddle.pir.OpResult)) + else values + ) if broadcast_shape: indices = paddle.broadcast_to(indices, broadcast_shape) values = paddle.broadcast_to(values, indices.shape) + else: + if isinstance(values, (paddle.Tensor, paddle.pir.OpResult)): + if len(indices.shape) != len(values.shape): + raise ValueError( + "`indices` and `values` must have the same number of dimensions!" + ) + for i in range(len(arr.shape)): + if ( + i != axis and arr.shape[i] < indices.shape[i] + ) or indices.shape[i] > values.shape[i]: + raise RuntimeError( + "Size does not match at dimension {} expected index {} to be smaller than self {} apart from dimension {} and to be smaller size than values {}".format( + i, indices.shape, arr.shape, axis, values.shape + ) + ) + else: + values = paddle.to_tensor(values).astype(arr.dtype) + elements = 1 + for num in values.shape: + elements *= num + if elements == 1: # paddle.pir.OpResult has no attribute 'size' + values = paddle.broadcast_to(values, indices.shape) + axis_max_size = arr.shape[axis] + if not (indices < axis_max_size).all(): + raise RuntimeError( + "one of element of indices is out of bounds for dimension {} with size {}".format( + axis, axis_max_size + ) + ) + if in_dynamic_or_pir_mode(): return _C_ops.put_along_axis(arr, indices, values, axis, reduce) else: check_variable_and_dtype( @@ -5268,9 +5324,6 @@ def put_along_axis(arr, indices, values, axis, reduce='assign'): check_variable_and_dtype( indices, 'index', ['int32', 'int64'], 'put_along_axis' ) - if broadcast_shape: - indices = paddle.broadcast_to(indices, broadcast_shape) - values = paddle.broadcast_to(values, indices.shape) helper = LayerHelper('put_along_axis', **locals()) dtype = helper.input_dtype() result = helper.create_variable_for_type_inference(dtype) diff --git a/test/legacy_test/prim_op_test.py b/test/legacy_test/prim_op_test.py index 743a856058a6f0..d1680c33e6c924 100644 --- a/test/legacy_test/prim_op_test.py +++ b/test/legacy_test/prim_op_test.py @@ -192,7 +192,10 @@ def parse_attri_value(name, op_inputs, op_proto_attrs): tmp = input_arguments[idx_of_op_proto_arguments] idx_of_op_proto_arguments += 1 else: - tmp = Empty() # use the default value + # tmp = Empty() # use the default value + tmp = parse_attri_value( + arg_name, op_proto_ins, op_proto_attrs + ) if isinstance(tmp, Empty): results.append(get_default(idx, api_defaults)) diff --git a/test/legacy_test/test_put_along_axis_op.py b/test/legacy_test/test_put_along_axis_op.py index 83194145bb18e7..43d2e80c25e24a 100644 --- a/test/legacy_test/test_put_along_axis_op.py +++ b/test/legacy_test/test_put_along_axis_op.py @@ -74,12 +74,52 @@ def init_data(self): self.x_shape = (10, 10, 10) self.value_type = "float16" self.value = np.array([99]).astype(self.value_type) - self.index_type = "int32" + self.index_type = "int64" self.index = np.array([[[0]]]).astype(self.index_type) self.axis = 1 self.axis_type = "int64" +class TestPutAlongAxisOpCase2(TestPutAlongAxisOp): + def setUp(self): + self.init_data() + self.reduce_op = "assign" + self.op_type = "put_along_axis" + self.python_api = paddle.tensor.put_along_axis + self.xnp = np.random.random(self.x_shape).astype(self.x_type) + # numpy put_along_axis is an inplace operation. + self.target = copy.deepcopy(self.xnp) + for i in range(5): + for j in range(5): + for k in range(5): + self.target[i, self.index[i, j, k], k] = self.value[i, j, k] + self.inputs = { + 'Input': self.xnp, + 'Index': self.index, + 'Value': self.value, + } + self.attrs = { + 'Axis': self.axis, + 'Reduce': self.reduce_op, + 'include_self': True, + 'broadcast': False, + } + self.outputs = {'Result': self.target} + + def init_data(self): + self.dtype = 'float32' + self.x_type = "float32" + self.x_shape = (10, 10, 10) + self.value_type = "float32" + self.value = ( + np.arange(1, 126).reshape((5, 5, 5)).astype(self.value_type) + ) + self.index_type = "int64" + self.index = np.zeros((5, 5, 5)).astype(self.index_type) + self.axis = 1 + self.axis_type = "int64" + + @unittest.skipIf( not core.is_compiled_with_cuda() or not core.is_bfloat16_supported(core.CUDAPlace(0)), @@ -271,6 +311,172 @@ def test_inplace_dygraph(self): pass +class TestPutAlongAxisAPICase4(unittest.TestCase): + def setUp(self): + np.random.seed(0) + self.shape = [3, 5] + self.index1_shape = [1, 4] + self.index_np1 = np.array([[0, 1, 2, 0]]).astype('int64') + self.index2_shape = [2, 3] + self.index_np2 = np.array([[0, 1, 2], [0, 1, 4]]).astype('int64') + self.x_np = np.zeros((3, 5)).astype(np.float32) + self.value_shape = [2, 5] + self.value = ( + np.arange(1, 11).reshape(self.value_shape).astype(np.float32) + ) + self.place = [paddle.CPUPlace()] + if core.is_compiled_with_cuda(): + self.place.append(paddle.CUDAPlace(0)) + + def test_api_dygraph(self): + def run(place): + paddle.disable_static(place) + x_tensor = paddle.to_tensor(self.x_np) + index_tensor1 = paddle.to_tensor(self.index_np1) + value_tensor = paddle.to_tensor(self.value) + out = paddle.put_along_axis( + x_tensor, index_tensor1, value_tensor, 0, 'assign', True, False + ) + out_ref = copy.deepcopy(self.x_np) + for i in range(self.index1_shape[0]): + for j in range(self.index1_shape[1]): + out_ref[self.index_np1[i, j], j] = self.value[i, j] + np.testing.assert_allclose(out.numpy(), out_ref, rtol=0.001) + + # for ci coverage, numpy put_along_axis did not support argument of 'reduce' + paddle.put_along_axis( + x_tensor, index_tensor1, value_tensor, 0, 'mul', True, False + ) + paddle.put_along_axis( + x_tensor, index_tensor1, value_tensor, 0, 'add', True, False + ) + + index_tensor2 = paddle.to_tensor(self.index_np2) + out = paddle.put_along_axis( + x_tensor, index_tensor2, value_tensor, 1, 'assign', True, False + ) + out_ref = copy.deepcopy(self.x_np) + for i in range(self.index2_shape[0]): + for j in range(self.index2_shape[1]): + out_ref[i, self.index_np2[i, j]] = self.value[i, j] + np.testing.assert_allclose(out.numpy(), out_ref, rtol=0.001) + + # for ci coverage, numpy put_along_axis did not support argument of 'reduce' + paddle.put_along_axis( + x_tensor, index_tensor2, value_tensor, 1, 'mul', True, False + ) + paddle.put_along_axis( + x_tensor, index_tensor2, value_tensor, 1, 'add', True, False + ) + + paddle.enable_static() + + for place in self.place: + run(place) + + @test_with_pir_api + def test_api_static(self): + paddle.enable_static() + + def run(place): + with paddle.static.program_guard(paddle.static.Program()): + x1 = paddle.static.data('X', self.shape) + index1 = paddle.static.data('Index', self.index1_shape, "int64") + value_tensor = paddle.to_tensor(self.value) + out1 = paddle.put_along_axis( + x1, index1, value_tensor, 0, 'assign', True, False + ) + exe = paddle.static.Executor(place) + res = exe.run( + feed={ + 'X': self.x_np, + 'Value': self.value, + 'Index': self.index_np1, + }, + fetch_list=[out1], + ) + out_ref = copy.deepcopy(self.x_np) + for i in range(self.index1_shape[0]): + for j in range(self.index1_shape[1]): + out_ref[self.index_np1[i, j], j] = self.value[i, j] + + for out in res: + np.testing.assert_allclose(out, out_ref, rtol=0.001) + + with paddle.static.program_guard(paddle.static.Program()): + x2 = paddle.static.data('X', self.shape) + index2 = paddle.static.data('Index', self.index2_shape, "int64") + value_tensor = paddle.to_tensor(self.value) + out2 = paddle.put_along_axis( + x2, index2, value_tensor, 1, 'assign', True, False + ) + exe = paddle.static.Executor(place) + res = exe.run( + feed={ + 'X': self.x_np, + 'Value': self.value, + 'Index': self.index_np2, + }, + fetch_list=[out2], + ) + out_ref = copy.deepcopy(self.x_np) + for i in range(self.index2_shape[0]): + for j in range(self.index2_shape[1]): + out_ref[i, self.index_np2[i, j]] = self.value[i, j] + + for out in res: + np.testing.assert_allclose(out, out_ref, rtol=0.001) + + for place in self.place: + run(place) + + def test_error(self): + tensorx = paddle.to_tensor([[1, 2, 3], [4, 5, 6]]).astype("float32") + indices = paddle.to_tensor([1]).astype("int32") + values = paddle.to_tensor([2]) + # len(arr.shape) != len(indices.shape) + try: + res = paddle.put_along_axis( + tensorx, indices, 1.0, 0, 'assign', True, False + ) + except Exception as error: + self.assertIsInstance(error, ValueError) + indices = paddle.to_tensor([[1]]).astype("int32") + # len(values.shape) != len(indices.shape) + try: + res = paddle.put_along_axis( + tensorx, indices, values, 0, 'assign', True, False + ) + except Exception as error: + self.assertIsInstance(error, ValueError) + indices = paddle.to_tensor( + [[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12]] + ).astype("int32") + # indices too large + try: + res = paddle.put_along_axis( + tensorx, indices, 1.0, 0, 'assign', True, False + ) + except Exception as error: + self.assertIsInstance(error, RuntimeError) + indices = paddle.to_tensor([[10]]).astype("int32") + # the element of indices out of range + try: + res = paddle.put_along_axis( + tensorx, indices, 1.0, 0, 'assign', True, False + ) + except Exception as error: + self.assertIsInstance(error, RuntimeError) + + # use includ_self=False + try: + res = paddle.put_along_axis( + tensorx, indices, 1.0, 0, 'assign', False + ) + except Exception as error: + self.assertIsInstance(error, ValueError) + + if __name__ == "__main__": paddle.enable_static() unittest.main() diff --git a/test/legacy_test/test_take_along_axis_op.py b/test/legacy_test/test_take_along_axis_op.py index 3e687fcdf2a24e..dc73ae34aeea43 100644 --- a/test/legacy_test/test_take_along_axis_op.py +++ b/test/legacy_test/test_take_along_axis_op.py @@ -76,6 +76,34 @@ def init_data(self): self.axis_type = "int64" +class TestTakeAlongAxisOp(OpTest): + def setUp(self): + self.init_data() + self.op_type = "take_along_axis" + self.python_api = paddle.tensor.take_along_axis + self.check_cinn = True + self.xnp = np.random.random(self.x_shape).astype(self.x_type) + self.target = np.zeros((2, 3, 4)).astype(self.x_type) + for i in range(2): + for j in range(3): + for k in range(4): + self.target[i, j, k] = self.xnp[i, j, self.index[i, j, k]] + self.inputs = { + 'Input': self.xnp, + 'Index': self.index, + } + self.attrs = {'Axis': self.axis, 'broadcast': False} + self.outputs = {'Result': self.target} + + def init_data(self): + self.x_type = "float64" + self.x_shape = (10, 10, 10) + self.index_type = "int64" + self.index = np.random.randint(0, 10, (2, 3, 4)).astype(self.index_type) + self.axis = 2 + self.axis_type = "int64" + + @unittest.skipIf( not core.is_compiled_with_cuda() or not core.is_bfloat16_supported(core.CUDAPlace(0)), @@ -210,6 +238,67 @@ def setUp(self): self.place.append(paddle.CUDAPlace(0)) +class TestTakeAlongAxisAPICase2(unittest.TestCase): + def setUp(self): + np.random.seed(0) + self.shape = [3, 3] + self.index_shape = [1, 3] + self.index_np = np.array([[0, 1, 2]]).astype('int64') + self.x_np = np.random.random(self.shape).astype(np.float32) + self.place = [paddle.CPUPlace()] + self.axis = 0 + if core.is_compiled_with_cuda(): + self.place.append(paddle.CUDAPlace(0)) + + @test_with_pir_api + def test_api_static(self): + paddle.enable_static() + with paddle.static.program_guard(paddle.static.Program()): + x = paddle.static.data('X', self.shape) + index = paddle.static.data('Index', self.index_shape, "int64") + out = paddle.take_along_axis(x, index, self.axis, False) + exe = paddle.static.Executor(self.place[0]) + res = exe.run( + feed={'X': self.x_np, 'Index': self.index_np}, fetch_list=[out] + ) + out_ref = np.zeros_like(self.index_np, dtype=self.x_np.dtype) + for i in range(self.index_shape[0]): + for j in range(self.index_shape[1]): + out_ref[i, j] = self.x_np[self.index_np[i, j], j] + for out in res: + np.testing.assert_allclose(out, out_ref, rtol=0.001) + + def test_api_dygraph(self): + paddle.disable_static(self.place[0]) + x_tensor = paddle.to_tensor(self.x_np) + self.index = paddle.to_tensor(self.index_np) + out = paddle.take_along_axis(x_tensor, self.index, self.axis, False) + out_ref = np.zeros_like(self.index_np, dtype=self.x_np.dtype) + for i in range(self.index_shape[0]): + for j in range(self.index_shape[1]): + out_ref[i, j] = self.x_np[self.index_np[i, j], j] + np.testing.assert_allclose(out.numpy(), out_ref, rtol=0.001) + paddle.enable_static() + + def test_error(self): + paddle.disable_static(self.place[0]) + tensorx = paddle.to_tensor([[1, 2, 3], [4, 5, 6]]).astype("float32") + indices = paddle.to_tensor([1]).astype("int32") + # len(arr.shape) != len(indices.shape) + with self.assertRaises(ValueError): + res = paddle.take_along_axis(tensorx, indices, 0, False) + # the element of indices out of range + with self.assertRaises(RuntimeError): + indices = paddle.to_tensor([[100]]).astype("int32") + res = paddle.take_along_axis(tensorx, indices, 0, False) + # the shape of indices doesn't match + with self.assertRaises(RuntimeError): + indices = paddle.to_tensor( + [[1, 0, 0, 0], [1, 0, 0, 0], [1, 0, 0, 0]] + ).astype("int32") + res = paddle.take_along_axis(tensorx, indices, 0, False) + + if __name__ == "__main__": paddle.enable_static() unittest.main()