Commit

speed prune
Caozhou1995 committed Sep 14, 2023
1 parent 4fb4af0 commit a79e271
Showing 1 changed file with 117 additions and 17 deletions.
python/paddle/distributed/auto_tuner/utils.py: 134 changes (117 additions, 17 deletions)
@@ -36,8 +36,9 @@ def divisor(num, reverse=False):

def dist_degree(mode, num_gpus, num_nodes, tuner_cfg=None):
"""Return the degree of different parallel modes by gpus and nodes num."""
- assert mode in ["dp", "mp", "pp", "sharding"]
+ assert mode in ["dp", "mp", "pp", "sharding", "mbs", "vpp"]
results = []
prune_results = []
if mode == "dp":
results = divisor(num_gpus, reverse=False)

@@ -46,16 +47,63 @@ def dist_degree(mode, num_gpus, num_nodes, tuner_cfg=None):
results = list(range(num_nodes + 1, 0, -1))
else:
results = divisor(num_gpus, reverse=True)
for pp_degree in results:
prune_flag = False
num_layers = tuner_cfg["model_cfg"].get("num_layers", None)

if num_layers:
if num_layers % pp_degree != 0:
prune_flag = True

if not prune_flag:
prune_results.append(pp_degree)
results = prune_results

elif mode == "mp":
if tuner_cfg.get("enable_mp_prune", True):
gpus_per_node = num_gpus // num_nodes
results = divisor(gpus_per_node, reverse=True)
else:
results = divisor(num_gpus, reverse=True)
for mp_degree in results:
prune_flag = False
hidden_size = tuner_cfg["model_cfg"].get("hidden_size", None)
vocab_size = tuner_cfg["model_cfg"].get("vocab_size", None)
num_attention_heads = tuner_cfg["model_cfg"].get(
"num_attention_heads", None
)
seq_length = tuner_cfg["model_cfg"].get("seq_length", None)
use_sequence_paralel = tuner_cfg.get("use_sequence_paralel", False)

if hidden_size and hidden_size % mp_degree != 0:
prune_flag = True

if vocab_size and vocab_size % mp_degree != 0:
prune_flag = True

if num_attention_heads and num_attention_heads % mp_degree != 0:
prune_flag = True

if (
seq_length
and seq_length % mp_degree != 0
and use_sequence_paralel
):
prune_flag = True

if not prune_flag:
prune_results.append(mp_degree)
results = prune_results

elif mode == "sharding":
results = divisor(num_gpus, reverse=True)

elif mode == "mbs":
results = divisor(tuner_cfg["model_cfg"]["global_batch_size"])

elif mode == "vpp":
results = divisor(tuner_cfg["model_cfg"]["num_layers"], reverse=True)

return results
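
For a sense of what the new pruning does, here is a minimal, self-contained sketch (not part of the commit). It assumes the divisor() helper, whose body is not shown in this diff, returns all divisors of its argument, descending when reverse=True; the model sizes below are made up for illustration.

# Illustrative toy, not the auto_tuner implementation.
def toy_divisor(num, reverse=False):
    # assumed behaviour of divisor(): every divisor of num
    return sorted((i for i in range(1, num + 1) if num % i == 0), reverse=reverse)

def toy_pp_candidates(num_gpus, num_layers):
    # keep only pp degrees that split the layers evenly
    return [p for p in toy_divisor(num_gpus, reverse=True) if num_layers % p == 0]

def toy_mp_candidates(gpus_per_node, hidden_size, vocab_size, num_heads):
    # keep only mp degrees that split hidden size, vocab and attention heads evenly
    return [
        m
        for m in toy_divisor(gpus_per_node, reverse=True)
        if hidden_size % m == 0 and vocab_size % m == 0 and num_heads % m == 0
    ]

print(toy_pp_candidates(8, 30))              # [2, 1]: pp=8 and pp=4 are pruned, 30 layers do not split
print(toy_mp_candidates(8, 4096, 50257, 32)) # [1]: an odd vocab size prunes every mp degree except 1

Compared with the previous behaviour, which handed every divisor of num_gpus to the runner, combinations that would fail immediately (for example a tensor-parallel degree that does not divide the attention heads) never get launched, which is presumably where the "speed" in the commit title comes from.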


@@ -94,8 +142,8 @@ def default_candidates(tuner_cfg):
candidates["pp_degree"] = [1]

if tuner_cfg.get("vpp_degree", None) == "auto":
- candidates["vpp_degree"] = list(
-     range(tuner_cfg["model_cfg"]["num_layers"], 0, -1)
+ candidates["vpp_degree"] = dist_degree(
+     "vpp", num_gpus, num_nodes, tuner_cfg
)
elif tuner_cfg.get("vpp_degree", None):
candidates["vpp_degree"] = tuner_cfg.get("vpp_degree")
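
Routing the "auto" case through dist_degree changes which vpp values are tried: the old code enumerated every integer from num_layers down to 1, while the new code keeps only divisors of num_layers. A rough before/after with a made-up 32-layer model:

num_layers = 32
old_vpp_candidates = list(range(num_layers, 0, -1))  # 32 values, most of them invalid
new_vpp_candidates = [d for d in range(num_layers, 0, -1) if num_layers % d == 0]
print(new_vpp_candidates)  # [32, 16, 8, 4, 2, 1]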
@@ -135,8 +183,8 @@ def default_candidates(tuner_cfg):
candidates["recompute_granularity"] = [None]

if tuner_cfg.get("micro_batch_size", None) == "auto":
- candidates["micro_batch_size"] = list(
-     range(1, tuner_cfg["model_cfg"]["global_batch_size"])
+ candidates["micro_batch_size"] = dist_degree(
+     "mbs", num_gpus, num_nodes, tuner_cfg
)
elif tuner_cfg.get("micro_batch_size", None):
candidates["micro_batch_size"] = tuner_cfg.get("micro_batch_size")
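
micro_batch_size gets the same treatment: instead of every integer in [1, global_batch_size), only values that divide the global batch size are kept, since anything else can never produce a whole number of micro steps. Toy numbers for illustration:

global_batch_size = 64
old_mbs_candidates = list(range(1, global_batch_size))  # 63 candidates
new_mbs_candidates = [d for d in range(1, global_batch_size + 1) if global_batch_size % d == 0]
print(new_mbs_candidates)  # [1, 2, 4, 8, 16, 32, 64]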
@@ -159,29 +207,81 @@ def search_all(tuner_cfg):
sharding_degree_candidates = candidates["sharding_degree"]
use_recompute_candidates = candidates["use_recompute"]
recompute_granularity_candidates = candidates["recompute_granularity"]
- all_cfgs = list(

num_gpus = tuner_cfg["num_gpus"]
valid_degrees = []

for mp_degree in mp_degree_candidates:
degrees = []
if num_gpus % mp_degree != 0:
continue
degrees.append(mp_degree)
sharding_res = num_gpus // mp_degree

for sharding_degree in sharding_degree_candidates:
if sharding_res % sharding_degree != 0:
continue
degrees.append(sharding_degree)
pp_res = sharding_res // sharding_degree

for pp_degree in pp_degree_candidates:
if pp_res % pp_degree != 0:
continue
degrees.append(pp_degree)
dp_res = pp_res // pp_degree

for dp_degree in dp_degree_candidates:
if dp_res != dp_degree:
continue
degrees.append(dp_degree)
assert len(degrees) == 4
valid_degrees.append(copy.deepcopy(degrees))
degrees.pop()
degrees.pop()
degrees.pop()
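
The nested loops above build only (mp, sharding, pp, dp) tuples whose product is exactly num_gpus, rather than generating the full Cartesian product and rejecting most of it later. A compact sketch that yields the same set of tuples (ordering aside), with made-up candidate lists for an 8-GPU job:

import itertools

num_gpus = 8
mp_c, sharding_c, pp_c, dp_c = [1, 2, 4, 8], [1, 2, 4, 8], [1, 2], [1, 2, 4, 8]

valid_degrees = [
    (mp, sh, pp, dp)
    for mp, sh, pp, dp in itertools.product(mp_c, sharding_c, pp_c, dp_c)
    if mp * sh * pp * dp == num_gpus
]
print(valid_degrees[:3])  # [(1, 1, 1, 8), (1, 1, 2, 4), (1, 2, 1, 4)]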

other_dim_cfgs = list(
itertools.product(
- mp_degree_candidates,
- sharding_degree_candidates,
sharding_stage_candidates,
mbs_candidates,
- pp_degree_candidates,
vpp_degree_candidates,
use_recompute_candidates,
recompute_granularity_candidates,
- dp_degree_candidates,
)
)

all_cfgs = []
for valid_degree in valid_degrees:
for other_dim_cfg in other_dim_cfgs:
mp_degree, sharding_degree, pp_degree, dp_degree = valid_degree
(
sharding_stage,
mbs,
vpp,
use_recompute,
recompute_granularity,
) = list(other_dim_cfg)
if (
tuner_cfg["model_cfg"]["global_batch_size"]
% (mbs * sharding_degree * dp_degree)
!= 0
):
continue
if tuner_cfg["model_cfg"]["num_layers"] % (pp_degree * vpp) != 0:
continue
cfg = list(valid_degree) + list(other_dim_cfg)
all_cfgs.append(cfg)
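
The two continue statements above prune configurations whose arithmetic cannot work out: the global batch size must split evenly across dp_degree * sharding_degree replicas, each taking micro batches of size mbs, and the layer count must split evenly into pp_degree * vpp chunks. With made-up numbers:

# Toy check of the two prune conditions; values are illustrative only.
global_batch_size, num_layers = 32, 24
mbs, sharding_degree, dp_degree = 3, 2, 2
pp_degree, vpp = 4, 3

print(global_batch_size % (mbs * sharding_degree * dp_degree) != 0)  # True  -> this combination is dropped
print(num_layers % (pp_degree * vpp) != 0)                           # False -> this one is kept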

mapping = {
0: "mp_degree",
1: "sharding_degree",
- 2: "sharding_stage",
- 3: "micro_batch_size",
- 4: "pp_degree",
- 5: "vpp_degree",
- 6: "use_recompute",
- 7: "recompute_granularity",
- 8: "dp_degree",
+ 2: "pp_degree",
+ 3: "dp_degree",
+ 4: "sharding_stage",
+ 5: "micro_batch_size",
+ 6: "vpp_degree",
+ 7: "use_recompute",
+ 8: "recompute_granularity",
}
new_all_cfgs = []
for cfg in all_cfgs:
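
The rest of the loop is hidden behind the truncated diff, so the following is only a hypothetical sketch of how an index-to-name mapping like the one above is typically consumed, not the code that actually follows:

# Hypothetical illustration only.
mapping = {0: "mp_degree", 1: "sharding_degree", 2: "pp_degree", 3: "dp_degree"}
cfg = [2, 2, 2, 1]  # a toy degree tuple in the same order
named_cfg = {mapping[i]: value for i, value in enumerate(cfg)}
print(named_cfg)  # {'mp_degree': 2, 'sharding_degree': 2, 'pp_degree': 2, 'dp_degree': 1}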
