Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[TIR] In SplitHostDevice, check for variables in thread extents #16250

Merged
merged 10 commits into from
Jan 3, 2024

Conversation

Lunderberg
Copy link
Contributor

Otherwise, they would be undefined after being de-duplicated by ConvertSSA.

Otherwise, they would be undefined after being de-duplicated by
`ConvertSSA`.
The buf reported in apache#16237 can be resolved by tracking variable usage
in a thread extent.
@Lunderberg Lunderberg force-pushed the split_host_device_thread_extent branch from e7f10aa to 0ef00cb Compare December 18, 2023 21:49
@Lunderberg Lunderberg changed the title [Draft][TIR] In SplitHostDevice, check for variables in thread extents [TIR] In SplitHostDevice, check for variables in thread extents Dec 18, 2023
@Lunderberg Lunderberg marked this pull request as ready for review December 18, 2023 21:49
@Lunderberg
Copy link
Contributor Author

This PR resolves #16237, and reverts an earlier fix in #16236. The unit test added for SplitHostDevice uses the well-formed checker to variables are defined at their point-of-use. Since this was neither an existing part of tir.analysis.verify_well_formed, and an initial implementation using VarUseDefAnalysis didn't include the location of undefined usage, this PR also includes TIRVisitorWithPath, which tracks the location within an IRModule for use in error messages.

Environment threads must reuse the same `tir::Var` across all
`AttrStmt` instances in a `PrimFunc`, but must not reuse across
separate `PrimFunc`s in an `IRModule`.
Avoids issue in cortexm unit tests resulting from read/write
annotations being present in the root block, followed by application
of BindParams.
@junrushao
Copy link
Member

junrushao commented Jan 1, 2024

@jinhongyii could you please review this PR? There are many really good ones got staled over the holidays

if blockIdx_x * 128 + threadIdx_x < seq_len:
B[blockIdx_x * 128 + threadIdx_x] = A[blockIdx_x * 128 + threadIdx_x]

after = tvm.tir.transform.SplitHostDevice()(before)
Copy link
Contributor

@jinhongyii jinhongyii Jan 2, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

would be better if you can also assert_structural_equal expected IRModule here

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I didn't initially, as I've been trying to be more careful about adding tests on structural equal, for two main reasons. First, they tend to be very fragile to upstream changes (e.g. breakage due to improved simplification). Second, they are difficult for a reader to identify the behavior being tested (e.g. which changes should be made in the expected output, and which indicate a bug). Since this test is primarily to ensure that there are no undefined variables in the split device/host functions, that was the functionality tested.

Added for now, though I'll need to think on whether it should be there for the long-term.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I feel it would be better if the behavior of the pass is explicitly expressed through expected IRModule. Specifically, in this PR, reader needs to understand how you remove the undefined variables and what's the function signature after the pass.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

True, and there's definitely trade-offs in both directions. I generally like the readability of expected IRModules when they are small enough to read at a glance,

@Lunderberg Lunderberg merged commit eb15d04 into apache:main Jan 3, 2024
20 checks passed
@Lunderberg Lunderberg deleted the split_host_device_thread_extent branch January 3, 2024 14:03
junrushao pushed a commit to junrushao/tvm that referenced this pull request Jan 7, 2024
…he#16250)

* [TIR] In SplitHostDevice, check for variables in thread extents

Otherwise, they would be undefined after being de-duplicated by
`ConvertSSA`.

* Revert apache#16236

The buf reported in apache#16237 can be resolved by tracking variable usage
in a thread extent.

* lint fixes

* Update TIR well-formed checker for env thread SSA requirements

Environment threads must reuse the same `tir::Var` across all
`AttrStmt` instances in a `PrimFunc`, but must not reuse across
separate `PrimFunc`s in an `IRModule`.

* Update ConvertSSA to handle environment threads' SSA requirements

* lint fix

* Updated docstrings for VerifyWellFormed

* Rely on script.Complete for read/writes

Avoids issue in cortexm unit tests resulting from read/write
annotations being present in the root block, followed by application
of BindParams.

* Typo fix

* Added structural equal comparison in unit test
@LeiWang1999
Copy link
Contributor

Hi @Lunderberg , I encountered some issues around the verify_well_formed analysis. it may have some conflicts with the f16 mma tensorcore tensorization.

Traceback (most recent call last):
  File "/home/t-leiwang/ladder_workspace/tvm_gpu_gemm/tensorirscript_imma/7.padding_mma_f16_f16_nt.py", line 235, in <module>
    sch.tensorize(loop_a, intrin_group["load_a"])
  File "/home/t-leiwang/mlc_workspace/mma_verify/python/tvm/tir/schedule/_type_checker.py", line 340, in wrap
    return func(*args, **kwargs)
  File "/home/t-leiwang/mlc_workspace/mma_verify/python/tvm/tir/schedule/schedule.py", line 2921, in tensorize
    _ffi_api.ScheduleTensorize(  # type: ignore # pylint: disable=no-member
  File "/home/t-leiwang/mlc_workspace/mma_verify/python/tvm/_ffi/_ctypes/packed_func.py", line 239, in __call__
    raise_last_ffi_error()
  File "/home/t-leiwang/mlc_workspace/mma_verify/python/tvm/_ffi/base.py", line 481, in raise_last_ffi_error
    raise py_err
ValueError: Traceback (most recent call last):
  54: tvm::runtime::PackedFuncObj::Extractor<tvm::runtime::PackedFuncSubObj<tvm::runtime::TypedPackedFunc<void (tvm::tir::Schedule, tvm::runtime::ObjectRef, tvm::runtime::String, bool)>::AssignTypedLambda<tvm::tir::{lambda(tvm::tir::Schedule, tvm::runtime::ObjectRef, tvm::runtime::String, bool)#14}>(tvm::tir::{lambda(tvm::tir::Schedule, tvm::runtime::ObjectRef, tvm::runtime::String, bool)#14}, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >)::{lambda(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)#1}> >::Call(tvm::runtime::PackedFuncObj const*, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >, tvm::runtime::TVMRetValue)
  53: tvm::tir::TracedScheduleNode::Tensorize(tvm::tir::LoopRV const&, tvm::runtime::String const&, bool)
  52: tvm::tir::ConcreteScheduleNode::Tensorize(tvm::tir::LoopRV const&, tvm::runtime::String const&, bool)
  51: tvm::tir::ScheduleStateNode::DebugVerify() const
  50: tvm::tir::VerifyCachedFlags(tvm::tir::ScheduleState const&)
  49: tvm::tir::ScheduleState::ScheduleState(tvm::IRModule, int, bool)
  48: tvm::tir::VerifyWellFormed(tvm::tir::PrimFunc const&, bool)
  47: tvm::tir::TIRVisitorWithPath::Visit(tvm::tir::PrimFunc const&, tvm::ObjectPath)
  46: tvm::tir::TIRVisitorWithPath::Visit(tvm::tir::Stmt const&, tvm::ObjectPath)
  45: tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&, tvm::ObjectPath)>::InitVTable()::{lambda(tvm::runtime::ObjectRef const&, tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&, tvm::ObjectPath)>*, tvm::ObjectPath)#18}::_FUN(tvm::runtime::ObjectRef const&, tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&, tvm::ObjectPath)>*, tvm::ObjectPath)
  44: tvm::tir::TIRVisitorWithPath::VisitStmt_(tvm::tir::BlockRealizeNode const*, tvm::ObjectPath)
  43: tvm::tir::TIRVisitorWithPath::Visit(tvm::tir::Stmt const&, tvm::ObjectPath)
  42: tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&, tvm::ObjectPath)>::InitVTable()::{lambda(tvm::runtime::ObjectRef const&, tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&, tvm::ObjectPath)>*, tvm::ObjectPath)#17}::_FUN(tvm::runtime::ObjectRef const&, tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&, tvm::ObjectPath)>*, tvm::ObjectPath)
  41: tvm::tir::TIRVisitorWithPath::VisitStmt_(tvm::tir::BlockNode const*, tvm::ObjectPath)
  40: tvm::tir::TIRVisitorWithPath::Visit(tvm::tir::Stmt const&, tvm::ObjectPath)
  39: tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&, tvm::ObjectPath)>::InitVTable()::{lambda(tvm::runtime::ObjectRef const&, tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&, tvm::ObjectPath)>*, tvm::ObjectPath)#4}::_FUN(tvm::runtime::ObjectRef const&, tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&, tvm::ObjectPath)>*, tvm::ObjectPath)
  38: tvm::tir::TIRVisitorWithPath::VisitStmt_(tvm::tir::ForNode const*, tvm::ObjectPath)
  37: tvm::tir::TIRVisitorWithPath::Visit(tvm::tir::Stmt const&, tvm::ObjectPath)
  36: tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&, tvm::ObjectPath)>::InitVTable()::{lambda(tvm::runtime::ObjectRef const&, tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&, tvm::ObjectPath)>*, tvm::ObjectPath)#4}::_FUN(tvm::runtime::ObjectRef const&, tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&, tvm::ObjectPath)>*, tvm::ObjectPath)
  35: tvm::tir::TIRVisitorWithPath::VisitStmt_(tvm::tir::ForNode const*, tvm::ObjectPath)
  34: tvm::tir::TIRVisitorWithPath::Visit(tvm::tir::Stmt const&, tvm::ObjectPath)
  33: tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&, tvm::ObjectPath)>::InitVTable()::{lambda(tvm::runtime::ObjectRef const&, tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&, tvm::ObjectPath)>*, tvm::ObjectPath)#4}::_FUN(tvm::runtime::ObjectRef const&, tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&, tvm::ObjectPath)>*, tvm::ObjectPath)
  32: tvm::tir::TIRVisitorWithPath::VisitStmt_(tvm::tir::ForNode const*, tvm::ObjectPath)
  31: tvm::tir::TIRVisitorWithPath::Visit(tvm::tir::Stmt const&, tvm::ObjectPath)
  30: tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&, tvm::ObjectPath)>::InitVTable()::{lambda(tvm::runtime::ObjectRef const&, tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&, tvm::ObjectPath)>*, tvm::ObjectPath)#4}::_FUN(tvm::runtime::ObjectRef const&, tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&, tvm::ObjectPath)>*, tvm::ObjectPath)
  29: tvm::tir::TIRVisitorWithPath::VisitStmt_(tvm::tir::ForNode const*, tvm::ObjectPath)
  28: tvm::tir::TIRVisitorWithPath::Visit(tvm::tir::Stmt const&, tvm::ObjectPath)
  27: tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&, tvm::ObjectPath)>::InitVTable()::{lambda(tvm::runtime::ObjectRef const&, tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&, tvm::ObjectPath)>*, tvm::ObjectPath)#13}::_FUN(tvm::runtime::ObjectRef const&, tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&, tvm::ObjectPath)>*, tvm::ObjectPath)
  26: tvm::tir::TIRVisitorWithPath::VisitStmt_(tvm::tir::SeqStmtNode const*, tvm::ObjectPath)
  25: tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&, tvm::ObjectPath)>::InitVTable()::{lambda(tvm::runtime::ObjectRef const&, tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&, tvm::ObjectPath)>*, tvm::ObjectPath)#4}::_FUN(tvm::runtime::ObjectRef const&, tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&, tvm::ObjectPath)>*, tvm::ObjectPath)
  24: tvm::tir::TIRVisitorWithPath::VisitStmt_(tvm::tir::ForNode const*, tvm::ObjectPath)
  23: tvm::tir::TIRVisitorWithPath::Visit(tvm::tir::Stmt const&, tvm::ObjectPath)
  22: tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&, tvm::ObjectPath)>::InitVTable()::{lambda(tvm::runtime::ObjectRef const&, tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&, tvm::ObjectPath)>*, tvm::ObjectPath)#13}::_FUN(tvm::runtime::ObjectRef const&, tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&, tvm::ObjectPath)>*, tvm::ObjectPath)
  21: tvm::tir::TIRVisitorWithPath::VisitStmt_(tvm::tir::SeqStmtNode const*, tvm::ObjectPath)
  20: tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&, tvm::ObjectPath)>::InitVTable()::{lambda(tvm::runtime::ObjectRef const&, tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&, tvm::ObjectPath)>*, tvm::ObjectPath)#4}::_FUN(tvm::runtime::ObjectRef const&, tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&, tvm::ObjectPath)>*, tvm::ObjectPath)
  19: tvm::tir::TIRVisitorWithPath::VisitStmt_(tvm::tir::ForNode const*, tvm::ObjectPath)
  18: tvm::tir::TIRVisitorWithPath::Visit(tvm::tir::Stmt const&, tvm::ObjectPath)
  17: tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&, tvm::ObjectPath)>::InitVTable()::{lambda(tvm::runtime::ObjectRef const&, tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&, tvm::ObjectPath)>*, tvm::ObjectPath)#4}::_FUN(tvm::runtime::ObjectRef const&, tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&, tvm::ObjectPath)>*, tvm::ObjectPath)
  16: tvm::tir::TIRVisitorWithPath::VisitStmt_(tvm::tir::ForNode const*, tvm::ObjectPath)
  15: tvm::tir::TIRVisitorWithPath::Visit(tvm::tir::Stmt const&, tvm::ObjectPath)
  14: tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&, tvm::ObjectPath)>::InitVTable()::{lambda(tvm::runtime::ObjectRef const&, tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&, tvm::ObjectPath)>*, tvm::ObjectPath)#4}::_FUN(tvm::runtime::ObjectRef const&, tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&, tvm::ObjectPath)>*, tvm::ObjectPath)
  13: tvm::tir::TIRVisitorWithPath::VisitStmt_(tvm::tir::ForNode const*, tvm::ObjectPath)
  12: tvm::tir::TIRVisitorWithPath::Visit(tvm::tir::Stmt const&, tvm::ObjectPath)
  11: tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&, tvm::ObjectPath)>::InitVTable()::{lambda(tvm::runtime::ObjectRef const&, tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&, tvm::ObjectPath)>*, tvm::ObjectPath)#13}::_FUN(tvm::runtime::ObjectRef const&, tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&, tvm::ObjectPath)>*, tvm::ObjectPath)
  10: tvm::tir::TIRVisitorWithPath::VisitStmt_(tvm::tir::SeqStmtNode const*, tvm::ObjectPath)
  9: tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&, tvm::ObjectPath)>::InitVTable()::{lambda(tvm::runtime::ObjectRef const&, tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&, tvm::ObjectPath)>*, tvm::ObjectPath)#18}::_FUN(tvm::runtime::ObjectRef const&, tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&, tvm::ObjectPath)>*, tvm::ObjectPath)
  8: tvm::tir::TIRVisitorWithPath::VisitStmt_(tvm::tir::BlockRealizeNode const*, tvm::ObjectPath)
  7: tvm::tir::TIRVisitorWithPath::Visit(tvm::tir::Stmt const&, tvm::ObjectPath)
  6: tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&, tvm::ObjectPath)>::InitVTable()::{lambda(tvm::runtime::ObjectRef const&, tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&, tvm::ObjectPath)>*, tvm::ObjectPath)#17}::_FUN(tvm::runtime::ObjectRef const&, tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&, tvm::ObjectPath)>*, tvm::ObjectPath)
  5: tvm::tir::TIRVisitorWithPath::VisitStmt_(tvm::tir::BlockNode const*, tvm::ObjectPath)
  4: tvm::tir::TIRVisitorWithPath::EnterDef(tvm::tir::Buffer const&, tvm::ObjectPath)
  3: tvm::tir::TIRVisitorWithPath::Visit(tvm::PrimExpr const&, tvm::ObjectPath)
  2: tvm::tir::ExprFunctor<void (tvm::PrimExpr const&, tvm::ObjectPath)>::InitVTable()::{lambda(tvm::runtime::ObjectRef const&, tvm::tir::ExprFunctor<void (tvm::PrimExpr const&, tvm::ObjectPath)>*, tvm::ObjectPath)#1}::_FUN(tvm::runtime::ObjectRef const&, tvm::tir::ExprFunctor<void (tvm::PrimExpr const&, tvm::ObjectPath)>*, tvm::ObjectPath)
  1: tvm::tir::UndefinedVarVerifier::VisitExpr_(tvm::tir::VarNode const*, tvm::ObjectPath)
  0: tvm::tir::(anonymous namespace)::Verifier<tvm::tir::UndefinedVarVerifier>::VerifyStream::~VerifyStream()
  File "/home/t-leiwang/mlc_workspace/mma_verify/src/tir/analysis/verify_well_formed.cc", line 100
ValueError: Invalid use of undefined variable elem_offset at <root>.body.block.body.body.body.body.body.seq[1].body.seq[2].body.body.body.seq[0].block.match_buffers[0].buffer.elem_offset.

code to reproduce: https://gist.github.com/LeiWang1999/1b008e1d6f780291b45037fc9756bf8c

@LeiWang1999
Copy link
Contributor

looks like the verify is too strict to tensorize primitive:

  void VisitExpr_(const VarNode* op, ObjectPath path) override {
    auto var = GetRef<Var>(op);

    auto active_def = currently_defined_.find(var);
    auto verify = Verify(active_def != currently_defined_.end());
    verify << "ValueError: "
           << "Invalid use of undefined variable " << var << " at " << path << ".";

    // Check if there was a previous definition, and append the
    // location to the error message if there was.  This is to aid in
    // debugging, by distinguishing between a variable that is
    // currently out-of-scope, and a variable that never had a
    // definition in the first place.
    if (auto prev_def = previously_defined_.find(var); prev_def != previously_defined_.end()) {
      verify << ".  While this variable was previously defined at " << prev_def->second
             << ", this definition is no longer in-scope.";
    }
  }

@Lunderberg
Copy link
Contributor Author

Taking a look, looks like the the issue is here. Currently, the BlockNode::match_buffers is the point of definition for the view's data value, and for the view itself.

Based on how it is handled in LowerMatchBuffer, it looks like the shape, strides, and elem_offset of the matched buffers should be treated the same as function arguments. That is, they are points of definition if the expression is a variable that hasn't already been defined, but otherwise are points of use.

@Lucien0
Copy link
Contributor

Lucien0 commented Feb 22, 2024

Hi @Lunderberg , I also encountered some problems during verify_well_formed analysis. It looks like it checks buffer.data when visiting the buffer map to see if it has been defined before.

You can check the case test_tvmscript_printer_tir::test_prim_func_no_sugar_shared_buffer_data(), which tir like:

@T.prim_func
def main(a: T.handle, b: T.handle):
    A = T.match_buffer(a, (128, 128))
    B = T.match_buffer(b, (256, 256), data=A.data)
    T.evaluate(0)

When creating schedule on it , an error will be repoted

ValueError: TIR is ill-formed, due to multiple nested definitions of variable A.  It was first defined at <root>.buffer_map[<root>.params[0]].data, and was re-defined at <root>.buffer_map[<root>.params[1]].data

Could you please explain it further?

@Lunderberg
Copy link
Contributor Author

@Lucien0 This error message occurs because the VerifyWellFormed checker treats all elements of the PrimFunc::buffer_map as independent non-aliasing buffers. However, this is a stronger restriction that the TIR language requires. It is legal for a PrimFunc to require two DLTensor arguments, and to require that the two DLTensors share the same backing allocation. This is what your example states. However, the well-formed checker doesn't currently handle this, and erroneously expects every DLTensor argument to define an independent buffer/allocation.

I've submitted #16655, which should resolve this issue, along with a few others that have come up in the well-formed checker. Can you test and see if it solves your issue?

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants