Skip to content

Commit

Permalink
reads writes done
Browse files Browse the repository at this point in the history
  • Loading branch information
shingjan committed Dec 8, 2021
1 parent 3c05eb6 commit e47a560
Showing 1 changed file with 18 additions and 2 deletions.
20 changes: 18 additions & 2 deletions src/printer/tvmscript_printer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -206,6 +206,7 @@ class TVMScriptPrinter : public StmtFunctor<Doc(const Stmt&)>,
Doc PrintBlockVarRemaps();
Doc PrintBlockVars(const BlockRealizeNode* op);
Doc PrintBlockAttr(const BlockRealizeNode* op);
Doc PrintBlockAttrArray(const ArrayNode* op);
Doc PrintBlockBody(const BlockNode* op);
virtual Doc PrintBlockName(const BlockNode* block_op);
Doc PrintBufferRegion(const BufferRegionNode* op);
Expand Down Expand Up @@ -1095,8 +1096,10 @@ Doc TVMScriptPrinter::PrintBlockAttr(const BlockRealizeNode* op) {
if (!is_one(op->predicate)) {
block_attr_doc << Doc::NewLine() << tir_prefix_ << ".where(" << Print(op->predicate) << ")";
}
block_attr_doc << Doc::NewLine() << tir_prefix_ << ".reads(" << Print(block_op->reads) << ")";
block_attr_doc << Doc::NewLine() << tir_prefix_ << ".writes(" << Print(block_op->writes) << ")";
block_attr_doc << Doc::NewLine() << tir_prefix_ << ".reads("
<< PrintBlockAttrArray(block_op->reads.as<ArrayNode>()) << ")";
block_attr_doc << Doc::NewLine() << tir_prefix_ << ".writes("
<< PrintBlockAttrArray(block_op->writes.as<ArrayNode>()) << ")";
if (!block_op->annotations.empty()) {
block_attr_doc << Doc::NewLine() << tir_prefix_ << ".block_attr({";
block_attr_doc << PrintAnnotations(block_op->annotations);
Expand All @@ -1105,6 +1108,19 @@ Doc TVMScriptPrinter::PrintBlockAttr(const BlockRealizeNode* op) {
return block_attr_doc;
}

// This function is to make sure arguments of T.reads() and T.writes() is not parsed by printer as a
// List. Therefore the brackets are removed before and after printing arguments out
Doc TVMScriptPrinter::PrintBlockAttrArray(const ArrayNode* op) {
Doc doc;
for (size_t i = 0; i < op->size(); ++i) {
if (i != 0) {
doc << ", ";
}
doc << Print(op->at(i));
}
return doc;
}

Doc TVMScriptPrinter::PrintBlockBody(const BlockNode* op) {
Doc body;
for (const auto& alloc_buf : op->alloc_buffers) {
Expand Down

0 comments on commit e47a560

Please sign in to comment.