Skip to content

Commit

Permalink
rename linear_system to int_constraints
Browse files Browse the repository at this point in the history
  • Loading branch information
yzhliu committed Apr 5, 2020
1 parent a9a29b5 commit c2678f8
Show file tree
Hide file tree
Showing 6 changed files with 106 additions and 101 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,10 @@

/*!
* \file tvm/arith/linear_system.h
* \brief Linear system data structures and solvers
* \brief integer constraints data structures and solvers
*/
#ifndef TVM_ARITH_LINEAR_SYSTEM_H_
#define TVM_ARITH_LINEAR_SYSTEM_H_
#ifndef TVM_ARITH_INT_SOLVER_H_
#define TVM_ARITH_INT_SOLVER_H_

#include <tvm/ir/expr.h>
#include <tvm/tir/expr.h>
Expand All @@ -37,13 +37,13 @@ using tir::VarNode;
using tir::IterVar;

/*!
* \brief Represent a linear system including variables, their ranges and
* \brief Represent integer constrains including (integer) variables, their ranges and
* the relations between them (either equations or inequalities).
* \sa LinearSystem
*/
class LinearSystemNode : public Object {
class IntConstraintsNode : public Object {
public:
// e.g., \alpha, \beta
// e.g., \alpha, \beta, must be integers
Array<Var> variables;
// e.g., 1 <= \alpha <= N, etc.
Map<Var, Range> ranges;
Expand All @@ -57,32 +57,32 @@ class LinearSystemNode : public Object {
v->Visit("relations", &relations);
}

static constexpr const char* _type_key = "arith.LinearSystem";
TVM_DECLARE_FINAL_OBJECT_INFO(LinearSystemNode, Object);
static constexpr const char* _type_key = "arith.IntConstraints";
TVM_DECLARE_FINAL_OBJECT_INFO(IntConstraintsNode, Object);
};

/*!
* \brief Managed reference to LinearSystemNode.
* \sa LinearSystemNode
* \brief Managed reference to IntConstraintsNode.
* \sa IntConstraintsNode
*/
class LinearSystem : public ObjectRef {
class IntConstraints : public ObjectRef {
public:
/*!
* \brief Constructor by fields
* \param variables The variables in the system.
* \param variables The variables in the constraints, must be integers.
* \param ranges The ranges of the variables.
* \param relations The linear relations between the variables
* (either equations or inequalities)
*/
TVM_DLL LinearSystem(Array<Var> variables,
Map<Var, Range> ranges,
Array<PrimExpr> relations);
TVM_DLL IntConstraints(Array<Var> variables,
Map<Var, Range> ranges,
Array<PrimExpr> relations);

TVM_DEFINE_OBJECT_REF_METHODS(LinearSystem, ObjectRef, LinearSystemNode);
TVM_DEFINE_OBJECT_REF_METHODS(IntConstraints, ObjectRef, IntConstraintsNode);
};

/*!
* \brief We can have different set of variables to represent the same linear system.
* \brief We can have different set of variables to represent the same constraints.
* For example, the following two systems are equivalent,
* {a + b = 0 | a >= 0, b >= 0} and
* {m - n = 0 | m >= 0, n <= 0}
Expand All @@ -93,12 +93,12 @@ class LinearSystem : public ObjectRef {
* dst : {m - n = 0 | m >= 0, n <= 0}
* src_to_dst : {a -> m, b -> -n}
* dst_to_src : {m -> a, n -> -b}
* \sa LinearSystemTransform
* \sa IntConstraintsTransform
*/
class LinearSystemTransformNode : public Object {
class IntConstraintsTransformNode : public Object {
public:
LinearSystem src;
LinearSystem dst;
IntConstraints src;
IntConstraints dst;
Map<Var, PrimExpr> src_to_dst;
Map<Var, PrimExpr> dst_to_src;

Expand All @@ -109,31 +109,32 @@ class LinearSystemTransformNode : public Object {
v->Visit("dst_to_src", &dst_to_src);
}

static constexpr const char* _type_key = "arith.LinearSystemTransform";
TVM_DECLARE_FINAL_OBJECT_INFO(LinearSystemTransformNode, Object);
static constexpr const char* _type_key = "arith.IntConstraintsTransform";
TVM_DECLARE_FINAL_OBJECT_INFO(IntConstraintsTransformNode, Object);
};

/*!
* \brief Managed reference to LinearSystemTransformNode.
* \sa LinearSystemTransformNode
* \brief Managed reference to IntConstraintsTransformNode.
* \sa IntConstraintsTransformNode
*/
class LinearSystemTransform : public ObjectRef {
class IntConstraintsTransform : public ObjectRef {
public:
/*!
* \brief Constructor by fields
* \param src source linear system, e.g., {a + b = 0 | a >= 0, b >= 0}
* \param dst linear system equivalent to the source, e.g., {m - n = 0 | m >= 0, n <= 0}
* \param src source integer constraints, e.g., {a + b = 0 | a >= 0, b >= 0}
* \param dst integer constraints equivalent to the source,
* e.g., {m - n = 0 | m >= 0, n <= 0}
* \param src_to_dst mapping from variables in the \p src to the variables in the \p dst,
* e.g., {a -> m, b -> -n}
* \param dst_to_src mapping from variables in the \p dst to the variables in the \p src,
* e.g., {m -> a, n -> -b}
*/
TVM_DLL LinearSystemTransform(LinearSystem src,
LinearSystem dst,
Map<Var, PrimExpr> src_to_dst,
Map<Var, PrimExpr> dst_to_src);
TVM_DLL IntConstraintsTransform(IntConstraints src,
IntConstraints dst,
Map<Var, PrimExpr> src_to_dst,
Map<Var, PrimExpr> dst_to_src);

TVM_DEFINE_OBJECT_REF_METHODS(LinearSystemTransform, ObjectRef, LinearSystemTransformNode);
TVM_DEFINE_OBJECT_REF_METHODS(IntConstraintsTransform, ObjectRef, IntConstraintsTransformNode);
};

/*!
Expand Down Expand Up @@ -165,8 +166,8 @@ void SmithNormalFormDiag(std::vector<std::vector<int64_t>> *S,
* as well as inequalities inferred from the \p system_to_solve.
* You can get the mapping from the original variables to the solution via ret->src_to_dst.
*/
LinearSystemTransform SolveEquations(const LinearSystem& system_to_solve);
IntConstraintsTransform SolveLinearEquations(const IntConstraints &system_to_solve);

} // namespace arith
} // namespace tvm
#endif // TVM_ARITH_LINEAR_SYSTEM_H_
#endif // TVM_ARITH_INT_SOLVER_H_
2 changes: 1 addition & 1 deletion python/tvm/arith/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,4 +20,4 @@
from .analyzer import ModularSet, ConstIntBound, Analyzer
from .bound import deduce_bound
from .pattern import detect_linear_equation, detect_clip_bound
from .linear_system import solve_equations
from .int_solver import solve_linear_equations
Original file line number Diff line number Diff line change
Expand Up @@ -14,39 +14,39 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Linear system data structures and solvers"""
"""integer constraints data structures and solvers"""
import tvm._ffi
from tvm.runtime import Object
from . import _ffi_api


@tvm._ffi.register_object("arith.LinearSystem")
class LinearSystem(Object):
"""Represent a linear system including variables, their ranges and
the linear relations between them (either equations or inequalities)
@tvm._ffi.register_object("arith.IntConstraints")
class IntConstraints(Object):
"""Represent a set of integer constraints including variables, their ranges and
the relations between them (either equations or inequalities)
Parameters
----------
variables : List[tvm.tir.Var]
The variables in the system.
The variables in the constraints. Must be integers
ranges : Map[tvm.tir.Var, tvm.ir.Range]
The ranges of the variables.
relations : List[tvm.ir.PrimExpr]
The linear relations between the variables (either equations or inequalities)
The relations between the variables (either equations or inequalities)
"""
def __init__(self, variables, ranges, relations):
self.__init_handle_by_constructor__(
_ffi_api.LinearSystem, variables, ranges, relations)
_ffi_api.IntConstraints, variables, ranges, relations)


@tvm._ffi.register_object("arith.LinearSystemTransform")
class LinearSystemTransform(Object):
"""We can have different set of variables to represent the same linear system.
For example, the following two systems are equivalent,
@tvm._ffi.register_object("arith.IntConstraintsTransform")
class IntConstraintsTransform(Object):
"""We can have different set of variables to represent the same integer constraints.
For example, the following two constrains are equivalent,
{a + b = 0 | a >= 0, b >= 0} and
{m - n = 0 | m >= 0, n <= 0}
This data structure represents the transformation
between two equivalent linear systems.
between two equivalent integer constraints.
In the above example,
src : {a + b = 0 | a >= 0, b >= 0}
dst : {m - n = 0 | m >= 0, n <= 0}
Expand All @@ -55,10 +55,10 @@ class LinearSystemTransform(Object):
Parameters
----------
src : arith.LinearSystem
source linear system, e.g., {a + b = 0 | a >= 0, b >= 0}
dst : arith.LinearSystem
linear system equivalent to the source, e.g., {m - n = 0 | m >= 0, n <= 0}
src : arith.IntConstraints
source integer constraints, e.g., {a + b = 0 | a >= 0, b >= 0}
dst : arith.IntConstraints
integer constraints equivalent to the source, e.g., {m - n = 0 | m >= 0, n <= 0}
src_to_dst : Map[tvm.tir.Var, tvm.ir.PrimExpr]
mapping from variables in the src to the variables in the dst,
e.g., {a -> m, b -> -n}
Expand All @@ -68,15 +68,15 @@ class LinearSystemTransform(Object):
"""
def __init__(self, src, dst, src_to_dst, dst_to_src):
self.__init_handle_by_constructor__(
_ffi_api.LinearSystemTransform, src, dst, src_to_dst, dst_to_src)
_ffi_api.IntConstraintsTransform, src, dst, src_to_dst, dst_to_src)


def solve_equations(equations, variables, ranges):
def solve_linear_equations(equations, variables, ranges):
"""Solve linear equations.
Parameters
----------
equations: List[tvm.ir.PrimExpr] or LinearSystem
equations: List[tvm.ir.PrimExpr] or IntConstraints
The equations of the variables
variables : List[tvm.tir.Var]
The variables in the system.
Expand All @@ -85,15 +85,15 @@ def solve_equations(equations, variables, ranges):
Returns
-------
linear_system_transform : LinearSystemTransform
A new linear system, with less variables (if the problem is NOT of full rank),
int_constraints_transform : IntConstraintsTransform
New integer constraints, with less variables (if the problem is NOT of full rank),
or no variable (if the problem is of full rank),
or an empty linear system (if the problem is unsolvable).
or an empty integer constraints (if the problem is unsolvable).
It also provides the ranges of the variables in the new system,
as well as inequalities inferred from the problem.
You can get the mapping from the original variables to the solution via
linear_system_transform.src_to_dst.
int_constraints_transform.src_to_dst.
"""
if isinstance(equations, LinearSystem):
return _ffi_api.SolveEquations(equations)
return _ffi_api.SolveEquations(variables, ranges, equations)
if isinstance(equations, IntConstraints):
return _ffi_api.SolveLinearEquations(equations)
return _ffi_api.SolveLinearEquations(variables, ranges, equations)
44 changes: 24 additions & 20 deletions src/arith/linear_system.cc → src/arith/int_constraints.cc
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,10 @@
*/

/*!
* \file linear_system.cc
* \brief The linear system data structures.
* \file int_constraints.cc
* \brief The integer constraints data structures.
*/
#include <tvm/arith/linear_system.h>
#include <tvm/arith/int_solver.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/expr_functor.h>
#include <tvm/runtime/registry.h>
Expand All @@ -33,47 +33,51 @@
namespace tvm {
namespace arith {

LinearSystem::LinearSystem(Array<Var> variables,
Map<Var, Range> ranges,
Array<PrimExpr> relations) {
ObjectPtr<LinearSystemNode> node = make_object<LinearSystemNode>();
IntConstraints::IntConstraints(Array<Var> variables,
Map<Var, Range> ranges,
Array<PrimExpr> relations) {
ObjectPtr<IntConstraintsNode> node = make_object<IntConstraintsNode>();
for (const auto& var : variables) {
CHECK(var.dtype().is_int() || var.dtype().is_uint())
<< "Variables in IntConstraints must be integers";
}
node->variables = std::move(variables);
node->ranges = std::move(ranges);
node->relations = std::move(relations);
data_ = std::move(node);
}

TVM_REGISTER_NODE_TYPE(LinearSystemNode);
TVM_REGISTER_NODE_TYPE(IntConstraintsNode);

TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<LinearSystemNode>([](const ObjectRef& node, ReprPrinter* p) {
auto* op = static_cast<const LinearSystemNode*>(node.get());
p->stream << "LinearSystem("
.set_dispatch<IntConstraintsNode>([](const ObjectRef& node, ReprPrinter* p) {
auto* op = static_cast<const IntConstraintsNode*>(node.get());
p->stream << "IntConstraints("
<< op->variables
<< ", " << op->ranges
<< ", " << op->relations
<< ")";
});


LinearSystemTransform::LinearSystemTransform(LinearSystem src,
LinearSystem dst,
Map<Var, PrimExpr> src_to_dst,
Map<Var, PrimExpr> dst_to_src) {
ObjectPtr<LinearSystemTransformNode> node = make_object<LinearSystemTransformNode>();
IntConstraintsTransform::IntConstraintsTransform(IntConstraints src,
IntConstraints dst,
Map<Var, PrimExpr> src_to_dst,
Map<Var, PrimExpr> dst_to_src) {
ObjectPtr<IntConstraintsTransformNode> node = make_object<IntConstraintsTransformNode>();
node->src = std::move(src);
node->dst = std::move(dst);
node->src_to_dst = std::move(src_to_dst);
node->dst_to_src = std::move(dst_to_src);
data_ = std::move(node);
}

TVM_REGISTER_NODE_TYPE(LinearSystemTransformNode);
TVM_REGISTER_NODE_TYPE(IntConstraintsTransformNode);

TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<LinearSystemTransformNode>([](const ObjectRef& node, ReprPrinter* p) {
auto* op = static_cast<const LinearSystemTransformNode*>(node.get());
p->stream << "LinearSystemTransform("
.set_dispatch<IntConstraintsTransformNode>([](const ObjectRef& node, ReprPrinter* p) {
auto* op = static_cast<const IntConstraintsTransformNode*>(node.get());
p->stream << "IntConstraintsTransform("
<< "\n\t" << op->src
<< "\n\t" << op->dst
<< "\n\t" << op->src_to_dst
Expand Down
Loading

0 comments on commit c2678f8

Please sign in to comment.