[Unity][Parser] Check well-formedness in the parser #16569

Merged Mar 21, 2024 (58 commits)

@slyubomirsky slyubomirsky commented Feb 14, 2024

As discussed in several TVM Open Development meetings, we were permitting invalid programs to be parsed by not checking well-formedness during parsing. This PR makes a small change to the parser to check well-formedness in both Relax and TIR proactively, though this required correcting bugs in tests. Those bugs should not have been there in the first place!

I had held off on making this PR until this PR on MLC-LLM, which ensured that there would not be invalid constructs in models constructed by MLC-LLM.

One issue is that the well-formed checker does not create the most readable error messages. If the normalizer's error reporting could also be used in the well-formed checker, that would be a useful change to make as well.

Comment on lines -264 to -266
# The CallTIRInplaceAttrs cannot be constructed from the Python
# API. Therefore, declaring the Expected output first, so that
# the attributes can be used for the non-normalized Before.
Contributor Author

I am pretty sure it is possible to construct attrs in Python (I've done it for Relay) but it's very tedious, so I corrected the comment. It is definitely easier to rely on the one that will be constructed on the C++ side.

Contributor

Good point. It's definitely more tedious than I'd like, and (I think) requires going through the make_node API rather than having a python constructor.

Contributor

@Lunderberg Lunderberg left a comment

I like the PR, and it's definitely good to have the early validation. I think the main changes I'd suggest are:

  1. Verifying that the TIR is well-formed as well.
  2. Unit test for the check_well_formed=False attribute
  3. Using the check_well_formed=False attribute instead of relax.BlockBuilder for unit tests that require ill-formed inputs. This would be both for the improved readability of TVMScript, and because we may want to enable well-formed checks for some uses of the relax.BlockBuilder as well.
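
As a rough illustration of the opt-out pattern discussed above, here is a minimal pure-Python sketch of a decorator that checks by default and can be disabled per function. Everything here (`toy_prim_func`, `undefined_names`, `IllFormedError`) is hypothetical and is not TVM's implementation; the "check" is only a stand-in that looks for unresolved global names.

```python
import builtins


class IllFormedError(ValueError):
    """Raised when the toy check rejects a function."""


def undefined_names(func):
    # co_names lists the global and attribute names used by the body.
    # Names that resolve neither to a global nor to a builtin are reported.
    # (Toy stand-in for a real well-formedness check.)
    return [
        name
        for name in func.__code__.co_names
        if name not in func.__globals__ and not hasattr(builtins, name)
    ]


def toy_prim_func(func=None, *, check_well_formed=True):
    """Hypothetical decorator: usable bare or with check_well_formed=False."""

    def decorate(f):
        missing = undefined_names(f)
        if check_well_formed and missing:
            raise IllFormedError(f"{f.__name__}: undefined {missing}")
        return f

    # Bare use passes the function directly; parameterized use passes None.
    return decorate if func is None else decorate(func)


@toy_prim_func
def well_formed(a, n):
    return a + n


@toy_prim_func(check_well_formed=False)  # deliberately ill-formed test input
def ill_formed_but_exempt(a):
    return a + n_is_undefined_here
```

Keeping the check on by default means an exemption has to be written out explicitly at each ill-formed test input.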

@slyubomirsky slyubomirsky changed the title [Relax][Parser] Check well-formedness in the parser [Unity][Parser] Check well-formedness in the parser Feb 14, 2024
slyubomirsky commented Feb 14, 2024

@Lunderberg thank you for the review. I've pushed a change to enable a well-formed check for individual Relax functions (this exposed some bugs in our tests!) and PrimFuncs, which can also be disabled with the check_well_formed argument, but it led to some TIR well-formedness failures. I am not sure what could be wrong with most of them so I'd appreciate a look if you have time once the CI finishes. Oddly, setting check_well_formed to false on those did not correct the errors, so I'm not sure what the issue was.

@@ -71,7 +71,8 @@ def test_domain_touched():
def test_domain_touched_vector():
m = tvm.runtime.convert(128)

@T.prim_func
# n is undefined
Contributor Author

Is this intended?

Contributor

Probably not. We could make it well formed by adding n: T.int32 to the arguments. It looks like this test is validating that a fixed integer extent can be inferred, even when n is dynamic.

@@ -82,7 +82,8 @@ def _get_block(f):


def test_match_buffer():
@T.prim_func
# well-formed checker complains about multiple definitions for a variable A0_s1>?
Contributor Author

Not sure what was going on here, but the well-formed checker complains. Bug or intended?

Contributor

I'd say bug in the unit test. The A0_s1 is a variable generated to represent A.strides[1]. The strides = [s, s] should probably be strides = [s, 1].

Contributor Author

Switching the strides value causes the test to fail later (the assert matched_buffer1.strides[1] != matched_buffer2.strides[1] fails), so I don't think that's the issue. I do not know what the intent of the test case is, but it may be necessary to make bigger changes to it.

@@ -65,7 +65,7 @@ def matmul_m_128(a: T.handle, b: T.handle, c: T.handle) -> None:
C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vj, vk]


-@T.prim_func
+@T.prim_func(check_well_formed=False)
Contributor Author

Not clear to me what's wrong here

Contributor

I believe this is from x being undefined in the TIR. It appears as part of a shape in T.match_buffer, but as x * 8 rather than on its own.

Contributor Author

Hm, is this a legitimate test case then? I am not sure. It would be easy enough to add x as an argument, but would that still make sense as a test?
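
To make the point about `x * 8` concrete, here is a small sketch using Python's standard `ast` module. It encodes a toy rule taken from the comment above (an assumption, not TVM's actual checker): a shape dimension defines a variable only when the dimension is exactly that variable, and a variable inside a larger expression is only a use. The function name is hypothetical.

```python
import ast


def shape_defs_and_undefined(shape_exprs):
    """Return (defined, undefined) variable names for a list of shape dims.

    Toy rule (assumption): a dimension such as "x" defines x, while a
    compound dimension such as "x * 8" only uses x.
    """
    defined, used = set(), set()
    for text in shape_exprs:
        node = ast.parse(text, mode="eval").body
        if isinstance(node, ast.Name):
            defined.add(node.id)
        else:
            used.update(n.id for n in ast.walk(node) if isinstance(n, ast.Name))
    return defined, used - defined


# x appears only inside "x * 8", so nothing defines it.
print(shape_defs_and_undefined(["x * 8", "16"]))  # (set(), {'x'})

# A bare "x" dimension elsewhere would define it.
print(shape_defs_and_undefined(["x", "x * 8"]))   # ({'x'}, set())
```

Under this rule, adding `x` as an explicit argument (the fix floated above) plays the same role as the bare `"x"` dimension in the second call.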

@@ -275,7 +275,8 @@ def before(A: T.Buffer([16, 16], "float32"), B: T.Buffer([16, 16], "float32")):
for i in range(256):
B_flat[i] = A_flat[i] * 2.0

@T.prim_func(private=True)
# well-formed checker complains about multiple nested definitions of B_flat?
Contributor Author

This error was especially baffling

Contributor Author

@Lunderberg do you think anything needs to be done about this case? Is it legitimate to have a well-formedness error here?

@slyubomirsky
Contributor Author

I've determined that the segfault in the tests comes from tests/python/tir-transform/test_transform_inject_rolling_buffer.py. According to the backtrace, this comes from checking the StructInfo for a global variable, but oddly, it does not appear to come from checking well-formedness. Any idea what could be the issue? @Lunderberg

slyubomirsky commented Feb 23, 2024

It seems the segfault is happening due to parsing this line and others like it. I'm not sure what my changes have to do with it at all (i.e., why it hasn't happened before). If I remove the line, I get a failure due to tensor_2 being undefined, so I would appreciate advice for how to fix this test case. (Edit: I just added tensor_2 as an argument. Unclear why the line was there in the first place.)

Member

@yongwww yongwww left a comment

looks good!

slyubomirsky commented Feb 27, 2024

No clue why tests/python/tir-transform/test_tir_transform_force_narrow_index_to_i32.py::test_thread_axis2 is failing. There is no well-formed error there, but I get a complaint about dtypes not matching (for the loop iterator i0_i1_i2_i3_fused_2). Not sure why it wouldn't have failed before.

Same with tests/python/tir-transform/test_tir_transform_hoist_if.py::test_hoisting_block_scope_4

slyubomirsky commented Feb 27, 2024

Note that tests/python/tir-transform/test_tir_transform_hoist_if.py::test_hoisting_block_scope_4 and test_tir_transform_force_narrow_index_to_i32.py::test_thread_axis2 also fail on mainline, so we might have to fix real (unrelated) bugs there or disable the tests. The same is likely true of tests/python/tir-transform/test_transform_default_gpu_schedule.py::test_add_on_metal.

@Lunderberg
Contributor

I did a bisect on test_tir_transform_force_narrow_index_to_i32.py::test_thread_axis2, and it's been broken for quite some time. This merge commit from main into unity broke it, and it has been broken ever since. There are some itervar-related changes in src/script/ir_builder/tir/ir.cc that I suspect to be the cause. They remove an unconditional usage of the int32 dtype, which could previously have been assumed.

@@ -116,7 +116,8 @@ def no_normal_reduction(a: T.handle, b: T.handle) -> None:
B[vi] = B[vi] + A[vi, vk]


@T.prim_func
Contributor Author

Any idea what a fix would look like? I'd prefer not to have exceptions to the check without knowing what a well-formed version would look like

Contributor

In the T.evaluate(T.tvm_thread_allreduce(...)) call, k should be replaced with vk.

Contributor Author

Thanks!

Contributor Author

Changing it causes the test to fail, so I think we need to change the transformation too.

Contributor

Yeah, I think this pass pre-dates schedulable TIR, and doesn't check for mapping of ForNode::loop_var to BlockNode::iter_vars. For now, I'd probably just mark it as ill-formed, to avoid expanding the scope of this PR.

@@ -27,8 +27,9 @@


def opt_gemm_normalize():
-@tvm.script.ir_module
+@tvm.script.ir_module(check_well_formed=False)
Contributor Author

Any idea for making this example well-formed? I tried replacing the hanging buffers with decl_buffer but it introduced a new error

Contributor

I think this is a bug in the well-formed checker. The behavior of a BufferRealize node depends on whether it is an externally-provided buffer or not. For an externally-provided buffer, it indicates the region in which the buffer is accessed. For other buffers, it indicates the region for which the buffer must be allocated.

So, the well-formed checker should treat BufferRealize as a point of definition if the buffer hasn't already been defined. I've submitted #16655, which resolves this issue as well as a few other failures for well-formed checks that have been reported.

Contributor Author

Excellent! I'm glad to know that this little change has uncovered real bugs

Contributor

Absolutely. I really like enabling checks to run in all cases for that reason. There's always some edge case that they find, which can lead to further improvement.
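
The BufferRealize rule described in this thread can be sketched in a few lines. This is a toy model over `(kind, buffer_name)` tuples, not the checker changed in #16655; the function name and statement encoding are assumptions for illustration, and scoping of the realized region is ignored.

```python
def check_buffers(params, stmts):
    """Toy checker over (kind, buffer_name) statements.

    params: buffers supplied by the caller, treated as already defined.
    A "realize" of an unknown buffer acts as its point of definition;
    a "realize" of a known buffer only marks the region being accessed.
    """
    defined = set(params)
    errors = []
    for kind, name in stmts:
        if kind == "realize":
            defined.add(name)  # no effect if the buffer was already defined
        elif kind == "access" and name not in defined:
            errors.append(f"buffer {name} used before definition")
    return errors


# B is not a parameter, so its realize defines it and the access is fine.
print(check_buffers(["A"], [("realize", "A"), ("realize", "B"), ("access", "B")]))  # []

# Without any realize, accessing B is reported.
print(check_buffers(["A"], [("access", "B")]))  # ['buffer B used before definition']
```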

@@ -20,7 +20,8 @@
import tvm.testing


@T.prim_func
# A_local is undefined
@T.prim_func(check_well_formed=False)
Contributor Author

Any suggestion for avoiding this undefined var? Would decl_buffer work?

Contributor

Using decl_buffer would remove the undefined buffer object A_local, but its data pointer would still be undefined. I think the A_local = T.Buffer(...) will need to be A_local = T.alloc_buffer(...) instead.

Contributor Author

Good to know, I'll do that

Contributor Author

Hm, possible parsing bug? I get this error:

ValueError: Block frame or PrimFunc frame not find. Please ensure 'T.alloc_buffer' is called under T.block() or T.prim_func()
 --> /home/slyubomirsky/code/tvm/tests/python/codegen/test_inject_ptx_ldg32.py:31:15
    |  
 31 |      A_local = T.alloc_buffer((32), "float32", scope="local")
    |                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

It is being done inside a PrimFunc, so I don't know why we would get that error. I'll add a block and see if the example will work regardless.

Contributor

Whoops, you were right initially, and it should have been a T.decl_buffer. I got it confused between the T.decl_buffer construct in TVMScript, and the DeclBuffer node in C++. If T.decl_buffer has the default value of data=None, then it will expand to a DeclBuffer and the corresponding Allocate node.

The T.alloc_buffer is only valid within a block, and fills the BlockNode::alloc_buffers field.

@@ -350,7 +351,8 @@ def kernel_2(A: T.Buffer([256], "float32")):
return mod

def expected(self):
@I.ir_module
# complaints of duplicate definitions of threadIdx_x
Contributor

Can you verify that this occurs in the expected and not just the before? The before case should be ill-formed, because threadIdx_x is deliberately defined outside the TIR function and reused, but the expected case defines threadIdx_x inside of each function.

Contributor Author

@slyubomirsky slyubomirsky Feb 28, 2024

No, you're right, it only happens in the before. (Retracted: "It happens in both.")
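
The difference between the before and expected cases can be modeled with object identity. The sketch below is a toy, not TVM code: `Var`, `shared_definitions`, and the `kernel_1`/`kernel_2` layout are illustrative assumptions. It flags a variable object that is defined in more than one function, while allowing separate variables that merely share a name.

```python
class Var:
    """Toy variable; identity (not name) distinguishes variables."""

    def __init__(self, name):
        self.name = name


def shared_definitions(module):
    """Return names of Var objects defined in more than one function.

    module maps a function name to the list of Var objects it defines.
    """
    owner = {}
    shared = []
    for func_name, defs in module.items():
        for var in defs:
            if id(var) in owner and owner[id(var)] != func_name:
                shared.append(var.name)
            owner.setdefault(id(var), func_name)
    return shared


tx = Var("threadIdx_x")
# "before": one Var created outside the functions and reused by both.
before = {"kernel_1": [tx], "kernel_2": [tx]}
# "expected": each function defines its own Var with the same name.
expected = {"kernel_1": [Var("threadIdx_x")], "kernel_2": [Var("threadIdx_x")]}

print(shared_definitions(before))    # ['threadIdx_x']
print(shared_definitions(expected))  # []
```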

yongwww commented Mar 4, 2024

@tvm-bot rerun

@slyubomirsky
Contributor Author

#16682 fixes the remaining issue in test_tir_transform_hoist_if.py.

@slyubomirsky slyubomirsky force-pushed the parser-check-well-formed branch from c67fbf2 to d0c9570 Compare March 6, 2024 23:09
@slyubomirsky
Contributor Author

Another failing test, presumably unrelated to my changes: tests/python/tir-transform/test_tir_transform_inject_ptx_async_copy.py::test_vectorize_cp_async_in_if_then_else. It complains that the var data_im2col_reindex_shared_dyn is undefined, which appears wrong, since it is defined with an alloc_buffer.

slyubomirsky commented Mar 12, 2024

I've determined that the above failure was introduced by commit ff0b99c5ce4371ec966cd4fa07ae36351faf2a5e. In particular, reordering MergeSharedMemoryAllocations in src/driver/driver_api.cc triggers it--reverting those changes fixes it. However, I don't want to reverse the reordering because there is a comment on it that implies it was done for a reason, so I will see if there is a bug in that pass.

@Lunderberg
Contributor

Taking a look at that commit, there's a couple of things that stand out to me.

  • Related to MergeSharedMemoryAllocations

    • The MergeSharedMemoryAllocations is a copy/paste of StorageRewrite with modifications. The StorageRewrite pass is applied earlier and already handled merging of shared memory allocations, so I'm not sure why the separate pass is needed.
      • It looks like there's some different descriptions of StorageRewrite. This 2021 comment states that StorageRewrite cannot merge allocations of different dtypes. However, this 2018 PR added memory-sharing of allocations of different dtypes for StorageRewrite. This may be related to static vs dynamic shapes.
    • The StorageRewrite pass only merged allocations that were within the same attr::thread_extent annotation, and would be part of the same kernel. Assuming that MergeSharedMemoryAllocations kept this behavior from StorageRewrite, moving it after SplitHostDevice shouldn't be necessary.
    • I'd have to dig into it to be sure, but using MergeSharedMemoryAllocations after ThreadSync("warp") seems like it could cause an issue. Applying MergeSharedMemoryAllocations (merge two buffers into one) feels like it would undo the benefit of ThreadSync("warp") (synchronize to avoid contention of a single buffer).
  • Related to HoistIfThenElse

    • The comment about requiring HoistIfThenElse to be before UnrollLoop doesn't make sense to me. Any condition that could be statically proven for all indices prior to UnrollLoop could also be statically proven for a specific loop index. The end result of UnrollLoop and HoistIfThenElse should be the same regardless of the ordering.
      • Slight caveat: Unless the HoistIfThenElse causes a different decision to be made by the automatic unroll heuristics in UnrollLoop. However, those are disabled by default, and are not enabled by the unit test.
    • Moving HoistIfThenElse changed its ordering relative to user-specified passes in the "tir.add_lower_pass" configuration. While I don't know of any usage of this functionality, it's a bit worrying that it occurred without discussion.
    • The new function attribute "tir.HoistIfThenElseExprWithBlock" overrides the user's configuration for "tir.HoistIfThenElse". If the user had selected a stronger hoisting to be performed (e.g. hoisting tir::IfThenElse statements), it would no longer be applied.
    • The configuration used when the new function attribute "tir.HoistIfThenElseExprWithBlock" is set doesn't make sense for this pass's location in the lowering flow. The HoistIfThenElse pass occurs after ConvertBlocksToOpaque has removed all block variables, so the HoistedConditionals::kUsingBlockVar flag shouldn't be required.
  • Related to unit tests

    • The unit tests added in test_gpu_low_batch_gemv.py only exercise the meta-schedule functionality, and do not exercise any of the changes made to driver_api.cc or to HoistIfThenElse.

@slyubomirsky slyubomirsky force-pushed the parser-check-well-formed branch from fe40ad9 to e77ef6e Compare March 19, 2024 00:53
slyubomirsky commented Mar 21, 2024

Hm, the only remaining failure has to do with doc generation, as the doc generator finds multiple definitions of Target. One results from my previous changes (the import * in python/tvm/tir/transform/__init__.py will re-export the import of target), but I am not sure what is causing the others. I will see if this is entirely due to the re-import.

@slyubomirsky slyubomirsky merged commit 6c701fe into apache:main Mar 21, 2024
18 checks passed
tqchen added a commit that referenced this pull request Mar 22, 2024
tqchen commented Mar 22, 2024

Unfortunately, we found that this PR caused an outage of the MLC compilation, which seems to relate to the MergeSharedMemoryAllocations location change.

We should also carve out a unit test from the failing TIR, which could help catch future regressions.

I opened #16769 as a temporary measure to get things back for now. We can redo the PR, ideally without changing the MergeSharedMemoryAllocations location, or confirm that the carved-out TIR cases pass.

cc @slyubomirsky @jinhongyii @yongwww

@@ -607,9 +608,6 @@ transform::Sequential MixedModulePassManager(IRModule mixed_mod, Target target)

mixed_pass_list.push_back(tir::transform::AnnotateDeviceRegions());
mixed_pass_list.push_back(tir::transform::SplitHostDevice());
// MergeSharedMemoryAllocations must be applied after SplitHostDevice
// because the merged allocation site is at the beginning of each device function
Member

I assume the main reason it is here is that we cannot merge dynamic shared memory across kernel boundaries.

Lunderberg added a commit to octoml/mlc-llm that referenced this pull request Mar 27, 2024
Required for TVM compatibility after apache/tvm#16569
csullivan pushed a commit to octoml/mlc-llm that referenced this pull request Apr 2, 2024
Required for TVM compatibility after apache/tvm#16569
thaisacs pushed a commit to thaisacs/tvm that referenced this pull request Apr 3, 2024
* Check well-formedness in the parser

* Correct packed funcs in NN frontend

* Support the check_well_formed optional argument to I.ir_module

* Also check well-formedness in TIR

* Enable normalization for individual Relax functions and PrimFuncs

* Use the error raised by the TIR well-formed checker for the message

* Fix tvmscript test failures

* Whitespace

* Fix errors in verify_well_formed test

* Include a more helpful error message

* Fix TIR test failures

* Address well-formed failures in test_tir_specialize

* Correct well-formedness error in test_tir_analysis_oob

* Correct further well-formedness failures

* Remove __tvm_meta__ from test case to avoid parsing error

* Avoid circular import in entry.py

* Formatting fixes

* lint fix

* Add pylint exceptions

* Fix whitespace

* Fix more failed test cases

* Catch inappropriate use of decl_function instead of segfaulting

* Fix test_lower.py

* Mark purity in test_relax_2d_buffer_allocation.py

* Mark purity in test_dma_builtin.py

* Remove __tvm_meta___ from test_tir_usmp_analysis_extract_bufferinfo.py

* Suppress well-formed check in test_tir_transform_convert_blocks_to_opaque.py

* Remove __tvm_meta__ in test_tir_usmp_algo.py

* Remove __tvm_meta__ from more USMP tests

* Fix incorrect var in test_tir_transform_storage_flatten.py

* Remove all remaining instances of __tvm_meta__

* Fix purity error in test_dataflow_pattern.py

* Fix purity error in test_ast_printer

* Fix test_arith_domain_touched example

* Okay to set check_well_formed to True in test_tir_analysis_identify_mcmcpy

* Define variable in test_tir_analysis_oob

* Typo fix

* Add explanatory comment to test case

* Define the undefined vars in test_tir_transform_common_subexpr_elim

* Exception no longer necessary in test_tir_transform_inject_rolling_buffer

* Remove unnecessary check exemption in test_tir_transform_convert_ssa

* Avoid checking exemption in test_inject_ptx_ldg32

* Note special case in test_distributed_transform_propagate_sharding

* Exempt well-formed error in dlight/test_benchmark

* Exempt well-formedness errors in test_ethosu/, mostly uninitialized vars

* Whitespace

* Include non-CUDA GPUs in IsScheduledOnGPU

* Fix thread binding bug by changing thread binding var dtype

* Include overrides in test_runtime_builtin_paged_attention_kv_cache.py

* add exemptions in test_ethosu/test_replace_conv2d

* Add more ethosu exemptions

* More exemptions for ethosu tests

* Remove unused reference

* Indicate purity in test_transform_rewrite_cuda_graph

* Indicate purity in test_transform_normalize

* Reorder MergeSharedMemoryAllocations in GPU codegen

* Add target parameter for FP8StorageLegalize and FP8ComputeLegalize

* Don't re-import Target in tvm/tir/transform/transform.py
sunggg pushed a commit to octoml/mlc-llm that referenced this pull request Apr 3, 2024