-
Notifications
You must be signed in to change notification settings - Fork 5.7k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[Semi-Auto] SPMD Parallel Rule Base #53863
Changes from 40 commits
a47bf99
21b5a75
f7e39d7
c92992d
42a7b77
180edcc
ecbb1ae
f314b56
46153c7
1dcb80e
198bc1f
09d82a5
c3ea2a6
4cd1a2c
3631f06
ed0c31e
701d3fa
c1545a4
3ca2b73
747de08
968ce61
7be672d
3389b7e
3882a2c
3719a5a
73f49a8
aced5ea
ef92dc4
adcb470
43df148
27803af
5612da9
f9bd281
18f8d29
2628043
68a512a
f2b2edb
132558a
f3bc740
c11cdd2
491bf65
041abd4
60c90d3
0fc8ff9
eecd184
b70b3d5
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,4 @@ | ||
# Library of SPMD (single-program-multiple-data) sharding-propagation rules
# for semi-auto parallel; depends on the operator DistAttr definitions.
cc_library(
  spmd_rule
  SRCS common.cc dist_tensor_spec.cc matmul_spmd_rule.cc
  DEPS op_dist_attr)
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,212 @@ | ||
/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. | ||
|
||
Licensed under the Apache License, Version 2.0 (the "License"); | ||
you may not use this file except in compliance with the License. | ||
You may obtain a copy of the License at | ||
|
||
http://www.apache.org/licenses/LICENSE-2.0 | ||
|
||
Unless required by applicable law or agreed to in writing, software | ||
distributed under the License is distributed on an "AS IS" BASIS, | ||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
See the License for the specific language governing permissions and | ||
limitations under the License. */ | ||
|
||
#include <glog/logging.h> | ||
|
||
#include "paddle/fluid/distributed/auto_parallel/spmd_rules/common.h" | ||
#include "paddle/fluid/distributed/auto_parallel/spmd_rules/rules.h" | ||
|
||
namespace paddle { | ||
namespace distributed { | ||
namespace auto_parallel { | ||
|
||
// Base implementation of forward sharding inference. Concrete rules must
// override this; reaching the base version is always an error, so it throws
// Unimplemented unconditionally.
std::vector<TensorDistAttr> SPMDRuleBase::InferForward(
    const std::vector<DistTensorSpec>& input_specs,
    const paddle::framework::AttributeMap& attrs) {
  PADDLE_THROW(
      phi::errors::Unimplemented("InferForward should be called from a "
                                 "derived class of SPMDRuleBase !"));
}
|
||
// Base implementation of backward sharding inference. Concrete rules must
// override this; reaching the base version is always an error, so it throws
// Unimplemented unconditionally.
std::vector<TensorDistAttr> SPMDRuleBase::InferBackward(
    const std::vector<DistTensorSpec>& output_specs,
    const paddle::framework::AttributeMap& attrs) {
  PADDLE_THROW(
      phi::errors::Unimplemented("InferBackward should be called from a "
                                 "derived class of SPMDRuleBase !"));
}
|
||
std::unordered_map<std::string, int64_t> ShardingMergeForTensors( | ||
const std::vector<std::pair<const std::string, const std::vector<int64_t>>>& | ||
tensor_axes_to_dim_pairs) { | ||
std::unordered_map<std::string, int64_t> axis_to_dim_map; | ||
std::unordered_map<int64_t, std::string> dim_to_axis_map; | ||
int64_t merge_dim; | ||
|
||
for (auto& pair : tensor_axes_to_dim_pairs) { | ||
for (size_t i = 0; i < pair.second.size(); ++i) { | ||
auto tensor_axis = pair.first.substr(i, 1); | ||
auto mesh_dim = pair.second[i]; | ||
|
||
if (axis_to_dim_map.count(tensor_axis) == 0) { | ||
merge_dim = mesh_dim; | ||
} else { | ||
merge_dim = ShardingMergeForAxis( | ||
tensor_axis, mesh_dim, axis_to_dim_map[tensor_axis]); | ||
} | ||
axis_to_dim_map[tensor_axis] = merge_dim; | ||
if (merge_dim != -1) { | ||
if (dim_to_axis_map.count(merge_dim) == 0) { | ||
dim_to_axis_map.insert({merge_dim, tensor_axis}); | ||
} else if (dim_to_axis_map[merge_dim].find(tensor_axis) == | ||
std::string::npos) { | ||
dim_to_axis_map[merge_dim] += tensor_axis; | ||
} | ||
} | ||
} | ||
} | ||
|
||
// Resolute "mesh_dim shard by more than one axis" confict. | ||
// Now we just naive pick the first axis naively. | ||
// (TODO) use local cost model to pick the axis with lowest cost(in concern of | ||
// memory or communication or computation). | ||
for (auto& it : dim_to_axis_map) { | ||
if (it.second.size() > 1) { | ||
VLOG(4) << "Sharding Conflict: Mesh_Dim [" << it.first | ||
<< "] are Sharding Multiple Tensor Axis: [" << it.second | ||
<< "]. The Axis: [" << it.second[0] << "] is Picked."; | ||
for (size_t i = 1; i < it.second.size(); ++i) { | ||
axis_to_dim_map[it.second.substr(i, 1)] = -1; | ||
} | ||
} | ||
} | ||
|
||
return axis_to_dim_map; | ||
} | ||
|
||
// Rule1: A repicated dimension could be merged by any sharded dimension. | ||
// Rule2: A tensor axis could at most be sharded by one mesh dimension. | ||
// (TODO trigger heuristics cost model and reshard to handle axis sharded by | ||
// multiple dimension case.) | ||
int64_t ShardingMergeForAxis(const std::string& axis, | ||
const int64_t& mesh_dim1, | ||
const int64_t& mesh_dim2) { | ||
if (mesh_dim1 != mesh_dim2) { | ||
if (mesh_dim1 == -1) { | ||
return mesh_dim2; | ||
} else if (mesh_dim2 == -1) { | ||
return mesh_dim1; | ||
} else { | ||
// (TODO) local cost model here. | ||
PADDLE_THROW( | ||
phi::errors::Unimplemented("Tensor Axis[%s] is Sharded by two " | ||
"different mesh dimension [%d] and [%d].", | ||
axis, | ||
mesh_dim1, | ||
mesh_dim2)); | ||
} | ||
|
||
} else { | ||
return mesh_dim1; | ||
} | ||
} | ||
|
||
// Build a fresh TensorDistAttr for an output tensor, carrying over the
// process mesh, batch dim and dynamic dims of the source attribute.
// A plain copy constructor is deliberately not used here: the remaining
// members (e.g. the dims mapping) are meant to be filled in by the rule
// after this call.
TensorDistAttr CopyTensorDistAttrForOutput(
    const TensorDistAttr& src_dist_attr) {
  TensorDistAttr dst_dist_attr;
  dst_dist_attr.set_process_mesh(src_dist_attr.process_mesh());
  dst_dist_attr.set_batch_dim(src_dist_attr.batch_dim());
  dst_dist_attr.set_dynamic_dims(src_dist_attr.dynamic_dims());
  // "annotated" is intentionally left at its default (unset == false).
  return dst_dist_attr;
}
|
||
// Collect the mesh dimensions on which the output tensor is "partial":
// every mesh dim that shards some input axis (value > -1 in the merged map)
// but whose axis name does not appear in the output's axis notation.
std::vector<int64_t> ResoluteOutputPartialDimension(
    const std::unordered_map<std::string, int64_t>& axis_to_dim_map,
    const std::string& tensor_axes) {
  std::vector<int64_t> partial_on_dims;

  for (const auto& entry : axis_to_dim_map) {
    const bool axis_absent =
        tensor_axes.find(entry.first) == std::string::npos;
    if (axis_absent && entry.second > -1) {
      partial_on_dims.push_back(entry.second);
    }
  }
  return partial_on_dims;
}
|
||
// Generate the axis notation of a tensor for the einsum notation of a
// broadcast operation (alignment starts from the rightmost axis).
// tensor_ndim: rank of this tensor.
// broadcast_ndim: the maximum rank among the tensors in this broadcast.
// alphabet: characters used to name the axes; its length must be >=
// broadcast_ndim.
std::string GetBroadcastAxes(const int64_t& tensor_ndim,
                             const int64_t& broadcast_ndim,
                             const std::string& alphabet) {
  PADDLE_ENFORCE_GE(
      alphabet.size(),
      broadcast_ndim,
      phi::errors::InvalidArgument(
          "size of alphabet [%d] is less than broadcast ndim [%d]",
          alphabet.size(),
          broadcast_ndim));
  PADDLE_ENFORCE_GE(broadcast_ndim,
                    tensor_ndim,
                    phi::errors::InvalidArgument(
                        "broadcast ndim [%d] is less than tensor ndim [%d]",
                        broadcast_ndim,
                        tensor_ndim));
  if (tensor_ndim <= 0) {
    return std::string();
  }
  // The tensor's axes align with the trailing tensor_ndim characters.
  return alphabet.substr(broadcast_ndim - tensor_ndim, tensor_ndim);
}
|
||
// SPMDRuleMap
// Meyers-singleton accessor: the map is constructed on first use and its
// initialization is thread-safe since C++11.
SPMDRuleMap& SPMDRuleMap::Instance() {
  static SPMDRuleMap g_spmd_rule_map;
  return g_spmd_rule_map;
}
|
||
// Return the rule registered for `op_type`. When no rule is registered the
// currently-registered keys are dumped at VLOG(4) for debugging and a
// NotFound error is thrown.
SPMDRuleBase* SPMDRuleMap::Get(const std::string& op_type) const {
  SPMDRuleBase* rule_ptr = GetNullable(op_type);
  if (rule_ptr == nullptr) {
    // Dump what IS registered to make a missing registration easy to spot.
    std::string str;
    for (const auto& item : map_) {
      str.append(item.first);
      str.append(", ");
    }
    VLOG(4) << "Size of current map [" << map_.size() << "]";
    VLOG(4) << "Keys are [" << str << "]";
  }
  PADDLE_ENFORCE_NOT_NULL(
      rule_ptr,
      platform::errors::NotFound(
          "NO SPMD Rule has been registered for Operator [%s].", op_type));
  return rule_ptr;
}
|
||
// Non-throwing lookup: nullptr signals "no rule registered for op_type".
SPMDRuleBase* SPMDRuleMap::GetNullable(const std::string& op_type) const {
  auto iter = map_.find(op_type);
  return (iter == map_.end()) ? nullptr : iter->second.get();
}
|
||
int SPMDRuleMap::Insert(const std::string& op_type, | ||
std::unique_ptr<SPMDRuleBase> rule) { | ||
VLOG(4) << "Call SPMDRuleMap::Insert!"; | ||
PADDLE_ENFORCE_NE( | ||
Has(op_type), | ||
true, | ||
platform::errors::AlreadyExists( | ||
"SPMD Rule for Operator [%s] has been registered.", op_type)); | ||
map_.insert({op_type, std::move(rule)}); | ||
|
||
return 1; | ||
} | ||
|
||
} // namespace auto_parallel | ||
} // namespace distributed | ||
} // namespace paddle |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,160 @@ | ||
/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. | ||
|
||
Licensed under the Apache License, Version 2.0 (the "License"); | ||
you may not use this file except in compliance with the License. | ||
You may obtain a copy of the License at | ||
|
||
http://www.apache.org/licenses/LICENSE-2.0 | ||
|
||
Unless required by applicable law or agreed to in writing, software | ||
distributed under the License is distributed on an "AS IS" BASIS, | ||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
See the License for the specific language governing permissions and | ||
limitations under the License. */ | ||
|
||
#pragma once | ||
|
||
#include <iterator> | ||
#include <map> | ||
#include <string> | ||
#include <vector> | ||
|
||
#include "paddle/fluid/framework/attribute.h" | ||
// #include "paddle/fluid/framework/op_registry.h" | ||
#include "paddle/fluid/framework/type_defs.h" | ||
#include "paddle/fluid/platform/enforce.h" | ||
#include "paddle/phi/core/distributed/auto_parallel/dist_attr.h" | ||
#include "paddle/utils/flat_hash_map.h" | ||
|
||
#include "paddle/fluid/distributed/auto_parallel/spmd_rules/dist_tensor_spec.h" | ||
|
||
namespace paddle { | ||
namespace distributed { | ||
namespace auto_parallel { | ||
|
||
using paddle::framework::Attribute; | ||
|
||
class SPMDRuleBase { | ||
public: | ||
virtual ~SPMDRuleBase() {} | ||
|
||
// Merge the DistAttr of input tensors and infer the DistAttr of the output | ||
// tensors from the merged input information. The input are DistAttr and Shape | ||
// (wrapp as DistTensorSpec) of the input tensors (tensors follow the same | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. wrapp -> Wrapped There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. done |
||
// order defined in Op's Phi API) and Op Attribue of the current op. The ouput | ||
// are the Merged DistAttr of input tensors and the infered DistAttr of the | ||
// output tensors. The Merged DistAttr might be different from the original | ||
// Intput DistAttrs, which means that the corressponding input tensor need to | ||
// be reshard. | ||
virtual std::vector<TensorDistAttr> InferForward( | ||
const std::vector<DistTensorSpec>& input_specs, | ||
const paddle::framework::AttributeMap& attrs); | ||
|
||
// Merge the DistAttr of output tensors and infer the DistAttr of the input | ||
// tensors from the merged output information. The input are DistAttr and | ||
// Shape (wrapp as DistTensorSpec) of the input tensors and Op Attribue of the | ||
// current op. The ouput are the Merged DistAttr of output tensors and the | ||
// infered DistAttr of the input tensors. This function will be use in Static | ||
// Graph mode only, where we have the whole computation graph for sharding | ||
// propogation. | ||
virtual std::vector<TensorDistAttr> InferBackward( | ||
const std::vector<DistTensorSpec>& output_specs, | ||
const paddle::framework::AttributeMap& attrs); | ||
|
||
template <typename T> | ||
inline const T ExtractAttr( | ||
const std::string& name, | ||
const paddle::framework::AttributeMap& attrs) const { | ||
auto& attr = GetAttr(name, attrs); | ||
|
||
// In order to get bool attr properly | ||
framework::proto::AttrType attr_type = | ||
static_cast<framework::proto::AttrType>(attr.index() - 1); | ||
if (attr_type == framework::proto::AttrType::INT) { | ||
if (std::is_same<bool, T>::value) { | ||
return static_cast<bool>(PADDLE_GET_CONST(int, attr)); | ||
} | ||
} | ||
|
||
return PADDLE_GET_CONST(T, attr); | ||
} | ||
|
||
const Attribute& GetAttr(const std::string& name, | ||
const paddle::framework::AttributeMap& attrs) const { | ||
auto iter = attrs.find(name); | ||
PADDLE_ENFORCE_NE(iter, | ||
attrs.end(), | ||
paddle::platform::errors::NotFound( | ||
"(%s) is not found in AttributeMap.")); | ||
return iter->second; | ||
} | ||
}; | ||
|
||
// Merge the sharding specifications (dims mappings) of the given tensors.
// Axes with the same name across different tensors will be merged.
std::unordered_map<std::string, int64_t> ShardingMergeForTensors(
    const std::vector<std::pair<const std::string, const std::vector<int64_t>>>&
        tensor_axes_to_dim_pairs);

// Merge the sharding specification (dims mapping) for one tensor axis.
// Rule1: A replicated dimension could be merged by any sharded dimension.
// Rule2: A tensor axis could at most be sharded by one mesh dimension.
// (TODO trigger heuristics cost model and reshard to handle axis sharded by
// multiple dimension case.)
int64_t ShardingMergeForAxis(const std::string& axis,
                             const int64_t& mesh_dim1,
                             const int64_t& mesh_dim2);
|
||
// Create a new TensorDistAttr for an output tensor, copying the process
// mesh, batch dim and dynamic dims of `src_dist_attr` (the dims mapping is
// left to be filled in by the rule).
TensorDistAttr CopyTensorDistAttrForOutput(const TensorDistAttr& src_dist_attr);

// Resolve the partial mesh dimensions of an output tensor, given the
// merged sharding specification of the input tensors and the axis names of
// the output tensor.
std::vector<int64_t> ResoluteOutputPartialDimension(
    const std::unordered_map<std::string, int64_t>& axis_to_dim_map,
    const std::string& tensor_axes);

// Generate the axis notation of a tensor for the einsum notation of a
// broadcast operation (alignment starts from the rightmost axis).
// tenosr_ndim: the rank of the tensor. broadcast_ndim: the maximum rank of
// the tensors in this broadcast operation. alphabet: the characters used to
// represent the axes; its length should be >= broadcast_ndim.
std::string GetBroadcastAxes(const int64_t& tenosr_ndim,
                             const int64_t& broadcast_ndim,
                             const std::string& alphabet);
|
||
// The static map that stores and initializes all the registered SPMD rules. | ||
class SPMDRuleMap { | ||
public: | ||
~SPMDRuleMap() = default; | ||
|
||
// A singleton | ||
static SPMDRuleMap& Instance(); | ||
|
||
// Returns the spmd rule for the given op_type | ||
SPMDRuleBase* Get(const std::string& op_type) const; | ||
|
||
// Returns the spmd by name or nullptr if not registered | ||
SPMDRuleBase* GetNullable(const std::string& op_type) const; | ||
|
||
// Register a spmd for an op_type. | ||
int Insert(const std::string& op_type, std::unique_ptr<SPMDRuleBase> rule); | ||
|
||
bool Has(const std::string& op_type) const { | ||
return map_.find(op_type) != map_.end(); | ||
} | ||
|
||
private: | ||
SPMDRuleMap() = default; | ||
paddle::flat_hash_map<std::string, std::unique_ptr<SPMDRuleBase>> map_; | ||
DISABLE_COPY_AND_ASSIGN(SPMDRuleMap); | ||
}; | ||
|
||
// Register `rule_class` (constructed with __VA_ARGS__) as the SPMD rule for
// `op_type`. Expands to a static int whose initializer performs the
// registration during static initialization.
#define REGISTER_SPMD_RULE(op_type, rule_class, ...)                        \
  UNUSED static int __spmd_rule_holder_##op_type =                          \
      ::paddle::distributed::auto_parallel::SPMDRuleMap::Instance().Insert( \
          #op_type, std::make_unique<rule_class>(__VA_ARGS__))
|
||
} // namespace auto_parallel | ||
} // namespace distributed | ||
} // namespace paddle |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Put the related header first.