Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Windows native port #2478

Merged
merged 56 commits into from
Nov 19, 2024
Merged
Show file tree
Hide file tree
Changes from 53 commits
Commits
Show all changes
56 commits
Select commit Hold shift + click to select a range
1b87208
First commit, triton builds on windows
gshimansky Oct 4, 2024
62cbe2c
Fixed loading of nvidia driver.py module
gshimansky Oct 5, 2024
022402d
Fixed compile command line for windows compilers
gshimansky Oct 9, 2024
69dee26
Make loader buildable on windows
gshimansky Oct 10, 2024
9382773
Fixed loading spirv_utils dynamic library module
gshimansky Oct 11, 2024
dbc280a
Fixed getting device context on windows
gshimansky Oct 11, 2024
bcf7d94
Formatter corrections
gshimansky Oct 11, 2024
f43486d
Disable compilation warning psapi
gshimansky Oct 15, 2024
e3e8068
Specify different levels of C++ for Linux and Windows
gshimansky Oct 15, 2024
af38761
Fixed constants of different size on linux and windows
gshimansky Oct 15, 2024
33fc428
Fixed -fPIC
gshimansky Oct 15, 2024
f1c0970
Fixed print_helper cuda windows addition
gshimansky Oct 16, 2024
fb33d3d
Removed superfluous C++ standard options
gshimansky Oct 16, 2024
c992c68
Restored comment
gshimansky Oct 16, 2024
9a68984
Removed torch.cuda.synchronize and delete=False for temp files
gshimansky Oct 17, 2024
3b2ebf6
Restored original test code
gshimansky Oct 17, 2024
6c47d44
Restored delete=False when creating a tempfile
gshimansky Oct 17, 2024
37a7e8e
Use sysconfig instead of condition
gshimansky Oct 17, 2024
fe2116b
Take commit 188370325395c79a4ba3de0bc47e39a19fc83224 from upstream tr…
gshimansky Oct 18, 2024
726dab6
Revert hashing algorythm to what it was
gshimansky Oct 18, 2024
7c5c234
Changed condition to sysconfig call
gshimansky Oct 18, 2024
a3268ea
Removed redundant PARTIAL_SOURCES_INTENDED
gshimansky Oct 21, 2024
62774dc
Manually delete tempfile after it is not needed any more
gshimansky Oct 22, 2024
ae77f77
Merge branch 'main' into gregory/windows-support
gshimansky Oct 22, 2024
34589f1
Use /Zc:preprocessor instead of MSVC workarounds
gshimansky Oct 22, 2024
14bc5c2
Use command list instead of string
gshimansky Oct 22, 2024
97d2441
Merge branch 'main' into gregory/windows-support
gshimansky Oct 23, 2024
835497a
Merge branch 'main' into gregory/windows-support
gshimansky Oct 25, 2024
6b5a1d4
Remove change that is already implemented in 9d424e02ed4db695cc58baf9…
gshimansky Oct 25, 2024
39d252a
Remove change because __builtin_prefetch is no longer called on Windows
gshimansky Oct 25, 2024
14f01c8
Remove ifdef because AMD headers build on windows successfully
gshimansky Oct 25, 2024
ec7dff7
Merge branch 'main' into gregory/windows-support
gshimansky Oct 28, 2024
f57b7bf
Merge branch 'main' into gregory/windows-support
gshimansky Oct 28, 2024
42a6307
Removed windows llvm URL arch because we don't have llvm windows bina…
gshimansky Oct 29, 2024
79937ce
Removed ifdef around AMD calls because they successfully compile on w…
gshimansky Oct 29, 2024
34f2dce
Use 'long long' type for int64_t
gshimansky Oct 29, 2024
506ee92
Removed debug print
gshimansky Oct 30, 2024
dc8a628
Merge branch 'main' into gregory/windows-support
gshimansky Oct 30, 2024
81b009a
Removed cuobjdump.exe, nvdisasm.exe and ptxas.exe from ignore
gshimansky Oct 30, 2024
1c985de
Merge branch 'main' into gregory/windows-support
gshimansky Oct 31, 2024
cb3bbf9
Merge branch 'main' into gregory/windows-support
gshimansky Oct 31, 2024
fad8e24
Merge branch 'main' into gregory/windows-support
gshimansky Nov 1, 2024
e9db820
Merge branch 'main' into gregory/windows-support
gshimansky Nov 6, 2024
289b1fe
Removed redundant code
gshimansky Nov 6, 2024
74a1c93
Merge branch 'main' into gregory/windows-support
gshimansky Nov 7, 2024
a8cd2fa
Removed redundant code
gshimansky Nov 7, 2024
12a097d
Merge branch 'main' into gregory/windows-support
gshimansky Nov 7, 2024
d15586c
Removed redundant code
gshimansky Nov 8, 2024
36a1977
Merge branch 'main' into gregory/windows-support
gshimansky Nov 8, 2024
916613a
Merge branch 'main' into gregory/windows-support
gshimansky Nov 8, 2024
f017395
Remove empty line
gshimansky Nov 8, 2024
fdb63be
Merge branch 'main' into gregory/windows-support
gshimansky Nov 12, 2024
d940e61
Merge branch 'main' into gregory/windows-support
gshimansky Nov 14, 2024
55babb3
Merge branch 'main' into gregory/windows-support
gshimansky Nov 18, 2024
4a47656
Added Windows system to allow downloading llvm binary
gshimansky Nov 19, 2024
c44eebd
Merge branch 'main' into gregory/windows-support
gshimansky Nov 19, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,9 @@ python/*.whl
python/triton/_C/*.pyd
python/triton/_C/*.so
python/triton/_C/*.dylib
python/triton/_C/*.pdb
python/triton/_C/*.exe
python/triton/_C/*.ilk

benchmarks/dist
benchmarks/*.egg-info/
Expand Down
65 changes: 49 additions & 16 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -8,23 +8,24 @@ endif()

include(ExternalProject)

set(CMAKE_CXX_STANDARD 17)

set(CMAKE_INCLUDE_CURRENT_DIR ON)

project(triton CXX)
include(CTest)

if(NOT WIN32)
list(APPEND CMAKE_MODULE_PATH "${CMAKE_CURRENT_SOURCE_DIR}/cmake")
endif()


list(APPEND CMAKE_MODULE_PATH "${CMAKE_CURRENT_SOURCE_DIR}/cmake")

# Options
if(WIN32)
set(DEFAULT_BUILD_PROTON OFF)
else()
set(DEFAULT_BUILD_PROTON ON)
endif()

# Define the option with the determined default value
option(TRITON_BUILD_PROTON "Build the Triton Proton profiler" ${DEFAULT_BUILD_PROTON})
option(TRITON_BUILD_TUTORIALS "Build C++ Triton tutorials" ON)
option(TRITON_BUILD_PYTHON_MODULE "Build Python Triton bindings" OFF)
option(TRITON_BUILD_PROTON "Build the Triton Proton profiler" ON)
option(TRITON_BUILD_UT "Build C++ Triton Unit Tests" ON)
option(TRITON_BUILD_WITH_CCACHE "Build with ccache (if available)" ON)
set(TRITON_CODEGEN_BACKENDS "" CACHE STRING "Enable different codegen backends")
Expand All @@ -49,10 +50,21 @@ endif()
# used conditionally in this file and by lit tests

# Customized release build type with assertions: TritonRelBuildWithAsserts
set(CMAKE_C_FLAGS_TRITONRELBUILDWITHASSERTS "-O2 -g")
set(CMAKE_CXX_FLAGS_TRITONRELBUILDWITHASSERTS "-O2 -g")
set(CMAKE_C_FLAGS_TRITONBUILDWITHO1 "-O1")
set(CMAKE_CXX_FLAGS_TRITONBUILDWITHO1 "-O1")
if(NOT MSVC)
set(CMAKE_CXX_STANDARD 17)
set(CMAKE_C_FLAGS_TRITONRELBUILDWITHASSERTS "-O2 -g")
set(CMAKE_CXX_FLAGS_TRITONRELBUILDWITHASSERTS "-O2 -g")
set(CMAKE_C_FLAGS_TRITONBUILDWITHO1 "-O1")
set(CMAKE_CXX_FLAGS_TRITONBUILDWITHO1 "-O1")
else()
set(CMAKE_CXX_STANDARD 20)
victor-eds marked this conversation as resolved.
Show resolved Hide resolved
set(CMAKE_C_FLAGS_TRITONRELBUILDWITHASSERTS "/Zi /Ob0 /Od /RTC1 /bigobj /Zc:preprocessor")
set(CMAKE_CXX_FLAGS_TRITONRELBUILDWITHASSERTS "/Zi /Ob0 /Od /RTC1 /bigobj /Zc:preprocessor")
Comment on lines +61 to +62
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@gshimansky these flags are for debug build, aren't?

Suggested change
set(CMAKE_C_FLAGS_TRITONRELBUILDWITHASSERTS "/Zi /Ob0 /Od /RTC1 /bigobj /Zc:preprocessor")
set(CMAKE_CXX_FLAGS_TRITONRELBUILDWITHASSERTS "/Zi /Ob0 /Od /RTC1 /bigobj /Zc:preprocessor")
set(CMAKE_C_FLAGS_TRITONRELBUILDWITHASSERTS "/Zi /RTC1 /bigobj /Zc:preprocessor")
set(CMAKE_CXX_FLAGS_TRITONRELBUILDWITHASSERTS "/Zi /RTC1 /bigobj /Zc:preprocessor")

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes. At the moment there is just one windows build and it is debug build mostly. We can split it into optimized and debug build like we have on Linux.

set(CMAKE_EXE_LINKER_FLAGS_TRITONRELBUILDWITHASSERTS "/debug:fastlink /INCREMENTAL")
set(CMAKE_MODULE_LINKER_FLAGS_TRITONRELBUILDWITHASSERTS "/debug:fastlink /INCREMENTAL")
set(CMAKE_SHARED_LINKER_FLAGS_TRITONRELBUILDWITHASSERTS "/debug:fastlink /INCREMENTAL")
set(CMAKE_STATIC_LINKER_FLAGS_TRITONRELBUILDWITHASSERTS "/debug:fastlink /INCREMENTAL")
endif()

# Default build type
if(NOT CMAKE_BUILD_TYPE)
Expand All @@ -70,7 +82,15 @@ endif()

# Compiler flags
include_directories(${CMAKE_CURRENT_SOURCE_DIR}/include)
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -D__STDC_FORMAT_MACROS -fPIC -std=gnu++17")
if(NOT MSVC)
if(NOT WIN32)
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -D__STDC_FORMAT_MACROS -fPIC")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why -std=gnu++17 is removed?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It was a redundant option as was discussed earlier in this PR. C++ level is defined by CMAKE_CXX_STANDARD and we had it defined here https://github.com/intel/intel-xpu-backend-for-triton/pull/2478/files#diff-1e7de1ae2d059d21e1dd75d5812d5a34b0222cef273b7c3a2af62eb747f9d20aL11 but now it was moved into condition for Linux/Windows because MSVC is unable to parse some of the templates on level 17 so it needs 20, while gcc refuses to compile code on level 20 https://github.com/intel/intel-xpu-backend-for-triton/pull/2478/files#diff-1e7de1ae2d059d21e1dd75d5812d5a34b0222cef273b7c3a2af62eb747f9d20aR37 .

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can we make this change upstream?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Doesn't set(CMAKE_CXX_STANDARD 17) correspond to -std=c++17, not -std=gnu++17?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is -std=gnu++17 necessary? Looks like all tests in this PR pass without it.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is -std=gnu++17 necessary? Looks like all tests in this PR pass without it.

The Triton project has been using this extension for a very long time, for example I found a mention of gnu++11 in triton-lang/triton@50587bb. Even if we assume that all code, not just the one being tested, does not use this extension, it is unlikely that they will decide to remove something that they have been using for a long time.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Then it might be some legacy like Windows build support that Triton inherited from previous projects and by now nobody knows why it was added.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I double checked. Explicitly usage of -std=c++17 brokes build: #2771. Looks like at least on GCC set(CMAKE_CXX_STANDARD 17) add -std=gnu++17 implicitly.

else()
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -D__STDC_FORMAT_MACROS -Wno-deprecated")
endif()
else()
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -D__STDC_FORMAT_MACROS /wd4244 /wd4624 /wd4715 /wd4530")
endif()


# #########
Expand Down Expand Up @@ -124,7 +144,11 @@ endfunction()


# Disable warnings that show up in external code (gtest;pybind11)
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Werror -Wno-covered-switch-default -fvisibility=hidden")
if(NOT MSVC)
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Werror -Wno-covered-switch-default -fvisibility=hidden")
else()
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} /WX-")
endif()

include_directories(".")
include_directories(${MLIR_INCLUDE_DIRS})
Expand All @@ -134,7 +158,8 @@ include_directories(${PROJECT_BINARY_DIR}/include) # Tablegen'd files
include_directories(${PROJECT_SOURCE_DIR}/third_party)
include_directories(${PROJECT_BINARY_DIR}/third_party) # Tablegen'd files

# link_directories(${LLVM_LIBRARY_DIR})
link_directories(${LLVM_LIBRARY_DIR})

add_subdirectory(include)
add_subdirectory(lib)

Expand Down Expand Up @@ -163,6 +188,8 @@ if(TRITON_BUILD_PYTHON_MODULE)
# using pip install.
include_directories(${PYTHON_INCLUDE_DIRS})
include_directories(${PYBIND11_INCLUDE_DIR})
message(STATUS "PYTHON_LIB_DIRS ${PYTHON_LIB_DIRS}")
link_directories(${PYTHON_LIB_DIRS})
else()
# Otherwise, we might be building from top CMakeLists.txt directly.
# Try to find Python and pybind11 packages.
Expand Down Expand Up @@ -245,7 +272,7 @@ if(TRITON_BUILD_PYTHON_MODULE)
LLVMAArch64CodeGen
LLVMAArch64AsmParser
)
elseif(CMAKE_SYSTEM_PROCESSOR MATCHES "x86_64")
elseif(CMAKE_SYSTEM_PROCESSOR MATCHES "x86_64" OR CMAKE_SYSTEM_PROCESSOR MATCHES "AMD64")
list(APPEND TRITON_LIBRARIES
LLVMX86CodeGen
LLVMX86AsmParser
Expand Down Expand Up @@ -280,6 +307,8 @@ if(TRITON_BUILD_PYTHON_MODULE)
target_link_libraries(triton PUBLIC ${TRITON_LIBRARIES})
if(WIN32)
target_link_libraries(triton PRIVATE ${CMAKE_DL_LIBS})
set_target_properties(triton PROPERTIES SUFFIX ".pyd")
set_target_properties(triton PROPERTIES PREFIX "lib")
else()
target_link_libraries(triton PRIVATE z)
endif()
Expand All @@ -306,6 +335,10 @@ if(NOT TRITON_BUILD_PYTHON_MODULE)
add_subdirectory(third_party/${CODEGEN_BACKEND})
endforeach()
endif()
if(WIN32)
option(CMAKE_USE_WIN32_THREADS_INIT "using WIN32 threads" ON)
option(gtest_disable_pthreads "Disable uses of pthreads in gtest." ON)
endif()

add_subdirectory(third_party/f2reduce)
add_subdirectory(bin)
Expand Down
56 changes: 54 additions & 2 deletions python/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,49 @@ def copy_externals():
]


def find_vswhere():
program_files = os.environ.get("ProgramFiles(x86)", "C:\\Program Files (x86)")
vswhere_path = Path(program_files) / "Microsoft Visual Studio" / "Installer" / "vswhere.exe"
if vswhere_path.exists():
return vswhere_path
return None
victor-eds marked this conversation as resolved.
Show resolved Hide resolved


def find_visual_studio(version_ranges):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could we reuse code from CLFinder.py?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We can and it should work in theory, but it doesn't. For some reason when I import these functions from CLFinder I am getting an error LINK : fatal error LNK1168: cannot open C:\b\tr\python\triton\_C\libtriton.pyd for writing. I have no idea how they are related but this link error is stably reproducible for me. I tried to debug the problem and found that environment after calling set_env_vars is identical when functions are reused from CLFinder.py or setup.py has its own copies, so now I am out of ideas how libtryton.pyd may end up locked.

vswhere = find_vswhere()
if not vswhere:
raise FileNotFoundError("vswhere.exe not found.")

for version_range in version_ranges:
command = [
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For me it works only if I specify -products:
"C:\Program Files (x86)\Microsoft Visual Studio\Installer\vswhere.exe" -version "[17.0,18.0)" -products Microsoft.VisualStudio.Product.BuildTools -requires Microsoft.VisualStudio.Component.VC.Tools.x86.x64 -property installationPath -prerelease

@gshimansky do you know why?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For me it works just fine if I don't specify -products or if I specify -products "*". Specifying -products Microsoft.VisualStudio.Product.BuildTools doesn't find anything for me but specifying -products Microsoft.VisualStudio.Product.Professional does.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

-products "*" works for me too. It seems that this is because I did not install the entire studio, but only the build tools. Should we add -products "*" to allow this case?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't know whether it may result in multiple different products found as the result, but we can add it for now.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't know whether it may result in multiple different products found as the result, but we can add it for now.

I thought about that too. But couldn't it be the same default behavior when -products parameter is not specified at all?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The easiest way at the moment is to just take the first product if there are several: return output.split("\n")[0] (as in #2744).

str(vswhere), "-version", version_range, "-requires", "Microsoft.VisualStudio.Component.VC.Tools.x86.x64",
"-property", "installationPath", "-prerelease"
]

try:
output = subprocess.check_output(command, text=True).strip()
if output:
return output
except subprocess.CalledProcessError:
continue

return None


def set_env_vars(vs_path, arch="x64"):
vcvarsall_path = Path(vs_path) / "VC" / "Auxiliary" / "Build" / "vcvarsall.bat"
if not vcvarsall_path.exists():
raise FileNotFoundError(f"vcvarsall.bat not found in expected path: {vcvarsall_path}")

command = ["call", vcvarsall_path, arch, "&&", "set"]
output = subprocess.check_output(command, shell=True, text=True)

for line in output.splitlines():
if '=' in line:
var, value = line.split('=', 1)
os.environ[var] = value


# Taken from https://github.com/pytorch/pytorch/blob/master/tools/setup_helpers/env.py
def check_env_flag(name: str, default: str = "") -> bool:
return os.getenv(name, default).upper() in ["ON", "1", "YES", "TRUE", "Y"]
Expand Down Expand Up @@ -281,10 +324,10 @@ def download_and_copy(name, src_path, dst_path, variable, version, url_func):
base_dir = os.path.dirname(__file__)
system = platform.system()
try:
arch = {"x86_64": "64", "arm64": "aarch64", "aarch64": "aarch64"}[platform.machine()]
arch = {"x86_64": "64", "AMD64": "64", "arm64": "aarch64", "aarch64": "aarch64"}[platform.machine()]
except KeyError:
arch = platform.machine()
supported = {"Linux": "linux", "Darwin": "linux"}
supported = {"Linux": "linux", "Darwin": "linux", "Windows": "win"}
url = url_func(supported[system], arch, version)
tmp_path = os.path.join(triton_cache_path, "nvidia", name) # path to cache the download
dst_path = os.path.join(base_dir, os.pardir, "third_party", "nvidia", "backend", dst_path) # final binary path
Expand Down Expand Up @@ -401,6 +444,11 @@ def get_proton_cmake_args(self):
def build_extension(self, ext):
lit_dir = shutil.which('lit')
ninja_dir = shutil.which('ninja')
if platform.system() == "Windows":
vs_path = find_visual_studio(["[17.0,18.0)", "[16.0,17.0)"])
env = set_env_vars(vs_path)
if not vs_path:
raise EnvironmentError("Visual Studio 2019 or 2022 not found.")
Comment on lines +451 to +453
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's probably a good idea to define initialize_visual_studio_env as in CLFinder.py file here as well.

Suggested change
env = set_env_vars(vs_path)
if not vs_path:
raise EnvironmentError("Visual Studio 2019 or 2022 not found.")
if not vs_path:
raise EnvironmentError("Visual Studio 2019 or 2022 not found.")
env = set_env_vars(vs_path)

# lit is used by the test suite
thirdparty_cmake_args = get_thirdparty_packages([get_llvm_package_info()])
thirdparty_cmake_args += self.get_pybind11_cmake_args()
Expand All @@ -421,6 +469,10 @@ def build_extension(self, ext):
"-DTRITON_CODEGEN_BACKENDS=" + ';'.join([b.name for b in backends if not b.is_external]),
"-DTRITON_PLUGIN_DIRS=" + ';'.join([b.src_dir for b in backends if b.is_external])
]
if platform.system() == "Windows":
installed_base = sysconfig.get_config_var('installed_base')
py_lib_dirs = os.getenv("PYTHON_LIB_DIRS", os.path.join(installed_base, "libs"))
cmake_args.append("-DPYTHON_LIB_DIRS=" + py_lib_dirs)
if lit_dir is not None:
cmake_args.append("-DLLVM_EXTERNAL_LIT=" + lit_dir)
cmake_args.extend(thirdparty_cmake_args)
Expand Down
55 changes: 55 additions & 0 deletions python/triton/runtime/CLFinder.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
import os
import subprocess
from pathlib import Path


def find_vswhere():
program_files = os.environ.get("ProgramFiles(x86)", "C:\\Program Files (x86)")
vswhere_path = Path(program_files) / "Microsoft Visual Studio" / "Installer" / "vswhere.exe"
if vswhere_path.exists():
return vswhere_path
return None


def find_visual_studio(version_ranges):
vswhere = find_vswhere()
if not vswhere:
raise FileNotFoundError("vswhere.exe not found.")

for version_range in version_ranges:
command = [
str(vswhere), "-version", version_range, "-requires", "Microsoft.VisualStudio.Component.VC.Tools.x86.x64",
"-property", "installationPath", "-prerelease"
]

try:
output = subprocess.check_output(command, text=True).strip()
if output:
return output
except subprocess.CalledProcessError:
continue

return None


def set_env_vars(vs_path, arch="x64"):
vcvarsall_path = Path(vs_path) / "VC" / "Auxiliary" / "Build" / "vcvarsall.bat"
if not vcvarsall_path.exists():
raise FileNotFoundError(f"vcvarsall.bat not found in expected path: {vcvarsall_path}")

command = f'call "{vcvarsall_path}" {arch} && set'
output = subprocess.check_output(command, shell=True, text=True)

for line in output.splitlines():
if '=' in line:
var, value = line.split('=', 1)
os.environ[var] = value


def initialize_visual_studio_env(version_ranges, arch="x64"):
# Check if the environment variable that vcvarsall.bat sets is present
if os.environ.get('VSCMD_ARG_TGT_ARCH') != arch:
Comment on lines +50 to +51
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

At this point vcvarsall.bat has not yet been called? This will only happen when calling set_env_vars function IIUC.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This check is to verify that we're not already running in a Developer Command Prompt environment set outside of setup.py, and if we're running, we're running in environment for the same architecture (32 or 64 bits) that we're going to build here.

vs_path = find_visual_studio(version_ranges)
if not vs_path:
raise EnvironmentError("Visual Studio not found in specified version ranges.")
set_env_vars(vs_path, arch)
45 changes: 36 additions & 9 deletions python/triton/runtime/build.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
import shutil
import subprocess
import setuptools
import platform
from .CLFinder import initialize_visual_studio_env


def is_xpu():
Expand All @@ -23,6 +25,29 @@ def quiet():
sys.stdout, sys.stderr = old_stdout, old_stderr


def _cc_cmd(cc, src, out, include_dirs, library_dirs, libraries):
if cc in ["cl", "clang-cl"]:
cc_cmd = [cc, src, "/nologo", "/O2", "/LD"]
cc_cmd += [f"/I{dir}" for dir in include_dirs]
cc_cmd += [f"/Fo{os.path.join(os.path.dirname(out), 'main.obj')}"]
cc_cmd += ["/link"]
cc_cmd += [f"/OUT:{out}"]
cc_cmd += [f"/IMPLIB:{os.path.join(os.path.dirname(out), 'main.lib')}"]
cc_cmd += [f"/PDB:{os.path.join(os.path.dirname(out), 'main.pdb')}"]
cc_cmd += [f"/LIBPATH:{dir}" for dir in library_dirs]
cc_cmd += [f'{lib}.lib' for lib in libraries]
else:
cc_cmd = [cc, src, "-O3", "-shared", "-Wno-psabi"]
if os.name != "nt":
cc_cmd += ["-fPIC"]
cc_cmd += [f'-l{lib}' for lib in libraries]
cc_cmd += [f"-L{dir}" for dir in library_dirs]
cc_cmd += [f"-I{dir}" for dir in include_dirs]
cc_cmd += ["-o", out]

return cc_cmd


def _build(name, src, srcdir, library_dirs, include_dirs, libraries, extra_compile_args=[]):
suffix = sysconfig.get_config_var('EXT_SUFFIX')
so = os.path.join(srcdir, '{name}{suffix}'.format(name=name, suffix=suffix))
Expand All @@ -33,6 +58,9 @@ def _build(name, src, srcdir, library_dirs, include_dirs, libraries, extra_compi
clang = shutil.which("clang")
gcc = shutil.which("gcc")
cc = gcc if gcc is not None else clang
if platform.system() == "Windows":
cc = "cl"
initialize_visual_studio_env(["[17.0,18.0)", "[16.0,17.0)"])
if cc is None:
raise RuntimeError("Failed to find C compiler. Please specify via CC environment variable.")
# This function was renamed and made public in Python 3.10
Expand All @@ -55,25 +83,24 @@ def _build(name, src, srcdir, library_dirs, include_dirs, libraries, extra_compi
clangpp = shutil.which("clang++")
gxx = shutil.which("g++")
icpx = shutil.which("icpx")
cxx = icpx or clangpp or gxx
cxx = icpx if os.name == "nt" else icpx or clangpp or gxx
if cxx is None:
raise RuntimeError("Failed to find C++ compiler. Please specify via CXX environment variable.")
cc = cxx
import numpy as np
numpy_include_dir = np.get_include()
include_dirs = include_dirs + [numpy_include_dir]
cc_cmd = [cxx]
if icpx is not None:
cc_cmd += ["-fsycl"]
extra_compile_args += ["-fsycl"]
else:
cc_cmd += ["--std=c++17"]
extra_compile_args += ["--std=c++17"]
if os.name == "nt":
library_dirs += [os.path.join(sysconfig.get_paths(scheme=scheme)["stdlib"], "..", "libs")]
else:
cc_cmd = [cc]

# for -Wno-psabi, see https://gcc.gnu.org/bugzilla/show_bug.cgi?id=111047
cc_cmd += [src, "-O3", "-shared", "-fPIC", "-Wno-psabi", "-o", so]
cc_cmd += [f'-l{lib}' for lib in libraries]
cc_cmd += [f"-L{dir}" for dir in library_dirs]
cc_cmd += [f"-I{dir}" for dir in include_dirs if dir is not None]
cc_cmd = _cc_cmd(cc, src, so, include_dirs, library_dirs, libraries)
cc_cmd += extra_compile_args

if os.getenv("VERBOSE"):
Expand All @@ -90,7 +117,7 @@ def _build(name, src, srcdir, library_dirs, include_dirs, libraries, extra_compi
language='c',
sources=[src],
include_dirs=include_dirs,
extra_compile_args=extra_compile_args + ['-O3'],
extra_compile_args=extra_compile_args + ['-O3' if "-O3" in cc_cmd else "/O2"],
extra_link_args=extra_link_args,
library_dirs=library_dirs,
libraries=libraries,
Expand Down
Loading