[Semi AutoParall] Support Partial Semantic I #55508
Changes from 80 commits
@@ -25,20 +25,40 @@ limitations under the License. */
#include "paddle/phi/core/distributed/auto_parallel/process_mesh.h"
#include "paddle/phi/core/distributed/auto_parallel/utils.h"
#include "paddle/phi/core/enforce.h"
#include "paddle/utils/flat_hash_map.h"

namespace phi {
namespace distributed {
namespace auto_parallel {

constexpr const char* kDefault = "default";

enum class ReduceType : std::uint8_t {
  SUM = 0,
  AVG,
  MAX,
  MIN,
  PRODUCT,
  ANY,
  ALL
};
constexpr const char* ReduceTypeStrings[] = {
    "SUM", "AVG", "MAX", "MIN", "PRODUCT", "ANY", "ALL"};

struct _Partial_ {
  std::string to_string() const;

  int64_t dim_{-1};
  ReduceType type_{ReduceType::SUM};
};
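For context, "partial" here roughly means: each rank along a mesh dim holds an incomplete local value, and a cross-rank reduction of the given ReduceType is still required to materialize the full tensor. Below is a minimal sketch, not part of this diff, showing how a _Partial_ entry lines up with ReduceTypeStrings for display; it assumes the declarations above are in scope (print_partial_example is a hypothetical helper name):

#include <cstdint>
#include <iostream>

void print_partial_example() {
  _Partial_ p;                // defaults: dim_ = -1, type_ = ReduceType::SUM
  p.dim_ = 0;                 // partial along mesh dim 0
  p.type_ = ReduceType::AVG;  // local values must be averaged across that dim
  std::cout << "partial on mesh dim " << p.dim_ << ", reduce type "
            << ReduceTypeStrings[static_cast<std::uint8_t>(p.type_)]
            << "\n";  // prints: partial on mesh dim 0, reduce type AVG
}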
class TensorDistAttr {
 public:
  TensorDistAttr() = default;

  explicit TensorDistAttr(const std::vector<int64_t>& tensor_shape);

  TensorDistAttr(const TensorDistAttr& dist_attr);

  TensorDistAttr& operator=(const TensorDistAttr& dist_attr);
@@ -52,6 +72,29 @@ class TensorDistAttr {

  void set_dims_mapping(const std::vector<int64_t>& dims_mapping);

  // True if the tensor is partial on any mesh dim.
  bool is_partial() const { return !partial_status_.empty(); }

  // Returns the mesh dims on which this tensor is partial.
  const std::vector<int64_t> partial_dims() const;

  const paddle::flat_hash_map<int64_t, _Partial_>& partial_status() const {
    return partial_status_;
  }

  // Set partial status by map.
  void set_partial_status(
      const paddle::flat_hash_map<int64_t, _Partial_>& partial_status);

  // Set partial status on each of the given dims.
  void set_partial_status(const std::vector<int64_t>& dims,
                          const ReduceType& type = ReduceType::SUM);

  // Clean the partial status on all dims.
  void clean_partial_status();

  // Clean the partial status on the given dims only.
  void clean_partial_dims(const std::vector<int64_t>& dims);
  void set_default_dims_mapping(const std::vector<int64_t>& tensor_shape);

  int64_t batch_dim() const { return batch_dim_; }
@@ -89,11 +132,17 @@ class TensorDistAttr {

  bool verify_annotated(const std::map<std::string, bool>& annotated) const;

  bool verify_partial_status() const;

  bool verify(const std::vector<int64_t>& tensor_shape) const;

  // TensorDistAttr from_string(const std::string& dist_str);
  std::string to_string() const;
  std::string partial_status_string() const;

  // In partial-support stage I, partial is always a runtime attribute, so
  // there is no need to serialize it. Partial serialization will be supported
  // in the future stage II.
  void from_proto(const TensorDistAttrProto& proto);

  TensorDistAttrProto to_proto() const;
@@ -109,6 +158,10 @@ class TensorDistAttr {
  int64_t batch_dim_{0};
  std::vector<bool> dynamic_dims_;
  std::map<std::string, bool> annotated_;
  // The partial map is expected to be small (fewer entries than mesh.size()),
  // and iteration (copy and comparison) is more frequent than random element
  // access. <key: dim on mesh, value: partial object>
  paddle::flat_hash_map<int64_t, _Partial_> partial_status_;
Review comment: Why use a map structure here? If the dim_ in _Partial_ indicates the mesh dim, it seems unnecessary to store another mesh dim as the key. In addition, if one tensor has only one reduce type, would a simpler data structure be better?

Reply: Correct!
};
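Since the comment above calls the map iteration-heavy, here is a hypothetical sketch of how partial_dims() might walk it; the real definition lives in the corresponding .cc file, which this excerpt does not show:

#include <algorithm>
#include <vector>

const std::vector<int64_t> TensorDistAttr::partial_dims() const {
  std::vector<int64_t> dims;
  dims.reserve(partial_status_.size());
  for (const auto& kv : partial_status_) {
    dims.push_back(kv.first);  // the key is the mesh dim the tensor is partial on
  }
  // flat_hash_map iteration order is unspecified, so sort for a stable result.
  std::sort(dims.begin(), dims.end());
  return dims;
}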
inline std::ostream& operator<<(std::ostream& os, const TensorDistAttr& obj) {

@@ -122,6 +175,12 @@ inline bool operator!=(const TensorDistAttr& lhs, const TensorDistAttr& rhs) {
  return !operator==(lhs, rhs);
}

bool operator==(const _Partial_& lhs, const _Partial_& rhs);

inline bool operator!=(const _Partial_& lhs, const _Partial_& rhs) {
  return !operator==(lhs, rhs);
}

}  // namespace auto_parallel
}  // namespace distributed
}  // namespace phi
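operator== for _Partial_ is only declared above. A minimal sketch of a definition consistent with the struct's two fields, assuming the .cc implementation (not shown in this diff) compares them directly:

bool operator==(const _Partial_& lhs, const _Partial_& rhs) {
  // Two partial descriptors match when they refer to the same mesh dim
  // and carry the same reduce type.
  return lhs.dim_ == rhs.dim_ && lhs.type_ == rhs.type_;
}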
Review comment: What does ALL mean?