Skip to content

Commit

Permalink
[TVMScript] Improve printer for TIR syntax sugar (apache#9680)
Browse files Browse the repository at this point in the history
  • Loading branch information
Yuanjing Shi authored and baoxinqi committed Dec 27, 2021
1 parent fe5206e commit 45471eb
Show file tree
Hide file tree
Showing 4 changed files with 138 additions and 43 deletions.
53 changes: 26 additions & 27 deletions python/tvm/script/tir/special_stmt.py
Original file line number Diff line number Diff line change
Expand Up @@ -317,8 +317,7 @@ class BlockReads(SpecialStmt):

def __init__(self):
def reads(
read_regions: Union[BufferSlice, List[BufferSlice]],
*other_regions: BufferSlice,
*read_regions: Union[BufferSlice, List[BufferSlice]],
span: Span = None,
):
assert self.context, "call 'exit_scope' before 'enter_scope'"
Expand All @@ -335,16 +334,18 @@ def reads(
+ str(", ".join(str(x) for x in block_scope.reads)),
span,
)
if isinstance(read_regions, BufferSlice):
read_regions = [read_regions]
for region in other_regions:
read_regions.append(region)
if not isinstance(read_regions, list):
self.context.report_error(
"Incorrect input type. "
+ f"Expected BufferSlice or List[BufferSlice], but got {type(read_regions)}",
span,
)
if len(read_regions) > 1:
for read_region in read_regions:
if not isinstance(read_region, BufferSlice):
self.context.report_error(
"Incorrect input type. Expected *BufferSlice or List[BufferSlice],"
+ f" but got {type(read_regions)}",
span,
)
elif len(read_regions) == 1:
if isinstance(read_regions[0], list):
read_regions = read_regions[0]

block_scope.reads = read_regions

super().__init__(reads, def_symbol=False)
Expand All @@ -368,8 +369,7 @@ class BlockWrites(SpecialStmt):

def __init__(self):
def writes(
write_region: Union[BufferSlice, List[BufferSlice]],
*other_region: BufferSlice,
*write_regions: Union[BufferSlice, List[BufferSlice]],
span: Span = None,
):
assert self.context, "call 'exit_scope' before 'enter_scope'"
Expand All @@ -386,19 +386,18 @@ def writes(
+ str(", ".join(str(x) for x in block_scope.writes)),
span,
)
if isinstance(write_region, list):
pass
elif isinstance(write_region, BufferSlice):
write_region = [write_region]
for region in other_region:
write_region.append(region)
else:
self.context.report_error(
"Incorrect input type. "
+ f"Expected BufferSlice or List[BufferSlice], but got {type(write_region)}",
span,
)
block_scope.writes = write_region
if len(write_regions) > 1:
for write_region in write_regions:
if not isinstance(write_region, BufferSlice):
self.context.report_error(
"Incorrect input type. Expected *BufferSlice or List[BufferSlice],"
+ f" but got {type(write_regions)}",
span,
)
elif len(write_regions) == 1:
if isinstance(write_regions[0], list):
write_regions = write_regions[0]
block_scope.writes = write_regions

super().__init__(writes, def_symbol=False)

Expand Down
112 changes: 104 additions & 8 deletions src/printer/tvmscript_printer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -206,6 +206,7 @@ class TVMScriptPrinter : public StmtFunctor<Doc(const Stmt&)>,
Doc PrintBlockVarRemaps();
Doc PrintBlockVars(const BlockRealizeNode* op);
Doc PrintBlockAttr(const BlockRealizeNode* op);
Doc PrintExpandedArray(const ArrayNode* op);
Doc PrintBlockBody(const BlockNode* op);
virtual Doc PrintBlockName(const BlockNode* block_op);
Doc PrintBufferRegion(const BufferRegionNode* op);
Expand All @@ -220,6 +221,13 @@ class TVMScriptPrinter : public StmtFunctor<Doc(const Stmt&)>,
Doc AllocBuf(const Buffer& buffer);
void TryDeallocVar(const Var& var);
bool ContainsOptionalInfo(const Stmt& stmt);
/*!
* \brief check if a buffer declaration has only 'shape' and 'dtype' arguments specified
* \param buffer The match buffer to be checked
*/
bool IsSimpleBuffer(const Buffer& buffer);
Doc PrintInlineBufferBind(const Buffer& buffer);
Doc PrintTuple(const ArrayNode* op);

/*! Helper functions for loop printing. */
/*!
Expand Down Expand Up @@ -404,7 +412,7 @@ Doc TVMScriptPrinter::AllocBufferDeclaration(const Buffer& buf) {
if (buf->offset_factor != 1 || print_factor_explicitly) {
doc << ", offset_factor=" << buf->offset_factor;
}
if (buf->buffer_type != 1) {
if (buf->buffer_type != BufferType::kDefault) {
doc << ", type=" << Doc::StrLiteral("auto");
}
return doc;
Expand Down Expand Up @@ -471,6 +479,60 @@ Doc TVMScriptPrinter::PrintMatchBufferRegion(const MatchBufferRegionNode* op) {
return doc;
}

// check if all arguments, except the first two, are specified for T.match_buffer
// if not, then this match buffer is printed out as T.buffer in prim_func arguments
bool TVMScriptPrinter::IsSimpleBuffer(const Buffer& buf) {
if (memo_var_.find(buf->data) != memo_var_.end()) {
return false;
}
if (!buf->strides.empty()) {
return false;
}
if (buf->elem_offset->IsInstance<VarNode>()) {
return false;
} else if (buf->elem_offset->IsInstance<IntImmNode>()) {
IntImm elem_offset = Downcast<IntImm>(buf->elem_offset);
if (elem_offset->value != 0) {
return false;
}
}
if (buf.scope() != "global") {
return false;
}
if (buf->data_alignment != runtime::kAllocAlignment) {
return false;
}
if (buf->offset_factor != 1) {
return false;
}
if (buf->buffer_type != BufferType::kDefault) {
return false;
}
return true;
}

Doc TVMScriptPrinter::PrintInlineBufferBind(const Buffer& buffer) {
Doc doc;
doc << tir_prefix_ << ".Buffer[" << PrintTuple(buffer->shape.as<ArrayNode>());
doc << ", " << PrintDType(buffer->dtype) << "]";
return doc;
}

// print array out as tuple with parentheses
Doc TVMScriptPrinter::PrintTuple(const ArrayNode* op) {
Doc doc;
doc << '(';
for (size_t i = 0; i < op->size(); ++i) {
if (i != 0) {
doc << ", ";
}
doc << Print(op->at(i));
}
if (op->size() == 1) doc << ",";
doc << ')';
return doc;
}

Doc TVMScriptPrinter::PrintCommReducer(const CommReducerNode* op) {
Doc doc;
int n_var = static_cast<int>(op->rhs.size());
Expand Down Expand Up @@ -1095,8 +1157,10 @@ Doc TVMScriptPrinter::PrintBlockAttr(const BlockRealizeNode* op) {
if (!is_one(op->predicate)) {
block_attr_doc << Doc::NewLine() << tir_prefix_ << ".where(" << Print(op->predicate) << ")";
}
block_attr_doc << Doc::NewLine() << tir_prefix_ << ".reads(" << Print(block_op->reads) << ")";
block_attr_doc << Doc::NewLine() << tir_prefix_ << ".writes(" << Print(block_op->writes) << ")";
block_attr_doc << Doc::NewLine() << tir_prefix_ << ".reads("
<< PrintExpandedArray(block_op->reads.as<ArrayNode>()) << ")";
block_attr_doc << Doc::NewLine() << tir_prefix_ << ".writes("
<< PrintExpandedArray(block_op->writes.as<ArrayNode>()) << ")";
if (!block_op->annotations.empty()) {
block_attr_doc << Doc::NewLine() << tir_prefix_ << ".block_attr({";
block_attr_doc << PrintAnnotations(block_op->annotations);
Expand All @@ -1105,6 +1169,19 @@ Doc TVMScriptPrinter::PrintBlockAttr(const BlockRealizeNode* op) {
return block_attr_doc;
}

// This function is to make sure arguments of T.reads() and T.writes() is not parsed by printer as a
// List. Therefore the brackets are removed before and after printing arguments out
Doc TVMScriptPrinter::PrintExpandedArray(const ArrayNode* op) {
Doc doc;
for (size_t i = 0; i < op->size(); ++i) {
if (i != 0) {
doc << ", ";
}
doc << Print(op->at(i));
}
return doc;
}

Doc TVMScriptPrinter::PrintBlockBody(const BlockNode* op) {
Doc body;
for (const auto& alloc_buf : op->alloc_buffers) {
Expand Down Expand Up @@ -1218,8 +1295,21 @@ Doc TVMScriptPrinter::PrintPrimFunc(const PrimFunc& primFunc) {
doc << "def " << (func2var_.find(op) == func2var_.end() ? "func" : func2var_[op]->name_hint)
<< "(";
std::vector<Doc> params;
std::unordered_set<Buffer, ObjectPtrHash, ObjectPtrEqual> simple_buf;
for (const auto& param : op->params) {
var_not_in_headers_.insert(param.get());
auto it = op->buffer_map.find(param);
// check if this param is a T.handle
if (it != op->buffer_map.end()) {
// check if this match_buffer has only the first two arguments specified
const Buffer& buf = (*it).second;
if (IsSimpleBuffer(buf)) {
simple_buf.insert(buf);
buf_not_in_headers_.insert(buf.get());
params.push_back(Print(buf) << ": " << PrintInlineBufferBind(buf));
continue;
}
}
params.push_back(Print(param) << ": " << Print(GetType(param)));
}
doc << PrintSep(params, Doc::Text(", ")) << ") -> " << Print(primFunc->ret_type) << ":";
Expand All @@ -1229,9 +1319,11 @@ Doc TVMScriptPrinter::PrintPrimFunc(const PrimFunc& primFunc) {
for (const auto& param : op->params) {
auto it = op->buffer_map.find(param);
if (it == op->buffer_map.end()) continue;
buf_not_in_headers_.insert((*it).second.get());
body << Print((*it).second) << " = " << tir_prefix_ << ".match_buffer(";
body << Print((*it).first) << ", " << memo_buf_decl_[(*it).second];
const Buffer& buf = (*it).second;
if (simple_buf.count(buf)) continue;
buf_not_in_headers_.insert(buf.get());
body << Print(buf) << " = " << tir_prefix_ << ".match_buffer(";
body << Print((*it).first) << ", " << memo_buf_decl_[buf];
body << ")" << Doc::NewLine();
}
// print body
Expand Down Expand Up @@ -1392,8 +1484,12 @@ Doc TVMScriptPrinter::PrintAnnotations(const Map<String, ObjectRef>& annotations
Doc TVMScriptPrinter::PrintLoop(const For& loop) {
Doc res;
res << "for " << Print(loop->loop_var) << " in " << tir_prefix_
<< "." + std::string(ForKind2String(loop->kind)) + "(" << Print(loop->min) << ", "
<< Print(loop->min + loop->extent);
<< "." + std::string(ForKind2String(loop->kind)) + "(";
if (is_zero(loop->min)) {
res << Print(loop->extent);
} else {
res << Print(loop->min) << ", " << Print(loop->min + loop->extent);
}
if (loop->thread_binding.defined()) {
res << ", thread=";
res << Print(loop->thread_binding.value()->thread_tag);
Expand Down
12 changes: 6 additions & 6 deletions tests/python/unittest/test_tvmscript_error_report.py
Original file line number Diff line number Diff line change
Expand Up @@ -544,10 +544,10 @@ def test_reorder_fail_nested_loop_inner():
with pytest.raises(tvm.tir.ScheduleError) as execinfo:
sch.reorder(k, i)
expected_sub_error_message = (
" for i in T.serial(0, 128):\n"
" for i in T.serial(128):\n"
" # tir.For#0\n"
" for j in T.serial(0, 128):\n"
" ^^^^^^^^^^^^^^^^^^^^^^^^^^\n"
" for j in T.serial(128):\n"
" ^^^^^^^^^^^^^^^^^^^^^^^\n"
)
assert expected_sub_error_message in str(execinfo.value)

Expand All @@ -560,9 +560,9 @@ def test_fuse_fail_nested_loop_outer():
sch.fuse(k, i)
expected_sub_error_message = (
" # tir.For#1\n"
" for i in T.serial(0, 128):\n"
" ^^^^^^^^^^^^^^^^^^^^^^^^^^\n"
" for j in T.serial(0, 128):\n"
" for i in T.serial(128):\n"
" ^^^^^^^^^^^^^^^^^^^^^^^\n"
" for j in T.serial(128):\n"
)
assert expected_sub_error_message in str(execinfo.value)

Expand Down
4 changes: 2 additions & 2 deletions tests/python/unittest/test_tvmscript_syntax_sugar.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,8 +118,8 @@ def elementwise_handle(
# match buffer - use buffer with kwargs
@T.prim_func
def elementwise_buffer_kwargs(
a: T.Buffer(shape=(128, 128, 128, 128), dtype="float32", elem_offset=None),
b: T.Buffer(shape=(128, 128, 128, 128), dtype="float32", elem_offset=None),
a: T.Buffer(shape=(128, 128, 128, 128), dtype="float32"),
b: T.Buffer(shape=(128, 128, 128, 128), dtype="float32"),
) -> None:
for i, j, k, l in T.grid(128, 128, 128, 128):
with T.block("B"):
Expand Down

0 comments on commit 45471eb

Please sign in to comment.