Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ROCm] Fixup arch checks for ROCM #2627

Merged
merged 4 commits into from
Feb 5, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 0 additions & 3 deletions Dockerfile.rocm
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,6 @@ RUN echo "Base image is $BASE_IMAGE"
# BASE_IMAGE for ROCm_5.7: "rocm/pytorch:rocm5.7_ubuntu22.04_py3.10_pytorch_2.0.1"
# BASE_IMAGE for ROCm_6.0: "rocm/pytorch:rocm6.0_ubuntu20.04_py3.9_pytorch_2.1.1"

# this does not always work for all rocm versions
RUN LLVM_GFX_ARCH=$(/opt/rocm/llvm/bin/amdgpu-offload-arch) && \
echo "LLVM_GFX_ARCH is $LLVM_GFX_ARCH"

ARG FA_GFX_ARCHS="gfx90a;gfx942"
RUN echo "FA_GFX_ARCHS is $FA_GFX_ARCHS"
Expand Down
90 changes: 56 additions & 34 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@

# Supported NVIDIA GPU architectures.
NVIDIA_SUPPORTED_ARCHS = {"7.0", "7.5", "8.0", "8.6", "8.9", "9.0"}
ROCM_SUPPORTED_ARCHS = {"gfx90a", "gfx908", "gfx906", "gfx1030", "gfx1100"}
ROCM_SUPPORTED_ARCHS = {"gfx90a", "gfx942"}
# SUPPORTED_ARCHS = NVIDIA_SUPPORTED_ARCHS.union(ROCM_SUPPORTED_ARCHS)


Expand Down Expand Up @@ -63,22 +63,6 @@ def _is_cuda() -> bool:
NVCC_FLAGS += [f"-D_GLIBCXX_USE_CXX11_ABI={ABI}"]


def get_amdgpu_offload_arch():
command = "/opt/rocm/llvm/bin/amdgpu-offload-arch"
try:
output = subprocess.check_output([command])
return output.decode('utf-8').strip()
except subprocess.CalledProcessError as e:
error_message = f"Error: {e}"
raise RuntimeError(error_message) from e
except FileNotFoundError as e:
# If the command is not found, print an error message
error_message = f"The command {command} was not found."
raise RuntimeError(error_message) from e

return None


def get_hipcc_rocm_version():
# Run the hipcc --version command
result = subprocess.run(['hipcc', '--version'],
Expand Down Expand Up @@ -138,6 +122,50 @@ def get_nvcc_cuda_version(cuda_dir: str) -> Version:
return nvcc_cuda_version


def get_pytorch_rocm_arch() -> Set[str]:
"""Get the cross section of Pytorch,and vllm supported gfx arches

ROCM can get the supported gfx architectures in one of two ways
Either through the PYTORCH_ROCM_ARCH env var, or output from
rocm_agent_enumerator.

In either case we can generate a list of supported arch's and
cross reference with VLLM's own ROCM_SUPPORTED_ARCHs.
"""
env_arch_list = os.environ.get("PYTORCH_ROCM_ARCH", None)

# If we don't have PYTORCH_ROCM_ARCH specified pull the list from rocm_agent_enumerator
if env_arch_list is None:
command = "rocm_agent_enumerator"
env_arch_list = subprocess.check_output([command]).decode('utf-8')\
.strip().replace("\n", ";")
arch_source_str = "rocm_agent_enumerator"
else:
arch_source_str = "PYTORCH_ROCM_ARCH env variable"

# List are separated by ; or space.
pytorch_rocm_arch = set(env_arch_list.replace(" ", ";").split(";"))

# Filter out the invalid architectures and print a warning.
arch_list = pytorch_rocm_arch.intersection(ROCM_SUPPORTED_ARCHS)

# If none of the specified architectures are valid, raise an error.
if not arch_list:
raise RuntimeError(
f"None of the ROCM architectures in {arch_source_str} "
f"({env_arch_list}) is supported. "
f"Supported ROCM architectures are: {ROCM_SUPPORTED_ARCHS}.")
invalid_arch_list = pytorch_rocm_arch - ROCM_SUPPORTED_ARCHS
if invalid_arch_list:
warnings.warn(
f"Unsupported ROCM architectures ({invalid_arch_list}) are "
f"excluded from the {arch_source_str} output "
f"({env_arch_list}). Supported ROCM architectures are: "
f"{ROCM_SUPPORTED_ARCHS}.",
stacklevel=2)
return arch_list


def get_torch_arch_list() -> Set[str]:
# TORCH_CUDA_ARCH_LIST can have one or more architectures,
# e.g. "8.0" or "7.5,8.0,8.6+PTX". Here, the "8.6+PTX" option asks the
Expand All @@ -162,22 +190,27 @@ def get_torch_arch_list() -> Set[str]:
# If none of the specified architectures are valid, raise an error.
if not arch_list:
raise RuntimeError(
"None of the CUDA/ROCM architectures in `TORCH_CUDA_ARCH_LIST` env "
"None of the CUDA architectures in `TORCH_CUDA_ARCH_LIST` env "
f"variable ({env_arch_list}) is supported. "
f"Supported CUDA/ROCM architectures are: {valid_archs}.")
f"Supported CUDA architectures are: {valid_archs}.")
invalid_arch_list = torch_arch_list - valid_archs
if invalid_arch_list:
warnings.warn(
f"Unsupported CUDA/ROCM architectures ({invalid_arch_list}) are "
f"Unsupported CUDA architectures ({invalid_arch_list}) are "
"excluded from the `TORCH_CUDA_ARCH_LIST` env variable "
f"({env_arch_list}). Supported CUDA/ROCM architectures are: "
f"({env_arch_list}). Supported CUDA architectures are: "
f"{valid_archs}.",
stacklevel=2)
return arch_list


# First, check the TORCH_CUDA_ARCH_LIST environment variable.
compute_capabilities = get_torch_arch_list()
if _is_hip():
rocm_arches = get_pytorch_rocm_arch()
NVCC_FLAGS += ["--offload-arch=" + arch for arch in rocm_arches]
else:
# First, check the TORCH_CUDA_ARCH_LIST environment variable.
compute_capabilities = get_torch_arch_list()

if _is_cuda() and not compute_capabilities:
# If TORCH_CUDA_ARCH_LIST is not defined or empty, target all available
# GPUs on the current machine.
Expand Down Expand Up @@ -286,17 +319,6 @@ def get_torch_arch_list() -> Set[str]:
"nvcc": NVCC_FLAGS_PUNICA,
},
))
elif _is_hip():
amd_archs = os.getenv("GPU_ARCHS")
if amd_archs is None:
amd_archs = get_amdgpu_offload_arch()
for arch in amd_archs.split(";"):
if arch not in ROCM_SUPPORTED_ARCHS:
raise RuntimeError(
f"Only the following arch is supported: {ROCM_SUPPORTED_ARCHS}"
f"amdgpu_arch_found: {arch}")
NVCC_FLAGS += [f"--offload-arch={arch}"]

elif _is_neuron():
neuronxcc_version = get_neuronxcc_version()

Expand Down
Loading