Simplified cmake file so no definitions are required by default (#684)
cliffburdick authored Jul 28, 2024
Parent: ab0c51b · Commit: c18cb48
Showing 9 changed files with 31 additions and 51 deletions.
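The pattern running through all nine files: CMake no longer pins configuration macros to an explicit 0/1 value; a macro is simply defined on the matx INTERFACE target when its option is enabled and left undefined otherwise, and the headers test for presence with #ifdef / defined(...) so that a build with no -D flags at all still compiles with sensible defaults. A minimal C++ sketch of that presence-based convention, using a hypothetical MY_FEATURE macro rather than a real MatX name:

// Illustrative sketch only; MY_FEATURE is a hypothetical macro, not a MatX one.
#include <cstdio>

// The feature is on only when the build system defines the macro at all, e.g.
// target_compile_definitions(tgt INTERFACE MY_FEATURE); any value it might
// carry is irrelevant, and leaving it undefined selects the default path.
#ifdef MY_FEATURE
constexpr bool feature_enabled = true;
#else
constexpr bool feature_enabled = false;
#endif

int main() {
  std::printf("feature enabled: %s\n", feature_enabled ? "yes" : "no");
  return 0;
}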
CMakeLists.txt: 10 additions & 21 deletions
@@ -172,21 +172,16 @@ set(WARN_FLAGS ${WARN_FLAGS} $<$<COMPILE_LANGUAGE:CXX>:-Werror>)
# endif()

# CUTLASS support is not maintained. Remove the option to avoid confusion
set (CUTLASS_INC "")
target_compile_definitions(matx INTERFACE MATX_ENABLE_CUTLASS=0)

if (MATX_NVTX_FLAGS)
add_definitions(-DMATX_NVTX_FLAGS)
target_compile_definitions(matx INTERFACE MATX_NVTX_FLAGS)
endif()
if (MATX_BUILD_32_BIT)
set(INT_TYPE "lp64")
add_definitions(-DINDEX_32_BIT)
target_compile_definitions(matx INTERFACE INDEX_32_BIT)
else()
set(INT_TYPE "ilp64")
add_definitions(-DINDEX_64_BIT)
target_compile_definitions(matx INTERFACE INDEX_64_BIT)
endif()

# Host support
@@ -195,7 +190,7 @@ if (MATX_EN_NVPL OR MATX_EN_X86_FFTW OR MATX_EN_BLIS OR MATX_EN_OPENBLAS)
find_package(OpenMP REQUIRED)
target_link_libraries(matx INTERFACE OpenMP::OpenMP_CXX)
target_compile_options(matx INTERFACE ${OpenMP_CXX_FLAGS})
target_compile_definitions(matx INTERFACE MATX_EN_OMP=1)
target_compile_definitions(matx INTERFACE MATX_EN_OMP)

set(BLAS_FLAGS MATX_EN_NVPL MATX_EN_BLIS MATX_EN_OPENBLAS)
set(ENABLED_BLAS_COUNT 0)
@@ -215,7 +210,7 @@ if (MATX_EN_NVPL OR MATX_EN_X86_FFTW OR MATX_EN_BLIS OR MATX_EN_OPENBLAS)
target_compile_definitions(matx INTERFACE NVPL_ILP64)
endif()
target_link_libraries(matx INTERFACE nvpl::fftw nvpl::blas_${INT_TYPE}_omp)
target_compile_definitions(matx INTERFACE MATX_EN_NVPL=1)
target_compile_definitions(matx INTERFACE MATX_EN_NVPL)
else()
# FFTW
if (MATX_EN_X86_FFTW)
@@ -225,26 +220,26 @@ if (MATX_EN_NVPL OR MATX_EN_X86_FFTW OR MATX_EN_BLIS OR MATX_EN_OPENBLAS)
find_library(FFTW_OMP_LIB fftw3_omp REQUIRED)
find_library(FFTWF_OMP_LIB fftw3f_omp REQUIRED)
target_link_libraries(matx INTERFACE ${FFTW_LIB} ${FFTWF_LIB} ${FFTW_OMP_LIB} ${FFTWF_OMP_LIB})
target_compile_definitions(matx INTERFACE MATX_EN_X86_FFTW=1)
target_compile_definitions(matx INTERFACE MATX_EN_X86_FFTW)
endif()

# BLAS
if (MATX_EN_BLIS)
message(STATUS "Enabling BLIS")
include(cmake/FindBLIS.cmake)
target_link_libraries(matx INTERFACE BLIS::BLIS)
target_compile_definitions(matx INTERFACE MATX_EN_BLIS=1)
target_compile_definitions(matx INTERFACE MATX_EN_BLIS)
elseif(MATX_EN_OPENBLAS)
message(STATUS "Enabling OpenBLAS")
include(cmake/FindOpenBLAS.cmake)
target_link_libraries(matx INTERFACE OpenBLAS::OpenBLAS)
target_compile_definitions(matx INTERFACE MATX_EN_OPENBLAS=1)
target_compile_definitions(matx INTERFACE MATX_EN_OPENBLAS)
endif()
endif()
endif()

if (MATX_DISABLE_CUB_CACHE)
target_compile_definitions(matx INTERFACE MATX_DISABLE_CUB_CACHE=1)
target_compile_definitions(matx INTERFACE MATX_DISABLE_CUB_CACHE)
endif()

if (MATX_EN_COVERAGE)
@@ -259,16 +254,14 @@ if (MATX_EN_CUTENSOR)

include(cmake/FindcuTENSOR.cmake)
include(cmake/FindcuTensorNet.cmake)
target_compile_definitions(matx INTERFACE MATX_ENABLE_CUTENSOR=1)
target_compile_definitions(matx INTERFACE MATX_ENABLE_CUTENSOR)

target_link_libraries(matx INTERFACE cuTENSOR::cuTENSOR)
target_link_libraries(matx INTERFACE cuTensorNet::cuTensorNet)

# CUDA toolkit and most accompanying libraries like cuTENSOR use the old rpath instead of RUNPATH.
# We switch to that format here for compatibility
target_link_libraries(matx INTERFACE "-Wl,--disable-new-dtags")
else()
target_compile_definitions(matx INTERFACE MATX_ENABLE_CUTENSOR=0)
endif()

if (MATX_MULTI_GPU)
@@ -280,8 +273,8 @@ endif()
if (MATX_EN_FILEIO OR MATX_EN_VISUALIZATION OR MATX_EN_PYBIND11 OR MATX_BUILD_EXAMPLES OR MATX_BUILD_TESTS OR MATX_BUILD_BENCHMARKS)
message(STATUS "Enabling pybind11 support")
set(MATX_EN_PYBIND11 ON)
target_compile_definitions(matx INTERFACE MATX_ENABLE_PYBIND11=1)
target_compile_definitions(matx INTERFACE MATX_ENABLE_FILEIO=1)
target_compile_definitions(matx INTERFACE MATX_ENABLE_PYBIND11)
target_compile_definitions(matx INTERFACE MATX_ENABLE_FILEIO)
target_compile_options(matx INTERFACE -DMATX_ROOT="${PROJECT_SOURCE_DIR}")

include(cmake/GetPyBind11.cmake)
@@ -301,15 +294,11 @@ if (MATX_EN_FILEIO OR MATX_EN_VISUALIZATION OR MATX_EN_PYBIND11 OR MATX_BUILD_EX

# Visualization requires Python libraries
if (MATX_EN_VISUALIZATION)
target_compile_definitions(matx INTERFACE MATX_ENABLE_VIZ=1)
target_compile_definitions(matx INTERFACE MATX_ENABLE_VIZ)
check_python_libs("plotly.express")
else()
target_compile_definitions(matx INTERFACE MATX_ENABLE_VIZ=0)
endif()
else()
message(WARNING "pybind11 support disabled. Visualizations and file IO will be disabled")
target_compile_definitions(matx INTERFACE MATX_ENABLE_PYBIND11=0)
target_compile_definitions(matx INTERFACE MATX_ENABLE_FILEIO=0)
endif()

# Add in all CUDA linker dependencies
include/matx/core/defines.h: 3 additions & 10 deletions
@@ -37,19 +37,12 @@

namespace matx {

#ifdef INDEX_64_BIT
using index_t = long long int;
#define INDEX_T_FMT "lld"
#endif

#ifdef INDEX_32_BIT
using index_t = int32_t;
#define INDEX_T_FMT "d"
#endif

#if ((defined(INDEX_64_BIT) && defined(INDEX_32_BIT)) || \
(!defined(INDEX_64_BIT) && !defined(INDEX_32_BIT)))
static_assert(false, "Must choose either 64-bit or 32-bit index mode");
#else
using index_t = long long int;
#define INDEX_T_FMT "lld"
#endif

#ifdef __CUDACC__
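With the static_assert gone, a translation unit built without any -D flags now gets the 64-bit index_t, and INDEX_32_BIT remains an explicit opt-in. A small stand-alone sketch showing how the INDEX_T_FMT macro pairs with the chosen width (the selection is copied locally only for illustration; real code should include matx/core/defines.h instead):

#include <cstdint>
#include <cstdio>

// Local copy of the selection defines.h now performs: 64-bit by default,
// 32-bit only when INDEX_32_BIT is defined by the build.
#ifdef INDEX_32_BIT
using index_t = std::int32_t;
#define INDEX_T_FMT "d"
#else
using index_t = long long int;
#define INDEX_T_FMT "lld"
#endif

int main() {
  index_t n = 1000000;
  // INDEX_T_FMT supplies the printf conversion matching the active width.
  std::printf("extent = %" INDEX_T_FMT "\n", n);
  return 0;
}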
include/matx/core/file_io.h: 1 addition & 1 deletion
@@ -49,7 +49,7 @@



#if MATX_ENABLE_FILEIO || DOXYGEN_ONLY
#if defined(MATX_ENABLE_FILEIO) || defined(DOXYGEN_ONLY)

namespace matx {
namespace io {
include/matx/core/pybind.h: 2 additions & 1 deletion
@@ -34,7 +34,7 @@
#include "matx/core/type_utils.h"
#include "matx/core/make_tensor.h"

#if MATX_ENABLE_PYBIND11
#ifdef MATX_ENABLE_PYBIND11

#include <pybind11/embed.h>
#include <pybind11/numpy.h>
@@ -93,6 +93,7 @@ class MatXPybind {
// Interpreter already running
}
}

AddPath(std::string(MATX_ROOT) + GENERATORS_PATH);
}

include/matx/core/tensor_desc.h: 3 additions & 3 deletions
@@ -487,12 +487,12 @@ using tensor_desc_cr_disi_dist = tensor_desc_cr_ds_t<index_t, index_t, RANK>;
*
* @tparam RANK Rank of shape
*/
#ifdef INDEX_64_BIT
#ifdef INDEX_32_BIT
template <int RANK>
using DefaultDescriptor = tensor_desc_cr_ds_64_64_t<RANK>;
using DefaultDescriptor = tensor_desc_cr_ds_32_32_t<RANK>;
#else
template <int RANK>
using DefaultDescriptor = tensor_desc_cr_ds_32_32_t<RANK>;
using DefaultDescriptor = tensor_desc_cr_ds_64_64_t<RANK>;
#endif

}; // namespace matx
include/matx/core/viz.h: 1 addition & 1 deletion
@@ -39,7 +39,7 @@
#include "matx/core/tensor.h"
#include "matx/core/pybind.h"

#if MATX_ENABLE_VIZ || DOXYGEN_ONLY
#if defined(MATX_ENABLE_VIZ) || defined(DOXYGEN_ONLY)

namespace matx {
namespace viz {
include/matx/executors/host.h: 3 additions & 3 deletions
@@ -87,19 +87,19 @@ class HostExecutor {
n_threads = 1;
}
else if constexpr (MODE == ThreadsMode::ALL) {
#if MATX_EN_OMP
#ifdef MATX_EN_OMP
n_threads = omp_get_num_procs();
#endif
}
params_ = HostExecParams(n_threads);

#if MATX_EN_OMP
#ifdef MATX_EN_OMP
omp_set_num_threads(params_.GetNumThreads());
#endif
}

HostExecutor(const HostExecParams &params) : params_(params) {
#if MATX_EN_OMP
#ifdef MATX_EN_OMP
omp_set_num_threads(params_.GetNumThreads());
#endif
}
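Because MATX_EN_OMP is now defined without a value (and may be absent entirely), host.h switches from #if to #ifdef around its OpenMP calls. A stand-alone sketch of the same guard pattern, with a hypothetical helper that is not part of MatX; compile with -fopenmp -DMATX_EN_OMP for the OpenMP path, or with no flags for the single-threaded fallback:

#ifdef MATX_EN_OMP
#include <omp.h>
#endif

// Hypothetical helper mirroring the guard style used in HostExecutor above.
int resolve_thread_count(int requested) {
#ifdef MATX_EN_OMP
  // A non-positive request means "use every available processor", much like
  // ThreadsMode::ALL in the hunk above.
  return requested > 0 ? requested : omp_get_num_procs();
#else
  (void)requested;   // OpenMP disabled: always run single-threaded
  return 1;
#endif
}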
include/matx/transforms/einsum.h: 2 additions & 2 deletions
@@ -33,7 +33,7 @@

#pragma once

#if MATX_ENABLE_CUTENSOR
#ifdef MATX_ENABLE_CUTENSOR
#include <cstdio>
#include <numeric>
#include "error.h"
@@ -486,7 +486,7 @@ namespace cutensor {
template <typename OutputType, typename... InT>
void einsum_impl([[maybe_unused]] OutputType &out, [[maybe_unused]] const std::string &subscripts, [[maybe_unused]] cudaStream_t stream, [[maybe_unused]] InT... tensors)
{
#if MATX_ENABLE_CUTENSOR
#ifdef MATX_ENABLE_CUTENSOR
MATX_NVTX_START("", matx::MATX_NVTX_LOG_API)

// Get parameters required by these tensors
include/matx/transforms/matmul/matmul_cuda.h: 6 additions & 9 deletions
@@ -34,7 +34,7 @@

#include <cublasLt.h>

#if MATX_ENABLE_CUTLASS == 1
#ifdef MATX_ENABLE_CUTLASS
#include "cutlass/gemm/device/gemm.h"
#include "cutlass/gemm/device/gemm_batched.h"
#endif
@@ -184,9 +184,6 @@ class MatMulCUDAHandle_t {
{
MATX_NVTX_START("", matx::MATX_NVTX_LOG_INTERNAL)

MATX_STATIC_ASSERT_STR((PROV != PROVIDER_TYPE_CUTLASS) || MATX_ENABLE_CUTLASS, matxMatMulError,
"Must use -DCUTLASS_DIR in CMake to enable CUTLASS support");

static_assert(RANK >= 2);
MATX_ASSERT(a.Size(TensorTypeA::Rank() - 1) == b.Size(TensorTypeB::Rank() - 2), matxInvalidSize);
MATX_ASSERT(c.Size(RANK - 1) == b.Size(TensorTypeB::Rank() - 1), matxInvalidSize);
@@ -865,7 +862,7 @@ class MatMulCUDAHandle_t {

if constexpr (RANK == 2) {
if constexpr (PROV == PROVIDER_TYPE_CUTLASS) {
#if MATX_ENABLE_CUTLASS
#ifdef MATX_ENABLE_CUTLASS
using CutlassAOrder = std::conditional_t<OrderA == MEM_ORDER_ROW_MAJOR,
cutlass::layout::RowMajor,
cutlass::layout::ColumnMajor>;
@@ -909,7 +906,7 @@ }
}
else {
static_assert(RANK > 2);
#if MATX_ENABLE_CUTLASS
#ifdef MATX_ENABLE_CUTLASS
using CutlassAOrder = std::conditional_t<OrderA == MEM_ORDER_ROW_MAJOR,
cutlass::layout::RowMajor,
cutlass::layout::ColumnMajor>;
@@ -930,7 +927,7 @@

if constexpr (RANK > 3) {
if constexpr (PROV == PROVIDER_TYPE_CUTLASS) {
#if MATX_ENABLE_CUTLASS
#ifdef MATX_ENABLE_CUTLASS
for (size_t iter = 0; iter < total_iter; iter++) {
// Get pointers into A/B/C for this round
auto ap = cuda::std::apply([&a_adj](auto... param) { return a_adj.GetPointer(param...); }, idx);
@@ -1004,7 +1001,7 @@ class MatMulCUDAHandle_t {
beta);
}
else if (c.Stride(RANK - 2) <= 1) {
#if MATX_ENABLE_CUTLASS
#ifdef MATX_ENABLE_CUTLASS
MatMulLaunch<OrderA, OrderB, MEM_ORDER_COL_MAJOR>(a, b, c, stream, alpha,
beta);
#else
@@ -1204,7 +1201,7 @@ void matmul_impl(TensorTypeC C, const TensorTypeA A,
(c = C).run(stream);
}

#if MATX_ENABLE_CUTLASS != 1
#ifndef MATX_ENABLE_CUTLASS
// cublasLt does not allow transpose modes on C. Thus we need to make sure that the right most dimension has a stride of 1.
// Use the identity CT = BT * AT to do the transpose through the gemm automatically. Note we only want to do this transpose if
// the rightmost stride is !=1 or this function will be an infinite recursion.
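For reference, the identity the comment above leans on is the transpose rule for a product:

    C = A B   =>   Cᵀ = (A B)ᵀ = Bᵀ Aᵀ

so when the rightmost stride of C is not 1, matmul_impl can swap and transpose the operands and let the GEMM write the transposed result directly, satisfying cublasLt's requirement of a unit-stride output dimension; the stride check mentioned in the comment keeps that swap from recursing indefinitely.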
