diff --git a/paddle/phi/infermeta/spmd_rules/slice.cc b/paddle/phi/infermeta/spmd_rules/slice.cc
index 1ec057a1734e5..73caa2e65aa45 100644
--- a/paddle/phi/infermeta/spmd_rules/slice.cc
+++ b/paddle/phi/infermeta/spmd_rules/slice.cc
@@ -28,6 +28,7 @@ using phi::distributed::auto_parallel::str_join;
 SpmdInfo SliceInferSpmdBase(const DistMetaTensor& input,
                             const std::vector<int64_t>& axes) {
+  // Step0: Verify input args based on slice logic
   auto input_shape = phi::vectorize(input.dims());
   int input_ndim = input_shape.size();
   auto input_dist_attr_src = input.dist_attr();
@@ -40,32 +41,39 @@ SpmdInfo SliceInferSpmdBase(const DistMetaTensor& input,
                         input_ndim,
                         input_dims_mapping.size()));
 
+  // Step1: Build Einsum Notation
   std::string alphabet = "abcdefghijklmnopqrstuvwxyz";
-  std::string input_axes = alphabet.substr(0, input_ndim);
-  std::string special_axes = alphabet.substr(input_ndim);
-  for (int i = 0; i < static_cast<int>(axes.size()); i++) {
-    int axis = axes[i] < 0 ? axes[i] + input_ndim : axes[i];
-    input_axes[axis] = special_axes[i];
-  }
+  // get einsum notation for input
+  std::string input_axes = alphabet.substr(0, input_ndim);
 
+  // get einsum notation for output
   std::string out_axes(input_axes);
   for (int i = 0; i < static_cast<int>(axes.size()); i++) {
     int axis = axes[i] < 0 ? axes[i] + input_ndim : axes[i];
+    // the sliced axis cannot be sharded; mark it with the special
+    // notation '1' so that its dim mapping resolves to -1.
     out_axes[axis] = '1';
   }
 
+  // Step2: Sharding Propagation
+  // Step2.1: merge input shardings
   std::unordered_map<std::string, int64_t> axis_to_dim_map =
       ShardingMergeForTensors({{input_axes, input_dims_mapping}});
 
+  // Step2.2: infer output dims mapping from merged input dims mapping
   std::vector<int64_t> out_dims_mapping =
       GetDimsMappingForAxes(out_axes, axis_to_dim_map);
+  // get the dist attributes for output. the sliced axis cannot
+  // be sharded; if it is sharded, set it to replicated.
   TensorDistAttr out_dist_attr =
       CopyTensorDistAttrForOutput(input_dist_attr_src);
   out_dist_attr.set_dims_mapping(out_dims_mapping);
 
+  // Step2.3: get new dist attribute for input. the sliced axis cannot
+  // be sharded; if it is sharded, set it to replicated.
   TensorDistAttr input_dist_attr_dst(input_dist_attr_src);
   for (int i = 0; i < static_cast<int>(axes.size()); i++) {
     int axis = axes[i] < 0 ? axes[i] + input_ndim : axes[i];
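
Note on the hunk above: the old rule renamed each sliced axis to a fresh letter from special_axes; the new rule instead overwrites the output notation with the literal character '1'. Since '1' never appears as a key in the merged axis-to-dim map, the lookup misses and the sliced dimension resolves to -1 (replicated). A minimal standalone sketch of that resolution step, assuming a hand-built map and char keys (the real helpers in Paddle's spmd utils use std::string keys):

    #include <cstdint>
    #include <iostream>
    #include <string>
    #include <unordered_map>
    #include <vector>

    int main() {
      // Hand-built merged sharding for a 3-D input "abc": dim 0 on mesh
      // axis 0, dim 1 on mesh axis 1, dim 2 replicated.
      std::unordered_map<char, int64_t> axis_to_dim = {
          {'a', 0}, {'b', 1}, {'c', -1}};
      std::string out_axes = "a1c";  // axis 1 sliced, notation set to '1'
      std::vector<int64_t> out_dims_mapping;
      for (char axis : out_axes) {
        auto it = axis_to_dim.find(axis);
        // '1' is never a key, so the sliced axis falls back to -1.
        out_dims_mapping.push_back(it == axis_to_dim.end() ? -1 : it->second);
      }
      for (int64_t d : out_dims_mapping) std::cout << d << ' ';  // 0 -1 -1
      std::cout << '\n';
    }
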
@@ -76,6 +84,7 @@ SpmdInfo SliceInferSpmdBase(const DistMetaTensor& input,
   VLOG(4) << "SliceInferSpmd:";
   VLOG(4) << "Einsum Notation: " << input_axes << "-->" << out_axes;
   VLOG(4) << "Input shape: [" << str_join(input_shape) << "] "
+          << "axes: [" << str_join(axes) << "] "
           << "src_dims_mapping: ["
           << str_join(input_dist_attr_src.dims_mapping()) << "] "
           << "dst_dims_mapping: [" << str_join(input_dims_mapping) << "]";
@@ -92,12 +101,15 @@ SpmdInfo SliceInferSpmd(const DistMetaTensor& input,
                         const std::vector<int>& ends,
                         const std::vector<int64_t>& infer_flags,
                         const std::vector<int64_t>& decrease_axis) {
+  // starts, ends, infer_flags and decrease_axis have no impact on the
+  // derivation; they exist only to align with the definition in the phi API
   return SliceInferSpmdBase(input, axes);
 }
 
 SpmdInfo SliceInferSpmdReverseBase(const DistMetaTensor& input,
                                    const DistMetaTensor& output,
                                    const std::vector<int64_t>& axes) {
+  // Step0: Verify input args based on slice logic
   auto output_shape = phi::vectorize(output.dims());
   int out_ndim = output_shape.size();
   auto out_dist_attr = output.dist_attr();
@@ -123,17 +135,24 @@ SpmdInfo SliceInferSpmdReverseBase(const DistMetaTensor& input,
                         out_ndim,
                         out_dims_mapping_size));
 
+  // Step1: Build Einsum Notation
   std::string alphabet = "abcdefghijklmnopqrstuvwxyz";
+
+  // get einsum notation for input
   std::string input_axes = alphabet.substr(0, input_ndim);
-  std::string special_axes = alphabet.substr(input_ndim);
+
+  // get einsum notation for output
+  std::string out_axes(input_axes);
   for (int i = 0; i < static_cast<int>(axes.size()); i++) {
     int axis = axes[i] < 0 ? axes[i] + input_ndim : axes[i];
-    input_axes[axis] = special_axes[i];
+    // the sliced axis cannot be sharded; mark it with the special
+    // notation '1' so that its dim mapping resolves to -1.
+    input_axes[axis] = '1';
   }
-  std::string out_axes(input_axes);
-
+  // Step2: Sharding Propagation
+  // Step2.1: merge output shardings
   std::vector<std::pair<std::string, std::vector<int64_t>>> axes_sharding_info;
   std::vector<int64_t> out_dims_mapping = output.dist_attr().dims_mapping();
   axes_sharding_info.emplace_back(std::make_pair(out_axes, out_dims_mapping));
@@ -141,12 +160,13 @@ SpmdInfo SliceInferSpmdReverseBase(const DistMetaTensor& input,
   std::unordered_map<std::string, int64_t> axis_to_dim_map =
       ShardingMergeForTensors(axes_sharding_info);
 
+  // Step2.2: infer input dims mapping from output dims mapping. the sliced
+  // axis cannot be sharded; if it is sharded, set it to replicated.
   input_dims_mapping = GetDimsMappingForAxes(input_axes, axis_to_dim_map, true);
-  for (int i = 0; i < static_cast<int>(axes.size()); i++) {
-    int axis = axes[i] < 0 ? axes[i] + input_ndim : axes[i];
-    input_dims_mapping[axis] = -1;
-  }
   input_dist_attr.set_dims_mapping(input_dims_mapping);
+
+  // Step2.3: get new dist attribute for output. the sliced axis cannot
+  // be sharded; if it is sharded, set it to replicated.
   out_dims_mapping = GetDimsMappingForAxes(out_axes, axis_to_dim_map, true);
   for (int i = 0; i < static_cast<int>(axes.size()); i++) {
     int axis = axes[i] < 0 ? axes[i] + input_ndim : axes[i];
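
The reverse rule above also deletes the explicit reset loop that forced input_dims_mapping[axis] = -1: once the sliced axes are written as '1' in input_axes, GetDimsMappingForAxes with its miss-to-replicated flag already returns -1 for them. Below is a simplified standalone analogue of the Step2.1 merge feeding that lookup; MergeShardings is hypothetical and collapses ShardingMergeForTensors' conflict handling to "first non-replicated mapping wins":

    #include <cstdint>
    #include <iostream>
    #include <string>
    #include <unordered_map>
    #include <utility>
    #include <vector>

    std::unordered_map<char, int64_t> MergeShardings(
        const std::vector<std::pair<std::string, std::vector<int64_t>>>&
            infos) {
      std::unordered_map<char, int64_t> merged;
      for (const auto& [axes, mapping] : infos) {
        for (size_t i = 0; i < axes.size(); ++i) {
          auto [it, inserted] = merged.emplace(axes[i], mapping[i]);
          // Keep an existing entry unless it is replicated (-1).
          if (!inserted && it->second == -1) it->second = mapping[i];
        }
      }
      return merged;
    }

    int main() {
      // Reverse rule: only the output tensor contributes sharding info.
      auto merged = MergeShardings({{"abc", {0, 1, -1}}});
      std::cout << merged['a'] << ' ' << merged['b'] << ' ' << merged['c']
                << '\n';  // prints: 0 1 -1
    }
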
@@ -158,6 +178,7 @@ SpmdInfo SliceInferSpmdReverseBase(const DistMetaTensor& input,
   VLOG(4) << "Einsum Notation: " << input_axes << "-->" << out_axes;
   VLOG(4) << "Output"
           << " shape: [" << str_join(phi::vectorize(output.dims())) << "] "
+          << "axes: [" << str_join(axes) << "] "
           << "src_dims_mapping: ["
           << str_join(output.dist_attr().dims_mapping()) << "] "
           << "dst_dims_mapping: [" << str_join(out_dist_attr.dims_mapping())
@@ -175,6 +196,8 @@ SpmdInfo SliceInferSpmdReverse(const DistMetaTensor& input,
                                const std::vector<int>& ends,
                                const std::vector<int64_t>& infer_flags,
                                const std::vector<int64_t>& decrease_axis) {
+  // starts, ends, infer_flags and decrease_axis have no impact on the
+  // derivation; they exist only to align with the definition in the phi API
   return SliceInferSpmdReverseBase(input, output, axes);
 }
 
@@ -184,6 +207,8 @@ SpmdInfo SliceInferSpmdDynamic(const DistMetaTensor& input,
                                const IntArray& ends,
                                const std::vector<int64_t>& infer_flags,
                                const std::vector<int64_t>& decrease_axis) {
+  // starts, ends, infer_flags and decrease_axis have no impact on the
+  // derivation; they exist only to align with the definition in the phi API
   std::vector<int> start_indexes(starts.GetData().begin(),
                                  starts.GetData().end());
   std::vector<int> end_indexes(ends.GetData().begin(), ends.GetData().end());
@@ -193,6 +218,7 @@ SpmdInfo SliceInferSpmdDynamic(const DistMetaTensor& input,
 SpmdInfo SliceGradInferBase(const DistMetaTensor& input,
                             const DistMetaTensor& out_grad,
                             const std::vector<int64_t>& axes) {
+  // Step0: Verify input args based on slice logic
   auto input_dist_attr = input.dist_attr();
   auto out_dist_attr = out_grad.dist_attr();
   input_dist_attr = UnShardTensorDims(input_dist_attr, axes);
@@ -220,17 +246,18 @@ SpmdInfo SliceGradInferBase(const DistMetaTensor& input,
                         out_ndim,
                         out_dims_mapping_size));
 
+  // Step1: Build Einsum Notation
   std::string alphabet = "abcdefghijklmnopqrstuvwxyz";
+
+  // get einsum notation for input
   std::string align_axes = alphabet.substr(0, input_ndim);
   std::string input_axes = align_axes;
-  std::string special_axes = alphabet.substr(input_ndim);
-  for (int i = 0; i < static_cast<int>(axes.size()); i++) {
-    int axis = axes[i] < 0 ? axes[i] + input_ndim : axes[i];
-    input_axes[axis] = special_axes[i];
-  }
+  // get einsum notation for output
   std::string out_axes(input_axes);
 
+  // Step2: Sharding Propagation
+  // Step2.1: merge input shardings
   std::vector<std::pair<std::string, std::vector<int64_t>>> axes_sharding_info;
   axes_sharding_info.emplace_back(
       std::make_pair(out_axes, out_dist_attr.dims_mapping()));
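
The ternary that recurs throughout these hunks is plain Python-style axis wrap-around. For reference, a tiny standalone sketch with made-up values:

    #include <cstdint>
    #include <iostream>
    #include <vector>

    int main() {
      const int input_ndim = 4;
      const std::vector<int64_t> axes = {-1, 2, -4};
      for (int64_t a : axes) {
        // Negative axes count from the end, as in the ternary in slice.cc.
        int axis = a < 0 ? static_cast<int>(a) + input_ndim
                         : static_cast<int>(a);
        std::cout << a << " -> " << axis << '\n';  // -1->3, 2->2, -4->0
      }
    }
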
@@ -238,8 +265,12 @@ SpmdInfo SliceGradInferBase(const DistMetaTensor& input,
       std::make_pair(input_axes, input_dist_attr.dims_mapping()));
   std::unordered_map<std::string, int64_t> axis_to_dim_map =
       ShardingMergeForTensors(axes_sharding_info);
+
+  // Step2.2: infer the aligned dims mapping from the merged shardings
   auto aligned_dim_mapping =
       GetDimsMappingForAxes(align_axes, axis_to_dim_map, true);
+
+  // get the dist attributes for output
   TensorDistAttr aligned_dist_attr = CopyTensorDistAttrForOutput(out_dist_attr);
   input_dist_attr.set_dims_mapping(aligned_dim_mapping);
   out_dist_attr.set_dims_mapping(aligned_dim_mapping);
@@ -276,6 +307,8 @@ SpmdInfo SliceGradInferSpmdDynamic(const DistMetaTensor& input,
                                    const IntArray& ends,
                                    const std::vector<int64_t>& infer_flags,
                                    const std::vector<int64_t>& decrease_axis) {
+  // starts, ends, infer_flags and decrease_axis have no impact on the
+  // derivation; they exist only to align with the definition in the phi API
   return SliceGradInferBase(input, out_grad, axes);
 }
 
@@ -284,6 +317,8 @@ SpmdInfo StridedSliceInferSpmdDynamic(const DistMetaTensor& input,
                                       const IntArray& starts,
                                       const IntArray& ends,
                                       const IntArray& strides) {
+  // starts, ends and strides have no impact on the derivation;
+  // they exist only to align with the definition in the phi API
   std::vector<int64_t> axes_bridge(axes.begin(), axes.end());
   return SliceInferSpmdBase(input, axes_bridge);
 }
 
@@ -294,6 +329,8 @@ SpmdInfo StridedSliceGradInferSpmdDynamic(const DistMetaTensor& input,
                                           const IntArray& starts,
                                           const IntArray& ends,
                                           const IntArray& strides) {
+  // starts, ends and strides have no impact on the derivation;
+  // they exist only to align with the definition in the phi API
   std::vector<int64_t> axes_bridge(axes.begin(), axes.end());
   return SliceGradInferBase(input, out_grad, axes_bridge);
 }
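
Taken together, the net effect of the forward rule can be condensed into a small standalone sketch (SliceDimsMapping is a hypothetical stand-in that models dims mappings as plain vectors rather than TensorDistAttr): every sliced axis is forced to replicated on both input and output, and every other mapping passes through unchanged.

    #include <cstdint>
    #include <iostream>
    #include <vector>

    std::vector<int64_t> SliceDimsMapping(const std::vector<int64_t>& src,
                                          const std::vector<int64_t>& axes) {
      std::vector<int64_t> dst = src;
      const int ndim = static_cast<int>(src.size());
      for (int64_t a : axes) {
        int axis = a < 0 ? static_cast<int>(a) + ndim : static_cast<int>(a);
        dst[axis] = -1;  // sliced axis becomes replicated
      }
      return dst;
    }

    int main() {
      // Input sharded on mesh dims 0 and 1; slice along tensor axis 1.
      for (int64_t d : SliceDimsMapping({0, 1, -1}, {1}))
        std::cout << d << ' ';
      std::cout << '\n';  // prints: 0 -1 -1
    }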