[Auto Parallel] Update AutoTuner (#56939)
* add vpp search

* support cmd json or yaml

* add auto tuner log

* add OOM prune

* support sft and lora best cfg

* support sp

* fix sft/lora cfg

* fix json/yaml bug and update multi nodes status

* add pp and mp prune flag

* update log and csv path

* merge dev

* speed prune

* add search stage and run best stage cmd

* fix recompute_granularity prune bug

* fix get best cfg in sft and lora mode bug

* fix single card bug

* update read metric
Caozhou1995 authored Sep 25, 2023
1 parent 589f0f2 commit 9c3bffb
Showing 6 changed files with 598 additions and 171 deletions.
86 changes: 83 additions & 3 deletions python/paddle/distributed/auto_tuner/prune.py
@@ -55,6 +55,11 @@ def prune_by_mp(tuner_cfg, cur_cfg, history_cfgs=None):
    mp_degree = cur_cfg.get("mp_degree", None)
    hidden_size = tuner_cfg["model_cfg"].get("hidden_size", None)
    vocab_size = tuner_cfg["model_cfg"].get("vocab_size", None)
    num_attention_heads = tuner_cfg["model_cfg"].get(
        "num_attention_heads", None
    )
    seq_length = tuner_cfg["model_cfg"].get("seq_length", None)
    use_sequence_paralel = tuner_cfg.get("use_sequence_paralel", False)

    if mp_degree is None:
        return False
@@ -65,6 +70,12 @@ def prune_by_mp(tuner_cfg, cur_cfg, history_cfgs=None):
    if vocab_size and vocab_size % mp_degree != 0:
        return True

    if num_attention_heads and num_attention_heads % mp_degree != 0:
        return True

    if seq_length and seq_length % mp_degree != 0 and use_sequence_paralel:
        return True

    mp_degree_candidates = tuner_cfg.get("mp_degree", None)

    if mp_degree_candidates == "auto":
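
The new guards above all encode divisibility constraints that tensor parallelism imposes on the model dimensions. A minimal standalone sketch of the same checks, with hypothetical model values (not taken from this PR):

# Sketch of the prune_by_mp divisibility guards; config values are illustrative.
model_cfg = {
    "hidden_size": 4096,
    "vocab_size": 32000,
    "num_attention_heads": 32,
    "seq_length": 2048,
}

def mp_divisibility_ok(mp_degree, model_cfg, use_sequence_parallel=False):
    # Every dimension sharded by tensor parallelism must divide evenly.
    if model_cfg["hidden_size"] % mp_degree != 0:
        return False
    if model_cfg["vocab_size"] % mp_degree != 0:
        return False
    if model_cfg["num_attention_heads"] % mp_degree != 0:
        return False
    # Sequence parallelism additionally shards the sequence dimension.
    if use_sequence_parallel and model_cfg["seq_length"] % mp_degree != 0:
        return False
    return True

print([d for d in (1, 2, 3, 4, 8) if mp_divisibility_ok(d, model_cfg)])
# -> [1, 2, 4, 8]; 3 is pruned since 4096, 32000, and 32 are not divisible by 3.
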
@@ -112,6 +123,50 @@ def prune_by_pp(tuner_cfg, cur_cfg, history_cfgs=None):
    return False


@register_prune
def prune_by_vpp(tuner_cfg, cur_cfg, history_cfgs=None):
"""
Prune by vpp (virtual pipeline parallelism), the rules are:
1. VPP degree should be evenly divided by number of layers.
2. VPP degree should be in the candidates of user defined.
"""
    pp_degree = cur_cfg.get("pp_degree", None)
    vpp_degree = cur_cfg.get("vpp_degree", None)
    num_layers = tuner_cfg["model_cfg"].get("num_layers", None)

    if pp_degree is None:
        return False

    if vpp_degree is None:
        return False

    if num_layers:
        if num_layers % (pp_degree * vpp_degree) != 0:
            return True
        if pp_degree == 1 and vpp_degree != 1:
            return True
        if pp_degree <= 2 and vpp_degree != 1:
            return True

    vpp_degree_candidates = tuner_cfg.get("vpp_degree", None)
    if vpp_degree_candidates == "auto":
        vpp_degree_candidates = tuner_cfg["candidates"]["vpp_degree"]
    if vpp_degree_candidates:
        if vpp_degree not in vpp_degree_candidates:
            return True

    cfgs = same_cfgs_beside("vpp_degree", cur_cfg, history_cfgs)
    if cfgs:
        for cfg in cfgs:
            # memory prune
            if (
                cfg["vpp_degree"] > vpp_degree
                and cfg.get("max_mem_usage") == "OOM"
            ):
                return True
    return False
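
A worked example of the rules above, using hypothetical shape values: with 32 layers and pp_degree=4, each pipeline stage holds 8 layers, so only vpp degrees that divide that chunk count evenly survive.

# Hypothetical walk-through of the vpp divisibility and pp-threshold rules.
num_layers, pp_degree = 32, 4
survivors = [
    vpp
    for vpp in (1, 2, 3, 4, 8)
    if num_layers % (pp_degree * vpp) == 0  # layers split evenly across chunks
    and not (pp_degree <= 2 and vpp != 1)   # vpp > 1 only allowed when pp > 2
]
print(survivors)  # [1, 2, 4, 8]; vpp=3 is pruned because 32 % (4 * 3) != 0
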


@register_prune
def prune_by_mbs(tuner_cfg, cur_cfg, history_cfgs=None):
"""
@@ -144,6 +199,13 @@ def prune_by_mbs(tuner_cfg, cur_cfg, history_cfgs=None):
    if local_batch_size:
        if local_batch_size % micro_batch_size != 0:
            return True
        acc_steps = local_batch_size // micro_batch_size
        vpp_degree = cur_cfg.get("vpp_degree", None)
        if vpp_degree is not None and vpp_degree > 1:
            pp_degree = cur_cfg.get("pp_degree", None)
            if pp_degree is not None:
                if acc_steps % pp_degree != 0:
                    return True

    if mbs_candidates:
        if micro_batch_size not in mbs_candidates:
@@ -158,6 +220,13 @@
            ):
                return True

            # memory prune
            if (
                cfg["micro_batch_size"] < micro_batch_size
                and cfg.get("max_mem_usage") == "OOM"
            ):
                return True

    return False
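
The new acc_steps guard couples the micro-batch size to the pipeline schedule: with virtual pipeline stages (vpp_degree > 1), the number of gradient-accumulation steps must be a multiple of pp_degree. A small illustration with assumed values:

# Assumed values illustrating the acc_steps guard in prune_by_mbs.
local_batch_size, pp_degree, vpp_degree = 32, 4, 2
for micro_batch_size in (1, 2, 4, 8, 16):
    acc_steps = local_batch_size // micro_batch_size
    pruned = vpp_degree > 1 and acc_steps % pp_degree != 0
    print(f"mbs={micro_batch_size:2d} acc_steps={acc_steps:2d}",
          "pruned" if pruned else "kept")
# mbs=16 yields acc_steps=2, which pp_degree=4 does not divide, so it is pruned.
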


@@ -208,6 +277,13 @@ def prune_by_sharding(tuner_cfg, cur_cfg, history_cfgs):
            ):
                return True

            # memory prune
            if (
                cfg["sharding_stage"] > sharding_stage
                and cfg.get("max_mem_usage") == "OOM"
            ):
                return True

    if sharding_degree == 1:
        cfgs = same_cfgs_beside("sharding_stage", cur_cfg, history_cfgs)
        if cfgs:
@@ -245,9 +321,6 @@ def prune_by_recompute(tuner_cfg, cur_cfg, history_cfgs):
        if recompute_granularity not in recompute_granularity_candidates:
            return True

    if not use_recompute and recompute_granularity:
        return True

    cfgs = same_cfgs_beside("use_recompute", cur_cfg, history_cfgs)
    if cfgs:
        for cfg in cfgs:
@@ -258,6 +331,13 @@
            ):
                return True

            if (
                cfg["use_recompute"]
                and not use_recompute
                and cfg.get("max_mem_usage") == "OOM"
            ):
                return True

    if not use_recompute:
        cfgs = same_cfgs_beside("recompute_granularity", cur_cfg, history_cfgs)
        if cfgs:
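
The memory prunes added throughout this file share one monotonicity argument: if a historical run that differs from the current config in only one knob, and sits on the memory-lighter side of that knob, already hit OOM, the current config must OOM as well. A generic sketch of the pattern (the helper and its arguments are illustrative, not part of this PR):

# Generic form of the OOM prunes added in this commit; names are illustrative.
def prune_if_lighter_cfg_oomed(cur_cfg, history_cfgs, key, lighter_than):
    """Return True when a past config that differs from cur_cfg only in
    `key`, and uses less memory by the `lighter_than` ordering, OOM'd."""
    for cfg in history_cfgs:
        same_elsewhere = all(
            cfg.get(k) == v for k, v in cur_cfg.items() if k != key
        )
        if (
            same_elsewhere
            and lighter_than(cfg[key], cur_cfg[key])
            and cfg.get("max_mem_usage") == "OOM"
        ):
            return True
    return False

# Instances in this diff: higher vpp_degree, higher sharding_stage, smaller
# micro_batch_size, and use_recompute=True are each the memory-lighter side,
# e.g. prune_if_lighter_cfg_oomed(cur, hist, "sharding_stage", lambda a, b: a > b)
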
26 changes: 25 additions & 1 deletion python/paddle/distributed/auto_tuner/recorder.py
@@ -47,10 +47,34 @@ def sort_metric(self, direction, metric_name) -> None:
            reverse=False,
        )

    def get_best(self, metric, direction) -> Tuple[dict, bool]:
    def get_best(self, metric, direction, mode=None) -> Tuple[dict, bool]:
        self.sort_metric(direction=direction, metric_name=metric)
        if len(self.history) == 0:
            # Empty history: indexing self.history[0] here would raise an
            # IndexError, so report failure with no candidate config.
            return (None, True)
        if mode == "SFT" or mode == "LoRA":
            best_cfg = self.history[0]
            if (
                isinstance(best_cfg["max_mem_usage"], str)
                or best_cfg["time"] == -1
            ):
                return (best_cfg, True)
            # Among the first few runs ranked by the metric, prefer the one
            # with the lowest memory usage that still finished successfully.
            first_few = 1
            for cfg in self.history:
                if (
                    not isinstance(cfg["max_mem_usage"], str)
                    and cfg["max_mem_usage"] < best_cfg["max_mem_usage"]
                    and cfg["time"] != -1
                ):
                    best_cfg = cfg
                first_few += 1
                if first_few >= 5:
                    break
            return (best_cfg, False)
        if (
            isinstance(self.history[0]["max_mem_usage"], str)
            or self.history[0]["time"] == -1
        ):
            return (self.history[0], True)
        return (self.history[0], False)

    def store_history(self, path="./history.csv"):
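
A hypothetical call site for the extended get_best; the class name, metric key, and direction value below are assumed from context rather than quoted from this PR:

# Assumed usage sketch; the "time" and "max_mem_usage" keys match the code above.
from paddle.distributed.auto_tuner.recorder import HistoryRecorder  # name assumed

recorder = HistoryRecorder()
# ... one history entry appended per finished trial ...
best_cfg, err = recorder.get_best(metric="time", direction="min", mode="SFT")
if err:
    print("no clean run recorded; nearest candidate:", best_cfg)
else:
    print("best config for SFT/LoRA:", best_cfg)
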