add comments #59372

Merged
merged 3 commits into from Dec 1, 2023
Changes from all commits
75 changes: 56 additions & 19 deletions paddle/phi/infermeta/spmd_rules/slice.cc
@@ -28,6 +28,7 @@ using phi::distributed::auto_parallel::str_join;

SpmdInfo SliceInferSpmdBase(const DistMetaTensor& input,
const std::vector<int64_t>& axes) {
// Step0: Verify input args based on slice logic
auto input_shape = phi::vectorize(input.dims());
int input_ndim = input_shape.size();
auto input_dist_attr_src = input.dist_attr();
@@ -40,32 +41,39 @@ SpmdInfo SliceInferSpmdBase(const DistMetaTensor& input,
input_ndim,
input_dims_mapping.size()));

// Step1: Build Einsum Notation
std::string alphabet = "abcdefghijklmnopqrstuvwxyz";
std::string input_axes = alphabet.substr(0, input_ndim);
std::string special_axes = alphabet.substr(input_ndim);

for (int i = 0; i < static_cast<int>(axes.size()); i++) {
int axis = axes[i] < 0 ? axes[i] + input_ndim : axes[i];
input_axes[axis] = special_axes[i];
}
// get einsum notation for input
std::string input_axes = alphabet.substr(0, input_ndim);

// get einsum notation for output
std::string out_axes(input_axes);

for (int i = 0; i < static_cast<int>(axes.size()); i++) {
int axis = axes[i] < 0 ? axes[i] + input_ndim : axes[i];
// the sliced axis cannot be sharded; mark its notation with the
// special '1' so that its dim mapping is set to -1.
out_axes[axis] = '1';
}
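
// A hypothetical walk-through (not part of this PR's changes): slicing a
// 3-D input along axis 1 gives input_axes "abc" and out_axes "a1c", so
// whatever sharding axis 1 carried is excluded from the output notation.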

// Step2: Sharding Propagation
// Step2.1: merge input shardings
std::unordered_map<std::string, int64_t> axis_to_dim_map =
ShardingMergeForTensors({{input_axes, input_dims_mapping}});

// Step2.2: infer output dims mapping from merged input dims mapping
std::vector<int64_t> out_dims_mapping =
GetDimsMappingForAxes(out_axes, axis_to_dim_map);
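
// Continuing the hypothetical example: with input dims_mapping [0, 1, -1],
// the merged map is {a: 0, b: 1, c: -1} and out_dims_mapping resolves to
// [0, -1, -1]; per the '1' convention above, the sliced axis maps to -1.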

// get the dist attributes for output. the sliced axes cannot be
// sharded; if one is sharded, set it to replicated.
TensorDistAttr out_dist_attr =
CopyTensorDistAttrForOutput(input_dist_attr_src);
out_dist_attr.set_dims_mapping(out_dims_mapping);

// Step2.3: get new dist attribute for input. the sliced axes cannot be
// sharded; if one is sharded, set it to replicated.
TensorDistAttr input_dist_attr_dst(input_dist_attr_src);
for (int i = 0; i < static_cast<int>(axes.size()); i++) {
int axis = axes[i] < 0 ? axes[i] + input_ndim : axes[i];
@@ -76,6 +84,7 @@ SpmdInfo SliceInferSpmdBase(const DistMetaTensor& input,
VLOG(4) << "SliceInferSpmd:";
VLOG(4) << "Einsum Notation: " << input_axes << "-->" << out_axes;
VLOG(4) << "Input shape: [" << str_join(input_shape) << "] "
<< "axes: [" << str_join(axes) << "] "
<< "src_dims_mapping: ["
<< str_join(input_dist_attr_src.dims_mapping()) << "] "
<< "dst_dims_mapping: [" << str_join(input_dims_mapping) << "]";
@@ -92,12 +101,15 @@ SpmdInfo SliceInferSpmd(const DistMetaTensor& input,
const std::vector<int>& ends,
const std::vector<int64_t>& infer_flags,
const std::vector<int64_t>& decrease_axis) {
// starts, ends, infer_flags and decrease_axis have no impact on the
// derivation; they are kept only to align with the definition in the phi api
return SliceInferSpmdBase(input, axes);
}

SpmdInfo SliceInferSpmdReverseBase(const DistMetaTensor& input,
const DistMetaTensor& output,
const std::vector<int64_t>& axes) {
// Step0: Verify input args based on slice logic
auto output_shape = phi::vectorize(output.dims());
int out_ndim = output_shape.size();
auto out_dist_attr = output.dist_attr();
@@ -123,30 +135,38 @@ SpmdInfo SliceInferSpmdReverseBase(const DistMetaTensor& input,
out_ndim,
out_dims_mapping_size));

// Step1: Build Einsum Notation
std::string alphabet = "abcdefghijklmnopqrstuvwxyz";

// get einsum notation for input
std::string input_axes = alphabet.substr(0, input_ndim);
std::string special_axes = alphabet.substr(input_ndim);

// get einsum notation for output
std::string out_axes(input_axes);

for (int i = 0; i < static_cast<int>(axes.size()); i++) {
int axis = axes[i] < 0 ? axes[i] + input_ndim : axes[i];
input_axes[axis] = special_axes[i];
// the sliced axis cannot be sharded; mark its notation with the
// special '1' so that its dim mapping is set to -1.
input_axes[axis] = '1';
}

std::string out_axes(input_axes);
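
// A hypothetical walk-through (not part of this PR's changes): for a 3-D
// input sliced along axis 1, input_axes and out_axes are both "a1c" here,
// so the sliced axis takes no part in the sharding merge below.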

// Step2: Sharding Propagation
// Step2.1: merge output shardings
std::vector<std::pair<std::string, std::vector<int64_t>>> axes_sharding_info;
std::vector<int64_t> out_dims_mapping = output.dist_attr().dims_mapping();
axes_sharding_info.emplace_back(std::make_pair(out_axes, out_dims_mapping));

std::unordered_map<std::string, int64_t> axis_to_dim_map =
ShardingMergeForTensors(axes_sharding_info);

// Step2.2: infer input dims mapping from output dims mapping. the sliced
// axes cannot be sharded; if one is sharded, set it to replicated.
input_dims_mapping = GetDimsMappingForAxes(input_axes, axis_to_dim_map, true);
for (int i = 0; i < static_cast<int>(axes.size()); i++) {
int axis = axes[i] < 0 ? axes[i] + input_ndim : axes[i];
input_dims_mapping[axis] = -1;
}
input_dist_attr.set_dims_mapping(input_dims_mapping);
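
// Continuing the hypothetical example: an output dims_mapping of [0, -1, 1]
// merges to {a: 0, c: 1}; the inferred input mapping is then [0, -1, 1],
// with the sliced axis pinned to -1 by the loop above.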

// Step2.3: get new dist attribute for output. the sliced axes cannot be
// sharded; if one is sharded, set it to replicated.
out_dims_mapping = GetDimsMappingForAxes(out_axes, axis_to_dim_map, true);
for (int i = 0; i < static_cast<int>(axes.size()); i++) {
int axis = axes[i] < 0 ? axes[i] + input_ndim : axes[i];
@@ -158,6 +178,7 @@ SpmdInfo SliceInferSpmdReverseBase(const DistMetaTensor& input,
VLOG(4) << "Einsum Notation: " << input_axes << "-->" << out_axes;
VLOG(4) << "Output"
<< " shape: [" << str_join(phi::vectorize(output.dims())) << "] "
<< "axes: [" << str_join(axes) << "] "
<< "src_dims_mapping: ["
<< str_join(output.dist_attr().dims_mapping()) << "] "
<< "dst_dims_mapping: [" << str_join(out_dist_attr.dims_mapping())
@@ -175,6 +196,8 @@ SpmdInfo SliceInferSpmdReverse(const DistMetaTensor& input,
const std::vector<int>& ends,
const std::vector<int64_t>& infer_flags,
const std::vector<int64_t>& decrease_axis) {
// starts, ends, infer_flags and decrease_axis have no impact on the
// derivation; they are kept only to align with the definition in the phi api
return SliceInferSpmdReverseBase(input, output, axes);
}

@@ -184,6 +207,8 @@ SpmdInfo SliceInferSpmdDynamic(const DistMetaTensor& input,
const IntArray& ends,
const std::vector<int64_t>& infer_flags,
const std::vector<int64_t>& decrease_axis) {
// starts, ends, infer_flags and decrease_axis have no impact on the
// derivation; they are kept only to align with the definition in the phi api
std::vector<int> start_indexes(starts.GetData().begin(),
starts.GetData().end());
std::vector<int> end_indexes(ends.GetData().begin(), ends.GetData().end());
@@ -193,6 +218,7 @@ SpmdInfo SliceGradInferBase(const DistMetaTensor& input,
SpmdInfo SliceGradInferBase(const DistMetaTensor& input,
const DistMetaTensor& out_grad,
const std::vector<int64_t>& axes) {
// Step0: Verify input args based on slice logic
auto input_dist_attr = input.dist_attr();
auto out_dist_attr = out_grad.dist_attr();
input_dist_attr = UnShardTensorDims(input_dist_attr, axes);
@@ -220,26 +246,31 @@ SpmdInfo SliceGradInferBase(const DistMetaTensor& input,
out_ndim,
out_dims_mapping_size));

// Step1: Build Einsum Notation
std::string alphabet = "abcdefghijklmnopqrstuvwxyz";

// get einsum notation for input
std::string align_axes = alphabet.substr(0, input_ndim);
std::string input_axes = align_axes;
std::string special_axes = alphabet.substr(input_ndim);

for (int i = 0; i < static_cast<int>(axes.size()); i++) {
int axis = axes[i] < 0 ? axes[i] + input_ndim : axes[i];
input_axes[axis] = special_axes[i];
}
// get einsum notation for output
std::string out_axes(input_axes);
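
// A hypothetical walk-through (not part of this PR's changes): for a 3-D
// input sliced along axis 1, align_axes is "abc" while input_axes and
// out_axes both become "adc"; the sliced axis gets its own letter, so it
// is aligned between input and out_grad but excluded from align_axes.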

// Step2: Sharding Propagation
// Step2.1: merge input and out_grad shardings
std::vector<std::pair<std::string, std::vector<int64_t>>> axes_sharding_info;
axes_sharding_info.emplace_back(
std::make_pair(out_axes, out_dist_attr.dims_mapping()));
axes_sharding_info.emplace_back(
std::make_pair(input_axes, input_dist_attr.dims_mapping()));
std::unordered_map<std::string, int64_t> axis_to_dim_map =
ShardingMergeForTensors(axes_sharding_info);

// Step2.2: infer the aligned dims mapping from the merged shardings
auto aligned_dim_mapping =
GetDimsMappingForAxes(align_axes, axis_to_dim_map, true);
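
// Continuing the hypothetical example: merging input dims_mapping
// [0, -1, 1] with out_grad dims_mapping [-1, -1, 1] gives
// {a: 0, d: -1, c: 1}; align_axes "abc" then yields aligned_dim_mapping
// [0, -1, 1], which both tensors are reset to below.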

// get the dist attributes for output
TensorDistAttr aligned_dist_attr = CopyTensorDistAttrForOutput(out_dist_attr);
input_dist_attr.set_dims_mapping(aligned_dim_mapping);
out_dist_attr.set_dims_mapping(aligned_dim_mapping);
@@ -276,6 +307,8 @@ SpmdInfo SliceGradInferSpmdDynamic(const DistMetaTensor& input,
const IntArray& ends,
const std::vector<int64_t>& infer_flags,
const std::vector<int64_t>& decrease_axis) {
// starts, ends, infer_flags and decrease_axis have no impact on the
// derivation; they are kept only to align with the definition in the phi api
return SliceGradInferBase(input, out_grad, axes);
}

@@ -284,6 +317,8 @@ SpmdInfo StridedSliceInferSpmdDynamic(const DistMetaTensor& input,
const IntArray& starts,
const IntArray& ends,
const IntArray& strides) {
// starts, ends and strides have no impact on the derivation;
// they are kept only to align with the definition in the phi api
std::vector<int64_t> axes_bridge(axes.begin(), axes.end());
return SliceInferSpmdBase(input, axes_bridge);
}
@@ -294,6 +329,8 @@ SpmdInfo StridedSliceGradInferSpmdDynamic(const DistMetaTensor& input,
const IntArray& starts,
const IntArray& ends,
const IntArray& strides) {
// starts, ends and strides have no impact on the derivation;
// they are kept only to align with the definition in the phi api
std::vector<int64_t> axes_bridge(axes.begin(), axes.end());
return SliceGradInferBase(input, out_grad, axes_bridge);
}