Skip to content

Commit

Permalink
Add StragglerDetection and auto-relaunch (#11328)
Browse files Browse the repository at this point in the history
Signed-off-by: Shriya Palsamudram <[email protected]>
  • Loading branch information
ShriyaPalsamudram authored Nov 19, 2024
1 parent 1c59ce3 commit 4b93e7f
Show file tree
Hide file tree
Showing 3 changed files with 121 additions and 1 deletion.
13 changes: 13 additions & 0 deletions nemo/collections/llm/recipes/callbacks/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
53 changes: 53 additions & 0 deletions nemo/collections/llm/recipes/callbacks/common.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Optional

from nemo_run import Config, cli

from nemo.utils.import_utils import safe_import

res_module, HAVE_RES = safe_import('nvidia_resiliency_ext.ptl_resiliency')


@cli.factory(is_target_default=True)
def straggler_det_callback(
straggler_report_time_interval: Optional[int] = 300, stop_if_detected_straggler: Optional[bool] = True
) -> Config[res_module.StragglerDetectionCallback]:
"""
This callback is used to detect slower ranks participating in a PyTorch distributed workload.
This callback is obtained from nvidia-resiliency-ext.
Performance scores are scalar values from 0.0 (worst) to 1.0 (best), reflecting each rank's performance.
A performance score can be interpreted as the ratio of current performance to reference performance.
Depending on the reference used, there are two types of performance scores:
Relative performance score: The best-performing rank in the workload is used as a reference.
Individual performance score: The best historical performance of the rank is used as a reference.
If the performance score drops below the threshold which is set to 0.7, it is deemed as a straggler.
To detect the stragglers, users can enable this callback which reports the performance scores every 5mins.
Args:
straggler_report_time_interval (int): Performance score reporting frequency in seconds, Default is 300 seconds.
stop_if_detected_straggler (bool): Whether to stop training if a straggler is detection. Default is True.
"""

return Config(
res_module.StragglerDetectionCallback,
report_time_interval=straggler_report_time_interval,
calc_relative_gpu_perf=True,
calc_individual_gpu_perf=True,
num_gpu_perf_scores_to_print=5,
gpu_relative_perf_threshold=0.7,
gpu_individual_perf_threshold=0.7,
stop_if_detected=stop_if_detected_straggler,
enable_ptl_logging=True,
)
56 changes: 55 additions & 1 deletion nemo/lightning/run/plugins.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,14 @@
from nemo.lightning.pytorch.strategies.megatron_strategy import MegatronStrategy
from nemo.utils import logging

from nemo.utils.import_utils import safe_import

res_module, HAVE_RES = safe_import('nvidia_resiliency_ext.ptl_resiliency')

# This file contains plugins based on NeMo-Run's run.Plugin API.
# Plugins operate both on a configured task and an executor at the same time, and are specific to NeMo-Run.
# If you are adding functionality that goes directly into the Pytorch Lightning trainer, you may consider adding a callback instead of a plugin.
# If you are adding functionality that goes directly into the Pytorch Lightning trainer,
# you may consider adding a callback instead of a plugin.


def _merge_callbacks(partial: run.Partial, callbacks: list[run.Config[Callback]]):
Expand Down Expand Up @@ -79,6 +84,55 @@ def setup(self, task: run.Partial | run.Script, executor: run.Executor):
_merge_callbacks(task, callbacks=self.callbacks)


@dataclass(kw_only=True)
class FaultTolerancePlugin(run.Plugin):
"""
A plugin for setting up the fault tolerance callback from nvidia-resiliency-ext.
This plugin enables workload hang detection, automatic calculation of timeouts used for hang detection, detection of rank(s) terminated due to an error and workload respawning in case of a failure.
Note: FaultTolerancePlugin does not work with the NsysPlugin.
Args:
num_in_process_restarts (int): Max number of restarts on failure, within the same job. Default is 3.
num_job_retries_on_failure (int): Max number of new job restarts on failure. Default is 2.
initial_rank_heartbeat_timeout (int): Timeouts are time intervals used by a rank monitor to detect that a rank is not alive. This is the max timeout for the initial heartbeat. Default is 1800.
rank_heartbeat_timeout (int): This is the timeout for subsequent hearbeats after the initial heartbeat. Default is 300.
"""

num_in_process_restarts: int = 3
num_job_retries_on_failure: int = 2
initial_rank_heartbeat_timeout: int = 1800
rank_heartbeat_timeout: int = 300

def setup(self, task: run.Partial | run.Script, executor: run.Executor):

assert HAVE_RES, "nvidia-resiliency-ext.ptl_resiliency is required to use the FaultTolerancePlugin."

executor.launcher = run.FaultTolerance(
max_restarts=self.num_in_process_restarts,
initial_rank_heartbeat_timeout=self.initial_rank_heartbeat_timeout,
rank_heartbeat_timeout=self.rank_heartbeat_timeout,
)
executor.retries = self.num_job_retries_on_failure

assert isinstance(task, run.Partial)

callbacks = [
run.Config(
res_module.FaultToleranceCallback, autoresume=True, calculate_timeouts=True, exp_dir=task.log.log_dir
)
]

assert not executor.launcher.nsys_profile, "Nsys not supported with the FaultTolerancePlugin."
if hasattr(task, "trainer") and hasattr(task.trainer, "callbacks"):
assert all(
map(
lambda cb: not cb.__fn_or_cls__ == NsysCallback if "__fn_or_cls__" in dir(cb) else True,
task.trainer.callbacks,
)
), "Nsys not supported with FaultTolerancePlugin."

_merge_callbacks(task, callbacks=callbacks)


@dataclass(kw_only=True)
class NsysPlugin(run.Plugin):
"""
Expand Down

0 comments on commit 4b93e7f

Please sign in to comment.