-
Notifications
You must be signed in to change notification settings - Fork 3.5k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[TIR] Improve well-formed check's handling of match buffer #16655
Changes from 1 commit
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -78,47 +78,22 @@ void TIRVisitorWithPath::Visit(const PrimFunc& func, ObjectPath path) { | |
// variable has occurred. Therefore, to ensure that we only avoid | ||
// duplicate calls to VisitVarDef, these semantics need to be | ||
// checked. | ||
std::unordered_set<Var, ObjectPtrHash, ObjectPtrEqual> defined_params; | ||
std::vector<std::variant<DefContext<Var>, DefContext<Buffer>>> context; | ||
|
||
auto ppath = path->Attr("params"); | ||
for (size_t i = 0; i < func->params.size(); i++) { | ||
context.push_back(WithDef(func->params[i], ppath->ArrayIndex(i))); | ||
defined_params.insert(func->params[i]); | ||
} | ||
|
||
auto try_visit_implicit_var_def = [this, &defined_params, &context](const PrimExpr& expr, | ||
ObjectPath path) { | ||
if (auto opt = expr.as<Var>()) { | ||
auto var = opt.value(); | ||
if (!defined_params.count(var)) { | ||
context.push_back(WithDef(var, path)); | ||
defined_params.insert(var); | ||
} | ||
} | ||
}; | ||
auto try_visit_implicit_var_def_array = [&try_visit_implicit_var_def](const Array<PrimExpr>& arr, | ||
ObjectPath path) { | ||
for (size_t i = 0; i < arr.size(); i++) { | ||
try_visit_implicit_var_def(arr[i], path->ArrayIndex(i)); | ||
} | ||
}; | ||
|
||
auto buffer_map_path = path->Attr("buffer_map"); | ||
for (size_t i = 0; i < func->params.size(); i++) { | ||
if (auto opt = func->buffer_map.Get(func->params[i])) { | ||
auto buf = opt.value(); | ||
auto buf_path = buffer_map_path->MapValue(ppath->ArrayIndex(i)); | ||
|
||
// A buffer in the buffer_map always defines its data pointer | ||
context.push_back(WithDef(buf->data, buf_path->Attr("data"))); | ||
|
||
// But other implicit definitions only apply if they weren't | ||
// provided as explicit parameters, and they weren't defined | ||
// implicitly by any previous buffer. | ||
try_visit_implicit_var_def_array(buf->shape, buf_path->Attr("shape")); | ||
try_visit_implicit_var_def_array(buf->strides, buf_path->Attr("strides")); | ||
try_visit_implicit_var_def(buf->elem_offset, buf_path->Attr("elem_offset")); | ||
for (auto& def : WithMatchBufferDefs(buf, buf_path)) { | ||
context.push_back(std::move(def)); | ||
} | ||
} | ||
} | ||
|
||
|
@@ -127,7 +102,7 @@ void TIRVisitorWithPath::Visit(const PrimFunc& func, ObjectPath path) { | |
for (size_t i = 0; i < func->params.size(); i++) { | ||
if (auto opt = func->buffer_map.Get(func->params[i])) { | ||
auto buf_path = buffer_map_path->MapValue(ppath->ArrayIndex(i)); | ||
EnterDef(opt.value(), buf_path); | ||
context.push_back(WithDef(opt.value(), buf_path)); | ||
} | ||
} | ||
|
||
|
@@ -199,12 +174,23 @@ void TIRVisitorWithPath::VisitStmt_(const LetStmtNode* op, ObjectPath path) { | |
void TIRVisitorWithPath::VisitStmt_(const AttrStmtNode* op, ObjectPath path) { | ||
Visit(op->value, path->Attr("value")); | ||
|
||
std::optional<DefContext<IterVar>> context = std::nullopt; | ||
std::vector<std::variant<DefContext<IterVar>, DefContext<Var>>> context; | ||
if (auto iter_var = op->node.as<IterVar>(); | ||
iter_var && (op->attr_key == attr::thread_extent || op->attr_key == attr::virtual_thread)) { | ||
// Some attributes serve as a source of definition for the | ||
// tir::Var they annotate. | ||
context = WithDef(iter_var.value(), path->Attr("node")); | ||
context.push_back(WithDef(iter_var.value(), path->Attr("node"))); | ||
} else if (op->attr_key == attr::buffer_bind_scope) { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Probably worth commenting that this acts as an older form of MatchBuffer, per the PR description. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Thank you, and added a comment with description. |
||
Array<ObjectRef> arr = Downcast<Array<ObjectRef>>(op->node); | ||
ICHECK_EQ(arr.size(), 2U); | ||
Buffer buffer_view = Downcast<Buffer>(arr[0]); | ||
Buffer orig_buffer = Downcast<Buffer>(arr[1]); | ||
Visit(orig_buffer, path->Attr("node")->ArrayIndex(1)); | ||
|
||
for (auto& var : WithMatchBufferDefs(buffer_view, path->Attr("node")->ArrayIndex(0))) { | ||
context.push_back(std::move(var)); | ||
} | ||
|
||
} else if (auto expr = op->node.as<PrimExpr>()) { | ||
Visit(expr.value(), path->Attr("node")); | ||
} | ||
|
@@ -250,7 +236,8 @@ void TIRVisitorWithPath::VisitStmt_(const BufferStoreNode* op, ObjectPath path) | |
void TIRVisitorWithPath::VisitStmt_(const BufferRealizeNode* op, ObjectPath path) { | ||
Visit(op->condition, path->Attr("condition")); | ||
Visit(op->bounds, path->Attr("bounds")); | ||
auto context = WithDef(op->buffer, path->Attr("buffer")); | ||
auto context = WithDefIfUndefined(op->buffer->data, path->Attr("buffer")->Attr("data")); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I imagine this accounts for the case where a BufferRealize can act as a point of definition? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. That's correct. In cases where the buffer's backing allocation is defined externally, the |
||
Visit(op->buffer, path->Attr("buffer")); | ||
Visit(op->body, path->Attr("body")); | ||
} | ||
|
||
|
@@ -318,18 +305,10 @@ void TIRVisitorWithPath::VisitStmt_(const BlockNode* op, ObjectPath path) { | |
for (size_t i = 0; i < op->match_buffers.size(); i++) { | ||
auto buf = op->match_buffers[i]->buffer; | ||
auto buffer_path = match_path->ArrayIndex(i)->Attr("buffer"); | ||
auto buffer_strides_path = buffer_path->Attr("strides"); | ||
context.push_back(WithDef(buf->data, buffer_path->Attr("data"))); | ||
// Define buffer strides and elem_offset if they are vars | ||
if (const auto* v = buf->elem_offset.as<VarNode>()) { | ||
context.push_back(WithDef(GetRef<Var>(v), buffer_path->Attr("elem_offset"))); | ||
} | ||
for (size_t i = 0; i < buf->strides.size(); ++i) { | ||
if (const auto* v = buf->strides[i].as<VarNode>()) { | ||
context.push_back(WithDef(GetRef<Var>(v), buffer_strides_path->ArrayIndex(i))); | ||
} | ||
|
||
for (auto& def : WithMatchBufferDefs(buf, buffer_path)) { | ||
context.push_back(std::move(def)); | ||
} | ||
context.push_back(WithDef(buf, buffer_path)); | ||
} | ||
} | ||
|
||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is this ever used? I don't see any reads from it.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There aren't any reads from it, as it holds a scoped context manager. On destruction, the
DefContext<T>
object removes items fromTIRVisitorWithPath::in_scope_definitions_
, and calls theExitDef
handler of the child class.Also, thank you for pointing this one out. When switching from
std::optional
tostd::vector
, I forgot to add awhile(context.size()) context.pop_back();
loop in case child classes rely onExitDef
being called in the reverse order fromEnterDef
.