[Semi-Auto] Support dynamic shape in reshape spmd rule (#58097)
* support dynamic input in reshape spmd rule

* remove the modification in dist_reshape.py

* small fix
pkuzyc authored Oct 23, 2023
1 parent 0b4ed68 commit f984ed1
Showing 2 changed files with 158 additions and 16 deletions.
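At a high level, the commit normalizes dynamic dims before building the reshape dimension transformation: '-1' entries in the input shape (dynamic dims from inference) are replaced by a fixed placeholder value, and '0' entries in the target shape copy the corresponding input dim. A minimal Python sketch of that normalization (illustrative only; the function name is invented here, and the placeholder constant mirrors the C++ below rather than any public API):

PLACEHOLDER = 256  # the special value the C++ code substitutes for '-1'

def normalize_shapes(x_shape, target):
    # dynamic input dims ('-1') become a fixed placeholder, so the only
    # '-1' left is the single one that reshape allows in the target shape
    x_shape = [PLACEHOLDER if d == -1 else d for d in x_shape]
    # '0' in the target shape means "keep the input dim at this index"
    target = [x_shape[i] if d == 0 else d for i, d in enumerate(target)]
    return x_shape, target

# mirrors the new test case: input [-1, -1, 3072], target [0, 0, -1, 192]
print(normalize_shapes([-1, -1, 3072], [0, 0, -1, 192]))
# -> ([256, 256, 3072], [256, 256, -1, 192])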
69 changes: 53 additions & 16 deletions paddle/phi/infermeta/spmd_rules/reshape.cc
@@ -50,7 +50,7 @@ std::vector<int64_t> InferTargetShape(const std::vector<int64_t>& shape,
PADDLE_ENFORCE_EQ(
product,
len,
phi::errors::InvalidArgument("The total size are not matched"));
phi::errors::InvalidArgument("The total size are not matched."));
return std::vector<int64_t>(shape);
} else {
std::vector<int64_t> new_shape(shape);
@@ -59,7 +59,7 @@ std::vector<int64_t> InferTargetShape(const std::vector<int64_t>& shape,
PADDLE_ENFORCE_EQ(len % infer_size,
0,
phi::errors::InvalidArgument(
"The total is not diviable by infer_size"));
"The total is not diviable by infer_size."));
new_shape[infer_idx] = infer_size;
return new_shape;
}
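The two hunks above only touch the error messages; for context, here is a hedged Python sketch of what InferTargetShape computes, resolving the single '-1' in a target shape against the total element count (function and variable names are illustrative, not the C++ API):

from functools import reduce
from operator import mul

def infer_target_shape(shape, total):
    # no '-1': the product of dims must already equal the element count
    if -1 not in shape:
        assert reduce(mul, shape, 1) == total, "total sizes do not match"
        return list(shape)
    # exactly one '-1': its value is total / (product of the known dims)
    known = reduce(mul, (d for d in shape if d != -1), 1)
    assert total % known == 0, "total size is not divisible by known dims"
    resolved = list(shape)
    resolved[resolved.index(-1)] = total // known
    return resolved

print(infer_target_shape([256, 256, -1, 192], 256 * 256 * 3072))
# -> [256, 256, 16, 192]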
@@ -143,8 +143,11 @@ std::vector<DimTrans*> MakeReshapeDimTrans(
SpmdInfo ReshapeInferSpmd(const DistMetaTensor& x,
const std::vector<int64_t>& shape) {
// Step0: Verify input args based on reshape logic
- auto src_shape = phi::vectorize(x.dims());
- int x_ndim = src_shape.size();
+ VLOG(2) << "Debug Info for reshape";
+ VLOG(2) << "shape: " << str_join(shape);
+ auto x_shape = phi::vectorize(x.dims());
+ int x_ndim = x_shape.size();
+ int out_ndim = shape.size();
auto x_dist_attr_src = x.dist_attr();
std::vector<int64_t> x_dims_mapping = x_dist_attr_src.dims_mapping();
PADDLE_ENFORCE_EQ(
@@ -154,20 +157,31 @@ SpmdInfo ReshapeInferSpmd(const DistMetaTensor& x,
"dims_mapping size [%d] are not matched.",
x_ndim,
x_dims_mapping.size()));
VLOG(4) << "ReshapeInferSpmd: X shape: [" << str_join(x_shape) << "]";
VLOG(4) << "Out shape: [" << str_join(shape) << "]";

// Step1: Build the transformation from
// the original shape to the target shape

+ // handle the case of dynamic shape, i.e. an input shape like
+ // [-1, -1, ...] with a target shape like [0, 0, ...]. This occurs in
+ // inference, but reshape allows only one '-1' in the target shape, so
+ // replace each '-1' in the input shape with the placeholder value '256'
+ for (int i = 0; i < x_ndim; i++) {
+ if (x_shape[i] == -1) {
+ x_shape[i] = 256;
+ }
+ }

// handle the '0' values in target shape, '0' indicates
// that the target shape is equal to the source shape
std::vector<int64_t> tgt_shape(shape);
- for (int64_t i = 0, n = static_cast<int64_t>(tgt_shape.size()); i < n; i++) {
+ for (int64_t i = 0; i < out_ndim; i++) {
if (tgt_shape[i] == 0) {
- tgt_shape[i] = src_shape[i];
+ tgt_shape[i] = x_shape[i];
}
}

- std::vector<DimTrans*> trans = MakeReshapeDimTrans(src_shape, tgt_shape);
+ std::vector<DimTrans*> trans = MakeReshapeDimTrans(x_shape, tgt_shape);

// Step2: Infer the dims mapping of input (if reshard is
// needed) and output from the dimension transformation.
@@ -181,17 +195,14 @@ SpmdInfo ReshapeInferSpmd(const DistMetaTensor& x,
TensorDistAttr out_dist_attr(x_dist_attr_src);
out_dist_attr.set_dims_mapping(dims_mapping_vec[1]);

VLOG(4) << "ReshapeInferSpmd: X shape: [" << str_join(src_shape)
<< "] Out shape: [" << str_join(tgt_shape) << "]";
VLOG(4) << "Transformation from input to output:";
for (int64_t i = 0, n = static_cast<int64_t>(trans.size()); i < n; i++) {
DimTrans* t = trans[i];
VLOG(4) << "\tOut axis[" << i << "]: " << t->to_string();
}
VLOG(4) << "X dims_mapping_src: [" << str_join(x_dims_mapping)
<< "] dims_mapping_dst: [" << str_join(dims_mapping_vec[0])
<< "]\n Out dims_mapping: [" << str_join(dims_mapping_vec[1])
<< "]\n\n";
<< "] dims_mapping_dst: [" << str_join(dims_mapping_vec[0]) << "]";
VLOG(4) << "Out dims_mapping: [" << str_join(dims_mapping_vec[1]) << "]\n\n";

CleanUp();
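For intuition about Step2, consider the new forward test case: reshaping [8, 1024, 3072] (or its dynamic form) to [0, 0, -1, 192] splits axis 2 into axes (2, 3) of sizes (16, 192), and an input sharding can follow the axis it maps to. A simplified Python sketch of how the dims mapping propagates (the real DimTrans machinery also handles merged axes and reshard; this shows only the split case, with made-up names):

# input axis -> output axes for the reshape [8, 1024, 3072] -> [8, 1024, 16, 192]
axis_map = {0: [0], 1: [1], 2: [2, 3]}
in_dims_mapping = [0, 1, -1]  # mesh dim 0 shards axis 0, mesh dim 1 shards axis 1

out_dims_mapping = [-1, -1, -1, -1]
for in_axis, out_axes in axis_map.items():
    # a sharded input axis stays on the major (first) output axis of its split
    out_dims_mapping[out_axes[0]] = in_dims_mapping[in_axis]

print(out_dims_mapping)  # -> [0, 1, -1, -1], matching the expected test output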

@@ -201,9 +212,12 @@ SpmdInfo ReshapeInferSpmd(const DistMetaTensor& x,
SpmdInfo ReshapeInferSpmdReverse(const DistMetaTensor& x,
const DistMetaTensor& out,
const std::vector<int64_t>& shape) {
VLOG(2) << "Debug Info for reshape_reverse";
VLOG(2) << "shape: " << str_join(shape);
// Step0: Verify input args based on reshape logic
auto x_shape = phi::vectorize(x.dims());
auto out_shape = phi::vectorize(out.dims());
+ int x_ndim = x_shape.size();
int out_ndim = out_shape.size();
auto out_dist_attr_src = out.dist_attr();
std::vector<int64_t> out_dims_mapping = out_dist_attr_src.dims_mapping();
@@ -214,14 +228,39 @@ SpmdInfo ReshapeInferSpmdReverse(const DistMetaTensor& x,
"dims_mapping size [%d] are not matched.",
out_ndim,
out_dims_mapping.size()));
VLOG(4) << "ReshapeInferSpmdReverse: Out shape: [" << str_join(out_shape)
<< "], X shape: [" << str_join(x_shape) << "]";

// Step1: Build the transformation from the output shape
// to original shape. This function infers the dims mapping
// from output to input, we first get the transformation
// from output to input so that we can infer the dims mapping
// with the map from output axes to input axes.
- // Shapes in InferSpmdReverse don't contain -1 or 0, so they will
- // not be modified and we can directly use them.

+ // handle the case of dynamic shape, i.e. an input shape like
+ // [-1, -1, ...] with a target shape like [0, 0, ...]. This occurs in
+ // inference, but reshape allows only one '-1' in the target shape, so
+ // replace each '-1' in the input shape with the placeholder value '256'
+ for (int i = 0; i < x_ndim; i++) {
+ if (x_shape[i] == -1) {
+ x_shape[i] = 256;
+ }
+ }

+ // handle the '0' values in the target shape: '0' indicates that
+ // the target dim is equal to the corresponding source dim
+ std::vector<int64_t> tgt_shape(shape);
+ for (int64_t i = 0; i < out_ndim; i++) {
+ if (shape[i] == 0) {
+ out_shape[i] = x_shape[i];
+ }
+ }

+ // The out_shape may contain '-1', which would cause an error when
+ // inferring the transformation from out_shape to x_shape, so infer
+ // the '-1' value before inferring DimTrans
+ int64_t nelm = std::accumulate(
+ x_shape.begin(), x_shape.end(), 1, std::multiplies<int64_t>());
+ out_shape = InferTargetShape(out_shape, nelm);
std::vector<DimTrans*> trans = MakeReshapeDimTrans(out_shape, x_shape);

// Step2: Infer the dims mapping of input with
@@ -236,8 +275,6 @@ SpmdInfo ReshapeInferSpmdReverse(const DistMetaTensor& x,
TensorDistAttr x_dist_attr(x.dist_attr());
x_dist_attr.set_dims_mapping(dims_mapping_vec[1]);

VLOG(4) << "ReshapeInferSpmdReverse: Out shape: [" << str_join(out_shape)
<< "] X shape: [" << str_join(x_shape) << "]";
VLOG(4) << "Transformation from output to input:";
for (int64_t i = 0, n = trans.size(); i < n; i++) {
DimTrans* t = trans[i];
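The reverse rule mirrors the forward one, with one extra wrinkle: under dynamic shape the output tensor's shape may still contain a '-1', so it is resolved against the input's element count (the nelm / InferTargetShape lines above) before the output-to-input transformation is built. A self-contained Python sketch of that resolution step (variable names are illustrative):

from functools import reduce
from operator import mul

x_shape = [256, 256, 3072]       # input shape, '-1' dims already replaced by 256
out_shape = [256, 256, -1, 192]  # output shape with one unresolved dynamic dim

nelm = reduce(mul, x_shape, 1)                             # total element count
known = reduce(mul, (d for d in out_shape if d != -1), 1)  # product of known dims
out_shape = [nelm // known if d == -1 else d for d in out_shape]

print(out_shape)  # -> [256, 256, 16, 192]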
105 changes: 105 additions & 0 deletions test/auto_parallel/spmd_rules/test_reshape_rule.py
@@ -243,6 +243,54 @@ def test_reshape_infer_forward(self):
infered_output_dist_attrs[0].dims_mapping, [0, 1, -1, -1]
)

# shape: [-1, -1, 3072] --> [0, 0, -1, 192]
# dims_mapping: [0, 1, -1] --> [0, 1, -1], [0, 1, -1, -1]
self.x_dist_tensor_spec.shape = [-1, -1, 3072]
self.attrs["shape"] = [0, 0, -1, 192]
self.x_dist_tensor_spec.set_dims_mapping([0, 1, -1])
result_dist_attrs = self.rule.infer_forward(
self.x_dist_tensor_spec, self.attrs['shape']
)
infered_input_dist_attrs = result_dist_attrs[0]
infered_output_dist_attrs = result_dist_attrs[1]

self.assertEqual(infered_input_dist_attrs[0].dims_mapping, [0, 1, -1])
self.assertEqual(
infered_output_dist_attrs[0].dims_mapping, [0, 1, -1, -1]
)

# shape: [-1, -1, 3072] --> [0, 0, -1, 192]
# dims_mapping: [0, -1, 1] --> [0, -1, -1], [0, -1, -1, -1]
self.x_dist_tensor_spec.shape = [-1, -1, 3072]
self.attrs["shape"] = [0, 0, -1, 192]
self.x_dist_tensor_spec.set_dims_mapping([0, -1, 1])
result_dist_attrs = self.rule.infer_forward(
self.x_dist_tensor_spec, self.attrs['shape']
)
infered_input_dist_attrs = result_dist_attrs[0]
infered_output_dist_attrs = result_dist_attrs[1]

self.assertEqual(infered_input_dist_attrs[0].dims_mapping, [0, -1, -1])
self.assertEqual(
infered_output_dist_attrs[0].dims_mapping, [0, -1, -1, -1]
)

# shape: [-1, -1, 3072] --> [0, 0, -1, 192]
# dims_mapping: [1, -1, 0] --> [1, -1, 0], [1, -1, 0, -1]
self.x_dist_tensor_spec.shape = [-1, -1, 3072]
self.attrs["shape"] = [0, 0, -1, 192]
self.x_dist_tensor_spec.set_dims_mapping([1, -1, 0])
result_dist_attrs = self.rule.infer_forward(
self.x_dist_tensor_spec, self.attrs['shape']
)
infered_input_dist_attrs = result_dist_attrs[0]
infered_output_dist_attrs = result_dist_attrs[1]

self.assertEqual(infered_input_dist_attrs[0].dims_mapping, [1, -1, 0])
self.assertEqual(
infered_output_dist_attrs[0].dims_mapping, [1, -1, 0, -1]
)

# shape: [6, 12, 48, 24] --> [3, 24, 6, -1, -1]
# raise error
self.attrs["shape"] = [3, 24, 6, -1, -1]
@@ -454,6 +502,63 @@ def test_reshape_infer_backward(self):
infered_output_dist_attrs[0].dims_mapping, [-1, -1, -1, -1, 0]
)

# shape: [8, 1024, 3072] --> [0, 0, -1, 192] (input --> output)
# dims_mapping: [0, 1, -1, -1] --> [0, 1, -1], [0, 1, -1, -1] (output --> input, output)
self.x_dist_tensor_spec.shape = [8, 1024, 3072]
self.output_dist_tensor_spec.shape = [0, 0, -1, 192]
self.attrs["shape"] = [0, 0, -1, 192]
self.output_dist_tensor_spec.set_dims_mapping([0, 1, -1, -1])
result_dist_attrs = self.rule.infer_backward(
self.x_dist_tensor_spec,
self.output_dist_tensor_spec,
self.attrs['shape'],
)
infered_input_dist_attrs = result_dist_attrs[0]
infered_output_dist_attrs = result_dist_attrs[1]

self.assertEqual(infered_input_dist_attrs[0].dims_mapping, [0, 1, -1])
self.assertEqual(
infered_output_dist_attrs[0].dims_mapping, [0, 1, -1, -1]
)

# shape: [-1, -1, 3072] --> [0, 0, -1, 192] (input --> output)
# dims_mapping: [0, 1, -1, -1] --> [0, 1, -1], [0, 1, -1, -1] (output --> input, output)
self.x_dist_tensor_spec.shape = [-1, -1, 3072]
self.output_dist_tensor_spec.shape = [0, 0, -1, 192]
self.attrs["shape"] = [0, 0, -1, 192]
self.output_dist_tensor_spec.set_dims_mapping([0, 1, -1, -1])
result_dist_attrs = self.rule.infer_backward(
self.x_dist_tensor_spec,
self.output_dist_tensor_spec,
self.attrs['shape'],
)
infered_input_dist_attrs = result_dist_attrs[0]
infered_output_dist_attrs = result_dist_attrs[1]

self.assertEqual(infered_input_dist_attrs[0].dims_mapping, [0, 1, -1])
self.assertEqual(
infered_output_dist_attrs[0].dims_mapping, [0, 1, -1, -1]
)

# shape: [-1, -1, 3072] --> [0, 0, -1, 192] (input --> output)
# dims_mapping: [0, -1, 1, -1] --> [0, -1, 1], [0, -1, 1, -1] (output --> input, output)
self.x_dist_tensor_spec.shape = [-1, -1, 3072]
self.output_dist_tensor_spec.shape = [0, 0, -1, 192]
self.attrs["shape"] = [0, 0, -1, 192]
self.output_dist_tensor_spec.set_dims_mapping([0, -1, 1, -1])
result_dist_attrs = self.rule.infer_backward(
self.x_dist_tensor_spec,
self.output_dist_tensor_spec,
self.attrs['shape'],
)
infered_input_dist_attrs = result_dist_attrs[0]
infered_output_dist_attrs = result_dist_attrs[1]

self.assertEqual(infered_input_dist_attrs[0].dims_mapping, [0, -1, 1])
self.assertEqual(
infered_output_dist_attrs[0].dims_mapping, [0, -1, 1, -1]
)


if __name__ == "__main__":
unittest.main()
