diff --git a/python/paddle/distributed/auto_tuner/memory_cost_model.py b/python/paddle/distributed/auto_tuner/memory_cost_model.py
new file mode 100644
index 00000000000000..0aa9590faa7ec4
--- /dev/null
+++ b/python/paddle/distributed/auto_tuner/memory_cost_model.py
@@ -0,0 +1,32 @@
+from argparse import ArgumentParser
+
+def parse_arguments():
+    parser = ArgumentParser()
+
+    # for distributed strategy
+    parser.add_argument("--dp_degree", type=int, help="dp degree")
+    parser.add_argument("--pp_degree", type=int, help="pp degree")
+    parser.add_argument("--mp_degree", type=int, help="mp degree")
+    parser.add_argument("--sharding_degree", type=int, help="sharding degree")
+    parser.add_argument("--sharding_stage", type=int, help="sharding stage")
+    parser.add_argument("--micro_batch_size", type=int, help="micro batch size")
+    parser.add_argument("--use_recompute", type=lambda x: str(x).lower() in ("true", "1"), help="use recompute (true/false or 1/0)")
+    parser.add_argument("--recompute_granularity", type=str, help="recompute granularity")
+
+    # for model config
+    parser.add_argument("--hidden_size", type=int, help="hidden size")
+    parser.add_argument("--num_attention_heads", type=int, help="number of attention heads")
+    parser.add_argument("--num_hidden_layers", type=int, help="number of hidden layers")
+    parser.add_argument("--max_sequence_length", type=int, help="maximum sequence length")
+    parser.add_argument("--vocab_size", type=int, help="vocabulary size")
+    parser.add_argument("--intermediate_size", type=int, help="intermediate size")
+
+    return parser.parse_args()
+
+def get_model_memory_usage(args):
+    # evaluate model memory usage based on distributed strategy and model setting
+    raise NotImplementedError("Please implement this function for memory usage estimation based on distributed strategy and model setting.")
+
+if __name__ == "__main__":
+    args = parse_arguments()
+    print(get_model_memory_usage(args))
\ No newline at end of file
diff --git a/python/paddle/distributed/auto_tuner/prune.py b/python/paddle/distributed/auto_tuner/prune.py
index 976089f9d05f2b..3f1ad78076e3fe 100644
--- a/python/paddle/distributed/auto_tuner/prune.py
+++ b/python/paddle/distributed/auto_tuner/prune.py
@@ -12,6 +12,9 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import os
+import subprocess
+
 _PRUNE_FUNC = []
 
 
@@ -353,3 +356,34 @@ def prune_by_num_gpus(tuner_cfg, cur_cfg, history_cfgs):
         return True
 
     return False
+
+@register_prune
+def prune_by_memory_estimation(tuner_cfg, cur_cfg, history_cfgs):
+    memory_estimation_tool = tuner_cfg.get("memory_estimation_tool", None)
+    max_memory_usage = tuner_cfg.get("max_mem_usage", None)
+    model_cfg = tuner_cfg["model_cfg"]
+
+    if memory_estimation_tool is None:
+        return False
+
+    if not os.path.exists(memory_estimation_tool):
+        raise ValueError(f"memory_estimation_tool should be a valid path, but got {memory_estimation_tool}")
+
+    if max_memory_usage is None:
+        raise ValueError("max_mem_usage should be set when using the memory estimation tool")
+
+    memory_estimation_cmd = f"python {memory_estimation_tool} --dp_degree {cur_cfg['dp_degree']} \
+        --mp_degree {cur_cfg['mp_degree']} --pp_degree {cur_cfg['pp_degree']} \
+        --sharding_degree {cur_cfg['sharding_degree']} --sharding_stage {cur_cfg['sharding_stage']} \
+        --use_recompute {cur_cfg['use_recompute']} --micro_batch_size {cur_cfg['micro_batch_size']} \
+        --recompute_granularity {cur_cfg['recompute_granularity']} \
+        --hidden_size {model_cfg['hidden_size']} --num_attention_heads {model_cfg['num_attention_heads']} \
+        --num_hidden_layers {model_cfg['num_hidden_layers']} \
+        --max_sequence_length {model_cfg['max_sequence_length']} \
+        --vocab_size {model_cfg['vocab_size']} --intermediate_size {model_cfg['intermediate_size']} "
+    result = subprocess.run(memory_estimation_cmd, shell=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)
+    if result.returncode == 0:
+        cur_memory_usage = float(result.stdout)
+        return cur_memory_usage > max_memory_usage
+    else:
+        raise ValueError(f"memory_estimation_tool failed with error: {result.stderr}")
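
Note for reviewers: the new prune rule is driven entirely by the tuner configuration. Below is a minimal, hypothetical sketch (all concrete values are invented for illustration) of the only keys `prune_by_memory_estimation` reads — `memory_estimation_tool`, `max_mem_usage`, and `model_cfg` — and the command it builds from them.

```python
# Hypothetical tuner_cfg excerpt; values are made up, key names follow the
# lookups in prune_by_memory_estimation above.
tuner_cfg = {
    "memory_estimation_tool": "python/paddle/distributed/auto_tuner/memory_cost_model.py",
    # Candidate is pruned when the tool prints a value larger than this
    # budget; the unit is whatever the tool itself reports.
    "max_mem_usage": 40.0,
    "model_cfg": {
        "hidden_size": 4096,
        "num_attention_heads": 32,
        "num_hidden_layers": 32,
        "max_sequence_length": 2048,
        "vocab_size": 50304,
        "intermediate_size": 16384,
    },
}

# For a candidate cur_cfg, the prune rule shells out roughly like:
#   python <memory_estimation_tool> --dp_degree 2 --mp_degree 2 --pp_degree 2 \
#       --sharding_degree 2 --sharding_stage 1 --micro_batch_size 1 \
#       --use_recompute True --recompute_granularity full \
#       --hidden_size 4096 ... --intermediate_size 16384
# and prunes the candidate when float(stdout) > max_mem_usage.
```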
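The patch leaves `get_model_memory_usage` as a `NotImplementedError` stub. Purely as an illustration of what could back it (not part of this PR), the sketch below gives a parameter-only estimate for a GPT-style model; it ignores activation memory, `micro_batch_size`, `max_sequence_length`, and recompute, and the 2/2/12 bytes-per-parameter split assumes fp16 training with Adam.

```python
def estimate_parameter_memory_gb(args):
    """Illustrative per-GPU parameter/optimizer memory estimate, in GB.

    Assumptions (not from this PR): GPT-style decoder, fp16 weights and
    gradients (2 + 2 bytes/param), fp32 Adam states and master weights
    (12 bytes/param), activations ignored.
    """
    h = args.hidden_size
    # One transformer layer: QKV + output projection (4 * h^2) plus the MLP
    # (2 * h * intermediate_size), ignoring biases and layer norms.
    layer_params = 4 * h * h + 2 * h * args.intermediate_size
    embedding_params = args.vocab_size * h
    total_params = args.num_hidden_layers * layer_params + embedding_params

    # Tensor and pipeline parallelism split the parameters across ranks.
    params_per_rank = total_params / (args.mp_degree * args.pp_degree)

    optimizer_bytes = 12
    if args.sharding_stage >= 1:
        # Sharding additionally splits the optimizer states across the group.
        optimizer_bytes /= args.sharding_degree

    mem_bytes = params_per_rank * (2 + 2 + optimizer_bytes)
    return mem_bytes / (1024**3)
```

With such an estimate wired into `get_model_memory_usage`, the `__main__` block would print a single number, which `prune_by_memory_estimation` then parses via `float(result.stdout)` and compares against `max_mem_usage`.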