Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Arith] MLIR PresburgerSet compile fix mlir >= 160 #15638

Merged
merged 1 commit into from
Aug 31, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ tvm_option(USE_HEXAGON_EXTERNAL_LIBS "Path to git repo containing external Hexag
tvm_option(USE_RPC "Build with RPC" ON)
tvm_option(USE_THREADS "Build with thread support" ON)
tvm_option(USE_LLVM "Build with LLVM, can be set to specific llvm-config path" OFF)
tvm_option(USE_MLIR "Build with MLIR support" OFF)
tvm_option(USE_STACKVM_RUNTIME "Include stackvm into the runtime" OFF)
tvm_option(USE_GRAPH_EXECUTOR "Build with tiny graph executor" ON)
tvm_option(USE_GRAPH_EXECUTOR_CUDA_GRAPH "Build with tiny graph executor with CUDA Graph for GPUs" OFF)
Expand Down
7 changes: 7 additions & 0 deletions cmake/modules/LibInfo.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,11 @@ function(add_lib_info src_file)
else()
string(STRIP ${TVM_INFO_LLVM_VERSION} TVM_INFO_LLVM_VERSION)
endif()
if (NOT DEFINED TVM_INFO_MLIR_VERSION)
set(TVM_INFO_MLIR_VERSION "NOT-FOUND")
else()
string(STRIP ${TVM_INFO_MLIR_VERSION} TVM_INFO_MLIR_VERSION)
endif()
if (NOT DEFINED CUDA_VERSION)
set(TVM_INFO_CUDA_VERSION "NOT-FOUND")
else()
Expand All @@ -47,6 +52,7 @@ function(add_lib_info src_file)
TVM_INFO_INDEX_DEFAULT_I64="${INDEX_DEFAULT_I64}"
TVM_INFO_INSTALL_DEV="${INSTALL_DEV}"
TVM_INFO_LLVM_VERSION="${TVM_INFO_LLVM_VERSION}"
TVM_INFO_MLIR_VERSION="${TVM_INFO_MLIR_VERSION}"
TVM_INFO_PICOJSON_PATH="${PICOJSON_PATH}"
TVM_INFO_RANG_PATH="${RANG_PATH}"
TVM_INFO_ROCM_PATH="${ROCM_PATH}"
Expand Down Expand Up @@ -86,6 +92,7 @@ function(add_lib_info src_file)
TVM_INFO_USE_LIBBACKTRACE="${USE_LIBBACKTRACE}"
TVM_INFO_USE_LIBTORCH="${USE_LIBTORCH}"
TVM_INFO_USE_LLVM="${USE_LLVM}"
TVM_INFO_USE_MLIR="${USE_MLIR}"
TVM_INFO_USE_METAL="${USE_METAL}"
TVM_INFO_USE_MICRO_STANDALONE_RUNTIME="${USE_MICRO_STANDALONE_RUNTIME}"
TVM_INFO_USE_MICRO="${USE_MICRO}"
Expand Down
2 changes: 2 additions & 0 deletions cmake/utils/FindLLVM.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,8 @@ macro(find_llvm use_llvm)
list(APPEND LLVM_LIBS "${__llvm_libdir}/libMLIRPresburger.a")
list(APPEND LLVM_LIBS "${__llvm_libdir}/libMLIRSupport.a")
set(TVM_MLIR_VERSION ${TVM_LLVM_VERSION})
message(STATUS "Build with MLIR")
message(STATUS "Set TVM_MLIR_VERSION=" ${TVM_MLIR_VERSION})
endif()
endif()
endif()
Expand Down
51 changes: 42 additions & 9 deletions src/arith/presburger_set.cc
Original file line number Diff line number Diff line change
Expand Up @@ -126,38 +126,54 @@ PrimExpr PresburgerSetNode::GenerateConstraint() const {
for (const IntegerRelation& disjunct : disjuncts) {
PrimExpr union_entry = Bool(1);
for (unsigned i = 0, e = disjunct.getNumEqualities(); i < e; ++i) {
PrimExpr linear_eq = IntImm(DataType::Int(32), 0);
PrimExpr linear_eq = IntImm(DataType::Int(64), 0);
if (disjunct.getNumCols() > 1) {
for (unsigned j = 0, f = disjunct.getNumCols() - 1; j < f; ++j) {
#if TVM_MLIR_VERSION >= 160
auto coeff = int64_t(disjunct.atEq(i, j));
#else
auto coeff = disjunct.atEq(i, j);
#endif
if (coeff >= 0 || is_zero(linear_eq)) {
linear_eq = linear_eq + IntImm(DataType::Int(32), coeff) * vars[j];
linear_eq = linear_eq + IntImm(DataType::Int(64), coeff) * vars[j];
} else {
linear_eq = linear_eq - IntImm(DataType::Int(32), -coeff) * vars[j];
linear_eq = linear_eq - IntImm(DataType::Int(64), -coeff) * vars[j];
}
}
}
#if TVM_MLIR_VERSION >= 160
auto c0 = int64_t(disjunct.atEq(i, disjunct.getNumCols() - 1));
#else
auto c0 = disjunct.atEq(i, disjunct.getNumCols() - 1);
linear_eq = linear_eq + IntImm(DataType::Int(32), c0);
#endif
linear_eq = linear_eq + IntImm(DataType::Int(64), c0);
union_entry = (union_entry && (linear_eq == 0));
}
for (unsigned i = 0, e = disjunct.getNumInequalities(); i < e; ++i) {
PrimExpr linear_eq = IntImm(DataType::Int(32), 0);
PrimExpr linear_eq = IntImm(DataType::Int(64), 0);
if (disjunct.getNumCols() > 1) {
for (unsigned j = 0, f = disjunct.getNumCols() - 1; j < f; ++j) {
#if TVM_MLIR_VERSION >= 160
auto coeff = int64_t(disjunct.atIneq(i, j));
#else
auto coeff = disjunct.atIneq(i, j);
#endif
if (coeff >= 0 || is_zero(linear_eq)) {
linear_eq = linear_eq + IntImm(DataType::Int(32), coeff) * vars[j];
linear_eq = linear_eq + IntImm(DataType::Int(64), coeff) * vars[j];
} else {
linear_eq = linear_eq - IntImm(DataType::Int(32), -coeff) * vars[j];
linear_eq = linear_eq - IntImm(DataType::Int(64), -coeff) * vars[j];
}
}
}
#if TVM_MLIR_VERSION >= 160
auto c0 = int64_t(disjunct.atIneq(i, disjunct.getNumCols() - 1));
#else
auto c0 = disjunct.atIneq(i, disjunct.getNumCols() - 1);
#endif
if (c0 >= 0) {
linear_eq = linear_eq + IntImm(DataType::Int(32), c0);
linear_eq = linear_eq + IntImm(DataType::Int(64), c0);
} else {
linear_eq = linear_eq - IntImm(DataType::Int(32), -c0);
linear_eq = linear_eq - IntImm(DataType::Int(64), -c0);
}
union_entry = (union_entry && (linear_eq >= 0));
}
Expand Down Expand Up @@ -199,10 +215,19 @@ PresburgerSet Intersect(const Array<PresburgerSet>& sets) {

IntSet EvalSet(const PrimExpr& e, const PresburgerSet& set) {
Array<PrimExpr> tvm_coeffs = DetectLinearEquation(e, set->GetVars());
#if TVM_MLIR_VERSION >= 160
SmallVector<mlir::presburger::MPInt> coeffs;
#else
SmallVector<int64_t> coeffs;
#endif

coeffs.reserve(tvm_coeffs.size());
for (const PrimExpr& it : tvm_coeffs) {
#if TVM_MLIR_VERSION >= 160
coeffs.push_back(mlir::presburger::MPInt(*as_const_int(it)));
#else
coeffs.push_back(*as_const_int(it));
#endif
}

IntSet result = IntSet().Nothing();
Expand All @@ -211,9 +236,17 @@ IntSet EvalSet(const PrimExpr& e, const PresburgerSet& set) {
auto range = simplex.computeIntegerBounds(coeffs);
auto maxRoundedDown(simplex.computeOptimum(Simplex::Direction::Up, coeffs));
auto opt = range.first.getOptimumIfBounded();
#if TVM_MLIR_VERSION >= 160
auto min = opt.has_value() ? IntImm(DataType::Int(64), int64_t(opt.value())) : neg_inf();
#else
auto min = opt.hasValue() ? IntImm(DataType::Int(64), opt.getValue()) : neg_inf();
#endif
opt = range.second.getOptimumIfBounded();
#if TVM_MLIR_VERSION >= 160
auto max = opt.has_value() ? IntImm(DataType::Int(64), int64_t(opt.value())) : pos_inf();
#else
auto max = opt.hasValue() ? IntImm(DataType::Int(64), opt.getValue()) : pos_inf();
#endif
auto interval = IntervalSet(min, max);
result = Union({result, interval});
}
Expand Down
6 changes: 6 additions & 0 deletions src/support/libinfo.cc
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,10 @@
#define TVM_INFO_LLVM_VERSION "NOT-FOUND"
#endif

#ifndef TVM_INFO_MLIR_VERSION
#define TVM_INFO_MLIR_VERSION "NOT-FOUND"
#endif

#ifndef TVM_INFO_USE_CUDA
#define TVM_INFO_USE_CUDA "NOT-FOUND"
#endif
Expand Down Expand Up @@ -271,6 +275,7 @@ TVM_DLL Map<String, String> GetLibInfo() {
{"INDEX_DEFAULT_I64", TVM_INFO_INDEX_DEFAULT_I64},
{"INSTALL_DEV", TVM_INFO_INSTALL_DEV},
{"LLVM_VERSION", TVM_INFO_LLVM_VERSION},
{"MLIR_VERSION", TVM_INFO_MLIR_VERSION},
{"PICOJSON_PATH", TVM_INFO_PICOJSON_PATH},
{"RANG_PATH", TVM_INFO_RANG_PATH},
{"ROCM_PATH", TVM_INFO_ROCM_PATH},
Expand Down Expand Up @@ -311,6 +316,7 @@ TVM_DLL Map<String, String> GetLibInfo() {
{"USE_LIBBACKTRACE", TVM_INFO_USE_LIBBACKTRACE},
{"USE_LIBTORCH", TVM_INFO_USE_LIBTORCH},
{"USE_LLVM", TVM_INFO_USE_LLVM},
{"USE_MLIR", TVM_INFO_USE_MLIR},
{"USE_METAL", TVM_INFO_USE_METAL},
{"USE_MICRO_STANDALONE_RUNTIME", TVM_INFO_USE_MICRO_STANDALONE_RUNTIME},
{"USE_MICRO", TVM_INFO_USE_MICRO},
Expand Down