From 75d9c8f1399cfaf5125beae309b0559cb1835c69 Mon Sep 17 00:00:00 2001 From: Balint Cristian Date: Tue, 29 Aug 2023 21:23:54 +0300 Subject: [PATCH] [Arith] MLIR PresburgerSet compile fix mlir >= 160 --- CMakeLists.txt | 1 + cmake/modules/LibInfo.cmake | 7 +++++ cmake/utils/FindLLVM.cmake | 2 ++ src/arith/presburger_set.cc | 51 ++++++++++++++++++++++++++++++------- src/support/libinfo.cc | 6 +++++ 5 files changed, 58 insertions(+), 9 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 7b69145806ad..3037d40ec5a3 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -54,6 +54,7 @@ tvm_option(USE_HEXAGON_EXTERNAL_LIBS "Path to git repo containing external Hexag tvm_option(USE_RPC "Build with RPC" ON) tvm_option(USE_THREADS "Build with thread support" ON) tvm_option(USE_LLVM "Build with LLVM, can be set to specific llvm-config path" OFF) +tvm_option(USE_MLIR "Build with MLIR support" OFF) tvm_option(USE_STACKVM_RUNTIME "Include stackvm into the runtime" OFF) tvm_option(USE_GRAPH_EXECUTOR "Build with tiny graph executor" ON) tvm_option(USE_GRAPH_EXECUTOR_CUDA_GRAPH "Build with tiny graph executor with CUDA Graph for GPUs" OFF) diff --git a/cmake/modules/LibInfo.cmake b/cmake/modules/LibInfo.cmake index ad153cce0455..fd12b9d0386d 100644 --- a/cmake/modules/LibInfo.cmake +++ b/cmake/modules/LibInfo.cmake @@ -24,6 +24,11 @@ function(add_lib_info src_file) else() string(STRIP ${TVM_INFO_LLVM_VERSION} TVM_INFO_LLVM_VERSION) endif() + if (NOT DEFINED TVM_INFO_MLIR_VERSION) + set(TVM_INFO_MLIR_VERSION "NOT-FOUND") + else() + string(STRIP ${TVM_INFO_MLIR_VERSION} TVM_INFO_MLIR_VERSION) + endif() if (NOT DEFINED CUDA_VERSION) set(TVM_INFO_CUDA_VERSION "NOT-FOUND") else() @@ -47,6 +52,7 @@ function(add_lib_info src_file) TVM_INFO_INDEX_DEFAULT_I64="${INDEX_DEFAULT_I64}" TVM_INFO_INSTALL_DEV="${INSTALL_DEV}" TVM_INFO_LLVM_VERSION="${TVM_INFO_LLVM_VERSION}" + TVM_INFO_MLIR_VERSION="${TVM_INFO_MLIR_VERSION}" TVM_INFO_PICOJSON_PATH="${PICOJSON_PATH}" TVM_INFO_RANG_PATH="${RANG_PATH}" TVM_INFO_ROCM_PATH="${ROCM_PATH}" @@ -86,6 +92,7 @@ function(add_lib_info src_file) TVM_INFO_USE_LIBBACKTRACE="${USE_LIBBACKTRACE}" TVM_INFO_USE_LIBTORCH="${USE_LIBTORCH}" TVM_INFO_USE_LLVM="${USE_LLVM}" + TVM_INFO_USE_MLIR="${USE_MLIR}" TVM_INFO_USE_METAL="${USE_METAL}" TVM_INFO_USE_MICRO_STANDALONE_RUNTIME="${USE_MICRO_STANDALONE_RUNTIME}" TVM_INFO_USE_MICRO="${USE_MICRO}" diff --git a/cmake/utils/FindLLVM.cmake b/cmake/utils/FindLLVM.cmake index f10e5f1eb8da..f40d97d9ba7f 100644 --- a/cmake/utils/FindLLVM.cmake +++ b/cmake/utils/FindLLVM.cmake @@ -150,6 +150,8 @@ macro(find_llvm use_llvm) list(APPEND LLVM_LIBS "${__llvm_libdir}/libMLIRPresburger.a") list(APPEND LLVM_LIBS "${__llvm_libdir}/libMLIRSupport.a") set(TVM_MLIR_VERSION ${TVM_LLVM_VERSION}) + message(STATUS "Build with MLIR") + message(STATUS "Set TVM_MLIR_VERSION=" ${TVM_MLIR_VERSION}) endif() endif() endif() diff --git a/src/arith/presburger_set.cc b/src/arith/presburger_set.cc index f1d86c861a59..3798ba190446 100644 --- a/src/arith/presburger_set.cc +++ b/src/arith/presburger_set.cc @@ -126,38 +126,54 @@ PrimExpr PresburgerSetNode::GenerateConstraint() const { for (const IntegerRelation& disjunct : disjuncts) { PrimExpr union_entry = Bool(1); for (unsigned i = 0, e = disjunct.getNumEqualities(); i < e; ++i) { - PrimExpr linear_eq = IntImm(DataType::Int(32), 0); + PrimExpr linear_eq = IntImm(DataType::Int(64), 0); if (disjunct.getNumCols() > 1) { for (unsigned j = 0, f = disjunct.getNumCols() - 1; j < f; ++j) { +#if TVM_MLIR_VERSION >= 160 + auto coeff = int64_t(disjunct.atEq(i, j)); +#else auto coeff = disjunct.atEq(i, j); +#endif if (coeff >= 0 || is_zero(linear_eq)) { - linear_eq = linear_eq + IntImm(DataType::Int(32), coeff) * vars[j]; + linear_eq = linear_eq + IntImm(DataType::Int(64), coeff) * vars[j]; } else { - linear_eq = linear_eq - IntImm(DataType::Int(32), -coeff) * vars[j]; + linear_eq = linear_eq - IntImm(DataType::Int(64), -coeff) * vars[j]; } } } +#if TVM_MLIR_VERSION >= 160 + auto c0 = int64_t(disjunct.atEq(i, disjunct.getNumCols() - 1)); +#else auto c0 = disjunct.atEq(i, disjunct.getNumCols() - 1); - linear_eq = linear_eq + IntImm(DataType::Int(32), c0); +#endif + linear_eq = linear_eq + IntImm(DataType::Int(64), c0); union_entry = (union_entry && (linear_eq == 0)); } for (unsigned i = 0, e = disjunct.getNumInequalities(); i < e; ++i) { - PrimExpr linear_eq = IntImm(DataType::Int(32), 0); + PrimExpr linear_eq = IntImm(DataType::Int(64), 0); if (disjunct.getNumCols() > 1) { for (unsigned j = 0, f = disjunct.getNumCols() - 1; j < f; ++j) { +#if TVM_MLIR_VERSION >= 160 + auto coeff = int64_t(disjunct.atIneq(i, j)); +#else auto coeff = disjunct.atIneq(i, j); +#endif if (coeff >= 0 || is_zero(linear_eq)) { - linear_eq = linear_eq + IntImm(DataType::Int(32), coeff) * vars[j]; + linear_eq = linear_eq + IntImm(DataType::Int(64), coeff) * vars[j]; } else { - linear_eq = linear_eq - IntImm(DataType::Int(32), -coeff) * vars[j]; + linear_eq = linear_eq - IntImm(DataType::Int(64), -coeff) * vars[j]; } } } +#if TVM_MLIR_VERSION >= 160 + auto c0 = int64_t(disjunct.atIneq(i, disjunct.getNumCols() - 1)); +#else auto c0 = disjunct.atIneq(i, disjunct.getNumCols() - 1); +#endif if (c0 >= 0) { - linear_eq = linear_eq + IntImm(DataType::Int(32), c0); + linear_eq = linear_eq + IntImm(DataType::Int(64), c0); } else { - linear_eq = linear_eq - IntImm(DataType::Int(32), -c0); + linear_eq = linear_eq - IntImm(DataType::Int(64), -c0); } union_entry = (union_entry && (linear_eq >= 0)); } @@ -199,10 +215,19 @@ PresburgerSet Intersect(const Array& sets) { IntSet EvalSet(const PrimExpr& e, const PresburgerSet& set) { Array tvm_coeffs = DetectLinearEquation(e, set->GetVars()); +#if TVM_MLIR_VERSION >= 160 + SmallVector coeffs; +#else SmallVector coeffs; +#endif + coeffs.reserve(tvm_coeffs.size()); for (const PrimExpr& it : tvm_coeffs) { +#if TVM_MLIR_VERSION >= 160 + coeffs.push_back(mlir::presburger::MPInt(*as_const_int(it))); +#else coeffs.push_back(*as_const_int(it)); +#endif } IntSet result = IntSet().Nothing(); @@ -211,9 +236,17 @@ IntSet EvalSet(const PrimExpr& e, const PresburgerSet& set) { auto range = simplex.computeIntegerBounds(coeffs); auto maxRoundedDown(simplex.computeOptimum(Simplex::Direction::Up, coeffs)); auto opt = range.first.getOptimumIfBounded(); +#if TVM_MLIR_VERSION >= 160 + auto min = opt.has_value() ? IntImm(DataType::Int(64), int64_t(opt.value())) : neg_inf(); +#else auto min = opt.hasValue() ? IntImm(DataType::Int(64), opt.getValue()) : neg_inf(); +#endif opt = range.second.getOptimumIfBounded(); +#if TVM_MLIR_VERSION >= 160 + auto max = opt.has_value() ? IntImm(DataType::Int(64), int64_t(opt.value())) : pos_inf(); +#else auto max = opt.hasValue() ? IntImm(DataType::Int(64), opt.getValue()) : pos_inf(); +#endif auto interval = IntervalSet(min, max); result = Union({result, interval}); } diff --git a/src/support/libinfo.cc b/src/support/libinfo.cc index 53f9292d162f..d94c74b5bb22 100644 --- a/src/support/libinfo.cc +++ b/src/support/libinfo.cc @@ -31,6 +31,10 @@ #define TVM_INFO_LLVM_VERSION "NOT-FOUND" #endif +#ifndef TVM_INFO_MLIR_VERSION +#define TVM_INFO_MLIR_VERSION "NOT-FOUND" +#endif + #ifndef TVM_INFO_USE_CUDA #define TVM_INFO_USE_CUDA "NOT-FOUND" #endif @@ -271,6 +275,7 @@ TVM_DLL Map GetLibInfo() { {"INDEX_DEFAULT_I64", TVM_INFO_INDEX_DEFAULT_I64}, {"INSTALL_DEV", TVM_INFO_INSTALL_DEV}, {"LLVM_VERSION", TVM_INFO_LLVM_VERSION}, + {"MLIR_VERSION", TVM_INFO_MLIR_VERSION}, {"PICOJSON_PATH", TVM_INFO_PICOJSON_PATH}, {"RANG_PATH", TVM_INFO_RANG_PATH}, {"ROCM_PATH", TVM_INFO_ROCM_PATH}, @@ -311,6 +316,7 @@ TVM_DLL Map GetLibInfo() { {"USE_LIBBACKTRACE", TVM_INFO_USE_LIBBACKTRACE}, {"USE_LIBTORCH", TVM_INFO_USE_LIBTORCH}, {"USE_LLVM", TVM_INFO_USE_LLVM}, + {"USE_MLIR", TVM_INFO_USE_MLIR}, {"USE_METAL", TVM_INFO_USE_METAL}, {"USE_MICRO_STANDALONE_RUNTIME", TVM_INFO_USE_MICRO_STANDALONE_RUNTIME}, {"USE_MICRO", TVM_INFO_USE_MICRO},