Skip to content

Commit

Permalink
[TensorIR] change IntRV to ExprRV (#8077)
Browse files Browse the repository at this point in the history
  • Loading branch information
Hzfengsy authored May 20, 2021
1 parent 7c732af commit 1203d73
Show file tree
Hide file tree
Showing 5 changed files with 45 additions and 46 deletions.
20 changes: 10 additions & 10 deletions include/tvm/tir/schedule/schedule.h
Original file line number Diff line number Diff line change
Expand Up @@ -66,12 +66,12 @@ class LoopRV : public runtime::ObjectRef {
TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(LoopRV, runtime::ObjectRef, LoopRVNode);
};

/**************** Random variable: IntRV ****************/
/**************** Random variable: ExprRV ****************/

/*! \brief An integer random variable */
using IntRV = PrimExpr;
/*! \brief An expr random variable */
using ExprRV = PrimExpr;

using IntRVNode = PrimExprNode;
using ExprRVNode = PrimExprNode;

/**************** The Schedule class ****************/

Expand Down Expand Up @@ -124,11 +124,11 @@ class ScheduleNode : public runtime::Object {
*/
virtual For Get(const LoopRV& loop_rv) const = 0;
/*!
* \brief Get the value corresponding to the specific random variable
* \param int_rv The random variable to be looked up
* \return The corresponding value
* \brief Get the expr corresponding to the specific random variable
* \param expr_rv The random variable to be looked up
* \return The corresponding expr
*/
virtual int64_t Get(const IntRV& int_rv) const = 0;
virtual PrimExpr Get(const ExprRV& expr_rv) const = 0;
/*!
* \brief Get the block sref corresponding to the specific BlockRV
* \param block_rv The BlockRV to be looked up
Expand Down Expand Up @@ -165,9 +165,9 @@ class ScheduleNode : public runtime::Object {
virtual void RemoveRV(const LoopRV& loop_rv) = 0;
/*!
* \brief Remove an integer random variable from the symbol table
* \param int_rv The random variable to be removed
* \param expr_rv The random variable to be removed
*/
virtual void RemoveRV(const IntRV& int_rv) = 0;
virtual void RemoveRV(const ExprRV& expr_rv) = 0;

public:
/******** Block/Loop relation ********/
Expand Down
2 changes: 1 addition & 1 deletion python/tvm/tir/schedule/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,4 +19,4 @@

from .block_scope import BlockScope, Dependency, DepKind, StmtSRef
from .state import ScheduleDebugMask, ScheduleState
from .schedule import LoopRV, BlockRV, IntRV, RAND_VAR_TYPE, Schedule
from .schedule import LoopRV, BlockRV, ExprRV, RAND_VAR_TYPE, Schedule
12 changes: 6 additions & 6 deletions python/tvm/tir/schedule/schedule.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,9 +37,9 @@ class BlockRV(Object):
"""A random variable that refers to a block"""


IntRV = PrimExpr # A random variable that evaluates to an integer
ExprRV = PrimExpr # A random variable that evaluates to an integer

RAND_VAR_TYPE = Union[IntRV, BlockRV, LoopRV] # pylint: disable=invalid-name
RAND_VAR_TYPE = Union[ExprRV, BlockRV, LoopRV] # pylint: disable=invalid-name


@_register_object("tir.Schedule")
Expand Down Expand Up @@ -132,7 +132,7 @@ def show(self, rand_var: RAND_VAR_TYPE) -> str:
"""Returns a string representation of the value that the random variable evaluates to
Parameters
----------
rand_var : Union[IntRV, BlockRV, LoopRV]
rand_var : Union[ExprRV, BlockRV, LoopRV]
The random variable to be evaluated
Returns
----------
Expand All @@ -150,12 +150,12 @@ def get(
"""Returns:
- the corresponding Block that a BlockRV evaluates to;
- the corresponding For that a LoopRV evaluates to;
- the corresponding integer that a IntRV evaluates to;
- the corresponding integer that a ExprRV evaluates to;
- the corresponding Block that a block sref points to;
- the corresponding For that a loop sref points to;
Parameters
----------
rand_var_or_sref : Union[IntRV, BlockRV, LoopRV, StmtSRef]
rand_var_or_sref : Union[ExprRV, BlockRV, LoopRV, StmtSRef]
The random variable / sref to be evaluated
Returns
----------
Expand Down Expand Up @@ -192,7 +192,7 @@ def remove_rv(self, rand_var: RAND_VAR_TYPE) -> None:
"""Remove a random variable from the symbol table
Parameters
----------
rand_var : Union[BlockRV, LoopRV, IntRV]
rand_var : Union[BlockRV, LoopRV, ExprRV]
The random variable to be removed
"""
return _ffi_api_schedule.ScheduleRemoveRV(self, rand_var) # pylint: disable=no-member
Expand Down
48 changes: 24 additions & 24 deletions src/tir/schedule/concrete_schedule.h
Original file line number Diff line number Diff line change
Expand Up @@ -65,12 +65,12 @@ class ConcreteScheduleNode : public ScheduleNode {
/******** Lookup random variables ********/
inline Block Get(const BlockRV& block_rv) const final;
inline For Get(const LoopRV& loop_rv) const final;
inline int64_t Get(const IntRV& int_rv) const final;
inline PrimExpr Get(const ExprRV& expr_rv) const final;
inline StmtSRef GetSRef(const BlockRV& block_rv) const final;
inline StmtSRef GetSRef(const LoopRV& loop_rv) const final;
void RemoveRV(const BlockRV& block_rv) final { RemoveFromSymbolTable(block_rv); }
void RemoveRV(const LoopRV& loop_rv) final { RemoveFromSymbolTable(loop_rv); }
void RemoveRV(const IntRV& int_rv) final { RemoveFromSymbolTable(int_rv); }
void RemoveRV(const ExprRV& expr_rv) final { RemoveFromSymbolTable(expr_rv); }
using ScheduleNode::GetSRef;

public:
Expand Down Expand Up @@ -103,17 +103,17 @@ class ConcreteScheduleNode : public ScheduleNode {
template <class T>
inline T CreateRV(const StmtSRef& sref);
/*!
* \brief Add an integer as a random variable into the symbol table
* \param number The integer to be added to the symbol table
* \brief Add an expr as a random variable into the symbol table
* \param expr The expr to be added to the symbol table
* \return The new random variable created
*/
inline IntRV CreateRV(int64_t number);
inline ExprRV CreateRV(const PrimExpr& expr);
/*!
* \brief Add integers as random variables into the symbol table
* \param numbers The integers to be added to the symbol table
* \brief Add expr as random variables into the symbol table
* \param exprs The expr to be added to the symbol table
* \return The new random variables created
*/
inline Array<IntRV> CreateRV(const Array<Integer>& numbers);
inline Array<ExprRV> CreateRV(const Array<PrimExpr>& exprs);
/*! \brief Remove a random variable from the symbol table */
inline void RemoveFromSymbolTable(const ObjectRef& rv);
};
Expand All @@ -134,18 +134,18 @@ inline For ConcreteScheduleNode::Get(const LoopRV& loop_rv) const {
return GetRef<For>(loop);
}

inline int64_t ConcreteScheduleNode::Get(const IntRV& int_rv) const {
auto it = this->symbol_table_.find(int_rv);
inline PrimExpr ConcreteScheduleNode::Get(const ExprRV& expr_rv) const {
auto it = this->symbol_table_.find(expr_rv);
if (it == this->symbol_table_.end()) {
LOG(FATAL) << "IndexError: Cannot find corresponding IntRV: " << int_rv;
LOG(FATAL) << "IndexError: Cannot find corresponding ExprRV: " << expr_rv;
}
const ObjectRef& obj = (*it).second;
const auto* int_imm = obj.as<IntImmNode>();
if (int_imm == nullptr) {
LOG(FATAL) << "ValueError: IntRV's corresponding type is invalid: "
const auto* expr_node = obj.as<PrimExprNode>();
if (expr_node == nullptr) {
LOG(FATAL) << "ValueError: ExprRV's corresponding type is invalid: "
<< (obj.defined() ? obj->GetTypeKey() : "None");
}
return int_imm->value;
return GetRef<PrimExpr>(expr_node);
}

inline StmtSRef ConcreteScheduleNode::GetSRef(const BlockRV& block_rv) const {
Expand Down Expand Up @@ -211,18 +211,18 @@ inline T ConcreteScheduleNode::CreateRV(const StmtSRef& sref) {
return rv;
}

inline IntRV ConcreteScheduleNode::CreateRV(int64_t number) {
Var rv;
this->symbol_table_.Set(rv, Integer(number));
inline ExprRV ConcreteScheduleNode::CreateRV(const PrimExpr& expr) {
ExprRV rv;
this->symbol_table_.Set(rv, expr);
return std::move(rv);
}

inline Array<IntRV> ConcreteScheduleNode::CreateRV(const Array<Integer>& numbers) {
Array<IntRV> result;
result.reserve(numbers.size());
for (int64_t number : numbers) {
Var rv;
this->symbol_table_.Set(rv, IntImm(DataType::Int(32), number));
inline Array<ExprRV> ConcreteScheduleNode::CreateRV(const Array<PrimExpr>& exprs) {
Array<ExprRV> result;
result.reserve(exprs.size());
for (const PrimExpr& expr : exprs) {
ExprRV rv;
this->symbol_table_.Set(rv, expr);
result.push_back(rv);
}
return result;
Expand Down
9 changes: 4 additions & 5 deletions src/tir/schedule/schedule.cc
Original file line number Diff line number Diff line change
Expand Up @@ -79,9 +79,8 @@ TVM_REGISTER_GLOBAL("tir.schedule.ScheduleGet")
if (const auto* block_rv = obj.as<BlockRVNode>()) {
return self->Get(GetRef<BlockRV>(block_rv));
}
if (const auto* int_rv = obj.as<IntRVNode>()) {
int64_t result = self->Get(GetRef<IntRV>(int_rv));
return IntImm(DataType::Int(32), result);
if (const auto* expr_rv = obj.as<ExprRVNode>()) {
return self->Get(GetRef<ExprRV>(expr_rv));
}
LOG(FATAL) << "TypeError: Cannot evaluate the random variable of type: " << obj->GetTypeKey()
<< ". Its value is: " << obj;
Expand Down Expand Up @@ -109,8 +108,8 @@ TVM_REGISTER_GLOBAL("tir.schedule.ScheduleRemoveRV")
if (const auto* block_rv = obj.as<BlockRVNode>()) {
return self->RemoveRV(GetRef<BlockRV>(block_rv));
}
if (const auto* int_rv = obj.as<IntRVNode>()) {
return self->RemoveRV(GetRef<IntRV>(int_rv));
if (const auto* expr_rv = obj.as<ExprRVNode>()) {
return self->RemoveRV(GetRef<ExprRV>(expr_rv));
}
LOG(FATAL) << "TypeError: Invalid type: " << obj->GetTypeKey();
throw;
Expand Down

0 comments on commit 1203d73

Please sign in to comment.