[AutoTuner] support prune by memory cost model (#59727)
* [AutoTuner] support prune by memory cost model

* add print check

* fix parser

* fix parser

* add estimated_memory_usage to cur_cfg so it gets saved
SylarTiaNII authored Dec 9, 2023
1 parent c0c1ac2 commit 948ec01
Showing 2 changed files with 140 additions and 2 deletions.
95 changes: 95 additions & 0 deletions python/paddle/distributed/auto_tuner/memory_cost_model.py
@@ -0,0 +1,95 @@
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from argparse import ArgumentParser


def parse_arguments():
    parser = ArgumentParser()

    # for distributed strategy
    parser.add_argument(
        "--dp_degree", type=int, required=True, help="dp degree"
    )
    parser.add_argument(
        "--mp_degree", type=int, required=True, help="mp degree"
    )
    parser.add_argument(
        "--pp_degree", type=int, required=True, help="pp degree"
    )
    parser.add_argument(
        "--vpp_degree", type=int, required=True, help="vpp degree"
    )
    parser.add_argument(
        "--sharding_degree", type=int, required=True, help="sharding degree"
    )
    parser.add_argument(
        "--sharding_stage", type=int, required=True, help="sharding stage"
    )
    parser.add_argument(
        "--micro_batch_size", type=int, required=True, help="micro batch size"
    )
    parser.add_argument(
        # Note: argparse's type=bool treats any non-empty string (including
        # "False") as True, so the flag value is parsed explicitly here.
        "--use_recompute",
        type=lambda x: str(x).lower() in ("true", "1"),
        required=True,
        help="use recompute",
    )
    parser.add_argument(
        "--recompute_granularity",
        type=str,
        required=True,
        choices=["None", "core_attn", "full_attn", "full"],
        help="recompute granularity",
    )

    # for model config
    parser.add_argument(
        "--hidden_size", type=int, required=False, help="hidden size"
    )
    parser.add_argument(
        "--num_attention_heads",
        type=int,
        required=False,
        help="number of attention heads",
    )
    parser.add_argument(
        "--num_layers", type=int, required=False, help="number of hidden layers"
    )
    parser.add_argument(
        "--max_sequence_length",
        type=int,
        required=False,
        help="maximum sequence length",
    )
    parser.add_argument(
        "--vocab_size", type=int, required=False, help="vocabulary size"
    )
    parser.add_argument(
        "--intermediate_size",
        type=int,
        required=False,
        help="intermediate size",
    )

    return parser.parse_args()


def get_model_memory_usage(args):
    # evaluate model memory usage based on distributed strategy and model setting
    raise NotImplementedError(
        "Please implement this function for memory usage estimation based on distributed strategy and model setting."
    )


if __name__ == "__main__":
    args = parse_arguments()
    print(get_model_memory_usage(args))
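
The estimator above is deliberately left as a stub: get_model_memory_usage raises NotImplementedError, and whatever implementation is plugged in must print a single number to stdout, because the prune pass below parses the output with float(result.stdout) and compares it against max_mem_usage (so the unit has to match that setting). As a purely hypothetical illustration, a sketch of such an estimator might look like the following; it assumes a GPT-style decoder, fp16 training with fp32 Adam state, and very rough activation accounting, none of which comes from this commit.

def estimate_model_memory_gb(args):
    # Hypothetical estimator, NOT part of this commit; assumes a GPT-style
    # decoder trained in fp16 with fp32 Adam optimizer state.
    h = args.hidden_size
    ffn = args.intermediate_size if args.intermediate_size else 4 * h
    seq = args.max_sequence_length
    mbs = args.micro_batch_size

    # Parameters: attention (~4*h*h) + MLP (~2*h*ffn) per layer, plus embeddings.
    params = args.num_layers * (4 * h * h + 2 * h * ffn) + args.vocab_size * h

    # Model states: fp16 params and grads plus fp32 master weights and Adam
    # moments (~16 bytes per parameter), split across mp/pp ranks; sharding
    # partitions them further (a very rough approximation).
    model_state_bytes = params * 16 / (args.mp_degree * args.pp_degree)
    if args.sharding_stage >= 1:
        model_state_bytes /= args.sharding_degree

    # Activations: crude per-layer estimate; full recompute keeps only the
    # layer inputs alive during the backward pass.
    act_per_layer = 2 * mbs * seq * h * 16
    if args.use_recompute and args.recompute_granularity == "full":
        act_per_layer = 2 * mbs * seq * h
    act_bytes = act_per_layer * args.num_layers / (args.mp_degree * args.pp_degree)

    return round((model_state_bytes + act_bytes) / 1024**3, 2)

The dp_degree and num_attention_heads arguments are parsed but unused in this sketch; a realistic estimator would also account for them, along with pipeline bubbles and memory fragmentation.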
47 changes: 45 additions & 2 deletions python/paddle/distributed/auto_tuner/prune.py
@@ -12,13 +12,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.


import logging
import os
import subprocess

from paddle.distributed.launch.main import ctx

logger = logging.getLogger('auto_tuner')

_PRUNE_FUNC = []
_PRUNE_HISTORY_FUNC = []

@@ -491,6 +491,49 @@ def prune_by_num_gpus(tuner_cfg, cur_cfg, history_cfgs=[]):


@register_prune
def prune_by_memory_estimation(tuner_cfg, cur_cfg, history_cfgs=[]):
    memory_estimation_tool = tuner_cfg.get("memory_estimation_tool", None)
    # TODO(@gexiao): get from system api
    max_memory_usage = tuner_cfg.get("max_mem_usage", None)
    model_cfg = tuner_cfg["model_cfg"]

    if memory_estimation_tool is None:
        return False

    if not os.path.exists(memory_estimation_tool):
        raise ValueError(
            f"memory_estimation_tool should be a valid path, but got {memory_estimation_tool}"
        )

    if max_memory_usage is None:
        raise ValueError(
            "max_mem_usage should be set when using memory estimation tool"
        )

    memory_estimation_cmd = f"python {memory_estimation_tool} --dp_degree {cur_cfg['dp_degree']} --mp_degree {cur_cfg['mp_degree']} \
        --pp_degree {cur_cfg['pp_degree']} --vpp_degree {cur_cfg['vpp_degree']} \
        --sharding_degree {cur_cfg['sharding_degree']} --sharding_stage {cur_cfg['sharding_stage']} \
        --use_recompute {cur_cfg['use_recompute']} --micro_batch_size {cur_cfg['micro_batch_size']} \
        --recompute_granularity {cur_cfg['recompute_granularity']} \
        --hidden_size {model_cfg['hidden_size']} --num_attention_heads {model_cfg['num_attention_heads']} \
        --num_layers {model_cfg['num_layers']} --max_sequence_length {model_cfg['max_sequence_length']} \
        --vocab_size {model_cfg['vocab_size']} --intermediate_size {model_cfg['intermediate_size']} "
    result = subprocess.run(
        memory_estimation_cmd,
        shell=True,
        capture_output=True,
        text=True,
    )
    if result.returncode == 0:
        cur_memory_usage = round(float(result.stdout), 2)
        cur_cfg["estimated_memory_usage"] = cur_memory_usage
        return cur_memory_usage > max_memory_usage
    else:
        raise ValueError(
            f"memory_estimation_tool failed with error: {result.stderr}"
        )


def prune_by_sharding_overlap(tuner_cfg, cur_cfg, history_cfgs=[]):
"""Prune by sharding overlap for single dp estimation"""
if "sharding_overlap" in cur_cfg:
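
For reference, prune_by_memory_estimation is a no-op when memory_estimation_tool is absent from the tuner configuration, and it raises if max_mem_usage is missing once the tool is set; the model_cfg keys it forwards are exactly the optional arguments parsed by memory_cost_model.py. A hypothetical configuration fragment, with illustrative paths and sizes only, could look like this:

# Hypothetical tuner_cfg fragment; all values below are illustrative.
tuner_cfg = {
    # Script that prints the estimated memory usage for a candidate config.
    "memory_estimation_tool": "auto_tuner/memory_cost_model.py",
    # Upper bound compared against the printed estimate; same unit as the tool's output.
    "max_mem_usage": 64,
    "model_cfg": {
        "hidden_size": 4096,
        "num_attention_heads": 32,
        "num_layers": 32,
        "max_sequence_length": 4096,
        "vocab_size": 32000,
        "intermediate_size": 11008,
    },
}

Candidates whose estimated usage exceeds max_mem_usage are pruned before launch, and the estimate is stored in cur_cfg["estimated_memory_usage"] so it can be saved with the tuning results.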
