[AutoTuner] support prune by memory cost model
SylarTiaNII committed Dec 5, 2023
1 parent b960cec commit 6cd6ac2
Showing 2 changed files with 66 additions and 0 deletions.
32 changes: 32 additions & 0 deletions python/paddle/distributed/auto_tuner/memory_cost_model.py
@@ -0,0 +1,32 @@
from argparse import ArgumentParser


def str2bool(v):
    # argparse's type=bool treats any non-empty string (including "False")
    # as True, so parse boolean flags explicitly.
    return str(v).lower() in ("1", "true", "yes")


def parse_arguments():
    # All arguments are optional flags ("--name value") because the caller in
    # prune.py invokes this script with flag-style arguments.
    parser = ArgumentParser()

    # for distributed strategy
    parser.add_argument("--dp_degree", type=int, help="dp degree")
    parser.add_argument("--pp_degree", type=int, help="pp degree")
    parser.add_argument("--mp_degree", type=int, help="mp degree")
    parser.add_argument("--sharding_degree", type=int, help="sharding degree")
    parser.add_argument("--sharding_stage", type=int, help="sharding stage")
    parser.add_argument("--micro_batch_size", type=int, help="micro batch size")
    parser.add_argument("--use_recompute", type=str2bool, help="use recompute")
    parser.add_argument(
        "--recompute_granularity", type=int, help="recompute granularity"
    )

    # for model config
    parser.add_argument("--hidden_size", type=int, help="hidden size")
    parser.add_argument(
        "--num_attention_heads", type=int, help="number of attention heads"
    )
    parser.add_argument(
        "--num_hidden_layers", type=int, help="number of hidden layers"
    )
    parser.add_argument(
        "--max_sequence_length", type=int, help="maximum sequence length"
    )
    parser.add_argument("--vocab_size", type=int, help="vocabulary size")
    parser.add_argument("--intermediate_size", type=int, help="intermediate size")

    return parser.parse_args()

def get_model_memory_usage(args):
    # evaluate model memory usage based on distributed strategy and model setting
    raise NotImplementedError(
        "Please implement this function for memory usage estimation "
        "based on distributed strategy and model setting."
    )


if __name__ == "__main__":
    args = parse_arguments()
    print(get_model_memory_usage(args))
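
As a starting point for the unimplemented estimator, here is a minimal sketch of what get_model_memory_usage could compute. It assumes a GPT-style decoder trained in fp16 with fp32 Adam state (~16 bytes per parameter) and a crude per-layer activation term; every formula and constant below is an illustrative assumption, not something this commit specifies. Whatever the real implementation computes, it should print a single number, since prune.py parses the script's stdout with float().

def estimate_memory_usage_gib(args):
    # Rough parameter count of a GPT-style decoder:
    # attention (4*h*h) + FFN (2*h*ffn) per layer, plus token embeddings.
    h = args.hidden_size
    n_params = args.num_hidden_layers * (
        4 * h * h + 2 * h * args.intermediate_size
    ) + args.vocab_size * h

    # fp16 weights + fp16 grads + fp32 master weights + Adam moments
    # ~= 16 bytes/param, split across tensor- and pipeline-parallel ranks.
    model_shards = args.mp_degree * args.pp_degree
    param_bytes = 16 * n_params / model_shards
    if args.sharding_stage >= 1:
        # sharding stage >= 1 additionally splits the ~12 bytes of
        # optimizer state across the sharding group
        param_bytes = (
            4 * n_params + 12 * n_params / args.sharding_degree
        ) / model_shards

    # Crude fp16 activation footprint per micro batch, halved by recompute.
    act_bytes = (
        2
        * args.micro_batch_size
        * args.max_sequence_length
        * h
        * args.num_hidden_layers
        / model_shards
    )
    if args.use_recompute:
        act_bytes /= 2

    return (param_bytes + act_bytes) / 1024**3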
34 changes: 34 additions & 0 deletions python/paddle/distributed/auto_tuner/prune.py
@@ -12,6 +12,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import subprocess

_PRUNE_FUNC = []


@@ -353,3 +356,34 @@ def prune_by_num_gpus(tuner_cfg, cur_cfg, history_cfgs):
        return True

    return False


@register_prune
def prune_by_memory_estimation(tuner_cfg, cur_cfg, history_cfgs):
    """Prune a candidate config when its estimated memory usage exceeds the limit."""
    memory_estimation_tool = tuner_cfg.get("memory_estimation_tool", None)
    max_memory_usage = tuner_cfg.get("max_mem_usage", None)
    model_cfg = tuner_cfg["model_cfg"]

    if memory_estimation_tool is None:
        return False

    if not os.path.exists(memory_estimation_tool):
        raise ValueError(
            f"memory_estimation_tool should be a valid path, but got {memory_estimation_tool}"
        )

    if max_memory_usage is None:
        raise ValueError("max_mem_usage should be set when using memory estimation tool")

    # Build the command as an argument list so values need no shell quoting.
    memory_estimation_cmd = [
        "python",
        memory_estimation_tool,
        "--dp_degree", str(cur_cfg["dp_degree"]),
        "--mp_degree", str(cur_cfg["mp_degree"]),
        "--pp_degree", str(cur_cfg["pp_degree"]),
        "--sharding_degree", str(cur_cfg["sharding_degree"]),
        "--sharding_stage", str(cur_cfg["sharding_stage"]),
        "--micro_batch_size", str(cur_cfg["micro_batch_size"]),
        "--use_recompute", str(cur_cfg["use_recompute"]),
        "--recompute_granularity", str(cur_cfg["recompute_granularity"]),
        "--hidden_size", str(model_cfg["hidden_size"]),
        "--num_attention_heads", str(model_cfg["num_attention_heads"]),
        "--num_hidden_layers", str(model_cfg["num_hidden_layers"]),
        "--max_sequence_length", str(model_cfg["max_sequence_length"]),
        "--vocab_size", str(model_cfg["vocab_size"]),
        "--intermediate_size", str(model_cfg["intermediate_size"]),
    ]
    result = subprocess.run(
        memory_estimation_cmd,
        stdout=subprocess.PIPE,
        stderr=subprocess.PIPE,
        text=True,
    )
    if result.returncode == 0:
        # The tool prints a single number; prune when it exceeds the budget.
        cur_memory_usage = float(result.stdout)
        return cur_memory_usage > max_memory_usage
    raise ValueError(f"memory_estimation_tool failed with error: {result.stderr}")
