Skip to content

Commit

Permalink
refactor user arg parsing and add unit tests
Browse files Browse the repository at this point in the history
  • Loading branch information
mrwyattii committed Dec 15, 2023
1 parent bc1b5a6 commit 3da3fbc
Show file tree
Hide file tree
Showing 3 changed files with 75 additions and 7 deletions.
14 changes: 11 additions & 3 deletions deepspeed/launcher/multinode_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,13 +56,21 @@ def __init__(self, args, world_info_base64):
def backend_exists(self):
return shutil.which('pdsh')

def parse_user_args(self):
processed_args = []
for arg in self.args.user_args:
# With pdsh, if we are passing a string as an argument, it will get
# split on whitespace. To avoid this and support strings that
# contain '"', we do this extra processing step:
if " " in arg:
arg = '"{}"'.format(arg.replace('"', '\\"'))
processed_args.append(arg)
return processed_args

@property
def name(self):
return "pdsh"

def parse_user_args(self):
return list(map(lambda x: x if x.startswith("-") else f"'{x}'", self.args.user_args))

def get_cmd(self, environment, active_resources):
environment['PDSH_RCMD_TYPE'] = 'ssh'
if self.args.ssh_port is not None: # only specify ssh port if it is specified
Expand Down
4 changes: 0 additions & 4 deletions deepspeed/launcher/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
import os
import re
import sys
import shlex
import json
import base64
import argparse
Expand Down Expand Up @@ -389,9 +388,6 @@ def parse_num_nodes(str_num_nodes: str, elastic_training: bool):
def main(args=None):
args = parse_args(args)

# For when argparse interprets remaining args as a single string
args.user_args = shlex.split(" ".join(list(map(lambda x: x if x.startswith("-") else f'"{x}"', args.user_args))))

if args.elastic_training:
assert args.master_addr != "", "Master Addr is required when elastic training is enabled"

Expand Down
64 changes: 64 additions & 0 deletions tests/unit/launcher/test_user_args.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

import pytest
import subprocess

from deepspeed.accelerator import get_accelerator

if not get_accelerator().is_available():
pytest.skip("only supported in accelerator environments.", allow_module_level=True)

user_arg_test_script = """import argparse
parser = argparse.ArgumentParser()
parser.add_argument("--prompt", type=str)
parser.add_argument("--local_rank", type=int, default=0)
parser.add_argument("--world_size", type=int, default=1)
args = parser.parse_args()
print("ARG PARSE SUCCESS")
"""


@pytest.fixture(scope="function")
def user_script_fp(tmpdir):
script_fp = tmpdir.join("user_arg_test.py")
with open(script_fp, "w") as f:
f.write(user_arg_test_script)
return script_fp


@pytest.fixture(scope="function")
def cmd(user_script_fp, prompt, multi_node):
if multi_node:
cmd = ("deepspeed", "--force_multi", "--num_nodes", "1", "--num_gpus", "1", user_script_fp, "--prompt", prompt)
else:
cmd = ("deepspeed", "--num_nodes", "1", "--num_gpus", "1", user_script_fp, "--prompt", prompt)
return cmd


@pytest.mark.parametrize("prompt", [
'''"I am 6' tall"''', """'I am 72" tall'""", """'"translate English to Romanian: "'""",
'''I'm going to tell them "DeepSpeed is the best"'''
])
@pytest.mark.parametrize("multi_node", [True, False])
def test_user_args(cmd):
p = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
out, err = p.communicate()
assert "ARG PARSE SUCCESS" in out.decode("utf-8"), f"User args not parsed correctly: {err.decode('utf-8')}"


def test_bash_string_args(tmpdir, user_script_fp):
bash_script = f"""
ARGS="--prompt 'DeepSpeed is the best'"
echo ${{ARGS}}|xargs deepspeed --num_nodes 1 --num_gpus 1 {user_script_fp}
"""

bash_fp = tmpdir.join("bash_script.sh")
with open(bash_fp, "w") as f:
f.write(bash_script)

p = subprocess.Popen(["bash", bash_fp], stdout=subprocess.PIPE, stderr=subprocess.PIPE)
out, err = p.communicate()
assert "ARG PARSE SUCCESS" in out.decode("utf-8"), f"User args not parsed correctly: {err.decode('utf-8')}"

0 comments on commit 3da3fbc

Please sign in to comment.