Skip to content

Commit

Permalink
Add an option to hack the triton compiler to avoid backend query
Browse files Browse the repository at this point in the history
  • Loading branch information
zhanglx13 committed Jul 22, 2024
1 parent 7319adc commit a61967b
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 0 deletions.
11 changes: 11 additions & 0 deletions scripts/amd/gemm/tune_gemm.py
Original file line number Diff line number Diff line change
Expand Up @@ -627,6 +627,10 @@ def parse_args():
action='store_true',
default=False,
help="Whether we want to skip the compilation stage")
parser.add_argument("--hack_triton_compiler",
action='store_true',
default=False,
help="Modify the triton source to avoid backend query")
args = parser.parse_args()
if not args.o:
if args.benchmark:
Expand Down Expand Up @@ -691,6 +695,7 @@ def main():
jobs = args.jobs
iters = args.iters
skipWarmup = args.no_warmup
hack_triton = args.hack_triton_compiler

# Get GPU ids
ngpus = args.ngpus
Expand Down Expand Up @@ -778,6 +783,12 @@ def main():
run_bash_command("rm -rf ~/.triton/cache")
run_bash_command(f"rm -rf {get_filename_myKernels()}")

## Modify triton compiler
## Hacky !!!
if hack_triton:
patch_triton_compiler()


configs = []

## Big for loop of tuning
Expand Down
24 changes: 24 additions & 0 deletions scripts/amd/gemm/utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,3 +89,27 @@ def get_default_tuning_result_filename():
path = os.path.dirname(os.path.abspath(__file__))
defaultName = f"{path}/../tuning_results_{git_branch_name}@{git_commit_hash}_{dt_string}.yaml"
return defaultName


def patch_triton_compiler():
    """Rewrite the installed triton sources in place to avoid backend queries.

    The current device, stream and target are read once from the active
    triton driver, and then ``sed -i`` is used to hard-code those values
    into triton's ``runtime/jit.py`` and into the AMD / NVIDIA backend
    ``driver.py`` files.

    The triton source tree is located through the "Editable" line printed by
    ``pip show triton``.  If that line cannot be found, a message is printed
    and the function returns without modifying anything.

    WARNING: the edits are destructive (``sed -i``) and are not reverted by
    this function; the triton checkout stays patched afterwards.
    """
    # Query the values to bake in *before* touching any file, while the
    # driver code is still unmodified.
    device = triton.runtime.driver.active.get_current_device()
    stream = triton.runtime.driver.active.get_current_stream(device)
    target = triton.runtime.driver.active.get_current_target()

    triton_location_str = run_bash_command("pip show triton | grep Editable")
    if not triton_location_str:
        print("triton source not found from pip show triton")
        # Nothing to patch; indexing the empty result below would only
        # raise an IndexError.
        return

    # NOTE(review): assumes run_bash_command returns a list of bytes lines
    # and that the last whitespace-separated token of the first line is the
    # editable install path -- confirm against run_bash_command.
    triton_dir = triton_location_str[0].split()[-1].decode('utf-8')

    jit_filename = os.path.join(triton_dir, "triton/runtime", "jit.py")

    # Replace the runtime device/stream queries in jit.py with the literal
    # values obtained above.
    run_bash_command(f"sed -i 's/driver.active.get_current_device()/{device}/g' {jit_filename}")
    run_bash_command(f"sed -i 's/driver.active.get_current_stream(device)/{stream}/g' {jit_filename}")

    hip_driver_filename = os.path.join(triton_dir, "../third_party/amd/backend/", "driver.py")
    cuda_driver_filename = os.path.join(triton_dir, "../third_party/nvidia/backend/", "driver.py")

    # In the backend drivers every `import torch` line is replaced by a
    # constant return (True for the AMD driver, False for the NVIDIA one),
    # and the AMD target lookup is replaced by a fixed GPUTarget built from
    # the arch queried above (64 is presumably the warp size -- TODO confirm).
    run_bash_command(f"sed -i 's/import torch/return True/g' {hip_driver_filename}")
    run_bash_command(f"sed -i 's/device = self.get_current_device()/return GPUTarget(\"hip\", \"{target.arch}\", 64)/g' {hip_driver_filename}")
    run_bash_command(f"sed -i 's/import torch/return False/g' {cuda_driver_filename}")

0 comments on commit a61967b

Please sign in to comment.