[Auto Parallel] Update AutoTuner #56939

Merged
merged 17 commits into from Sep 25, 2023
86 changes: 83 additions & 3 deletions python/paddle/distributed/auto_tuner/prune.py
@@ -55,6 +55,11 @@ def prune_by_mp(tuner_cfg, cur_cfg, history_cfgs=None):
mp_degree = cur_cfg.get("mp_degree", None)
hidden_size = tuner_cfg["model_cfg"].get("hidden_size", None)
vocab_size = tuner_cfg["model_cfg"].get("vocab_size", None)
num_attention_heads = tuner_cfg["model_cfg"].get(
"num_attention_heads", None
)
seq_length = tuner_cfg["model_cfg"].get("seq_length", None)
use_sequence_paralel = tuner_cfg.get("use_sequence_paralel", False)

if mp_degree is None:
return False
@@ -65,6 +70,12 @@ def prune_by_mp(tuner_cfg, cur_cfg, history_cfgs=None):
if vocab_size and vocab_size % mp_degree != 0:
return True

if num_attention_heads and num_attention_heads % mp_degree != 0:
return True

if seq_length and seq_length % mp_degree != 0 and use_sequence_paralel:
return True

mp_degree_candidates = tuner_cfg.get("mp_degree", None)

if mp_degree_candidates == "auto":
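The checks added above prune tensor-parallel (mp) degrees that do not evenly divide the attention-head count, or the sequence length when sequence parallelism is enabled. A minimal standalone sketch of the same divisibility rules, with hypothetical names not taken from the PR:

# Illustrative only: mirrors the mp divisibility rules added above.
def mp_degree_conflicts(mp_degree, num_attention_heads, seq_length, use_sequence_parallel):
    if num_attention_heads and num_attention_heads % mp_degree != 0:
        return True   # attention heads cannot be split evenly across mp ranks
    if use_sequence_parallel and seq_length and seq_length % mp_degree != 0:
        return True   # sequence parallelism also shards the sequence dimension
    return False

print(mp_degree_conflicts(6, 32, 4096, True))   # True: 32 heads do not divide evenly by 6
print(mp_degree_conflicts(8, 32, 4096, True))   # False: 32 % 8 == 0 and 4096 % 8 == 0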
@@ -112,6 +123,50 @@ def prune_by_pp(tuner_cfg, cur_cfg, history_cfgs=None):
return False


@register_prune
def prune_by_vpp(tuner_cfg, cur_cfg, history_cfgs=None):
"""
Prune by vpp (virtual pipeline parallelism). The rules are:
1. The number of layers must be evenly divisible by pp_degree * vpp_degree.
2. VPP degree must be one of the user-defined candidates.
3. A VPP degree other than 1 is pruned when pp_degree <= 2.
"""
pp_degree = cur_cfg.get("pp_degree", None)
vpp_degree = cur_cfg.get("vpp_degree", None)
num_layers = tuner_cfg["model_cfg"].get("num_layers", None)

if pp_degree is None:
return False

if vpp_degree is None:
return False

if num_layers:
if num_layers % (pp_degree * vpp_degree) != 0:
return True
if pp_degree == 1 and vpp_degree != 1:
return True
if pp_degree <= 2 and vpp_degree != 1:
return True

vpp_degree_candidates = tuner_cfg.get("vpp_degree", None)
if vpp_degree_candidates == "auto":
vpp_degree_candidates = tuner_cfg["candidates"]["vpp_degree"]
if vpp_degree_candidates:
if vpp_degree not in vpp_degree_candidates:
return True

cfgs = same_cfgs_beside("vpp_degree", cur_cfg, history_cfgs)
if cfgs:
for cfg in cfgs:
# memory prune
if (
cfg["vpp_degree"] > vpp_degree
and cfg.get("max_mem_usage") == "OOM"
):
return True
return False
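As a worked example of the docstring rules, using illustrative values rather than anything from the PR: with num_layers=24, pp_degree=4 and vpp_degree=2 the candidate survives because 24 % (4 * 2) == 0 and pp_degree > 2, while vpp_degree=5 is pruned (24 % 20 != 0) and any vpp_degree other than 1 is pruned outright when pp_degree <= 2.

# Illustrative helper restating prune_by_vpp's structural rules (hypothetical name).
def vpp_conflicts(num_layers, pp_degree, vpp_degree):
    if num_layers % (pp_degree * vpp_degree) != 0:
        return True   # every pipeline stage must hold an equal number of virtual chunks
    if pp_degree <= 2 and vpp_degree != 1:
        return True   # virtual pipeline is only allowed with more than two real stages
    return False

print(vpp_conflicts(24, 4, 2))   # False: 24 layers split into 4 * 2 = 8 equal chunks
print(vpp_conflicts(24, 2, 2))   # True: pp_degree <= 2 forces vpp_degree == 1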


@register_prune
def prune_by_mbs(tuner_cfg, cur_cfg, history_cfgs=None):
"""
@@ -144,6 +199,13 @@ def prune_by_mbs(tuner_cfg, cur_cfg, history_cfgs=None):
if local_batch_size:
if local_batch_size % micro_batch_size != 0:
return True
acc_steps = local_batch_size // micro_batch_size
vpp_degree = cur_cfg.get("vpp_degree", None)
if vpp_degree is not None and vpp_degree > 1:
pp_degree = cur_cfg.get("pp_degree", None)
if pp_degree is not None:
if acc_steps % pp_degree != 0:
return True

if mbs_candidates:
if micro_batch_size not in mbs_candidates:
@@ -158,6 +220,13 @@
):
return True

# memory prune
if (
cfg["micro_batch_size"] < micro_batch_size
and cfg.get("max_mem_usage") == "OOM"
):
return True

return False
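Two rules are added to prune_by_mbs above: when vpp_degree > 1, the gradient-accumulation step count (local_batch_size // micro_batch_size) must be divisible by pp_degree, and a candidate is skipped if a run with a smaller micro batch size already hit OOM. A short arithmetic sketch of the accumulation-step rule, with illustrative values:

# Illustrative only: the interleaved-pipeline constraint added to prune_by_mbs.
local_batch_size, pp_degree, vpp_degree = 32, 4, 2

for micro_batch_size in (4, 16):
    acc_steps = local_batch_size // micro_batch_size
    pruned = vpp_degree > 1 and acc_steps % pp_degree != 0
    # micro_batch_size=4  -> acc_steps=8, 8 % 4 == 0 -> kept
    # micro_batch_size=16 -> acc_steps=2, 2 % 4 != 0 -> pruned
    print(micro_batch_size, "pruned" if pruned else "kept")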


@@ -208,6 +277,13 @@ def prune_by_sharding(tuner_cfg, cur_cfg, history_cfgs):
):
return True

# memory prune
if (
cfg["sharding_stage"] > sharding_stage
and cfg.get("max_mem_usage") == "OOM"
):
return True

if sharding_degree == 1:
cfgs = same_cfgs_beside("sharding_stage", cur_cfg, history_cfgs)
if cfgs:
@@ -245,9 +321,6 @@ def prune_by_recompute(tuner_cfg, cur_cfg, history_cfgs):
if recompute_granularity not in recompute_granularity_candidates:
return True

if not use_recompute and recompute_granularity:
return True

cfgs = same_cfgs_beside("use_recompute", cur_cfg, history_cfgs)
if cfgs:
for cfg in cfgs:
@@ -258,6 +331,13 @@ def prune_by_recompute(tuner_cfg, cur_cfg, history_cfgs):
):
return True

if (
cfg["use_recompute"]
and not use_recompute
and cfg.get("max_mem_usage") == "OOM"
):
return True

if not use_recompute:
cfgs = same_cfgs_beside("recompute_granularity", cur_cfg, history_cfgs)
if cfgs:
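Several additions in this file share one heuristic: if a previously tried configuration, identical except for the knob being tuned, already ran out of memory with a setting the rules treat as no more memory-hungry than the current one (a larger vpp_degree, a smaller micro_batch_size, a higher sharding_stage, or recompute enabled where the candidate disables it), the candidate is pruned without being launched. A condensed sketch of that pattern, using a hypothetical helper rather than code from the PR:

# Illustrative only: the OOM-based memory-prune pattern used by the rules above.
def pruned_by_prior_oom(history_cfgs, key, cur_value, needs_no_more_memory):
    """Return True if a run differing only in `key` already hit OOM with a value
    the heuristic treats as needing no more memory than `cur_value`."""
    for cfg in history_cfgs:
        if cfg.get("max_mem_usage") == "OOM" and needs_no_more_memory(cfg[key], cur_value):
            return True
    return False

# Mirrors the micro-batch rule: a run with micro_batch_size=2 that OOMed means
# micro_batch_size=4 cannot fit either, so it is skipped.
history = [{"micro_batch_size": 2, "max_mem_usage": "OOM"}]
print(pruned_by_prior_oom(history, "micro_batch_size", 4, lambda past, cur: past < cur))  # True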
26 changes: 25 additions & 1 deletion python/paddle/distributed/auto_tuner/recorder.py
@@ -48,10 +48,34 @@ def sort_metric(self, direction, metric_name) -> None:
)
return

def get_best(self, metric, direction) -> Tuple[dict, bool]:
def get_best(self, metric, direction, mode=None) -> Tuple[dict, bool]:
self.sort_metric(direction=direction, metric_name=metric)
if len(self.history) == 0:
return (self.history[0], True)
if mode == "SFT" or mode == "LoRA":
best_cfg = self.history[0]
if (
isinstance(best_cfg["max_mem_usage"], str)
or best_cfg["time"] == -1
):
return (best_cfg, True)
first_few = 1
for cfg in self.history:
if (
not isinstance(cfg["max_mem_usage"], str)
and cfg["max_mem_usage"] < best_cfg["max_mem_usage"]
and cfg["time"] != -1
):
best_cfg = cfg
first_few += 1
if first_few >= 5:
break
return (best_cfg, False)
if (
isinstance(self.history[0]["max_mem_usage"], str)
or self.history[0]["time"] == -1
):
return (self.history[0], True)
return (self.history[0], False)
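With mode set to "SFT" or "LoRA", get_best no longer just returns the head of the metric-sorted history: starting from the best-performing valid run, it scans the first few entries and prefers the one with the lowest max_mem_usage, and it still returns an error flag when the top entry has no usable memory or time reading. A standalone re-trace of that selection rule with illustrative values (not a call into the real class):

# Illustrative only: the "SFT"/"LoRA" selection logic of get_best, retraced in isolation.
history = [  # assumed to be already sorted by the metric (time, ascending)
    {"mp_degree": 2, "time": 10.5, "max_mem_usage": 61000},
    {"mp_degree": 4, "time": 11.0, "max_mem_usage": 52000},
    {"mp_degree": 8, "time": 12.3, "max_mem_usage": "OOM"},
]

best_cfg = history[0]
first_few = 1
for cfg in history:
    if (
        not isinstance(cfg["max_mem_usage"], str)
        and cfg["max_mem_usage"] < best_cfg["max_mem_usage"]
        and cfg["time"] != -1
    ):
        best_cfg = cfg
        first_few += 1
    if first_few >= 5:
        break

print(best_cfg["mp_degree"])  # 4: slightly slower than the top run, but it fits in less memory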

def store_history(self, path="./history.csv"):