Make detect_cuda_arch_list return integer CUDA arch numbers (List[int])
instead of strings, and compare the integer directly in the _flashinfer
arch check.

NOTE(review): this patch was recovered from a copy whose line breaks and
indentation had been collapsed. Relative nesting is restored from the Python
syntax, but the absolute indentation of the _flashinfer hunk is an
assumption -- confirm against the target file, or apply with a
whitespace-tolerant option if the context does not match.

diff --git a/python/mlc_llm/interface/compiler_flags.py b/python/mlc_llm/interface/compiler_flags.py
index 2d0d668672..fb72a8b806 100644
--- a/python/mlc_llm/interface/compiler_flags.py
+++ b/python/mlc_llm/interface/compiler_flags.py
@@ -96,7 +96,7 @@ def _flashinfer(target) -> bool:
         return False
     arch_list = detect_cuda_arch_list(target)
     for arch in arch_list:
-        if int(re.findall(r"\d+", arch)[0]) < 80:
+        if arch < 80:
             logger.warning("flashinfer is not supported on CUDA arch < 80")
             return False
     return True
diff --git a/python/mlc_llm/support/auto_target.py b/python/mlc_llm/support/auto_target.py
index 3cf49c43ba..d60be66ced 100644
--- a/python/mlc_llm/support/auto_target.py
+++ b/python/mlc_llm/support/auto_target.py
@@ -293,14 +293,19 @@ def build(mod: IRModule, args: "CompileArgs", pipeline=None):
     return build
 
 
-def detect_cuda_arch_list(target: Target) -> List[str]:
+def detect_cuda_arch_list(target: Target) -> List[int]:
     """Detect the CUDA architecture list from the target."""
 
+    def convert_to_num(arch_str):
+        arch_num_str = ''.join(filter(str.isdigit, arch_str))
+        assert arch_num_str, f"'{arch_str}' does not contain any digits"
+        return int(arch_num_str)
+
     assert target.kind.name == "cuda", f"Expect target to be CUDA, but got {target}"
     if MLC_MULTI_ARCH is not None:
-        multi_arch = [x.strip() for x in MLC_MULTI_ARCH.split(",")]
+        multi_arch = [convert_to_num(x) for x in MLC_MULTI_ARCH.split(",")]
     else:
         assert target.arch.startswith("sm_")
-        multi_arch = [target.arch[3:]]
+        multi_arch = [convert_to_num(target.arch[3:])]
     multi_arch = list(set(multi_arch))
     return multi_arch