[AutoTuner] Add cost model for autotuner (PaddlePaddle#58183)
* add memory cost model

* support sharding stage prefix

* fix str type error
Caozhou1995 authored Oct 19, 2023
1 parent 429e408 commit 576d02a
Showing 5 changed files with 367 additions and 28 deletions.
143 changes: 143 additions & 0 deletions python/paddle/distributed/auto_tuner/cost_model.py
@@ -0,0 +1,143 @@
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


def all_params(mp, pp, sharding, h, l, V):
# TODO: TBD - add some fixed structure models.
return 1


def full_recompute_acts(mp, pp, s, b, h, l):
# TODO: TBD - add some fixed structure models.
return 1


def all_acts(mp, pp, s, b, h, l, a):
# TODO: TBD - add some fixed structure models.
return 1


def to_gb(p):
return p / (2**30)


def get_mem(total_cards, parallel_cfg, l, h, a, V, s, gbs):
"""Estimate the memory of model unser parallel strategy."""
sharding = parallel_cfg["sharding_degree"]
mp = parallel_cfg["mp_degree"]
b = parallel_cfg["micro_batch_size"]
pp = parallel_cfg["pp_degree"]
vpp = parallel_cfg["vpp_degree"]
use_recompute = parallel_cfg["use_recompute"]

sep = 1

lbs = int(gbs / sharding / s)
lbs = int(lbs / pp) * pp
assert s % sep == 0
s_sep = s // sep
assert a % (sep * mp) == 0, f'{a} vs {sep * mp}'

vpp_ratio = 1
if vpp > 1:
assert l % (pp * vpp) == 0
vpp_ratio = 1 + (pp - 1) / (pp * vpp)

params = to_gb(all_params(mp, pp, sharding, h, l, V))

acts = 0
assert l % pp == 0

if use_recompute:
acts = to_gb(full_recompute_acts(mp, pp, s_sep, b, h, l)) * vpp_ratio
else:
acts = to_gb(all_acts(mp, pp, s, b, h, l, a)) * vpp_ratio
assert acts > 0

peak_mem = params + acts
return peak_mem


def divisor(num, reverse=False):
"""Get the divisor of a given number."""
results = set()
i = 1
mid = num // 2 + 1
while i < mid:
if num % i == 0:
results.add(i)
results.add(num // i)
i += 1
results = list(results)
return sorted(results, reverse=reverse)


def get_not_oom_cfgs(cfgs, tuner_cfg):
"""Get not OOM parallel strategies."""
total_cards, l, h, a, V, s, gbs, per_card_memory = (
tuner_cfg["estimated_num_gpus"],
tuner_cfg["model_cfg"]["num_layers"],
tuner_cfg["model_cfg"]["hidden_size"],
tuner_cfg["model_cfg"]["num_attention_heads"],
tuner_cfg["model_cfg"]["vocab_size"],
tuner_cfg["model_cfg"]["seq_length"],
tuner_cfg["model_cfg"]["global_batch_size"],
tuner_cfg.get("per_card_memory", 80),
)
pruned_cfgs = []
for cfg in cfgs:
mp = cfg["mp_degree"]
sharding = cfg["sharding_degree"]
mbs = cfg["micro_batch_size"]
pp = cfg["pp_degree"]
vpp = cfg["vpp_degree"]
dp = cfg["dp_degree"]
use_recompute = cfg["use_recompute"]

if mp * sharding * pp * dp != total_cards:
continue
if gbs % sharding != 0:
continue
if gbs // sharding % dp != 0:
continue
if gbs // sharding // dp % mbs != 0:
continue
if l % pp != 0:
continue
if l // pp % vpp != 0:
continue
if vpp != 1 and pp <= 2:
continue
if a % mp != 0 or V % mp != 0 or h % mp != 0:
continue

pruned_cfgs.append(cfg)
valid_cfgs = []
for cfg in pruned_cfgs:
mem = get_mem(total_cards, cfg, l, h, a, V, s, gbs)
# TODO: Uncomment when it is actually implemented.
# if (
# mem < per_card_memory
# and mem
# > tuner_cfg.get(
# "search_algo", {"name": "dp_estimation", "threshold": 0.7}
# ).get("threshold", 0.7)
# * per_card_memory
# ):
# cfg["memory_cost"] = mem
# valid_cfgs.append(cfg)
cfg["memory_cost"] = mem
valid_cfgs.append(cfg)
assert valid_cfgs
return valid_cfgs
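
The pruning in get_not_oom_cfgs can be exercised on its own. Below is a minimal usage sketch, not part of the commit: it assumes a Paddle build that already contains this cost_model module, and every numeric value in tuner_cfg and in the candidate configuration is illustrative. Since all_params/all_acts are still placeholders returning 1, the reported memory_cost is meaningless for now; the sketch only shows the call shape and which divisibility checks a candidate must pass.

# Illustrative only; assumes paddle.distributed.auto_tuner.cost_model from this commit is importable.
from paddle.distributed.auto_tuner.cost_model import get_not_oom_cfgs

tuner_cfg = {
    "estimated_num_gpus": 16,   # total cards the degree product must match
    "per_card_memory": 80,      # GB, default used by the (commented-out) filter
    "model_cfg": {
        "num_layers": 24,
        "hidden_size": 2048,
        "num_attention_heads": 16,
        "vocab_size": 50304,
        "seq_length": 2048,
        "global_batch_size": 64,
    },
}

candidates = [
    {
        # 2 * 2 * 2 * 2 == 16 cards, and 24 layers split evenly over pp * vpp
        "mp_degree": 2,
        "sharding_degree": 2,
        "pp_degree": 2,
        "vpp_degree": 1,
        "dp_degree": 2,
        "micro_batch_size": 2,
        "use_recompute": True,
    },
]

# Candidates whose degrees do not multiply to the card count, or whose
# heads/vocab/hidden size are not divisible by mp, are dropped; survivors
# come back annotated with an estimated "memory_cost" in GB.
for cfg in get_not_oom_cfgs(candidates, tuner_cfg):
    print(cfg["memory_cost"])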
30 changes: 29 additions & 1 deletion python/paddle/distributed/auto_tuner/search.py
@@ -16,7 +16,7 @@
from abc import ABC, abstractmethod

from .prune import _PRUNE_FUNC
from .utils import gbs_search_all, search_all
from .utils import gbs_search_all, search_all, search_by_dp_estimation


class SearchAlgo(ABC):
@@ -54,6 +54,34 @@ def search_once(self, history_cfgs):
return new_cfg


class DpEstimationSearch(SearchAlgo):
def __init__(self, tuner_cfg):
super().__init__(tuner_cfg)
self.idx = 0
self.all_tasks = search_by_dp_estimation(tuner_cfg)
assert len(self.all_tasks) > 0, "Unable to perform this search."
# change global_batch_size and dp_degree
tuner_cfg["model_cfg"]["global_batch_size"] = (
tuner_cfg["model_cfg"]["global_batch_size"]
// self.all_tasks[0]["dp_degree"]
)
for task in self.all_tasks:
task["estimated_dp_degree"] = task["dp_degree"]
task["dp_degree"] = 1

def search_once(self, history_cfgs):
new_cfg = None
stop = False
while not stop:
if self.idx < len(self.all_tasks):
new_cfg = self.all_tasks[self.idx]
self.idx += 1
stop = not self.prune(self.tuner_cfg, new_cfg, history_cfgs)
else:
return None
return new_cfg


class GBSSearch(SearchAlgo):
def __init__(self, tuner_cfg):
super().__init__(tuner_cfg)
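The key bookkeeping in the new DpEstimationSearch above happens in __init__: each candidate task keeps its original data-parallel degree only as an estimate, while the task is actually launched with dp_degree=1 against a proportionally reduced global batch size. The standalone sketch below reproduces just that rescaling with invented task values (the real task list comes from search_by_dp_estimation in utils.py, which is not shown in this diff):

# Illustrative reproduction of the rescaling done in DpEstimationSearch.__init__;
# the task dictionaries are made up for demonstration.
model_cfg = {"global_batch_size": 64}
all_tasks = [
    {"dp_degree": 4, "mp_degree": 2, "pp_degree": 2, "micro_batch_size": 2},
    {"dp_degree": 4, "mp_degree": 4, "pp_degree": 1, "micro_batch_size": 2},
]

# Shrink the global batch size by the dp degree of the first task ...
model_cfg["global_batch_size"] //= all_tasks[0]["dp_degree"]   # 64 -> 16

# ... and run every task single-dp, remembering which degree it stands in for.
for task in all_tasks:
    task["estimated_dp_degree"] = task["dp_degree"]
    task["dp_degree"] = 1
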
7 changes: 6 additions & 1 deletion python/paddle/distributed/auto_tuner/tuner.py
@@ -29,13 +29,18 @@ def __init__(self, tuner_cfg):
self.cur_task_id = 1
self.task_limit = tuner_cfg.get("task_limit", 100)

search_algo = tuner_cfg.get("search_algo", "grid")
search_algo = tuner_cfg.get("search_algo", {"name": "grid"})["name"]

if search_algo == "grid":
from .search import GridSearch

tuner_cfg["candidates"] = default_candidates(tuner_cfg)
self.algo = GridSearch(tuner_cfg)
elif search_algo == "dp_estimation":
from .search import DpEstimationSearch

tuner_cfg["candidates"] = default_candidates(tuner_cfg)
self.algo = DpEstimationSearch(tuner_cfg)
elif search_algo == "gbs":
from .search import GBSSearch

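
With the tuner.py change, search_algo is read as a mapping with a "name" key instead of a bare string. A tuner configuration that selects the new estimation-based search might therefore look like the fragment below; apart from "name": "dp_estimation" (and the 0.7 threshold default that appears in the commented-out memory filter in cost_model.py), the values are illustrative:

# Illustrative tuner_cfg fragment; only the "name" key is consumed by tuner.py.
tuner_cfg = {
    "search_algo": {"name": "dp_estimation", "threshold": 0.7},
    "estimated_num_gpus": 16,
    "per_card_memory": 80,
    # model_cfg and the candidate degree ranges follow as usual.
}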
