Skip to content
This repository has been archived by the owner on Nov 25, 2022. It is now read-only.

Commit

Permalink
[TIR] Add DeclBuffer IR node and functors (apache#12300)
Browse files Browse the repository at this point in the history
* [TIR] Add DeclBuffer node

* [TIR] Add IR functors for DeclBuffer

* [TVMScript] Add printer and parser for DeclBuffer

* Update printer

* Update printer

* Add test case

* lint

* fix
  • Loading branch information
vinx13 authored and xinetzone committed Nov 25, 2022
1 parent f4c91ce commit 16b6287
Show file tree
Hide file tree
Showing 15 changed files with 241 additions and 2 deletions.
34 changes: 34 additions & 0 deletions include/tvm/tir/stmt.h
Original file line number Diff line number Diff line change
Expand Up @@ -681,6 +681,40 @@ class AllocateConst : public Stmt {
TVM_DEFINE_OBJECT_REF_METHODS(AllocateConst, Stmt, AllocateConstNode);
};

/*! \brief Declare a buffer that can be used in the body */
class DeclBufferNode : public StmtNode {
public:
/*! \brief The buffer being declared */
Buffer buffer;
/*! \brief The body to be executed */
Stmt body;

void VisitAttrs(AttrVisitor* v) {
v->Visit("buffer", &buffer);
v->Visit("body", &body);
v->Visit("span", &span);
}

bool SEqualReduce(const DeclBufferNode* other, SEqualReducer equal) const {
return equal(buffer, other->buffer) && equal(body, other->body);
}

void SHashReduce(SHashReducer hash_reduce) const {
hash_reduce(buffer);
hash_reduce(body);
}

static constexpr const char* _type_key = "tir.DeclBuffer";
TVM_DECLARE_FINAL_OBJECT_INFO(DeclBufferNode, StmtNode);
};

/*! \brief Managed reference to DeclBufferNode */
class DeclBuffer : public Stmt {
public:
TVM_DLL DeclBuffer(Buffer buffer, Stmt body, Span span = Span());
TVM_DEFINE_OBJECT_REF_METHODS(DeclBuffer, Stmt, DeclBufferNode);
};

/*!
* \brief The container of seq statement.
* Represent a sequence of statements.
Expand Down
4 changes: 4 additions & 0 deletions include/tvm/tir/stmt_functor.h
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,7 @@ class StmtFunctor<R(const Stmt& n, Args... args)> {
virtual R VisitStmt_(const WhileNode* op, Args... args) STMT_FUNCTOR_DEFAULT;
virtual R VisitStmt_(const AllocateNode* op, Args... args) STMT_FUNCTOR_DEFAULT;
virtual R VisitStmt_(const AllocateConstNode* op, Args... args) STMT_FUNCTOR_DEFAULT;
virtual R VisitStmt_(const DeclBufferNode* op, Args... args) STMT_FUNCTOR_DEFAULT;
virtual R VisitStmt_(const StoreNode* op, Args... args) STMT_FUNCTOR_DEFAULT;
virtual R VisitStmt_(const BufferStoreNode* op, Args... args) STMT_FUNCTOR_DEFAULT;
virtual R VisitStmt_(const BufferRealizeNode* op, Args... args) STMT_FUNCTOR_DEFAULT;
Expand Down Expand Up @@ -116,6 +117,7 @@ class StmtFunctor<R(const Stmt& n, Args... args)> {
IR_STMT_FUNCTOR_DISPATCH(WhileNode);
IR_STMT_FUNCTOR_DISPATCH(AllocateNode);
IR_STMT_FUNCTOR_DISPATCH(AllocateConstNode);
IR_STMT_FUNCTOR_DISPATCH(DeclBufferNode);
IR_STMT_FUNCTOR_DISPATCH(StoreNode);
IR_STMT_FUNCTOR_DISPATCH(AssertStmtNode);
IR_STMT_FUNCTOR_DISPATCH(ProducerStoreNode);
Expand Down Expand Up @@ -159,6 +161,7 @@ class TVM_DLL StmtVisitor : protected StmtFunctor<void(const Stmt&)> {
void VisitStmt_(const WhileNode* op) override;
void VisitStmt_(const AllocateNode* op) override;
void VisitStmt_(const AllocateConstNode* op) override;
void VisitStmt_(const DeclBufferNode* op) override;
void VisitStmt_(const StoreNode* op) override;
void VisitStmt_(const BufferStoreNode* op) override;
void VisitStmt_(const BufferRealizeNode* op) override;
Expand Down Expand Up @@ -260,6 +263,7 @@ class TVM_DLL StmtMutator : protected StmtFunctor<Stmt(const Stmt&)> {
Stmt VisitStmt_(const WhileNode* op) override;
Stmt VisitStmt_(const AllocateNode* op) override;
Stmt VisitStmt_(const AllocateConstNode* op) override;
Stmt VisitStmt_(const DeclBufferNode* op) override;
Stmt VisitStmt_(const StoreNode* op) override;
Stmt VisitStmt_(const BufferStoreNode* op) override;
Stmt VisitStmt_(const BufferRealizeNode* op) override;
Expand Down
12 changes: 12 additions & 0 deletions python/tvm/script/tir/__init__.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -187,6 +187,18 @@ def match_buffer(
buffer_type: str = "default",
axis_separators: Optional[List[int]] = None,
) -> Buffer: ...
def decl_buffer(
shape: Sequence[Union[PrimExpr, int]],
dtype: str = "float32",
data: Var = None,
strides: Optional[Sequence[int]] = None,
elem_offset: Optional[int] = None,
scope: str = "global",
align: int = -1,
offset_factor: int = 0,
buffer_type: str = "default",
axis_separators: Optional[List[int]] = None,
) -> Buffer: ...
def buffer_decl(
shape: Sequence[Union[PrimExpr, int]],
dtype: str = "float32",
Expand Down
81 changes: 81 additions & 0 deletions python/tvm/script/tir/scope_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,6 +224,87 @@ def setup_buffer(data, dtype, shape, annotations: dict = None, span: Span = None
context.update_symbol(name, self.buffer, node)


@register
class DeclBuffer(WithScopeHandler):
"""Special Stmt decl_buffer(shape, dtype, data, strides, elem_offset, scope, align,
offset_factor, buffer_type, axis_separators)
Example
-------
.. code-block:: python
A = T.decl_buffer((128, 128), dtype="float32")
"""

def __init__(self):
def decl_buffer(
shape,
dtype="float32",
data=None,
strides=None,
elem_offset=None,
scope="global",
align=-1,
offset_factor=0,
buffer_type="default",
axis_separators=None,
span=None,
):
return tvm.tir.DeclBuffer(self.buffer, self.body, span=span)

super().__init__(decl_buffer, concise_scope=True, def_symbol=True)

def enter_scope(
self,
node: synr.ast.Node,
context: ContextMaintainer,
arg_list: List[Any],
span: synr.ast.Span,
):
# define buffer vars in symbol table
if isinstance(node, synr.ast.With):
vars = WithScopeHandler.get_optional_vars(node, context)
if len(vars) != 1:
context.report_error(f"Unexpected number of vars: 1 vs. {len(vars)}", node.span)
name = vars[0].id.name
var_span = vars[0].id.span
elif isinstance(node, synr.ast.Assign):
if len(node.lhs) != 1:
context.report_error(f"Unexpected number of vars: 1 vs. {len(node.lhs)}", node.span)
name = node.lhs[0].id.name
var_span = node.lhs[0].id.span
else:
raise Exception("Internal Bug")

def setup_buffer(
shape,
dtype,
data,
strides,
elem_offset,
scope,
align,
offset_factor,
buffer_type,
axis_separators,
span: Span = None,
):
self.buffer = tvm.tir.decl_buffer(
shape=shape,
dtype=dtype,
data=data,
strides=strides,
elem_offset=elem_offset,
scope=scope,
data_alignment=align,
offset_factor=offset_factor,
buffer_type=buffer_type,
axis_separators=axis_separators,
span=span,
)

setup_buffer(*arg_list, span=tvm_span_from_synr(var_span))
context.update_symbol(name, self.buffer, node)


@register
class LaunchThread(WithScopeHandler):
"""With scope handler T.launch_thread(env_var, extent)"""
Expand Down
1 change: 1 addition & 0 deletions python/tvm/tir/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
Allocate,
AllocateConst,
AttrStmt,
DeclBuffer,
)

from .stmt import ProducerRealize, SeqStmt
Expand Down
20 changes: 20 additions & 0 deletions python/tvm/tir/stmt.py
Original file line number Diff line number Diff line change
Expand Up @@ -377,6 +377,26 @@ def __init__(self, buffer_var, dtype, extents, data_or_idx, body, annotations=No
)


@tvm._ffi.register_object("tir.DeclBuffer")
class DeclBuffer(Stmt):
"""DeclBuffer node.
Parameters
----------
buffer: Buffer
The buffer being declared.
body: Stmt
The body statement to be executed.
span: Optional[Span]
The location of this DeclBuffer in the source code.
"""

def __init__(self, buffer, body, span=None):
self.__init_handle_by_constructor__(_ffi_api.DeclBuffer, buffer, body, span)


@tvm._ffi.register_object("tir.AttrStmt")
class AttrStmt(Stmt):
"""AttrStmt node.
Expand Down
1 change: 1 addition & 0 deletions src/printer/text_printer.h
Original file line number Diff line number Diff line change
Expand Up @@ -353,6 +353,7 @@ class TIRTextPrinter : public StmtFunctor<Doc(const Stmt&)>,
Doc VisitStmt_(const ProducerRealizeNode* op) override;
Doc VisitStmt_(const AllocateNode* op) override;
Doc VisitStmt_(const AllocateConstNode* op) override;
Doc VisitStmt_(const DeclBufferNode* op) override;
Doc VisitStmt_(const IfThenElseNode* op) override;
Doc VisitStmt_(const SeqStmtNode* op) override;
Doc VisitStmt_(const EvaluateNode* op) override;
Expand Down
12 changes: 12 additions & 0 deletions src/printer/tir_text_printer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -557,6 +557,18 @@ Doc TIRTextPrinter::VisitStmt_(const AllocateConstNode* op) {
return doc;
}

Doc TIRTextPrinter::VisitStmt_(const DeclBufferNode* op) {
Doc doc;
doc << AllocBuf(op->buffer) << " = decl_buffer(" << Print(op->buffer->data) << ", "
<< PrintDType(op->buffer->dtype) << ", " << Print(op->buffer->shape) << ")" << Doc::NewLine();
if (op->body->IsInstance<SeqStmtNode>()) {
doc << PrintBody(op->body);
} else {
doc << ";" << Doc::NewLine() << Print(op->body);
}
return doc;
}

Doc TIRTextPrinter::VisitStmt_(const IfThenElseNode* op) {
Doc doc;
doc << "if " << Print(op->condition) << PrintBody(op->then_case);
Expand Down
19 changes: 19 additions & 0 deletions src/printer/tvmscript_printer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -245,6 +245,7 @@ class TVMScriptPrinter : public StmtFunctor<Doc(const Stmt&)>,
Doc VisitStmt_(const BufferRealizeNode* op) override;
Doc VisitStmt_(const AllocateNode* op) override;
Doc VisitStmt_(const AllocateConstNode* op) override;
Doc VisitStmt_(const DeclBufferNode* op) override;
Doc VisitStmt_(const IfThenElseNode* op) override;
Doc VisitStmt_(const SeqStmtNode* op) override;
Doc VisitStmt_(const ForNode* op) override;
Expand Down Expand Up @@ -1161,6 +1162,24 @@ Doc TVMScriptPrinter::VisitStmt_(const AllocateConstNode* alloc) {
return doc;
}

Doc TVMScriptPrinter::VisitStmt_(const DeclBufferNode* op) {
const Buffer& buffer = op->buffer;
buf_not_in_headers_.insert(buffer.get());
Doc buffer_name = Print(op->buffer);
Doc func_call;
func_call << tir_prefix_ << ".decl_buffer(" << memo_buf_decl_.at(buffer) << ")";

Doc doc;
if (current_num_ != num_child_ - 1) {
doc << "with " << func_call << " as " << buffer_name << ":";
doc << Doc::Indent(4, Doc::NewLine() << PrintBody(op->body));
} else {
doc << buffer_name << " = " << func_call << Doc::NewLine();
doc << PrintBody(op->body);
}
return doc;
}

Doc TVMScriptPrinter::VisitStmt_(const IfThenElseNode* op) {
Doc doc;
doc << "if " << Print(op->condition) << ":";
Expand Down
2 changes: 2 additions & 0 deletions src/target/source/codegen_c.cc
Original file line number Diff line number Diff line change
Expand Up @@ -661,6 +661,8 @@ void CodeGenC::VisitStmt_(const AllocateConstNode* op) {
this->PrintStmt(op->body);
}

void CodeGenC::VisitStmt_(const DeclBufferNode* op) { this->PrintStmt(op->body); }

void CodeGenC::VisitExpr_(const LoadNode* op, std::ostream& os) { // NOLINT(*)
LOG(FATAL) << "Unexpected deprecated LoadNode. Use BufferLoadNode instead.";
}
Expand Down
1 change: 1 addition & 0 deletions src/target/source/codegen_c.h
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,7 @@ class CodeGenC : public ExprFunctor<void(const PrimExpr&, std::ostream&)>,
void VisitStmt_(const EvaluateNode* op) override;
void VisitStmt_(const SeqStmtNode* op) override;
void VisitStmt_(const AllocateConstNode* op) override;
void VisitStmt_(const DeclBufferNode* op) override;

/*!
* \brief Print expr representing the thread tag
Expand Down
23 changes: 23 additions & 0 deletions src/tir/ir/stmt.cc
Original file line number Diff line number Diff line change
Expand Up @@ -508,6 +508,29 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
p->Print(op->body);
});

// DeclBuffer
DeclBuffer::DeclBuffer(Buffer buffer, Stmt body, Span span) {
ObjectPtr<DeclBufferNode> node = make_object<DeclBufferNode>();
node->buffer = std::move(buffer);
node->body = std::move(body);
node->span = std::move(span);
data_ = std::move(node);
}

TVM_REGISTER_GLOBAL("tir.DeclBuffer").set_body_typed([](Buffer buffer, Stmt body, Span span) {
return DeclBuffer(buffer, body, span);
});

TVM_REGISTER_NODE_TYPE(DeclBufferNode);

TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<DeclBufferNode>([](const ObjectRef& node, ReprPrinter* p) {
auto* op = static_cast<const DeclBufferNode*>(node.get());
p->PrintIndent();
p->stream << "decl_buffer " << op->buffer << "\n";
p->stream << op->body;
});

// ProducerRealize
ProducerRealize::ProducerRealize(DataProducer producer, Region bounds, PrimExpr condition,
Stmt body, String storage_scope, Span span) {
Expand Down
14 changes: 14 additions & 0 deletions src/tir/ir/stmt_functor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,8 @@ void StmtVisitor::VisitStmt_(const AllocateConstNode* op) {
this->VisitStmt(op->body);
}

void StmtVisitor::VisitStmt_(const DeclBufferNode* op) { this->VisitStmt(op->body); }

void StmtVisitor::VisitStmt_(const StoreNode* op) {
LOG(FATAL) << "Unexpected use of deprecated StoreNode. Please use BufferStoreNode instead.";
}
Expand Down Expand Up @@ -336,6 +338,18 @@ Stmt StmtMutator::VisitStmt_(const AllocateConstNode* op) {
}
}

Stmt StmtMutator::VisitStmt_(const DeclBufferNode* op) {
Stmt body = this->VisitStmt(op->body);

if (body.same_as(op->body)) {
return GetRef<Stmt>(op);
} else {
auto n = CopyOnWrite(op);
n->body = std::move(body);
return Stmt(n);
}
}

Stmt StmtMutator::VisitStmt_(const IfThenElseNode* op) {
PrimExpr condition = this->VisitExpr(op->condition);
Stmt then_case = this->VisitStmt(op->then_case);
Expand Down
6 changes: 4 additions & 2 deletions tests/cpp/ir_functor_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -181,6 +181,7 @@ TEST(IRF, StmtVisitor) {
DataType dtype = DataType::Float(32);
Var buf_var("b", PointerType(PrimType(dtype)));
Buffer buffer = decl_buffer({16});
body = DeclBuffer(buffer, std::move(body));
BufferRegion buffer_region(buffer, {Range::FromMinExtent(x + 1, 1)});
MatchBufferRegion match_buffer_region(decl_buffer({1}), buffer_region);

Expand Down Expand Up @@ -309,6 +310,7 @@ TEST(IRF, StmtMutator) {
DataType dtype = DataType::Float(32);
Var buf_var("b", PointerType(PrimType(dtype)));
Buffer buffer = decl_buffer({16});
body = DeclBuffer(buffer, std::move(body));
BufferRegion buffer_region(buffer, {Range::FromMinExtent(x + 1, 1)});
MatchBufferRegion match_buffer_region(decl_buffer({1}), buffer_region);
// construct block and block_realize
Expand All @@ -318,8 +320,8 @@ TEST(IRF, StmtMutator) {
body = v(std::move(block_realize));
// the body should be changed
Block new_block = body.as<BlockRealizeNode>()->block;
ICHECK(new_block->body.as<AllocateNode>()->extents[1].same_as(x));
ICHECK(new_block->init.as<AllocateNode>()->extents[1].same_as(x));
ICHECK(new_block->body.as<DeclBufferNode>()->body.as<AllocateNode>()->extents[1].same_as(x));
ICHECK(new_block->init.as<DeclBufferNode>()->body.as<AllocateNode>()->extents[1].same_as(x));
ICHECK(new_block->reads[0]->region[0]->min.same_as(x));
ICHECK(new_block->writes[0]->region[0]->min.same_as(x));
ICHECK(new_block->match_buffers[0]->source->region[0]->min.same_as(x));
Expand Down
Loading

0 comments on commit 16b6287

Please sign in to comment.