Skip to content

Commit

Permalink
adapt_phi/embedding
Browse files Browse the repository at this point in the history
  • Loading branch information
pkuzyc committed Sep 22, 2023
1 parent 579103d commit d085f59
Show file tree
Hide file tree
Showing 7 changed files with 135 additions and 130 deletions.
5 changes: 0 additions & 5 deletions paddle/fluid/distributed/auto_parallel/spmd_rules/rules.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@

#include "paddle/fluid/distributed/auto_parallel/spmd_rules/common.h"
#include "paddle/fluid/distributed/auto_parallel/spmd_rules/cross_entropy_with_softmax_spmd_rule.h"
#include "paddle/fluid/distributed/auto_parallel/spmd_rules/embedding_spmd_rule.h"
#include "paddle/fluid/distributed/auto_parallel/spmd_rules/replicated_spmd_rule.h"
#include "paddle/fluid/distributed/auto_parallel/spmd_rules/reshape_spmd_rule.h"
#include "paddle/fluid/distributed/auto_parallel/spmd_rules/softmax_spmd_rule.h"
Expand All @@ -31,10 +30,6 @@ namespace auto_parallel {
// replicated rule
REGISTER_SPMD_RULE(replicated, ReplicatedSPMDRule);

// embedding rule
REGISTER_SPMD_RULE(embedding, EmbeddingSPMDRule);
REGISTER_SPMD_RULE(lookup_table_v2, EmbeddingSPMDRule);

// softmax rule
REGISTER_SPMD_RULE(softmax, SoftmaxSPMDRule);
REGISTER_SPMD_RULE(log_softmax, SoftmaxSPMDRule);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ AttrType InferSpmdContext::AttrAt(size_t idx) const {

template float InferSpmdContext::AttrAt(size_t idx) const;
template int InferSpmdContext::AttrAt(size_t idx) const;
template int64_t InferSpmdContext::AttrAt(size_t idx) const;

template <>
bool InferSpmdContext::AttrAt(size_t idx) const {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,7 @@ struct InferSpmdFnImpl<Return (*)(Args...), infer_spmd_fn> {
PD_SPECIALIZE_InferSpmdFnCallHelper_FOR_ATTRIBUTE(bool);
PD_SPECIALIZE_InferSpmdFnCallHelper_FOR_ATTRIBUTE(int);
PD_SPECIALIZE_InferSpmdFnCallHelper_FOR_ATTRIBUTE(float);
PD_SPECIALIZE_InferSpmdFnCallHelper_FOR_ATTRIBUTE(int64_t);
PD_SPECIALIZE_InferSpmdFnCallHelper_FOR_CONST_ATTRIBUTE_REF(std::vector<int>);
PD_SPECIALIZE_InferSpmdFnCallHelper_FOR_CONST_ATTRIBUTE_REF(
std::vector<int64_t>);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,59 +12,56 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/distributed/auto_parallel/spmd_rules/embedding_spmd_rule.h"
#include "paddle/phi/infermeta/spmd_rules/embedding.h"

namespace paddle {
#include "glog/logging.h"

#include "paddle/phi/core/distributed/auto_parallel/dist_attr.h"
#include "paddle/phi/core/distributed/auto_parallel/inferspmd_utils.h"
#include "paddle/phi/core/distributed/auto_parallel/utils.h"
#include "paddle/phi/infermeta/spmd_rules/utils.h"

namespace phi {
namespace distributed {
namespace auto_parallel {

using phi::distributed::auto_parallel::str_join;

// step0: verify input args based on embedding logic
std::pair<std::vector<TensorDistAttr>, std::vector<TensorDistAttr>>
EmbeddingSPMDRule::InferForward(const std::vector<DistTensorSpec>& input_specs,
const paddle::framework::AttributeMap& attrs) {
auto input_specs_size = input_specs.size();
PADDLE_ENFORCE_EQ(
input_specs_size,
2,
phi::errors::InvalidArgument(
"The size of InputSpec of embedding should be 2, but got [%d].",
input_specs_size));
auto x_shape = input_specs[0].shape();
auto weight_shape = input_specs[1].shape();
SpmdInfo EmbeddingInferSpmd(const DistMetaTensor& x,
const DistMetaTensor& weight,
int padding_idx,
bool sparse) {
// Step0: Verify input args based on embedding logic
auto x_shape = phi::vectorize(x.dims());
auto weight_shape = phi::vectorize(weight.dims());
int x_ndim = static_cast<int>(x_shape.size());
int weight_ndim = static_cast<int>(weight_shape.size());
auto x_dist_attr_src = input_specs[0].dist_attr();
auto weight_dist_attr_src = input_specs[1].dist_attr();
auto x_dist_attr_src = x.dist_attr();
auto weight_dist_attr_src = weight.dist_attr();
std::vector<int64_t> x_dims_mapping = x_dist_attr_src.dims_mapping();
std::vector<int64_t> weight_dims_mapping =
weight_dist_attr_src.dims_mapping();

PADDLE_ENFORCE_EQ(
x_ndim,
x_dims_mapping.size(),
phi::errors::InvalidArgument(
"Mismatch of X's tensor size: [%d] and X's dims_mapping size [%d].",
x_ndim,
x_dims_mapping.size()));
phi::errors::InvalidArgument("The Tensor X's rank [%d] and X's "
"dims_mapping size [%d] are not matched.",
x_ndim,
x_dims_mapping.size()));
PADDLE_ENFORCE_EQ(
weight_ndim,
weight_dims_mapping.size(),
phi::errors::InvalidArgument(
"Mismatch of W's tensor size: [%d] and W's dims_mapping size [%d].",
weight_ndim,
weight_dims_mapping.size()));
phi::errors::InvalidArgument("Tensor W's tensor rank [%d] and W's "
"dims_mapping size [%d] are not matched.",
weight_ndim,
weight_dims_mapping.size()));
PADDLE_ENFORCE_EQ(
weight_ndim,
2,
phi::errors::InvalidArgument("Embedding table should have TWO dimension, "
"but got a tensor with [%d] dimension.",
weight_ndim));

int64_t padding_idx = ExtractAttr<int64_t>("padding_idx", attrs);
bool sparse = ExtractAttr<bool>("sparse", attrs);

// determine parallel mode
int64_t weight_row_axis_mapping = weight_dims_mapping[0];

Expand Down Expand Up @@ -103,35 +100,25 @@ EmbeddingSPMDRule::InferForward(const std::vector<DistTensorSpec>& input_specs,
<< "sparse: "
<< "[" << (sparse ? "true" : "false") << "]; ";

// step1: build Einsum Notation
// Step1: Build Einsum Notation
std::string alphabet = "abcdefghilmnopqrstuvwxyz";
std::string x_axes = GetBroadcastAxes(x_ndim, x_ndim, alphabet);
std::string weight_axes = "jk";
std::string out_axes = x_axes + "k";

// step2: Sharding Propagation
// Step2: Sharding Propagation
// Step2.1: merge input shardings
auto axis_to_dim_map = ShardingMergeForTensors(
{{x_axes, x_dims_mapping}, {weight_axes, weight_dims_mapping}}, false);

// step3: Infer Output's Dims Mapping.
TensorDistAttr output_dist_attr_dst =
CopyTensorDistAttrForOutput(x_dist_attr_src);
std::vector<int64_t> out_dims_mapping;
out_dims_mapping.reserve(out_axes.size());
for (size_t i = 0; i < out_axes.size(); ++i) {
out_dims_mapping.push_back(axis_to_dim_map[out_axes.substr(i, 1)]);
}
output_dist_attr_dst.set_dims_mapping(out_dims_mapping);

// step3.1: Handle Partial
// (TODO) support case where embedding table is partial at very beginning.
std::vector<int64_t> partial_on_dims;
if (weight_row_axis_mapping > -1) {
partial_on_dims.push_back(weight_row_axis_mapping);
}
output_dist_attr_dst.set_partial_status(partial_on_dims);
// Step2.2: infer output's dims mapping.
TensorDistAttr out_dist_attr = CopyTensorDistAttrForOutput(x_dist_attr_src);
std::vector<int64_t> out_dims_mapping =
GetDimsMappingForAxes(out_axes, axis_to_dim_map);
out_dist_attr.set_dims_mapping(out_dims_mapping);

// step4: merge potential conflict in inputs
// Step2.3: merge potential conflict in inputs,
// update input dims mapping with merged shardings.
TensorDistAttr x_dist_attr_dst = CopyTensorDistAttrForOutput(x_dist_attr_src);
x_dist_attr_dst.set_dims_mapping(
GetDimsMappingForAxes(x_axes, axis_to_dim_map));
Expand All @@ -140,38 +127,39 @@ EmbeddingSPMDRule::InferForward(const std::vector<DistTensorSpec>& input_specs,
weight_dist_attr_dst.set_dims_mapping(
GetDimsMappingForAxes(weight_axes, axis_to_dim_map));

VLOG(4) << "EmbeddingSPMDRule InferForward: "
// Step3: Handle Partial
// (TODO) support case where embedding table is partial at very beginning.
std::vector<int64_t> partial_on_dims;
if (weight_row_axis_mapping > -1) {
partial_on_dims.push_back(weight_row_axis_mapping);
}
out_dist_attr.set_partial_status(partial_on_dims);

VLOG(4) << "EmbeddingInferSpmd:\n"
<< "Einsum notation: [" << x_axes << "," << weight_axes << " --> "
<< out_axes << "]. " << std::endl
<< "X shape: [" << str_join(x_shape) << "], src_dims_mapping: ["
<< str_join(x_dims_mapping) << "], dst_dims_mapping: ["
<< str_join(x_dist_attr_dst.dims_mapping()) << "]; Y shape: ["
<< str_join(x_dist_attr_dst.dims_mapping()) << "]\n W shape: ["
<< str_join(weight_shape) << "], src_dims_mapping: ["
<< str_join(weight_dims_mapping) << "], dst_dims_mapping: ["
<< str_join(weight_dist_attr_dst.dims_mapping())
<< "]; Output dims_mapping: [" << str_join(out_dims_mapping)
<< "], partial_on_dims: [" << str_join(partial_on_dims) << "]";
<< "]\n Out dims_mapping: [" << str_join(out_dims_mapping)
<< "], partial_on_dims: [" << str_join(partial_on_dims) << "]\n\n";

return {{x_dist_attr_dst, weight_dist_attr_dst}, {output_dist_attr_dst}};
return {{x_dist_attr_dst, weight_dist_attr_dst}, {out_dist_attr}};
}

std::pair<std::vector<TensorDistAttr>, std::vector<TensorDistAttr>>
EmbeddingSPMDRule::InferBackward(
const std::vector<DistTensorSpec>& input_specs,
const std::vector<DistTensorSpec>& output_specs,
const paddle::framework::AttributeMap& attrs) {
SpmdInfo EmbeddingInferSpmdReverse(const DistMetaTensor& x,
const DistMetaTensor& weight,
const DistMetaTensor& out,
int padding_idx,
bool sparse) {
// Step0: Verify input args based on embedding logic
// InferBackward is called after InferForward, so we skip some checks.
auto output_specs_size = output_specs.size();
PADDLE_ENFORCE_EQ(
output_specs_size,
1,
phi::errors::InvalidArgument(
"The size of OutputSpec of embedding should be 1, but got [%d].",
output_specs_size));

auto x_shape = input_specs[0].shape();
auto x_shape = phi::vectorize(x.dims());
int x_ndim = static_cast<int>(x_shape.size());
auto out_shape = output_specs[0].shape();
auto out_shape = phi::vectorize(out.dims());
int out_ndim = static_cast<int>(out_shape.size());

PADDLE_ENFORCE_EQ(x_ndim,
Expand All @@ -182,10 +170,10 @@ EmbeddingSPMDRule::InferBackward(
x_ndim,
out_ndim));

auto out_dist_attr_src = output_specs[0].dist_attr();
auto out_dist_attr_src = out.dist_attr();
std::vector<int64_t> out_dims_mapping = out_dist_attr_src.dims_mapping();

// step1: build Einsum Notation
// Step1: Build Einsum Notation
std::string alphabet = "abcdefghilmnopqrstuvwxyz";
std::string x_axes = GetBroadcastAxes(out_ndim - 1, out_ndim - 1, alphabet);
std::string weight_axes = "jk";
Expand All @@ -195,32 +183,30 @@ EmbeddingSPMDRule::InferBackward(
// should not use input dims mapping for backward sharding merge
auto axis_to_dim_map =
ShardingMergeForTensors({{out_axes, out_dims_mapping}}, false);
TensorDistAttr x_dist_attr_dst =
CopyTensorDistAttrForOutput(input_specs[0].dist_attr());
x_dist_attr_dst.set_dims_mapping(GetDimsMappingForAxes(
TensorDistAttr x_dist_attr = CopyTensorDistAttrForOutput(x.dist_attr());
x_dist_attr.set_dims_mapping(GetDimsMappingForAxes(
x_axes, axis_to_dim_map, /*unsharded_miss_axis=*/true));
TensorDistAttr weight_dist_attr_dst =
CopyTensorDistAttrForOutput(input_specs[1].dist_attr());
weight_dist_attr_dst.set_dims_mapping(GetDimsMappingForAxes(
TensorDistAttr weight_dist_attr =
CopyTensorDistAttrForOutput(weight.dist_attr());
weight_dist_attr.set_dims_mapping(GetDimsMappingForAxes(
weight_axes, axis_to_dim_map, /*unsharded_miss_axis=*/true));

// step3: Handle Partial
// NOTE we skip the partial backward inference in Partial Stage-I.
// output partial --> weight sharded on first axis.

VLOG(4) << "EmbeddingSPMDRule InferBackward: "
VLOG(4) << "EmbeddingInferSpmdReverse:\n"
<< "Einsum notation: [" << x_axes << "," << weight_axes << " --> "
<< out_axes << "]. " << std::endl
<< "Out shape: [" << str_join(out_shape) << "], src_dims_mapping: ["
<< str_join(out_dims_mapping) << "], dst_dims_mapping: ["
<< str_join(out_dims_mapping) << "]; Input X dims_mapping: ["
<< str_join(x_dist_attr_dst.dims_mapping())
<< "], Input Weight dims_mapping:["
<< str_join(weight_dist_attr_dst.dims_mapping()) << "].";
<< str_join(out_dims_mapping) << "]\n Input X dims_mapping: ["
<< str_join(x_dist_attr.dims_mapping())
<< "]\n Input Weight dims_mapping:["
<< str_join(weight_dist_attr.dims_mapping()) << "]\n\n";

return {{x_dist_attr_dst, weight_dist_attr_dst}, {out_dist_attr_src}};
return {{x_dist_attr, weight_dist_attr}, {out_dist_attr_src}};
}

} // namespace auto_parallel
} // namespace distributed
} // namespace paddle
} // namespace phi
Original file line number Diff line number Diff line change
Expand Up @@ -14,29 +14,28 @@ limitations under the License. */

#pragma once

#include "paddle/fluid/distributed/auto_parallel/spmd_rules/common.h"
#include "paddle/phi/core/distributed/auto_parallel/dist_meta_tensor.h"
#include "paddle/phi/core/distributed/type_defs.h"

namespace paddle {
namespace phi {
namespace distributed {
namespace auto_parallel {

// (TODO) Support 3 parallel cases for embedding:
// 1. Batch dimensions of input ids is sharded on mesh.
// 2. Row-wise Parallel of embedding table. (NOTE: Row-wise Parallel need to
// change the embedding kernel for miss ids.)
// 3. Column-wise Parallel of embedding table.
// 4. Hybrid Parallelism of above 3 cases.
class EmbeddingSPMDRule : public SPMDRuleBase {
public:
std::pair<std::vector<TensorDistAttr>, std::vector<TensorDistAttr>>
InferForward(const std::vector<DistTensorSpec>& input_specs,
const paddle::framework::AttributeMap& attrs) override;

std::pair<std::vector<TensorDistAttr>, std::vector<TensorDistAttr>>
InferBackward(const std::vector<DistTensorSpec>& input_specs,
const std::vector<DistTensorSpec>& output_specs,
const paddle::framework::AttributeMap& attrs) override;
};
} // namespace auto_parallel
SpmdInfo EmbeddingInferSpmd(const DistMetaTensor& x,
const DistMetaTensor& weight,
int padding_idx,
bool sparse);

SpmdInfo EmbeddingInferSpmdReverse(const DistMetaTensor& x,
const DistMetaTensor& weight,
const DistMetaTensor& out,
int padding_idx,
bool sparse);

} // namespace distributed
} // namespace paddle
} // namespace phi
11 changes: 11 additions & 0 deletions paddle/phi/infermeta/spmd_rules/rules.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ limitations under the License. */
#include "paddle/phi/infermeta/spmd_rules/default_data_parallel.h"
#include "paddle/phi/infermeta/spmd_rules/elementwise.h"
#include "paddle/phi/infermeta/spmd_rules/layer_norm.h"
#include "paddle/phi/infermeta/spmd_rules/embedding.h"
#include "paddle/phi/infermeta/spmd_rules/matmul.h"
#include "paddle/phi/infermeta/spmd_rules/reduction.h"
#include "paddle/phi/infermeta/spmd_rules/replicated.h"
Expand Down Expand Up @@ -464,5 +465,15 @@ PD_REGISTER_SPMD_RULE(
PD_INFER_SPMD(phi::distributed::LayerNormInferSpmd),
PD_INFER_SPMD(phi::distributed::LayerNormInferSpmdReverse));

// embedding rule
PD_REGISTER_SPMD_RULE(
embedding,
PD_INFER_SPMD(phi::distributed::EmbeddingInferSpmd),
PD_INFER_SPMD(phi::distributed::EmbeddingInferSpmdReverse));
PD_REGISTER_SPMD_RULE(
lookup_table_v2,
PD_INFER_SPMD(phi::distributed::EmbeddingInferSpmd),
PD_INFER_SPMD(phi::distributed::EmbeddingInferSpmdReverse));

} // namespace distributed
} // namespace phi
Loading

0 comments on commit d085f59

Please sign in to comment.