From a1a9745592ede6ba967404b12c16e4938007d477 Mon Sep 17 00:00:00 2001 From: Jaemin Choi Date: Mon, 3 Mar 2025 23:18:19 -0800 Subject: [PATCH] Add NVTX ranges to categorize execution (#11945) * Add NVTX ranges to optimizer step Signed-off-by: Jaemin Choi * Use Tim's module Signed-off-by: Jaemin Choi * Fix NVTX functions import Signed-off-by: Jaemin Choi * Apply isort and black reformatting Signed-off-by: minitu Signed-off-by: Jaemin Choi * Apply isort and black reformatting Signed-off-by: minitu Signed-off-by: Jaemin Choi * Use NsysCallback and AppState Signed-off-by: Jaemin Choi * Apply isort and black reformatting Signed-off-by: minitu Signed-off-by: Jaemin Choi * Add nvtx_label Signed-off-by: Jaemin Choi * Add option to profile all ranks Signed-off-by: Jaemin Choi * Update NVTX label for MCore optimizer Signed-off-by: Jaemin Choi * Add NVTX range for data step Signed-off-by: Jaemin Choi * Apply isort and black reformatting Signed-off-by: minitu Signed-off-by: Jaemin Choi * Remove NVTX range for gpt_data_step Signed-off-by: Jaemin Choi * Cleanup Signed-off-by: Jaemin Choi * Use stack to keep track of NVTX ranges Signed-off-by: Jaemin Choi * Apply isort and black reformatting Signed-off-by: minitu Signed-off-by: Jaemin Choi * Apply isort and black reformatting Signed-off-by: minitu Signed-off-by: Jaemin Choi * Capitalize NVTX label Signed-off-by: Jaemin Choi * Fix linting failure Signed-off-by: Jaemin Choi --------- Signed-off-by: Jaemin Choi Signed-off-by: minitu Co-authored-by: Jaemin Choi Co-authored-by: minitu --- nemo/core/optim/mcore_optim.py | 8 +++ nemo/lightning/pytorch/callbacks/nsys.py | 8 ++- nemo/utils/app_state.py | 5 +- nemo/utils/nvtx.py | 63 ++++++++++++++++++++++++ 4 files changed, 82 insertions(+), 2 deletions(-) create mode 100644 nemo/utils/nvtx.py diff --git a/nemo/core/optim/mcore_optim.py b/nemo/core/optim/mcore_optim.py index 7a0f89aaa31f..aea7a2f0f86b 100644 --- a/nemo/core/optim/mcore_optim.py +++ b/nemo/core/optim/mcore_optim.py @@ -14,6 +14,8 @@ import torch +from nemo.utils.nvtx import nvtx_range_pop, nvtx_range_push + def _filter_empty_common_step(state_dict): """ @@ -42,6 +44,8 @@ class McoreDistributedOptimizer(torch.optim.Optimizer): optim (MegatronOptimizer): The distributed optimizer from Megatron Core. """ + NVTX_LABEL = "nemo.core.optim.mcore_optim" + def __init__(self, optim): self.defaults = {} self.mcore_optimizer = optim @@ -121,10 +125,14 @@ def step(self, closure=None): # Apply closure loss = None if closure is not None: + nvtx_range_push(f"{McoreDistributedOptimizer.NVTX_LABEL}.step.closure") loss = closure() + nvtx_range_pop(f"{McoreDistributedOptimizer.NVTX_LABEL}.step.closure") # return unused update_successful, grad_norm, num_zeros_in_grad + nvtx_range_push(f"{McoreDistributedOptimizer.NVTX_LABEL}.step.step") _, grad_norm, num_zeros_in_grad = self.mcore_optimizer.step() + nvtx_range_pop(f"{McoreDistributedOptimizer.NVTX_LABEL}.step.step") return loss, grad_norm, num_zeros_in_grad diff --git a/nemo/lightning/pytorch/callbacks/nsys.py b/nemo/lightning/pytorch/callbacks/nsys.py index 0368b2d52773..9b1d02f42e0e 100644 --- a/nemo/lightning/pytorch/callbacks/nsys.py +++ b/nemo/lightning/pytorch/callbacks/nsys.py @@ -18,6 +18,7 @@ from lightning.pytorch.callbacks.callback import Callback from nemo.utils import logging +from nemo.utils.app_state import AppState from nemo.utils.get_rank import get_rank @@ -48,9 +49,10 @@ class NsysCallback(Callback): end_step (int): Global batch to end profiling ranks (List[int]): Global rank IDs to profile gen_shape (bool): Generate model and kernel details including input shapes + nvtx_ranges (bool): Insert NVTX ranges to categorize execution Example: - >>> callback = NsysCallback(start_step=100, end_step=200, ranks=[0, 1], gen_shape=True) + >>> callback = NsysCallback(start_step=100, end_step=200, ranks=[0, 1], gen_shape=True, nvtx_ranges=False) >>> trainer = Trainer(callbacks=[callback]) """ @@ -60,6 +62,7 @@ def __init__( end_step: int, ranks: List[int] = [0], gen_shape: bool = False, + nvtx_ranges: bool = False, ): assert type(start_step) is int, f'Nsys start_step must be of type int. Found: {type(start_step)}' self._nsys_profile_start_step = start_step @@ -74,6 +77,9 @@ def __init__( self._nsys_profile_ranks = ranks self._nsys_profile_gen_shape = gen_shape + app_state = AppState() + app_state._nvtx_ranges = nvtx_ranges + logging.info( f'Nsys profiling setup with start_step: {self._nsys_profile_start_step},' f'and end_step: {self._nsys_profile_end_step}' diff --git a/nemo/utils/app_state.py b/nemo/utils/app_state.py index 8a83725f74a4..1bfab4b726ca 100644 --- a/nemo/utils/app_state.py +++ b/nemo/utils/app_state.py @@ -14,7 +14,7 @@ from dataclasses import dataclass from threading import Lock -from typing import Dict, Optional +from typing import Optional from nemo.utils.metaclasses import Singleton @@ -93,6 +93,9 @@ def __init__(self): # command-ling arguments for run self._cmd_args = None + # Insert NVTX ranges to categorize execution + self._nvtx_ranges = False + @property def device_id(self): """Property returns the device_id diff --git a/nemo/utils/nvtx.py b/nemo/utils/nvtx.py new file mode 100644 index 000000000000..7f375446c5c9 --- /dev/null +++ b/nemo/utils/nvtx.py @@ -0,0 +1,63 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import functools +from typing import Optional + +import torch + +from nemo.utils.app_state import AppState + +# pylint: disable=C0116 + + +@functools.lru_cache(maxsize=None) +def _nvtx_enabled() -> bool: + """Check if NVTX range profiling is enabled""" + return AppState()._nvtx_ranges + + +# Messages associated with active NVTX ranges +_nvtx_range_messages: list[str] = [] + + +def nvtx_range_push(msg: str) -> None: + # Return immediately if NVTX range profiling is not enabled + if not _nvtx_enabled(): + return + + # Push NVTX range to stack + _nvtx_range_messages.append(msg) + torch.cuda.nvtx.range_push(msg) + + +def nvtx_range_pop(msg: Optional[str] = None) -> None: + # Return immediately if NVTX range profiling is not enabled + if not _nvtx_enabled(): + return + + # Update list of NVTX range messages and check for consistency + if not _nvtx_range_messages: + raise RuntimeError("Attempted to pop NVTX range from empty stack") + last_msg = _nvtx_range_messages.pop() + if msg is not None and msg != last_msg: + raise ValueError( + f"Attempted to pop NVTX range from stack with msg={msg}, " f"but last range has msg={last_msg}" + ) + + # Pop NVTX range + torch.cuda.nvtx.range_pop() + + +# pylint: enable=C0116