[Semi AutoParall] Support Partial Semantic I #55508

Merged: 84 commits merged on Aug 4, 2023. The diff below shows changes from 80 commits.

Commits
a47bf99  base rule (JZ-LIANG, May 16, 2023)
21b5a75  add sharidng merge (JZ-LIANG, May 18, 2023)
f7e39d7  add sharidng axis merge (JZ-LIANG, May 19, 2023)
c92992d  define unified data class for inferencing dist_attr (pkuzyc, May 18, 2023)
42a7b77  test wrap DistTensorSpec in dygraph mode (pkuzyc, May 19, 2023)
180edcc  matmul main logic done (JZ-LIANG, May 23, 2023)
ecbb1ae  Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into… (JZ-LIANG, May 23, 2023)
f314b56  Merge remote-tracking branch 'zyc/develop' into semi-auto/rule-base (JZ-LIANG, May 23, 2023)
46153c7  shape int64 (JZ-LIANG, May 23, 2023)
1dcb80e  common cc (JZ-LIANG, May 23, 2023)
198bc1f  define unified data class for inferencing dist_attr (pkuzyc, May 18, 2023)
09d82a5  test wrap DistTensorSpec in dygraph mode (pkuzyc, May 19, 2023)
c3ea2a6  define python api and wrap function in static mode for DistTensorSpec (pkuzyc, May 23, 2023)
4cd1a2c  revise syntax (JZ-LIANG, May 24, 2023)
3631f06  Merge remote-tracking branch 'zyc/develop' into semi-auto/rule-base (JZ-LIANG, May 24, 2023)
ed0c31e  map bugfix (JZ-LIANG, May 29, 2023)
701d3fa  broadcast func (JZ-LIANG, May 29, 2023)
c1545a4  compile 1 (JZ-LIANG, May 29, 2023)
3ca2b73  add unitest (JZ-LIANG, May 31, 2023)
747de08  add registry (JZ-LIANG, Jun 6, 2023)
968ce61  Merge branch 'semi-auto/rule-base' of https://github.com/JZ-LIANG/Pad… (JZ-LIANG, Jun 6, 2023)
7be672d  update unitest (JZ-LIANG, Jun 6, 2023)
3389b7e  bugfix (JZ-LIANG, Jun 6, 2023)
3882a2c  bugfix (JZ-LIANG, Jun 6, 2023)
3719a5a  add pybind (JZ-LIANG, Jun 6, 2023)
73f49a8  bugfix (JZ-LIANG, Jun 6, 2023)
aced5ea  bugfix macro gloabl name space (JZ-LIANG, Jun 6, 2023)
ef92dc4  bugfix macro gloabl name space (JZ-LIANG, Jun 6, 2023)
adcb470  segment fault (JZ-LIANG, Jun 8, 2023)
43df148  pybind (JZ-LIANG, Jun 8, 2023)
27803af  pybind test (JZ-LIANG, Jun 8, 2023)
5612da9  pybind bugfixed1 (JZ-LIANG, Jun 14, 2023)
f9bd281  pybind bugfixed2 (JZ-LIANG, Jun 14, 2023)
18f8d29  pybind unitest (JZ-LIANG, Jun 14, 2023)
2628043  Merge remote-tracking branch 'upstream/develop' into semi-auto/rule-base (JZ-LIANG, Jun 16, 2023)
68a512a  merge dev (JZ-LIANG, Jun 16, 2023)
f2b2edb  merge dev (JZ-LIANG, Jun 16, 2023)
132558a  merge dev (JZ-LIANG, Jun 16, 2023)
f3bc740  fixed cmake conflict (JZ-LIANG, Jun 16, 2023)
c11cdd2  fixed cmake conflict (JZ-LIANG, Jun 16, 2023)
491bf65  rename get method (JZ-LIANG, Jun 20, 2023)
041abd4  revise inferforward output type (JZ-LIANG, Jun 20, 2023)
60c90d3  revise comment (JZ-LIANG, Jun 20, 2023)
d5d7557  replicated rule (JZ-LIANG, Jun 21, 2023)
44e9404  replicated rule 2 (JZ-LIANG, Jun 21, 2023)
7657ee5  revert bug deps (JZ-LIANG, Jun 27, 2023)
223f960  Merge branch 'semi-auto/revert-phi-dep' into semi-auto/replicated-rule (JZ-LIANG, Jun 27, 2023)
5dc1be3  add rule (JZ-LIANG, Jun 28, 2023)
3ce0e74  add unitest (JZ-LIANG, Jun 28, 2023)
80f2a03  add rule (JZ-LIANG, Jun 29, 2023)
062970d  add unitest (JZ-LIANG, Jul 6, 2023)
0cb4a9c  move ut of auto_parallel (zhiqiu, Jul 6, 2023)
ab67ce1  fix ut (zhiqiu, Jul 7, 2023)
f9675bd  Merge remote-tracking branch 'upstream/develop' into semi-auto/embedd… (JZ-LIANG, Jul 7, 2023)
7e31dea  Merge branch 'dev/mv_ut' of https://github.com/zhiqiu/Paddle into sem… (JZ-LIANG, Jul 7, 2023)
694b310  Merge branch 'semi-auto/embedding-rule' into semi-auto/softmax-rule (JZ-LIANG, Jul 7, 2023)
2d4e938  bugfix (JZ-LIANG, Jul 7, 2023)
f45eca8  bugfix (JZ-LIANG, Jul 7, 2023)
e9b4ddc  bugfix (JZ-LIANG, Jul 7, 2023)
43a4373  bugfix (JZ-LIANG, Jul 7, 2023)
ad31f1b  bugfix (JZ-LIANG, Jul 7, 2023)
9ca9969  bugfix (JZ-LIANG, Jul 7, 2023)
dfad99d  bugfix (JZ-LIANG, Jul 7, 2023)
fc3dfe6  Merge remote-tracking branch 'upstream/develop' into semi-auto/embedd… (JZ-LIANG, Jul 10, 2023)
def09f0  Merge branch 'semi-auto/embedding-rule' into semi-auto/softmax-rule (JZ-LIANG, Jul 10, 2023)
934cc61  resolute input sharding conflict maybe (JZ-LIANG, Jul 11, 2023)
daf098a  Merge branch 'semi-auto/embedding-rule' into semi-auto/softmax-rule (JZ-LIANG, Jul 11, 2023)
99a10f4  fixed comment (JZ-LIANG, Jul 12, 2023)
49257b9  Merge remote-tracking branch 'upstream/develop' into semi-auto/entrop… (JZ-LIANG, Jul 12, 2023)
fcf2ccb  add rule (JZ-LIANG, Jul 12, 2023)
4d2a854  Merge remote-tracking branch 'upstream/develop' into semi-auto/entrop… (JZ-LIANG, Jul 13, 2023)
6f7199a  add unitest (JZ-LIANG, Jul 13, 2023)
7c69300  Merge remote-tracking branch 'upstream/develop' into semi-auto/entrop… (JZ-LIANG, Jul 17, 2023)
b785790  add partial for distattr (JZ-LIANG, Jul 18, 2023)
35b0446  bugfix (JZ-LIANG, Jul 19, 2023)
ef7a4d6  Merge remote-tracking branch 'upstream/develop' into semi-auto/partial-I (JZ-LIANG, Jul 20, 2023)
d36137d  Merge remote-tracking branch 'upstream/develop' into semi-auto/partial-I (JZ-LIANG, Jul 24, 2023)
710a494  set --> map (JZ-LIANG, Jul 24, 2023)
a284d8b  pybind & unitest (JZ-LIANG, Jul 25, 2023)
052f0df  internal api (JZ-LIANG, Jul 25, 2023)
e2c13e9  partial status to map (JZ-LIANG, Jul 26, 2023)
f73bc2c  bugfix (JZ-LIANG, Jul 26, 2023)
ee903a8  update unitest (JZ-LIANG, Jul 31, 2023)
17c54d1  bugfix for partial set (JZ-LIANG, Aug 2, 2023)
@@ -160,6 +160,7 @@ MatmulSPMDRule::InferForward(const std::vector<DistTensorSpec>& input_specs,
// Step2.3.1 Output Partial
std::vector<int64_t> partial_on_dims =
ResoluteOutputPartialDimension(axis_to_dim_map, out_axes);
+ output_dist_attr_dst.set_partial_status(partial_on_dims);

// Step2.3.2 handle input tensor partial (TODO)
VLOG(4) << "MatmulSPMDRule InferForward: "
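To see how a contracted axis becomes a partial mesh dim, here is a minimal, self-contained sketch; the helper name and the einsum-style notation are illustrative stand-ins for ResoluteOutputPartialDimension, not Paddle code. In "mk,kn->mn", the contracted axis k is mapped to a mesh dim but absent from the output axes, so each rank computes only a partial sum and the output is marked partial on that mesh dim.

#include <cstdint>
#include <iostream>
#include <map>
#include <string>
#include <vector>

// Any axis that is mapped to a mesh dim but missing from the output
// notation leaves the output partial on that mesh dim.
std::vector<int64_t> ResolutePartialSketch(
    const std::map<std::string, int64_t>& axis_to_dim_map,
    const std::string& out_axes) {
  std::vector<int64_t> partial_dims;
  for (const auto& kv : axis_to_dim_map) {
    if (kv.second != -1 && out_axes.find(kv.first) == std::string::npos) {
      partial_dims.push_back(kv.second);
    }
  }
  return partial_dims;
}

int main() {
  // "mk,kn->mn": the contracted axis k is sharded on mesh dim 0,
  // while m and n are replicated (-1).
  std::map<std::string, int64_t> axis_to_dim_map = {
      {"m", -1}, {"k", 0}, {"n", -1}};
  for (int64_t d : ResolutePartialSketch(axis_to_dim_map, "mn")) {
    std::cout << "output is partial(SUM) on mesh dim " << d << "\n";
  }
  return 0;  // prints: output is partial(SUM) on mesh dim 0
}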
@@ -88,13 +88,15 @@ ReductionSPMDRule::InferForward(const std::vector<DistTensorSpec>& input_specs,
CopyTensorDistAttrForOutput(input_specs[0].dist_attr());
output_dist_attr.set_dims_mapping(output_dims_mapping);

- std::vector<TensorDistAttr> output_dist_attrs;
- output_dist_attrs.emplace_back(output_dist_attr);

// step2.4: handle partial
// Step2.4.1 Output Partial
std::vector<int64_t> partial_on_dims =
ResoluteOutputPartialDimension(axis_to_dim_map, output_axes);
+ output_dist_attr.set_partial_status(
+ partial_on_dims /*, handle reduce_type in future */);

+ std::vector<TensorDistAttr> output_dist_attrs;
+ output_dist_attrs.emplace_back(output_dist_attr);

// Step2.4.2 handle input tensor partial (TODO)
// If the op is a linear op, i.e. `linearity` is true, it supports
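The reordering in this hunk is the substantive change: emplace_back stores a copy, so a partial status set after the copy never reaches the stored output attribute. A minimal sketch of the pitfall, using a toy struct rather than Paddle's TensorDistAttr:

#include <cassert>
#include <vector>

struct ToyAttr {
  bool partial = false;
};

int main() {
  ToyAttr attr;
  std::vector<ToyAttr> outputs;
  outputs.emplace_back(attr);   // stores a copy of attr
  attr.partial = true;          // too late: the stored copy is unchanged
  assert(!outputs[0].partial);  // hence set_partial_status must run first
  return 0;
}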
6 changes: 5 additions & 1 deletion paddle/fluid/pybind/auto_parallel_py.cc
@@ -285,7 +285,11 @@ void BindAutoParallel(py::module *m) {
return TensorDistAttr(self);
},
py::arg("memo"))
.def("__str__", &TensorDistAttr::to_string);
.def("__str__", &TensorDistAttr::to_string)
.def("_is_partial", &TensorDistAttr::is_partial)
.def("_partial_dims", &TensorDistAttr::partial_dims)
.def("_clean_partial_dims", &TensorDistAttr::clean_partial_dims)
.def("_clean_partial_status", &TensorDistAttr::clean_partial_status);

py::class_<SPMDRuleBase>(*m, "SPMDRuleBase")
.def("infer_forward", &SPMDRuleBase::InferForward)
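For readers unfamiliar with the pattern above, the following is a hypothetical, self-contained pybind11 sketch (toy class and module name, not part of Paddle's build) showing how chained .def calls expose const member functions under underscore-prefixed, private-by-convention names:

#include <pybind11/pybind11.h>
#include <pybind11/stl.h>  // converts std::vector to a Python list

#include <cstdint>
#include <map>
#include <vector>

namespace py = pybind11;

struct ToyDistAttr {
  std::map<int64_t, int> partial_status;
  bool is_partial() const { return !partial_status.empty(); }
  std::vector<int64_t> partial_dims() const {
    std::vector<int64_t> dims;
    for (const auto& kv : partial_status) dims.push_back(kv.first);
    return dims;
  }
};

PYBIND11_MODULE(toy_auto_parallel, m) {
  py::class_<ToyDistAttr>(m, "ToyDistAttr")
      .def(py::init<>())
      .def("_is_partial", &ToyDistAttr::is_partial)
      .def("_partial_dims", &ToyDistAttr::partial_dims);
}

After building such a module, Python code could call toy_auto_parallel.ToyDistAttr()._is_partial(), mirroring the _is_partial and _partial_dims bindings added to TensorDistAttr here.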
93 changes: 92 additions & 1 deletion paddle/phi/core/distributed/auto_parallel/dist_attr.cc
@@ -24,6 +24,7 @@ namespace phi {
namespace distributed {
namespace auto_parallel {

// Partial is not allowed to be annotated by users for now.
std::vector<std::string> TensorDistAttr::fields_{
"process_mesh", "dims_mapping", "batch_dim", "dynamic_dims"};

@@ -44,6 +45,7 @@ TensorDistAttr& TensorDistAttr::operator=(const TensorDistAttr& dist_attr) {
std::swap(this->batch_dim_, tmp.batch_dim_);
std::swap(this->dynamic_dims_, tmp.dynamic_dims_);
std::swap(this->annotated_, tmp.annotated_);
std::swap(this->partial_status_, tmp.partial_status_);
return *this;
}

@@ -53,6 +55,7 @@ void TensorDistAttr::copy_from(const TensorDistAttr& dist_attr) {
set_batch_dim(dist_attr.batch_dim());
set_dynamic_dims(dist_attr.dynamic_dims());
set_annotated(dist_attr.annotated());
set_partial_status(dist_attr.partial_status());
}

void TensorDistAttr::set_process_mesh(const ProcessMesh& process_mesh) {
@@ -77,6 +80,46 @@ void TensorDistAttr::set_annotated(
annotated_ = annotated;
}

const std::vector<int64_t> TensorDistAttr::partial_dims() const {
std::vector<int64_t> keys;
keys.reserve(partial_status_.size());
for (auto& kv : partial_status_) {
keys.push_back(kv.first);
}
return keys;
}

void TensorDistAttr::set_partial_status(
const paddle::flat_hash_map<int64_t, _Partial_>& partial_status) {
partial_status_ = partial_status;
}

void TensorDistAttr::set_partial_status(const std::vector<int64_t>& dims,
const ReduceType& type) {
for (const auto& dim : dims) {
if (partial_status_.count(dim) != 0) {
PADDLE_THROW(phi::errors::InvalidArgument(
"Trying to Set dim %d as Partial which is already a Partial dim.",
dim));
}
_Partial_ partial = {dim, type};
partial_status_.emplace(dim, partial);
}
}

void TensorDistAttr::clean_partial_status() { partial_status_.clear(); }

void TensorDistAttr::clean_partial_dims(const std::vector<int64_t>& dims) {
for (const auto& dim : dims) {
if (partial_status_.count(dim) == 0) {
PADDLE_THROW(phi::errors::InvalidArgument(
"Trying to clean Partial on dim %d but it is not Partial.", dim));
} else {
partial_status_.erase(dim);
}
}
}
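To make the contract of set_partial_status and clean_partial_dims concrete, here is a self-contained toy that mirrors (but does not reuse) the Paddle class: double-setting a partial dim and cleaning a non-partial dim are both rejected.

#include <cassert>
#include <cstdint>
#include <map>
#include <stdexcept>
#include <vector>

class ToyDistAttr {
 public:
  void set_partial_dims(const std::vector<int64_t>& dims) {
    for (int64_t d : dims) {
      if (status_.count(d) != 0)
        throw std::invalid_argument("dim already partial");
      status_[d] = 0;  // stand-in for ReduceType::SUM
    }
  }
  void clean_partial_dims(const std::vector<int64_t>& dims) {
    for (int64_t d : dims) {
      if (status_.count(d) == 0)
        throw std::invalid_argument("dim is not partial");
      status_.erase(d);
    }
  }
  bool is_partial() const { return !status_.empty(); }

 private:
  std::map<int64_t, int> status_;  // mesh dim -> reduce type
};

int main() {
  ToyDistAttr attr;
  attr.set_partial_dims({0});
  assert(attr.is_partial());
  bool rejected = false;
  try {
    attr.set_partial_dims({0});  // double-set must be rejected
  } catch (const std::invalid_argument&) {
    rejected = true;
  }
  assert(rejected);
  attr.clean_partial_dims({0});
  assert(!attr.is_partial());
  return 0;
}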

void TensorDistAttr::set_default_dims_mapping(
const std::vector<int64_t>& tensor_shape) {
if (!tensor_shape.empty()) {
@@ -178,6 +221,21 @@ bool TensorDistAttr::verify_annotated(
return true;
}

bool TensorDistAttr::verify_partial_status() const {
VLOG(4) << "[TensorDistAttr verify_partial_status] "
<< partial_status_string();
for (auto& itr : partial_status_) {
if (itr.second.dim_ < 0 || itr.second.dim_ >= process_mesh_.ndim()) {
return false;
}
if (itr.second.type_ < ReduceType::SUM ||
itr.second.type_ > ReduceType::ALL) {
return false;
}
}
return true;
}

bool TensorDistAttr::verify(const std::vector<int64_t>& tensor_shape) const {
if (!verify_process_mesh(process_mesh_)) {
return false;
@@ -194,6 +252,9 @@ bool TensorDistAttr::verify(const std::vector<int64_t>& tensor_shape) const {
if (!verify_annotated(annotated_)) {
return false;
}
if (!verify_partial_status()) {
return false;
}
return true;
}

@@ -203,7 +264,8 @@ std::string TensorDistAttr::to_string() const {
dist_str += "dims_mappings: [" + str_join(dims_mapping_) + "], ";
dist_str += "batch_dim: " + std::to_string(batch_dim_) + ", ";
dist_str += "dynamic_dims: [" + str_join(dynamic_dims_) + "], ";
dist_str += "annotated: [" + str_join(annotated_) + "]}";
dist_str += "annotated: [" + str_join(annotated_) + "], ";
dist_str += "partial: " + partial_status_string() + ".}";
return dist_str;
}

@@ -254,6 +316,16 @@ void TensorDistAttr::parse_from_string(const std::string& data) {
from_proto(proto);
}

bool operator==(const _Partial_& lhs, const _Partial_& rhs) {
if (lhs.dim_ != rhs.dim_) {
return false;
}
if (lhs.type_ != rhs.type_) {
return false;
}
return true;
}

bool operator==(const TensorDistAttr& lhs, const TensorDistAttr& rhs) {
if (lhs.process_mesh() != rhs.process_mesh()) {
return false;
@@ -267,9 +339,28 @@ bool operator==(const TensorDistAttr& lhs, const TensorDistAttr& rhs) {
if (lhs.dynamic_dims() != rhs.dynamic_dims()) {
return false;
}
if (lhs.partial_status() != rhs.partial_status()) {
return false;
}
return true;
}

std::string TensorDistAttr::partial_status_string() const {
std::string partial_status_str = "[";
for (auto& itr : partial_status_) {
partial_status_str += itr.second.to_string() + ", ";
}
partial_status_str += "]";
return partial_status_str;
}

std::string _Partial_::to_string() const {
return "Partial(dims:" + std::to_string(dim_) + ", " +
ReduceTypeStrings[static_cast<int>(type_)] + ")";
}

} // namespace auto_parallel
} // namespace distributed
} // namespace phi
61 changes: 60 additions & 1 deletion paddle/phi/core/distributed/auto_parallel/dist_attr.h
@@ -25,20 +25,40 @@ limitations under the License. */
#include "paddle/phi/core/distributed/auto_parallel/process_mesh.h"
#include "paddle/phi/core/distributed/auto_parallel/utils.h"
#include "paddle/phi/core/enforce.h"
#include "paddle/utils/flat_hash_map.h"

namespace phi {
namespace distributed {
namespace auto_parallel {

constexpr const char* kDefault = "default";

enum class ReduceType : std::uint8_t {
SUM = 0,
AVG,
MAX,
MIN,
PRODUCT,
ANY,
ALL
[Inline review comment from a Contributor: "ALL means?"]
};
constexpr const char* ReduceTypeStrings[] = {
"SUM", "AVG", "MAX", "MIN", "PRODUCT", "ANY", "ALL"};

struct _Partial_ {
std::string to_string() const;

int64_t dim_{-1};
ReduceType type_{ReduceType::SUM};
};
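For intuition about the semantic itself: Partial(SUM) on a mesh dim means each rank along that dim holds one addend of the true tensor value, and the full tensor is recovered by applying the reduce type across ranks. A toy illustration, assuming a hypothetical 2-rank mesh dim (not Paddle code):

#include <cstddef>
#include <iostream>
#include <vector>

int main() {
  // Two ranks each hold a partial result of, e.g., a row-parallel matmul.
  std::vector<double> rank0 = {1.0, 2.0};
  std::vector<double> rank1 = {3.0, 4.0};
  std::vector<double> full(rank0.size());
  for (std::size_t i = 0; i < full.size(); ++i) {
    full[i] = rank0[i] + rank1[i];  // what an allreduce(SUM) would compute
  }
  std::cout << full[0] << ", " << full[1] << "\n";  // prints: 4, 6
  return 0;
}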

class TensorDistAttr {
public:
TensorDistAttr() = default;

explicit TensorDistAttr(const std::vector<int64_t>& tensor_shape);

- TensorDistAttr(const TensorDistAttr& tensor);
+ TensorDistAttr(const TensorDistAttr& dist_attr);

TensorDistAttr& operator=(const TensorDistAttr& dist_attr);

@@ -52,6 +72,29 @@ class TensorDistAttr {

void set_dims_mapping(const std::vector<int64_t>& dims_mapping);

// true if tensor is partial on any mesh dim.
bool is_partial() const { return !partial_status_.empty(); }

// returns the mesh dims on which this tensor is partial
const std::vector<int64_t> partial_dims() const;

const paddle::flat_hash_map<int64_t, _Partial_>& partial_status() const {
return partial_status_;
}

// set the whole partial status map at once
void set_partial_status(
const paddle::flat_hash_map<int64_t, _Partial_>& partial_status);

// mark each of the given mesh dims as partial
void set_partial_status(const std::vector<int64_t>& dims,
const ReduceType& type = ReduceType::SUM);
// clean all partial status
void clean_partial_status();

// clean partial status on the given mesh dims
void clean_partial_dims(const std::vector<int64_t>& dims);

void set_default_dims_mapping(const std::vector<int64_t>& tensor_shape);

int64_t batch_dim() const { return batch_dim_; }
@@ -89,11 +132,17 @@

bool verify_annotated(const std::map<std::string, bool>& annotated) const;

bool verify_partial_status() const;

bool verify(const std::vector<int64_t>& tensor_shape) const;

// TensorDistAttr from_string(const std::string& dist_str);
std::string to_string() const;
std::string partial_status_string() const;

// In partial-support stage I, partial is always a runtime attribute, so
// there is no need to serialize it. Partial serialization will be
// supported in stage II.
void from_proto(const TensorDistAttrProto& proto);

TensorDistAttrProto to_proto() const;
@@ -109,6 +158,10 @@
int64_t batch_dim_{0};
std::vector<bool> dynamic_dims_;
std::map<std::string, bool> annotated_;
// The partial map is expected to be small (fewer entries than the mesh
// size), and iteration (copy and comparison) is more frequent than random
// element access. <key: dim on mesh, value: partial object>
paddle::flat_hash_map<int64_t, _Partial_> partial_status_;
[Inline review comment from a Contributor:]
Why use a map structure here? If the "dim_" in Partial indicates the mesh dim, it seems unnecessary to store another mesh dim. In addition, if one tensor has only one reduce type, is it better to use a data structure like:
Partial {
  vector<int64_t> mesh_dims;
  ReduceType type_;
}

[Reply from the Contributor Author:]
Correct! At first I thought we would use a set for partial_status_, which is why I built the Partial struct. Later I found that a map works better for partial_status_, since in most use cases we use the dim as the key to retrieve the Partial entry.

};
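A small sketch of the dim-keyed retrieval pattern mentioned in the reply above, with std::unordered_map standing in for paddle::flat_hash_map and a hypothetical consumer that checks whether a given mesh dim is partial:

#include <cstdint>
#include <iostream>
#include <unordered_map>

enum class ToyReduceType { SUM, AVG };

int main() {
  // key: mesh dim, value: reduce type (stand-in for _Partial_)
  std::unordered_map<int64_t, ToyReduceType> partial_status = {
      {0, ToyReduceType::SUM}};
  // Rules typically ask "is mesh dim d partial?" by key lookup:
  int64_t d = 0;
  if (partial_status.count(d) != 0) {
    std::cout << "mesh dim " << d << " is partial; schedule a reduction\n";
  }
  return 0;
}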

inline std::ostream& operator<<(std::ostream& os, const TensorDistAttr& obj) {
@@ -122,6 +175,12 @@ inline bool operator!=(const TensorDistAttr& lhs, const TensorDistAttr& rhs) {
return !operator==(lhs, rhs);
}

bool operator==(const _Partial_& lhs, const _Partial_& rhs);

inline bool operator!=(const _Partial_& lhs, const _Partial_& rhs) {
return !operator==(lhs, rhs);
}

} // namespace auto_parallel
} // namespace distributed
} // namespace phi