Skip to content

Commit

Permalink
Merge pull request #22 from gmarkall/check-envvars-in-patch-needed
Browse files Browse the repository at this point in the history
Move env var checks into patch_needed() function
  • Loading branch information
gmarkall authored Aug 17, 2022
2 parents d6b2505 + 529f7a5 commit 87c009e
Showing 1 changed file with 34 additions and 20 deletions.
54 changes: 34 additions & 20 deletions ptxcompiler/patch.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,37 @@ def get_cubin(self, cc=None):
def patch_needed():
logger = get_logger()

# The patch is needed if the user explicitly forced it with an environment
# variable.
apply = os.getenv("PTXCOMPILER_APPLY_NUMBA_CODEGEN_PATCH")
if apply is not None:
logger.debug(f"PTXCOMPILER_APPLY_NUMBA_CODEGEN_PATCH={apply}")
try:
apply = int(apply)
except ValueError:
apply = False

if apply:
return True

# We should avoid checking whether the patch is needed if the user
# requested that we don't check (e.g. in a non-fork-safe environment)
check = os.getenv("PTXCOMPILER_CHECK_NUMBA_CODEGEN_PATCH_NEEDED")
if check is not None:
logger.debug(f"PTXCOMPILER_CHECK_NUMBA_CODEGEN_PATCH_NEEDED={check}")
try:
check = int(check)
except ValueError:
check = False
else:
check = True

if not check:
return False

# Check whether the patch is needed by comparing the driver and runtime
# versions - it is needed if the runtime version exceeds the driver
# version.
cp = subprocess.run([sys.executable, '-c', CMD], capture_output=True)

if cp.returncode:
Expand All @@ -128,26 +159,9 @@ def patch_needed():

def patch_numba_codegen_if_needed():
logger = get_logger()
check = os.getenv("PTXCOMPILER_CHECK_NUMBA_CODEGEN_PATCH_NEEDED")
apply = os.getenv("PTXCOMPILER_APPLY_NUMBA_CODEGEN_PATCH")
if check is not None:
logger.debug(f"PTXCOMPILER_CHECK_NUMBA_CODEGEN_PATCH_NEEDED={check}")
try:
check = int(check)
except ValueError:
check = False
else:
check = True
if apply is not None:
logger.debug(f"PTXCOMPILER_APPLY_NUMBA_CODEGEN_PATCH={apply}")
try:
apply = int(apply)
except ValueError:
apply = False
else:
apply = False
if apply or (check and patch_needed()):

if patch_needed():
logger.debug("Patching Numba codegen for forward compatibility")
codegen.JITCUDACodegen._library_class = PTXStaticCompileCodeLibrary
else:
logger.debug("Driver version sufficient: not patching Numba codegen")
logger.debug("Not patching Numba codegen")

0 comments on commit 87c009e

Please sign in to comment.