-
Notifications
You must be signed in to change notification settings - Fork 3.5k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[TIR] In SplitHostDevice, check for variables in thread extents #16250
[TIR] In SplitHostDevice, check for variables in thread extents #16250
Conversation
Otherwise, they would be undefined after being de-duplicated by `ConvertSSA`.
The buf reported in apache#16237 can be resolved by tracking variable usage in a thread extent.
e7f10aa
to
0ef00cb
Compare
This PR resolves #16237, and reverts an earlier fix in #16236. The unit test added for |
Environment threads must reuse the same `tir::Var` across all `AttrStmt` instances in a `PrimFunc`, but must not reuse across separate `PrimFunc`s in an `IRModule`.
Avoids issue in cortexm unit tests resulting from read/write annotations being present in the root block, followed by application of BindParams.
@jinhongyii could you please review this PR? There are many really good ones got staled over the holidays |
if blockIdx_x * 128 + threadIdx_x < seq_len: | ||
B[blockIdx_x * 128 + threadIdx_x] = A[blockIdx_x * 128 + threadIdx_x] | ||
|
||
after = tvm.tir.transform.SplitHostDevice()(before) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
would be better if you can also assert_structural_equal expected IRModule here
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I didn't initially, as I've been trying to be more careful about adding tests on structural equal, for two main reasons. First, they tend to be very fragile to upstream changes (e.g. breakage due to improved simplification). Second, they are difficult for a reader to identify the behavior being tested (e.g. which changes should be made in the expected output, and which indicate a bug). Since this test is primarily to ensure that there are no undefined variables in the split device/host functions, that was the functionality tested.
Added for now, though I'll need to think on whether it should be there for the long-term.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I feel it would be better if the behavior of the pass is explicitly expressed through expected IRModule. Specifically, in this PR, reader needs to understand how you remove the undefined variables and what's the function signature after the pass.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
True, and there's definitely trade-offs in both directions. I generally like the readability of expected IRModules when they are small enough to read at a glance,
…he#16250) * [TIR] In SplitHostDevice, check for variables in thread extents Otherwise, they would be undefined after being de-duplicated by `ConvertSSA`. * Revert apache#16236 The buf reported in apache#16237 can be resolved by tracking variable usage in a thread extent. * lint fixes * Update TIR well-formed checker for env thread SSA requirements Environment threads must reuse the same `tir::Var` across all `AttrStmt` instances in a `PrimFunc`, but must not reuse across separate `PrimFunc`s in an `IRModule`. * Update ConvertSSA to handle environment threads' SSA requirements * lint fix * Updated docstrings for VerifyWellFormed * Rely on script.Complete for read/writes Avoids issue in cortexm unit tests resulting from read/write annotations being present in the root block, followed by application of BindParams. * Typo fix * Added structural equal comparison in unit test
Hi @Lunderberg , I encountered some issues around the verify_well_formed analysis. it may have some conflicts with the f16 mma tensorcore tensorization. Traceback (most recent call last):
File "/home/t-leiwang/ladder_workspace/tvm_gpu_gemm/tensorirscript_imma/7.padding_mma_f16_f16_nt.py", line 235, in <module>
sch.tensorize(loop_a, intrin_group["load_a"])
File "/home/t-leiwang/mlc_workspace/mma_verify/python/tvm/tir/schedule/_type_checker.py", line 340, in wrap
return func(*args, **kwargs)
File "/home/t-leiwang/mlc_workspace/mma_verify/python/tvm/tir/schedule/schedule.py", line 2921, in tensorize
_ffi_api.ScheduleTensorize( # type: ignore # pylint: disable=no-member
File "/home/t-leiwang/mlc_workspace/mma_verify/python/tvm/_ffi/_ctypes/packed_func.py", line 239, in __call__
raise_last_ffi_error()
File "/home/t-leiwang/mlc_workspace/mma_verify/python/tvm/_ffi/base.py", line 481, in raise_last_ffi_error
raise py_err
ValueError: Traceback (most recent call last):
54: tvm::runtime::PackedFuncObj::Extractor<tvm::runtime::PackedFuncSubObj<tvm::runtime::TypedPackedFunc<void (tvm::tir::Schedule, tvm::runtime::ObjectRef, tvm::runtime::String, bool)>::AssignTypedLambda<tvm::tir::{lambda(tvm::tir::Schedule, tvm::runtime::ObjectRef, tvm::runtime::String, bool)#14}>(tvm::tir::{lambda(tvm::tir::Schedule, tvm::runtime::ObjectRef, tvm::runtime::String, bool)#14}, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >)::{lambda(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)#1}> >::Call(tvm::runtime::PackedFuncObj const*, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >, tvm::runtime::TVMRetValue)
53: tvm::tir::TracedScheduleNode::Tensorize(tvm::tir::LoopRV const&, tvm::runtime::String const&, bool)
52: tvm::tir::ConcreteScheduleNode::Tensorize(tvm::tir::LoopRV const&, tvm::runtime::String const&, bool)
51: tvm::tir::ScheduleStateNode::DebugVerify() const
50: tvm::tir::VerifyCachedFlags(tvm::tir::ScheduleState const&)
49: tvm::tir::ScheduleState::ScheduleState(tvm::IRModule, int, bool)
48: tvm::tir::VerifyWellFormed(tvm::tir::PrimFunc const&, bool)
47: tvm::tir::TIRVisitorWithPath::Visit(tvm::tir::PrimFunc const&, tvm::ObjectPath)
46: tvm::tir::TIRVisitorWithPath::Visit(tvm::tir::Stmt const&, tvm::ObjectPath)
45: tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&, tvm::ObjectPath)>::InitVTable()::{lambda(tvm::runtime::ObjectRef const&, tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&, tvm::ObjectPath)>*, tvm::ObjectPath)#18}::_FUN(tvm::runtime::ObjectRef const&, tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&, tvm::ObjectPath)>*, tvm::ObjectPath)
44: tvm::tir::TIRVisitorWithPath::VisitStmt_(tvm::tir::BlockRealizeNode const*, tvm::ObjectPath)
43: tvm::tir::TIRVisitorWithPath::Visit(tvm::tir::Stmt const&, tvm::ObjectPath)
42: tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&, tvm::ObjectPath)>::InitVTable()::{lambda(tvm::runtime::ObjectRef const&, tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&, tvm::ObjectPath)>*, tvm::ObjectPath)#17}::_FUN(tvm::runtime::ObjectRef const&, tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&, tvm::ObjectPath)>*, tvm::ObjectPath)
41: tvm::tir::TIRVisitorWithPath::VisitStmt_(tvm::tir::BlockNode const*, tvm::ObjectPath)
40: tvm::tir::TIRVisitorWithPath::Visit(tvm::tir::Stmt const&, tvm::ObjectPath)
39: tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&, tvm::ObjectPath)>::InitVTable()::{lambda(tvm::runtime::ObjectRef const&, tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&, tvm::ObjectPath)>*, tvm::ObjectPath)#4}::_FUN(tvm::runtime::ObjectRef const&, tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&, tvm::ObjectPath)>*, tvm::ObjectPath)
38: tvm::tir::TIRVisitorWithPath::VisitStmt_(tvm::tir::ForNode const*, tvm::ObjectPath)
37: tvm::tir::TIRVisitorWithPath::Visit(tvm::tir::Stmt const&, tvm::ObjectPath)
36: tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&, tvm::ObjectPath)>::InitVTable()::{lambda(tvm::runtime::ObjectRef const&, tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&, tvm::ObjectPath)>*, tvm::ObjectPath)#4}::_FUN(tvm::runtime::ObjectRef const&, tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&, tvm::ObjectPath)>*, tvm::ObjectPath)
35: tvm::tir::TIRVisitorWithPath::VisitStmt_(tvm::tir::ForNode const*, tvm::ObjectPath)
34: tvm::tir::TIRVisitorWithPath::Visit(tvm::tir::Stmt const&, tvm::ObjectPath)
33: tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&, tvm::ObjectPath)>::InitVTable()::{lambda(tvm::runtime::ObjectRef const&, tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&, tvm::ObjectPath)>*, tvm::ObjectPath)#4}::_FUN(tvm::runtime::ObjectRef const&, tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&, tvm::ObjectPath)>*, tvm::ObjectPath)
32: tvm::tir::TIRVisitorWithPath::VisitStmt_(tvm::tir::ForNode const*, tvm::ObjectPath)
31: tvm::tir::TIRVisitorWithPath::Visit(tvm::tir::Stmt const&, tvm::ObjectPath)
30: tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&, tvm::ObjectPath)>::InitVTable()::{lambda(tvm::runtime::ObjectRef const&, tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&, tvm::ObjectPath)>*, tvm::ObjectPath)#4}::_FUN(tvm::runtime::ObjectRef const&, tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&, tvm::ObjectPath)>*, tvm::ObjectPath)
29: tvm::tir::TIRVisitorWithPath::VisitStmt_(tvm::tir::ForNode const*, tvm::ObjectPath)
28: tvm::tir::TIRVisitorWithPath::Visit(tvm::tir::Stmt const&, tvm::ObjectPath)
27: tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&, tvm::ObjectPath)>::InitVTable()::{lambda(tvm::runtime::ObjectRef const&, tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&, tvm::ObjectPath)>*, tvm::ObjectPath)#13}::_FUN(tvm::runtime::ObjectRef const&, tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&, tvm::ObjectPath)>*, tvm::ObjectPath)
26: tvm::tir::TIRVisitorWithPath::VisitStmt_(tvm::tir::SeqStmtNode const*, tvm::ObjectPath)
25: tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&, tvm::ObjectPath)>::InitVTable()::{lambda(tvm::runtime::ObjectRef const&, tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&, tvm::ObjectPath)>*, tvm::ObjectPath)#4}::_FUN(tvm::runtime::ObjectRef const&, tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&, tvm::ObjectPath)>*, tvm::ObjectPath)
24: tvm::tir::TIRVisitorWithPath::VisitStmt_(tvm::tir::ForNode const*, tvm::ObjectPath)
23: tvm::tir::TIRVisitorWithPath::Visit(tvm::tir::Stmt const&, tvm::ObjectPath)
22: tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&, tvm::ObjectPath)>::InitVTable()::{lambda(tvm::runtime::ObjectRef const&, tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&, tvm::ObjectPath)>*, tvm::ObjectPath)#13}::_FUN(tvm::runtime::ObjectRef const&, tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&, tvm::ObjectPath)>*, tvm::ObjectPath)
21: tvm::tir::TIRVisitorWithPath::VisitStmt_(tvm::tir::SeqStmtNode const*, tvm::ObjectPath)
20: tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&, tvm::ObjectPath)>::InitVTable()::{lambda(tvm::runtime::ObjectRef const&, tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&, tvm::ObjectPath)>*, tvm::ObjectPath)#4}::_FUN(tvm::runtime::ObjectRef const&, tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&, tvm::ObjectPath)>*, tvm::ObjectPath)
19: tvm::tir::TIRVisitorWithPath::VisitStmt_(tvm::tir::ForNode const*, tvm::ObjectPath)
18: tvm::tir::TIRVisitorWithPath::Visit(tvm::tir::Stmt const&, tvm::ObjectPath)
17: tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&, tvm::ObjectPath)>::InitVTable()::{lambda(tvm::runtime::ObjectRef const&, tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&, tvm::ObjectPath)>*, tvm::ObjectPath)#4}::_FUN(tvm::runtime::ObjectRef const&, tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&, tvm::ObjectPath)>*, tvm::ObjectPath)
16: tvm::tir::TIRVisitorWithPath::VisitStmt_(tvm::tir::ForNode const*, tvm::ObjectPath)
15: tvm::tir::TIRVisitorWithPath::Visit(tvm::tir::Stmt const&, tvm::ObjectPath)
14: tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&, tvm::ObjectPath)>::InitVTable()::{lambda(tvm::runtime::ObjectRef const&, tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&, tvm::ObjectPath)>*, tvm::ObjectPath)#4}::_FUN(tvm::runtime::ObjectRef const&, tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&, tvm::ObjectPath)>*, tvm::ObjectPath)
13: tvm::tir::TIRVisitorWithPath::VisitStmt_(tvm::tir::ForNode const*, tvm::ObjectPath)
12: tvm::tir::TIRVisitorWithPath::Visit(tvm::tir::Stmt const&, tvm::ObjectPath)
11: tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&, tvm::ObjectPath)>::InitVTable()::{lambda(tvm::runtime::ObjectRef const&, tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&, tvm::ObjectPath)>*, tvm::ObjectPath)#13}::_FUN(tvm::runtime::ObjectRef const&, tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&, tvm::ObjectPath)>*, tvm::ObjectPath)
10: tvm::tir::TIRVisitorWithPath::VisitStmt_(tvm::tir::SeqStmtNode const*, tvm::ObjectPath)
9: tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&, tvm::ObjectPath)>::InitVTable()::{lambda(tvm::runtime::ObjectRef const&, tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&, tvm::ObjectPath)>*, tvm::ObjectPath)#18}::_FUN(tvm::runtime::ObjectRef const&, tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&, tvm::ObjectPath)>*, tvm::ObjectPath)
8: tvm::tir::TIRVisitorWithPath::VisitStmt_(tvm::tir::BlockRealizeNode const*, tvm::ObjectPath)
7: tvm::tir::TIRVisitorWithPath::Visit(tvm::tir::Stmt const&, tvm::ObjectPath)
6: tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&, tvm::ObjectPath)>::InitVTable()::{lambda(tvm::runtime::ObjectRef const&, tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&, tvm::ObjectPath)>*, tvm::ObjectPath)#17}::_FUN(tvm::runtime::ObjectRef const&, tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&, tvm::ObjectPath)>*, tvm::ObjectPath)
5: tvm::tir::TIRVisitorWithPath::VisitStmt_(tvm::tir::BlockNode const*, tvm::ObjectPath)
4: tvm::tir::TIRVisitorWithPath::EnterDef(tvm::tir::Buffer const&, tvm::ObjectPath)
3: tvm::tir::TIRVisitorWithPath::Visit(tvm::PrimExpr const&, tvm::ObjectPath)
2: tvm::tir::ExprFunctor<void (tvm::PrimExpr const&, tvm::ObjectPath)>::InitVTable()::{lambda(tvm::runtime::ObjectRef const&, tvm::tir::ExprFunctor<void (tvm::PrimExpr const&, tvm::ObjectPath)>*, tvm::ObjectPath)#1}::_FUN(tvm::runtime::ObjectRef const&, tvm::tir::ExprFunctor<void (tvm::PrimExpr const&, tvm::ObjectPath)>*, tvm::ObjectPath)
1: tvm::tir::UndefinedVarVerifier::VisitExpr_(tvm::tir::VarNode const*, tvm::ObjectPath)
0: tvm::tir::(anonymous namespace)::Verifier<tvm::tir::UndefinedVarVerifier>::VerifyStream::~VerifyStream()
File "/home/t-leiwang/mlc_workspace/mma_verify/src/tir/analysis/verify_well_formed.cc", line 100
ValueError: Invalid use of undefined variable elem_offset at <root>.body.block.body.body.body.body.body.seq[1].body.seq[2].body.body.body.seq[0].block.match_buffers[0].buffer.elem_offset. code to reproduce: https://gist.github.com/LeiWang1999/1b008e1d6f780291b45037fc9756bf8c |
looks like the verify is too strict to tensorize primitive: void VisitExpr_(const VarNode* op, ObjectPath path) override {
auto var = GetRef<Var>(op);
auto active_def = currently_defined_.find(var);
auto verify = Verify(active_def != currently_defined_.end());
verify << "ValueError: "
<< "Invalid use of undefined variable " << var << " at " << path << ".";
// Check if there was a previous definition, and append the
// location to the error message if there was. This is to aid in
// debugging, by distinguishing between a variable that is
// currently out-of-scope, and a variable that never had a
// definition in the first place.
if (auto prev_def = previously_defined_.find(var); prev_def != previously_defined_.end()) {
verify << ". While this variable was previously defined at " << prev_def->second
<< ", this definition is no longer in-scope.";
}
} |
Taking a look, looks like the the issue is here. Currently, the Based on how it is handled in |
Hi @Lunderberg , I also encountered some problems during verify_well_formed analysis. It looks like it checks buffer.data when visiting the buffer map to see if it has been defined before. You can check the case
When creating schedule on it , an error will be repoted
Could you please explain it further? |
@Lucien0 This error message occurs because the I've submitted #16655, which should resolve this issue, along with a few others that have come up in the well-formed checker. Can you test and see if it solves your issue? |
Otherwise, they would be undefined after being de-duplicated by
ConvertSSA
.