Skip to content

Commit

Permalink
[Target] Improve TVM Target related items (apache#45)
Browse files Browse the repository at this point in the history
* Refactor target_detector.py to improve CUDA target handling

* chore: Add get_all_nvidia_targets function to target_detector.py, and refactor the device mismatch print
  • Loading branch information
LeiWang1999 authored Jun 1, 2024
1 parent c570a76 commit efab450
Show file tree
Hide file tree
Showing 5 changed files with 36 additions and 3 deletions.
2 changes: 1 addition & 1 deletion python/bitblas/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,4 @@
# Licensed under the MIT License.
from .post_process import match_global_kernel, tensor_replace_dp4a, tensor_remove_make_int4 # noqa: F401
from .tensor_adapter import tvm_tensor_to_torch, lazy_tvm_tensor_to_torch, lazy_torch_to_tvm_tensor # noqa: F401
from .target_detector import auto_detect_nvidia_target # noqa: F401
from .target_detector import get_all_nvidia_targets, auto_detect_nvidia_target # noqa: F401
18 changes: 16 additions & 2 deletions python/bitblas/utils/target_detector.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,19 @@
# Licensed under the MIT License.

import subprocess
from typing import List
from thefuzz import process
from tvm.target import Target
from tvm.target.tag import list_tags

import logging

logger = logging.getLogger(__name__)

TARGET_MISSING_ERROR = (
"TVM target not found. Please set the TVM target environment variable using `export TVM_TARGET=<target>`, "
"where <target> is one of the available targets can be found in the output of `tools/get_available_targets.py`."
)

def get_gpu_model_from_nvidia_smi():
"""
Expand Down Expand Up @@ -41,13 +47,21 @@ def find_best_match(tags, query):
def check_target(best, default):
return best if Target(best).arch == Target(default).arch else default

if check_target(best_match, "cuda"):
if check_target(best_match, "cuda") == best_match:
return best_match if score >= MATCH_THRESHOLD else "cuda"
else:
logger.info(f"Best match '{best_match}' is not a valid CUDA target, falling back to 'cuda'")
logger.warning(TARGET_MISSING_ERROR)
return "cuda"


def get_all_nvidia_targets() -> List[str]:
"""
Returns all available NVIDIA targets.
"""
all_tags = list_tags()
return [tag for tag in all_tags if "nvidia" in tag]


def auto_detect_nvidia_target() -> str:
"""
Automatically detects the NVIDIA GPU architecture to set the appropriate TVM target.
Expand Down
1 change: 1 addition & 0 deletions requirements-dev.txt
Original file line number Diff line number Diff line change
Expand Up @@ -27,3 +27,4 @@ scipy
tornado
torch
thefuzz
tabulate
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -20,3 +20,4 @@ scipy
tornado
torch
thefuzz
tabulate
17 changes: 17 additions & 0 deletions tools/get_available_targets.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

from bitblas.utils import get_all_nvidia_targets
from tabulate import tabulate

def main():
# Get all available Nvidia targets
targets = get_all_nvidia_targets()

# Print available targets to console in a table format
table = [[i + 1, target] for i, target in enumerate(targets)]
headers = ["Index", "Target"]
print(tabulate(table, headers, tablefmt="pretty"))

if __name__ == "__main__":
main()

0 comments on commit efab450

Please sign in to comment.