Commit

test only

LiYuRio committed Nov 10, 2023
1 parent f9bc8d2 commit bf3cd8c
Showing 7 changed files with 43 additions and 16 deletions.

@@ -99,6 +99,8 @@ void SameNdMeshReshardFunction::Eval(phi::DeviceContext* dev_ctx,
SetValue(out, in.value());
SetDistProps(out, in.dims(), in_dist_attr);

VLOG(0) << "Same nd mesh, in_dist_attr: " << in_dist_attr;
VLOG(0) << "Same nd mesh, out_dist_attr: " << out_dist_attr;
// 1. change all the partial status to replicated status if needed
if (in_dist_attr.is_partial()) {
// Copy in_dist_attr.partial_status to avoid overwriting the value of
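For context on the "change all the partial status to replicated status" step in the comment above: a partial placement means every rank holds an unreduced contribution, and turning it into a replicated placement is an element-wise all-reduce. A single-process simulation of the sum case (toy code, not Paddle's reshard functors):

#include <cassert>
#include <vector>

// Simulate the "partial" buffers held by 3 ranks; the replicated result every
// rank should end up with is their element-wise sum (here a sum reduce).
std::vector<double> AllReduceSum(const std::vector<std::vector<double>>& partial) {
  std::vector<double> out(partial.front().size(), 0.0);
  for (const auto& rank_buf : partial) {
    for (size_t i = 0; i < out.size(); ++i) out[i] += rank_buf[i];
  }
  return out;
}

int main() {
  std::vector<std::vector<double>> partial = {{1, 2}, {3, 4}, {5, 6}};
  std::vector<double> replicated = AllReduceSum(partial);
  assert((replicated == std::vector<double>{9, 12}));
  return 0;
}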

@@ -49,7 +49,6 @@ void RToPReshardFunction::Eval(phi::DeviceContext* dev_ctx,
VLOG(3) << "Call RToPReshardFunction Eval";
const auto& out_process_mesh = out_dist_attr.process_mesh();
int64_t local_rank = GetCurRankCoordInMesh(out_process_mesh)[0];
IntArray shape(in.dims().Get(), in.dims().size());
const auto& in_reduce_type = out_dist_attr.partial_status().at(0);

if (local_rank != 0) {
@@ -59,6 +58,7 @@ void RToPReshardFunction::Eval(phi::DeviceContext* dev_ctx,
dev_ctx, Assign, in.value(), GetMutableTensor(out));
} else {
// reset the physical tensor to zero
IntArray shape(in.local_dims().Get(), in.local_dims().size());
RESHARD_FUNCTOR(
dev_ctx, Full, in.dtype(), shape, 0, GetMutableTensor(out));
}
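
The change above swaps in.dims() for in.local_dims() when zero-filling the output tensor on ranks that do not own the data. A standalone sketch (toy helper, not Paddle's DistTensor API) of why a dist tensor's per-rank local shape can differ from its global shape:

#include <cassert>
#include <cstdint>
#include <vector>

// Toy helper: derive the per-rank (local) shape of a tensor whose global
// shape is sharded over a 1-D mesh. dims_mapping[i] == 0 means dim i is
// split across the mesh; -1 means dim i is replicated. Assumes even splits.
std::vector<int64_t> LocalShape(const std::vector<int64_t>& global_shape,
                                const std::vector<int64_t>& dims_mapping,
                                int64_t mesh_size) {
  std::vector<int64_t> local = global_shape;
  for (size_t i = 0; i < local.size(); ++i) {
    if (dims_mapping[i] == 0) local[i] /= mesh_size;
  }
  return local;
}

int main() {
  // A [8, 16] tensor sharded on dim 0 over 4 ranks: each rank holds [2, 16].
  std::vector<int64_t> local = LocalShape({8, 16}, {0, -1}, 4);
  assert((local == std::vector<int64_t>{2, 16}));
  // Zero-filling a rank's physical buffer must therefore use the local
  // shape; the global shape would allocate mesh_size times too much.
  return 0;
}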

@@ -75,14 +75,14 @@ CommContext* CreateOrGetCommContext(const DeviceContext& dev_ctx,
#define RESHARD_FUNCTOR_IMPL(dev_ctx, fn_name, dtype, ...) \
do { \
if (phi::CPUContext::classof(dev_ctx)) { \
VLOG(4) << "Call `" << #fn_name << "` in Resharding on GPU."; \
VLOG(4) << "Call `" << #fn_name << "` in Resharding on CPU."; \
PD_VISIT_BOOL_AND_FLOATING_AND_INTEGRAL_TYPES( \
dtype, #fn_name, ([&] { \
fn_name<data_t>(static_cast<const CPUContext&>(*dev_ctx), \
__VA_ARGS__); \
})); \
} else if (phi::GPUContext::classof(dev_ctx)) { \
VLOG(4) << "Call `" << #fn_name << "` in Resharding on CPU."; \
VLOG(4) << "Call `" << #fn_name << "` in Resharding on GPU."; \
PD_VISIT_BOOL_AND_FLOATING_AND_INTEGRAL_TYPES( \
dtype, #fn_name, ([&] { \
fn_name<data_t>(static_cast<const GPUContext&>(*dev_ctx), \
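The RESHARD_FUNCTOR_IMPL hunk above fixes two VLOG strings whose CPU and GPU labels were swapped; the macro itself is a runtime dispatch on the concrete device context. A self-contained sketch of the same shape (toy context types, dynamic_cast standing in for phi's classof checks):

#include <iostream>
#include <string>

// Toy stand-ins for the device context hierarchy; not phi types.
struct DeviceContext { virtual ~DeviceContext() = default; };
struct CPUContext : DeviceContext {};
struct GPUContext : DeviceContext {};

// Pick the concrete context at runtime and log which backend was chosen,
// mirroring what the macro does with phi::CPUContext::classof / GPUContext::classof.
template <typename Fn>
void Dispatch(const DeviceContext& ctx, const std::string& fn_name, Fn&& fn) {
  if (dynamic_cast<const CPUContext*>(&ctx) != nullptr) {
    std::cout << "Call `" << fn_name << "` in Resharding on CPU.\n";
    fn(ctx);
  } else if (dynamic_cast<const GPUContext*>(&ctx) != nullptr) {
    std::cout << "Call `" << fn_name << "` in Resharding on GPU.\n";
    fn(ctx);
  }
}

int main() {
  CPUContext cpu;
  Dispatch(cpu, "Full", [](const DeviceContext&) { /* typed kernel call goes here */ });
  return 0;
}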

@@ -66,7 +66,6 @@ void SameStatusReshardFunction::Eval(phi::DeviceContext* dev_ctx,
const DistTensor& in,
const TensorDistAttr& out_dist_attr,
DistTensor* out) {
VLOG(3) << "Call SameStatusReshardFunction Eval";
const auto& in_dist_attr = in.dist_attr();
const auto& in_process_mesh = in_dist_attr.process_mesh();
const auto& in_process_ids = in_process_mesh.process_ids();

4 changes: 2 additions & 2 deletions paddle/phi/infermeta/spmd_rules/elementwise.cc
@@ -224,8 +224,8 @@ SpmdInfo ElementwiseBinaryInferSpmd(const DistMetaTensor& x,
out_dist_attr.set_dims_mapping(out_dims_mapping);

// Step2.3: Update inputs' dims mapping with merged one.
TensorDistAttr x_dist_attr_dst(x_dist_attr_src);
TensorDistAttr y_dist_attr_dst(y_dist_attr_src);
TensorDistAttr x_dist_attr_dst = CopyTensorDistAttrForOutput(x_dist_attr_src);
TensorDistAttr y_dist_attr_dst = CopyTensorDistAttrForOutput(y_dist_attr_src);
x_dist_attr_dst.set_dims_mapping(
GetDimsMappingForAxes(x_axes, axis_to_dim_map));
y_dist_attr_dst.set_dims_mapping(
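The elementwise change above replaces a plain copy of the source dist attrs with CopyTensorDistAttrForOutput before the merged dims mapping is written. A toy sketch of the presumed intent, namely keeping the process mesh while not inheriting the source tensor's dims mapping or partial status (toy struct, not Paddle's TensorDistAttr):

#include <cassert>
#include <cstdint>
#include <map>
#include <vector>

// Toy stand-in for a tensor dist attr; not Paddle's TensorDistAttr.
struct ToyDistAttr {
  int process_mesh_id = 0;
  std::vector<int64_t> dims_mapping;
  std::map<int64_t, int> partial_status;  // mesh dim -> reduce type
};

// Presumed intent of CopyTensorDistAttrForOutput: keep mesh-level info and
// leave the per-tensor placement state to be filled in by the SPMD rule.
ToyDistAttr CopyForOutput(const ToyDistAttr& src) {
  ToyDistAttr dst;
  dst.process_mesh_id = src.process_mesh_id;
  return dst;
}

int main() {
  ToyDistAttr x_src{7, {0, -1}, {{0, 1}}};   // sharded and partial on mesh dim 0
  ToyDistAttr x_dst = CopyForOutput(x_src);  // same mesh, clean placement state
  x_dst.dims_mapping = {0, -1};              // then set from the merged axes map
  assert(x_dst.process_mesh_id == 7 && x_dst.partial_status.empty());
  return 0;
}
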
40 changes: 31 additions & 9 deletions paddle/phi/infermeta/spmd_rules/matmul.cc
@@ -310,6 +310,22 @@ SpmdInfo MatmulGradInferSpmd(const DistMetaTensor& x,
y.dist_attr()));
};

auto confirm_dist_attr_with_arg_same_fn = [&](const ArgDistAttr& x_dist_attr,
const ArgDistAttr& y_dist_attr,
const char* debug_msg) {
const auto& x_single_dist_attr = get_attr(x_dist_attr);
const auto& y_single_dist_attr = get_attr(y_dist_attr);
PADDLE_ENFORCE_EQ(
DistAttrsAreBasicallyEqual(x_single_dist_attr, y_single_dist_attr),
true,
phi::errors::Unavailable("The matmul grad infer spmd `%s` verify "
"error: left dist attr is %s, "
"right dist attr is %s.",
debug_msg,
x_single_dist_attr,
y_single_dist_attr));
};

// TODO(chenweihang): Now for the case where the forward input generates
// an intermediate value through Reshard, because the intermediate value
// is destroyed after the forward calculation is completed, the x and y
@@ -343,6 +359,9 @@ SpmdInfo MatmulGradInferSpmd(const DistMetaTensor& x,
confirm_dist_attr_same_fn(
dy_spmd_info.first[0], out_grad, "trans x&y: dy-out_grad");
confirm_dist_attr_same_fn(dy_spmd_info.first[1], x, "trans x&y: dy-x");
return {
{dy_spmd_info.first[1], dx_spmd_info.first[0], dx_spmd_info.first[1]},
{dx_spmd_info.second[0], dy_spmd_info.second[0]}};
} else {
// X'Y: dX = YG', dY = XG
dx_spmd_info =
@@ -355,6 +374,9 @@ SpmdInfo MatmulGradInferSpmd(const DistMetaTensor& x,
confirm_dist_attr_same_fn(dy_spmd_info.first[0], x, "trans x: dy-x");
confirm_dist_attr_same_fn(
dy_spmd_info.first[1], out_grad, "trans x: dy-out_grad");
return {
{dy_spmd_info.first[0], dx_spmd_info.first[0], dx_spmd_info.first[1]},
{dx_spmd_info.second[0], dy_spmd_info.second[0]}};
}
} else {
if (trans_y) {
@@ -369,25 +391,25 @@ SpmdInfo MatmulGradInferSpmd(const DistMetaTensor& x,
confirm_dist_attr_same_fn(
dy_spmd_info.first[0], out_grad, "trans y: dy-out_grad");
confirm_dist_attr_same_fn(dy_spmd_info.first[1], x, "trans y: dy-x");
return {
{dy_spmd_info.first[1], dx_spmd_info.first[1], dx_spmd_info.first[0]},
{dx_spmd_info.second[0], dy_spmd_info.second[0]}};
} else {
// XY: dX = GY', dY = X'G
dx_spmd_info =
MatmulInferSpmd(out_grad, y, /*trans_x=*/false, /*trans_y=*/true);
dy_spmd_info =
MatmulInferSpmd(x, out_grad, /*trans_x=*/true, /*trans_y=*/false);
confirm_dist_attr_same_fn(
dx_spmd_info.first[0], out_grad, "no trans: dx-out_grad");
confirm_dist_attr_same_fn(dx_spmd_info.first[1], y, "no trans: dx-y");
confirm_dist_attr_same_fn(dy_spmd_info.first[0], x, "no trans: dy-x");
confirm_dist_attr_same_fn(
dy_spmd_info.first[1], out_grad, "no trans: dy-out_grad");
confirm_dist_attr_with_arg_same_fn(dx_spmd_info.first[0],
dy_spmd_info.first[1],
"no trans: dy-out_grad");
return {
{dy_spmd_info.first[0], dx_spmd_info.first[1], dx_spmd_info.first[0]},
{dx_spmd_info.second[0], dy_spmd_info.second[0]}};
}
}

// Here we assume that the input dist attr is unchanged after inference,
// and only return the gradient dist attr
return {{x.dist_attr(), y.dist_attr(), out_grad.dist_attr()},
{dx_spmd_info.second[0], dy_spmd_info.second[0]}};
}

} // namespace distributed
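
Each branch of MatmulGradInferSpmd now returns directly with the attrs inferred by its two MatmulInferSpmd calls, reordered back into (x, y, out_grad) order. A toy model of the index shuffling in the no-transpose branch (toy types; MatmulInferSpmd is stubbed out, not Paddle's SpmdInfo):

#include <array>
#include <cassert>
#include <string>
#include <vector>

// Toy model: "first" holds the dist attrs of the two operands in call order,
// "second" holds the dist attr of the matmul result.
using Attr = std::string;
struct ToySpmdInfo {
  std::vector<Attr> first;
  std::vector<Attr> second;
};

// Stub for MatmulInferSpmd that just records which operands were passed.
ToySpmdInfo ToyMatmulInferSpmd(const Attr& a, const Attr& b) {
  return {{a, b}, {"result"}};
}

int main() {
  Attr x = "x_attr", y = "y_attr", g = "out_grad_attr";
  // No-transpose branch: dX = G * Y^T, dY = X^T * G.
  ToySpmdInfo dx = ToyMatmulInferSpmd(g, y);  // first = {out_grad, y}
  ToySpmdInfo dy = ToyMatmulInferSpmd(x, g);  // first = {x, out_grad}
  // The patch picks the returned input attrs back in (x, y, out_grad) order:
  std::array<Attr, 3> inputs = {dy.first[0], dx.first[1], dx.first[0]};
  assert(inputs[0] == x && inputs[1] == y && inputs[2] == g);
  return 0;
}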

@@ -19,7 +19,11 @@

class TestSemiAutoParallelHybridStrategy(test_base.CommunicationTestDistBase):
def setUp(self):
super().setUp(num_of_devices=2, timeout=120, nnode=2)
super().setUp(
num_of_devices=2,
timeout=120,
nnode=2,
)
self._default_envs = {
"dtype": "float32",
"seed": "2023",
Expand Down
