[AutoTuner] support prune by memory cost model #59727

Merged: 7 commits, Dec 9, 2023
95 changes: 95 additions & 0 deletions python/paddle/distributed/auto_tuner/memory_cost_model.py
@@ -0,0 +1,95 @@
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from argparse import ArgumentParser


def parse_arguments():
parser = ArgumentParser()

# for distributed strategy
parser.add_argument(
"--dp_degree", type=int, required=True, help="dp degree"
)
parser.add_argument(
"--mp_degree", type=int, required=True, help="mp degree"
)
parser.add_argument(
"--pp_degree", type=int, required=True, help="pp degree"
)
parser.add_argument(
"--vpp_degree", type=int, required=True, help="vpp degree"
)
parser.add_argument(
"--sharding_degree", type=int, required=True, help="sharding degree"
)
parser.add_argument(
"--sharding_stage", type=int, required=True, help="sharding stage"
)
parser.add_argument(
"--micro_batch_size", type=int, required=True, help="micro batch size"
)
    parser.add_argument(
        "--use_recompute",
        # argparse's type=bool would treat any non-empty string (even "False") as True
        type=lambda v: str(v).lower() in ("true", "1"),
        required=True, help="use recompute",
    )
parser.add_argument(
"--recompute_granularity",
type=str,
required=True,
choices=["None", "core_attn", "full_attn", "full"],
help="recompute granularity",
)

# for model config
parser.add_argument(
"--hidden_size", type=int, required=False, help="hidden size"
)
parser.add_argument(
"--num_attention_heads",
type=int,
required=False,
help="number of attention heads",
)
parser.add_argument(
"--num_layers", type=int, required=False, help="number of hidden layers"
)
parser.add_argument(
"--max_sequence_length",
type=int,
required=False,
help="maximum sequence length",
)
parser.add_argument(
"--vocab_size", type=int, required=False, help="vocabulary size"
)
parser.add_argument(
"--intermediate_size",
type=int,
required=False,
help="intermediate size",
)

return parser.parse_args()


def get_model_memory_usage(args):
# evaluate model memory usage based on distributed strategy and model setting
raise NotImplementedError(
"Please implement this function for memory usage estimation based on distributed strategy and model setting."
)


if __name__ == "__main__":
args = parse_arguments()
print(get_model_memory_usage(args))
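`get_model_memory_usage` is deliberately left as a stub in this PR. As a rough illustration of the kind of estimator that could back it, the sketch below computes per-GPU parameter and optimizer memory for a GPT-style decoder; the 12*h^2 per-layer parameter count, the fp16/Adam byte constants, and the handling of sharding are illustrative assumptions, not part of the PR.

```python
def estimate_gpt_memory_mb(args):
    """Illustrative sketch only: rough per-GPU parameter + optimizer memory in
    MB for a GPT-style decoder; every constant here is an assumption."""
    h, layers, vocab = args.hidden_size, args.num_layers, args.vocab_size
    # ~12*h^2 weights per transformer block (attention + MLP) plus the embedding table.
    total_params = layers * 12 * h * h + vocab * h
    # Weights are split across tensor (mp) and pipeline (pp) parallel ranks.
    params_per_rank = total_params / (args.mp_degree * args.pp_degree)
    # fp16 params (2 B) + fp16 grads (2 B) take ~4 B/param; fp32 master weights
    # plus two Adam moments add ~12 B/param, spread by sharding stages 1-3.
    optimizer_bytes = 12 / args.sharding_degree if args.sharding_stage >= 1 else 12
    # Activation memory (micro_batch_size, max_sequence_length, recompute) is
    # deliberately ignored in this sketch.
    return params_per_rank * (4 + optimizer_bytes) / (1024**2)
```

Whatever estimator is plugged in only needs to return a single number, since the `__main__` block prints it and `prune_by_memory_estimation` (below) parses that value from stdout.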
47 changes: 45 additions & 2 deletions python/paddle/distributed/auto_tuner/prune.py
@@ -12,13 +12,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.


import logging
import os
import subprocess

from paddle.distributed.launch.main import ctx

logger = logging.getLogger('auto_tuner')

_PRUNE_FUNC = []
_PRUNE_HISTORY_FUNC = []

@@ -491,6 +491,49 @@ def prune_by_num_gpus(tuner_cfg, cur_cfg, history_cfgs=[]):


@register_prune
def prune_by_memory_estimation(tuner_cfg, cur_cfg, history_cfgs=[]):
memory_estimation_tool = tuner_cfg.get("memory_estimation_tool", None)
# TODO(@gexiao): get from system api
max_memory_usage = tuner_cfg.get("max_mem_usage", None)
model_cfg = tuner_cfg["model_cfg"]

if memory_estimation_tool is None:
return False

if not os.path.exists(memory_estimation_tool):
raise ValueError(
f"memory_estimation_tool shoule be a valid path, but got {memory_estimation_tool}"
)

if max_memory_usage is None:
raise ValueError(
"max_mem_usage should be set when using memory estimation tool"
)

memory_estimation_cmd = f"python {memory_estimation_tool} --dp_degree {cur_cfg['dp_degree']} --mp_degree {cur_cfg['mp_degree']} \
--pp_degree {cur_cfg['pp_degree']} --vpp_degree {cur_cfg['vpp_degree']} \
--sharding_degree {cur_cfg['sharding_degree']} --sharding_stage {cur_cfg['sharding_stage']} \
--use_recompute {cur_cfg['use_recompute']} --micro_batch_size {cur_cfg['micro_batch_size']} \
--recompute_granularity {cur_cfg['recompute_granularity']} \
--hidden_size {model_cfg['hidden_size']} --num_attention_heads {model_cfg['num_attention_heads']} \
--num_layers {model_cfg['num_layers']} --max_sequence_length {model_cfg['max_sequence_length']} \
--vocab_size {model_cfg['vocab_size']} --intermediate_size {model_cfg['intermediate_size']} "
result = subprocess.run(
memory_estimation_cmd,
shell=True,
capture_output=True,
text=True,
)
if result.returncode == 0:
cur_memory_usage = round(float(result.stdout), 2)
cur_cfg["estimated_memory_usage"] = cur_memory_usage
return cur_memory_usage > max_memory_usage
else:
raise ValueError(
f"memory_estimation_tool failed with error: {result.stderr}"
)


def prune_by_sharding_overlap(tuner_cfg, cur_cfg, history_cfgs=[]):
"""Prune by sharding overlap for single dp estimation"""
if "sharding_overlap" in cur_cfg:
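The new pruner only activates when the tuner configuration supplies `memory_estimation_tool` (and, with it, `max_mem_usage` and a `model_cfg`). A hypothetical configuration fragment, with made-up values and an unspecified memory unit, might look like this:

```python
# Hypothetical tuner configuration fragment; paths and numbers are illustrative.
tuner_cfg = {
    # Invoked as `python <memory_estimation_tool> --dp_degree ... --intermediate_size ...`;
    # the script must print a single float (the memory estimate) to stdout.
    "memory_estimation_tool": "python/paddle/distributed/auto_tuner/memory_cost_model.py",
    # Budget in the same unit as the tool's output; candidates above it are pruned.
    "max_mem_usage": 78 * 1024,
    "model_cfg": {
        "hidden_size": 4096,
        "num_attention_heads": 32,
        "num_layers": 32,
        "max_sequence_length": 2048,
        "vocab_size": 50304,
        "intermediate_size": 16384,
    },
}
```

When the tool exits successfully, the printed estimate is stored in `cur_cfg["estimated_memory_usage"]` and the candidate is pruned whenever it exceeds `max_mem_usage`; a non-zero exit code raises a `ValueError` instead.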