Skip to content

Commit

Permalink
[Fix] CUDA architecture detection bug fix
Browse files Browse the repository at this point in the history
This commit returns a list of integers and adds an assert to check that the string of CUDA architecture must contain numbers only.
  • Loading branch information
kobeyu committed Apr 24, 2024
1 parent 9ec75ee commit 23bbea6
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 4 deletions.
2 changes: 1 addition & 1 deletion python/mlc_llm/interface/compiler_flags.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ def _flashinfer(target) -> bool:
return False
arch_list = detect_cuda_arch_list(target)
for arch in arch_list:
if int(re.findall(r"\d+", arch)[0]) < 80:
if arch < 80:
logger.warning("flashinfer is not supported on CUDA arch < 80")
return False
return True
Expand Down
12 changes: 9 additions & 3 deletions python/mlc_llm/support/auto_target.py
Original file line number Diff line number Diff line change
Expand Up @@ -293,14 +293,20 @@ def build(mod: IRModule, args: "CompileArgs", pipeline=None):
return build


def detect_cuda_arch_list(target: Target) -> List[str]:
def detect_cuda_arch_list(target: Target) -> List[int]:
"""Detect the CUDA architecture list from the target."""

def convert_to_num(arch_str):
arch_num_str = "".join(filter(str.isdigit, arch_str))
assert arch_num_str, f"'{arch_str}' does not contain any digits"
return int(arch_num_str)

assert target.kind.name == "cuda", f"Expect target to be CUDA, but got {target}"
if MLC_MULTI_ARCH is not None:
multi_arch = [x.strip() for x in MLC_MULTI_ARCH.split(",")]
multi_arch = [convert_to_num(x) for x in MLC_MULTI_ARCH.split(",")]
else:
assert target.arch.startswith("sm_")
multi_arch = [target.arch[3:]]
multi_arch = [convert_to_num(target.arch[3:])]
multi_arch = list(set(multi_arch))
return multi_arch

Expand Down

0 comments on commit 23bbea6

Please sign in to comment.