Fix setup_env_ranks to Properly Set Environment Variables Instead of Raising Error (deepspeedai#6979)

This pull request addresses an issue in setup_env_ranks where, under
certain conditions, the function raises an error instead of setting the
necessary MPI-related environment variables (LOCAL_RANK, RANK, and
WORLD_SIZE). The intended behavior is to properly map Open MPI variables
(OMPI_COMM_WORLD_*) to the variables expected by DeepSpeed/PyTorch, but
the code currently raises an EnvironmentError if these Open MPI
variables are not found.

With this fix, setup_env_ranks will:

- Correctly detect and map the required Open MPI environment variables.
- Only raise an error if there is genuinely no valid way to obtain rank
information from the environment (i.e., neither the Open MPI variables
nor any fallback mechanism is available); the sketch below illustrates
the intended mapping.
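
For illustration, the mapping described above amounts to roughly the
following (a minimal sketch only; map_ompi_env is a hypothetical helper
name, and the actual change is the _setup_mpi_environment method in the
diff below):

import os

# Open MPI launcher variables -> names expected by DeepSpeed/PyTorch
_OMPI_TO_TORCH = {
    'OMPI_COMM_WORLD_LOCAL_RANK': 'LOCAL_RANK',
    'OMPI_COMM_WORLD_RANK': 'RANK',
    'OMPI_COMM_WORLD_SIZE': 'WORLD_SIZE',
}

def map_ompi_env():
    # Raise only when the Open MPI variables are genuinely unavailable.
    missing = [var for var in _OMPI_TO_TORCH if var not in os.environ]
    if missing:
        raise EnvironmentError(f"MPI environment variables are not set: {missing}")
    for ompi_var, torch_var in _OMPI_TO_TORCH.items():
        os.environ[torch_var] = os.environ[ompi_var]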

Fix deepspeedai#6895

Co-authored-by: Logan Adams <[email protected]>
Co-authored-by: Olatunji Ruwase <[email protected]>
3 people authored and traincheck-team committed Feb 9, 2025
1 parent 90d674c commit 17c775f
Showing 2 changed files with 85 additions and 5 deletions.
19 changes: 19 additions & 0 deletions deepspeed/launcher/multinode_runner.py
@@ -134,12 +134,31 @@ def name(self):

    def validate_args(self):
        super().validate_args()

        # Validate and set MPI environment variables
        self._setup_mpi_environment()

        #TODO: Allow for include/exclude at node-level but not gpu-level
        if self.args.include != "" or self.args.exclude != "":
            raise ValueError(f"{self.name} backend does not support worker include/exclusion")
        if self.args.num_nodes != -1 or self.args.num_gpus != -1:
            raise ValueError(f"{self.name} backend does not support limiting num nodes/gpus")

    def _setup_mpi_environment(self):
        """Sets up MPI-related environment variables or raises an error if they're missing."""

        required_vars = ['OMPI_COMM_WORLD_LOCAL_RANK', 'OMPI_COMM_WORLD_RANK', 'OMPI_COMM_WORLD_SIZE']

        # Check if all these are present
        if not all(var in os.environ for var in required_vars):
            raise EnvironmentError("MPI environment variables are not set. "
                                   "Ensure you are running the script with an MPI-compatible launcher.")

        # Now safe to read all
        os.environ['LOCAL_RANK'] = os.environ['OMPI_COMM_WORLD_LOCAL_RANK']
        os.environ['RANK'] = os.environ['OMPI_COMM_WORLD_RANK']
        os.environ['WORLD_SIZE'] = os.environ['OMPI_COMM_WORLD_SIZE']

    def get_cmd(self, environment, active_resources):
        total_process_count = sum(self.resource_pool.values())

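For context (not part of this commit): the variables written by
_setup_mpi_environment (LOCAL_RANK, RANK, WORLD_SIZE) are the names a
typical PyTorch env:// rendezvous reads. A minimal consumer sketch,
assuming a CUDA-capable node and a hypothetical helper init_from_env:

import os
import torch
import torch.distributed as dist

def init_from_env():
    # With env:// rendezvous, RANK and WORLD_SIZE are read from the environment.
    # The launcher is also expected to set MASTER_ADDR and MASTER_PORT.
    dist.init_process_group(backend='nccl', init_method='env://')
    # LOCAL_RANK selects the GPU for this process on its node.
    torch.cuda.set_device(int(os.environ['LOCAL_RANK']))
    return dist.get_rank(), dist.get_world_size()
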
71 changes: 66 additions & 5 deletions tests/unit/launcher/test_multinode_runner.py
@@ -19,6 +19,14 @@ def runner_info():
    return env, hosts, world_info, args


@pytest.fixture
def mock_mpi_env(monkeypatch):
    # Provide the 3 required MPI variables:
    monkeypatch.setenv('OMPI_COMM_WORLD_LOCAL_RANK', '0')
    monkeypatch.setenv('OMPI_COMM_WORLD_RANK', '0')
    monkeypatch.setenv('OMPI_COMM_WORLD_SIZE', '1')


def test_pdsh_runner(runner_info):
    env, resource_pool, world_info, args = runner_info
    runner = mnrunner.PDSHRunner(args, world_info)
@@ -27,34 +35,87 @@ def test_pdsh_runner(runner_info):
    assert env['PDSH_RCMD_TYPE'] == 'ssh'


def test_openmpi_runner(runner_info):
def test_openmpi_runner(runner_info, mock_mpi_env):
    env, resource_pool, world_info, args = runner_info
    runner = mnrunner.OpenMPIRunner(args, world_info, resource_pool)
    cmd = runner.get_cmd(env, resource_pool)
    assert cmd[0] == 'mpirun'
    assert 'eth0' in cmd


def test_btl_nic_openmpi_runner(runner_info):
def test_btl_nic_openmpi_runner(runner_info, mock_mpi_env):
    env, resource_pool, world_info, _ = runner_info
    args = parse_args(['--launcher_arg', '-mca btl_tcp_if_include eth1', 'test_launcher.py'])

    runner = mnrunner.OpenMPIRunner(args, world_info, resource_pool)
    cmd = runner.get_cmd(env, resource_pool)
    assert 'eth0' not in cmd
    assert 'eth1' in cmd


def test_btl_nic_two_dashes_openmpi_runner(runner_info):
def test_btl_nic_two_dashes_openmpi_runner(runner_info, mock_mpi_env):
    env, resource_pool, world_info, _ = runner_info
    args = parse_args(['--launcher_arg', '--mca btl_tcp_if_include eth1', 'test_launcher.py'])

    runner = mnrunner.OpenMPIRunner(args, world_info, resource_pool)
    cmd = runner.get_cmd(env, resource_pool)
    assert 'eth0' not in cmd
    assert 'eth1' in cmd


def test_setup_mpi_environment_success():
"""Test that _setup_mpi_environment correctly sets environment variables when MPI variables exist."""
os.environ['OMPI_COMM_WORLD_LOCAL_RANK'] = '0'
os.environ['OMPI_COMM_WORLD_RANK'] = '1'
os.environ['OMPI_COMM_WORLD_SIZE'] = '2'

args = parse_args(['--launcher_arg', '--mca btl_tcp_if_include eth1', 'test_launcher.py'])

runner = mnrunner.OpenMPIRunner(args, None, None)
# Set up the MPI environment
runner._setup_mpi_environment()

assert os.environ['LOCAL_RANK'] == '0'
assert os.environ['RANK'] == '1'
assert os.environ['WORLD_SIZE'] == '2'

# Clean up environment
del os.environ['OMPI_COMM_WORLD_LOCAL_RANK']
del os.environ['OMPI_COMM_WORLD_RANK']
del os.environ['OMPI_COMM_WORLD_SIZE']
del os.environ['LOCAL_RANK']
del os.environ['RANK']
del os.environ['WORLD_SIZE']


def test_setup_mpi_environment_missing_variables():
"""Test that _setup_mpi_environment raises an EnvironmentError when MPI variables are missing."""

# Clear relevant environment variables
os.environ.pop('OMPI_COMM_WORLD_LOCAL_RANK', None)
os.environ.pop('OMPI_COMM_WORLD_RANK', None)
os.environ.pop('OMPI_COMM_WORLD_SIZE', None)

args = parse_args(['--launcher_arg', '--mca btl_tcp_if_include eth1', 'test_launcher.py'])

with pytest.raises(EnvironmentError, match="MPI environment variables are not set"):
mnrunner.OpenMPIRunner(args, None, None)


def test_setup_mpi_environment_fail():
"""Test that _setup_mpi_environment fails if only partial MPI variables are provided."""
os.environ['OMPI_COMM_WORLD_LOCAL_RANK'] = '0'
os.environ.pop('OMPI_COMM_WORLD_RANK', None) # missing variable
os.environ['OMPI_COMM_WORLD_SIZE'] = '2'

args = parse_args(['--launcher_arg', '--mca btl_tcp_if_include eth1', 'test_launcher.py'])

with pytest.raises(EnvironmentError, match="MPI environment variables are not set"):
runner = mnrunner.OpenMPIRunner(args, None, None)

# Clean up environment
del os.environ['OMPI_COMM_WORLD_LOCAL_RANK']
del os.environ['OMPI_COMM_WORLD_SIZE']


def test_mpich_runner(runner_info):
    env, resource_pool, world_info, args = runner_info
    runner = mnrunner.MPICHRunner(args, world_info, resource_pool)
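
A side note on test hygiene (not part of this commit): the success-path
test above restores os.environ by hand. An equivalent sketch using
pytest's monkeypatch fixture, which rolls back its own setenv calls
automatically, might look like the following; the test name is
hypothetical, the module-level imports of the test file above are
assumed, and the LOCAL_RANK/RANK/WORLD_SIZE values written by the code
under test would still need explicit cleanup if later tests depend on
them:

def test_setup_mpi_environment_success_monkeypatched(monkeypatch):
    # monkeypatch.setenv is undone automatically after the test.
    monkeypatch.setenv('OMPI_COMM_WORLD_LOCAL_RANK', '0')
    monkeypatch.setenv('OMPI_COMM_WORLD_RANK', '1')
    monkeypatch.setenv('OMPI_COMM_WORLD_SIZE', '2')

    args = parse_args(['--launcher_arg', '--mca btl_tcp_if_include eth1', 'test_launcher.py'])
    runner = mnrunner.OpenMPIRunner(args, None, None)
    runner._setup_mpi_environment()

    assert os.environ['LOCAL_RANK'] == '0'
    assert os.environ['RANK'] == '1'
    assert os.environ['WORLD_SIZE'] == '2'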
