[Unity][Parser] Check well-formedness in the parser #16569
# The CallTIRInplaceAttrs cannot be constructed from the Python
# API. Therefore, declaring the Expected output first, so that
# the attributes can be used for the non-normalized Before.
I am pretty sure it is possible to construct attrs in Python (I've done it for Relay) but it's very tedious, so I corrected the comment. It is definitely easier to rely on the one that will be constructed on the C++ side.
Good point. It's definitely more tedious than I'd like, and (I think) requires going through the `make_node` API rather than having a Python constructor.
I like the PR, and it's definitely good to have the early validation. I think the main changes would be:
- Verifying that the TIR is well-formed as well.
- Unit test for the `check_well_formed=False` attribute.
- Using the `check_well_formed=False` attribute instead of `relax.BlockBuilder` for unit tests that require ill-formed inputs. This would be both for the improved readability of TVMScript, and because we may want to enable well-formed checks for some uses of the `relax.BlockBuilder` as well.
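As a rough illustration of the opt-out flag discussed above, here is a minimal sketch of a decorator that runs a check by default and can be used both bare and with a keyword argument. Everything in it (`toy_parse`, `toy_check`, `IllFormedError`, and the "has a docstring" rule) is hypothetical and only mimics the shape of `check_well_formed=False`; it is not TVM's parser.

```python
class IllFormedError(ValueError):
    """Raised by the toy parser when the toy checker rejects a function."""


def toy_check(func):
    # Hypothetical stand-in for a well-formedness checker:
    # a function counts as "well-formed" here if it has a docstring.
    return func.__doc__ is not None


def toy_parse(func=None, *, check_well_formed=True):
    """Decorator usable as `@toy_parse` or `@toy_parse(check_well_formed=False)`."""

    def wrap(f):
        if check_well_formed and not toy_check(f):
            raise IllFormedError(f"{f.__name__} is not well-formed")
        return f

    # Bare usage passes the function directly.
    if func is not None:
        return wrap(func)
    # Keyword usage returns the actual decorator.
    return wrap


@toy_parse
def checked():
    """Passes the toy check."""


@toy_parse(check_well_formed=False)
def deliberately_ill_formed():
    pass  # no docstring, but the check is skipped
```

Keeping the check on by default means a test has to opt out explicitly, which makes deliberately ill-formed inputs easy to spot in review.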
@Lunderberg thank you for the review. I've pushed a change to enable a well-formed check for individual Relax functions (this exposed some bugs in our tests!) and PrimFuncs, which can also be disabled with the |
@@ -71,7 +71,8 @@ def test_domain_touched():
def test_domain_touched_vector():
m = tvm.runtime.convert(128)

@T.prim_func
# n is undefined
Is this intended?
Probably not. We could make it well formed by adding `n: T.int32` to the arguments. It looks like this test is validating that a fixed integer extent can be inferred, even when `n` is dynamic.
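The "undefined variable" complaint can be pictured with a toy expression tree. The sketch below is an assumption-laden illustration (the `Var`, `Const`, `Mul` classes and `undefined_vars` are hypothetical names, not TVM APIs): a variable is reported when it is used in the body but is not among the function's parameters, and adding it as a parameter makes the report go away.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Var:
    name: str


@dataclass(frozen=True)
class Const:
    value: int


@dataclass(frozen=True)
class Mul:
    lhs: object
    rhs: object


def free_vars(expr):
    """Collect the names of all variables used in a toy expression."""
    if isinstance(expr, Var):
        return {expr.name}
    if isinstance(expr, Mul):
        return free_vars(expr.lhs) | free_vars(expr.rhs)
    return set()  # constants use no variables


def undefined_vars(params, body):
    """Return sorted names used in `body` that are not function parameters."""
    defined = {p.name for p in params}
    used = set()
    for expr in body:
        used |= free_vars(expr)
    return sorted(used - defined)


a, n = Var("a"), Var("n")
body = [Mul(a, n), Const(128)]
missing = undefined_vars([a], body)   # `n` is used but never defined
fixed = undefined_vars([a, n], body)  # adding `n` as an argument fixes it
```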
@@ -82,7 +82,8 @@ def _get_block(f):

def test_match_buffer():
@T.prim_func
# well-formed checker complains about multiple definitions for a variable A0_s1>?
Not sure what was going on here, but the well-formed checker complains. Bug or intended?
I'd say bug in the unit test. The `A0_s1` is a variable generated to represent `A.strides[1]`. The `strides = [s, s]` should probably be `strides = [s, 1]`.
Switching the strides value causes the test to fail later (the `assert matched_buffer1.strides[1] != matched_buffer2.strides[1]` fails), so I don't think that's the issue. I do not know what the intent of the test case is, but it may be necessary to make bigger changes to it.
@@ -65,7 +65,7 @@ def matmul_m_128(a: T.handle, b: T.handle, c: T.handle) -> None:
C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vj, vk]

@T.prim_func
@T.prim_func(check_well_formed=False)
Not clear to me what's wrong here
I believe this is from `x` being undefined in the TIR. It appears as part of a shape in `T.match_buffer`, but as `x * 8` rather than on its own.
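One way to read the comment above is as a rule about where a shape variable counts as defined. The sketch below encodes that reading as an explicit assumption: a bare variable used as a buffer dimension defines it, while a variable that only appears inside a compound expression such as `x * 8` is merely used. The names `Var`, `Scaled`, and `shape_undefined_vars` are hypothetical and this is not the TVM checker.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Var:
    name: str


@dataclass(frozen=True)
class Scaled:
    """Toy compound dimension: var * factor (e.g. x * 8)."""
    var: Var
    factor: int


def shape_undefined_vars(params, shape):
    """Return names used in `shape` that nothing defines.

    Assumed toy rule: a bare Var dimension defines that variable;
    a Var inside Scaled only uses it.
    """
    defined = {p.name for p in params}
    defined |= {dim.name for dim in shape if isinstance(dim, Var)}
    used = {dim.var.name for dim in shape if isinstance(dim, Scaled)}
    return sorted(used - defined)


x = Var("x")
only_compound = shape_undefined_vars([], [Scaled(x, 8)])    # x never defined
bare_and_compound = shape_undefined_vars([], [x, Scaled(x, 8)])
as_argument = shape_undefined_vars([x], [Scaled(x, 8)])     # x passed in
```

Under this reading, adding `x` as an argument would silence the checker, which is the option questioned in the reply that follows.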
Hm, is this a legitimate test case then? I am not sure. It would be easy enough to add x as an argument, but would that still make sense as a test?
@@ -275,7 +275,8 @@ def before(A: T.Buffer([16, 16], "float32"), B: T.Buffer([16, 16], "float32")):
for i in range(256):
B_flat[i] = A_flat[i] * 2.0

@T.prim_func(private=True)
# well-formed checker complains about multiple nested definitions of B_flat?
This error was especially baffling
@Lunderberg do you think anything needs to be done about this case? Is it legitimate to have a well-formedness error here?
I've determined that the segfault in the tests comes from |
It seems the segfault is happening due to parsing this line and others like it. I'm not sure what my changes have to do with it at all (i.e., why it hasn't happened before). If I remove the line, I get a failure due to |
looks good!
No clue why Same with |
Note that |
I did a bisect on |
@@ -116,7 +116,8 @@ def no_normal_reduction(a: T.handle, b: T.handle) -> None:
B[vi] = B[vi] + A[vi, vk]

@T.prim_func
Any idea what a fix would look like? I'd prefer not to have exceptions to the check without knowing what a well-formed version would look like
In the `T.evaluate(T.tvm_thread_allreduce(...))` call, `k` should be replaced with `vk`.
Thanks!
Changing it causes the test to fail, so I think we need to change the transformation too.
Yeah, I think this pass pre-dates schedulable TIR, and doesn't check for mapping of `ForNode::loop_var` to `BlockNode::iter_vars`. For now, I'd probably just mark it as ill-formed, to avoid expanding the scope of this PR.
@@ -27,8 +27,9 @@

def opt_gemm_normalize():
@tvm.script.ir_module
@tvm.script.ir_module(check_well_formed=False)
Any idea for making this example well-formed? I tried replacing the hanging buffers with `decl_buffer`, but it introduced a new error.
I think this is a bug in the well-formed checker. The behavior of a `BufferRealize` node depends on whether it is an externally-provided buffer or not. For an externally-provided buffer, it indicates the region in which the buffer is accessed. For other buffers, it indicates the region for which the buffer must be allocated.
So, the well-formed checker should treat `BufferRealize` as a point of definition if the buffer hasn't already been defined. I've submitted #16655, which resolves this issue as well as a few other failures for well-formed checks that have been reported.
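The rule described above, "treat a realize node as a definition if the buffer has not already been defined", can be sketched on a toy statement list. This is only an illustration under stated assumptions: the `("realize", name)` / `("access", name)` encoding and the `check_buffers` helper are hypothetical and are not the change made in #16655.

```python
def check_buffers(stmts, external=()):
    """Report accesses to buffers that were never defined.

    stmts: list of (kind, buffer_name) with kind "realize" or "access".
    external: names of buffers provided from outside (already defined).
    """
    defined = set(external)
    errors = []
    for kind, name in stmts:
        if kind == "realize":
            # For an external buffer this only marks the accessed region;
            # for any other buffer it doubles as the point of definition.
            defined.add(name)
        elif kind == "access" and name not in defined:
            errors.append(f"buffer {name} used before definition")
    return errors


# Internal buffer: the realize acts as its definition, so no error.
internal_ok = check_buffers([("realize", "B"), ("access", "B")])
# External buffer: already defined, the realize only annotates a region.
external_ok = check_buffers([("realize", "A"), ("access", "A")], external=["A"])
# No realize and not external: genuinely undefined.
undefined = check_buffers([("access", "C")])
```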
Excellent! I'm glad to know that this little change has uncovered real bugs
Absolutely. I really like enabling checks to run in all cases for that reason. There's always some edge case that they find, which can lead to further improvement.
@@ -20,7 +20,8 @@
import tvm.testing

@T.prim_func
# A_local is undefined
@T.prim_func(check_well_formed=False)
Any suggestion for avoiding this undefined var? Would `decl_buffer` work?
Using `decl_buffer` would remove the undefined buffer object `A_local`, but its data pointer would still be undefined. I think the `A_local = T.Buffer(...)` will need to be `A_local = T.alloc_buffer(...)` instead.
Good to know, I'll do that
Hm, possible parsing bug? I get this error:
ValueError: Block frame or PrimFunc frame not find. Please ensure 'T.alloc_buffer' is called under T.block() or T.prim_func()
--> /home/slyubomirsky/code/tvm/tests/python/codegen/test_inject_ptx_ldg32.py:31:15
|
31 | A_local = T.alloc_buffer((32), "float32", scope="local")
| ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
It is being done inside a PrimFunc, so I don't know why we would get that error. I'll add a block and see if the example will work regardless.
Whoops, you were right initially, and it should have been a `T.decl_buffer`. I got it confused between the `T.decl_buffer` construct in TVMScript, and the `DeclBuffer` node in C++. If `T.decl_buffer` has the default value of `data=None`, then it will expand to a `DeclBuffer` and the corresponding `Allocate` node.
The `T.alloc_buffer` is only valid within a block, and fills the `BlockNode::alloc_buffers` field.
@@ -350,7 +351,8 @@ def kernel_2(A: T.Buffer([256], "float32")):
return mod

def expected(self):
@I.ir_module
# complaints of duplicate definitions of threadIdx_x
Can you verify that this occurs in the `expected` and not just the `before`? The before case should be ill-formed, because `threadIdx_x` is deliberately defined outside the TIR function and reused, but the expected case defines `threadIdx_x` inside of each function.
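The difference between the two cases can be sketched with a toy model in which a variable is identified by object identity rather than by name (an assumption made for illustration; `Var` and `duplicate_definitions` are hypothetical names, not TVM APIs). Sharing one variable object across two functions is flagged, while two distinct objects that happen to share a name are not.

```python
class Var:
    """Toy variable: identity, not the name, distinguishes variables."""

    def __init__(self, name):
        self.name = name


def duplicate_definitions(functions):
    """Return sorted names of variables defined in more than one place.

    functions: dict mapping a function name to the list of Var objects
    that function defines.
    """
    seen = set()
    duplicates = set()
    for defined_vars in functions.values():
        for var in defined_vars:
            if id(var) in seen:
                duplicates.add(var.name)
            seen.add(id(var))
    return sorted(duplicates)


# "before": one variable object created outside and reused by both kernels.
shared = Var("threadIdx_x")
before = {"kernel_1": [shared], "kernel_2": [shared]}

# "expected": each kernel defines its own variable with the same name.
expected = {"kernel_1": [Var("threadIdx_x")], "kernel_2": [Var("threadIdx_x")]}

before_dups = duplicate_definitions(before)
expected_dups = duplicate_definitions(expected)
```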
~~It happens in both~~ No, you're right, it only happens in the `before`.
@tvm-bot rerun |
#16682 fixes the remaining issue in |
c67fbf2
to
d0c9570
Compare
Another failing test, presumably unrelated to my changes: |
I've determined that the above failure was introduced by commit |
Taking a look at that commit, there's a couple of things that stand out to me.
|
fe40ad9
to
e77ef6e
Compare
Hm, the only remaining failure has to do with doc generation, as the doc generator finds multiple definitions of Target. One results from my previous changes (the |
Unfortunately, we found that this PR caused an outage of the MLC compilation, which seems to relate to We should also carve out a UT from the failing TIR, which could potentially help catch future regressions. I opened #16769 as a temporary measure to get things back for now. We can redo the PR, ideally not changing the |
@@ -607,9 +608,6 @@ transform::Sequential MixedModulePassManager(IRModule mixed_mod, Target target)

mixed_pass_list.push_back(tir::transform::AnnotateDeviceRegions());
mixed_pass_list.push_back(tir::transform::SplitHostDevice());
// MergeSharedMemoryAllocations must be applied after SplitHostDevice
// because the merged allocation site is at the beginning of each device function
I assume the main reason this is here is that we cannot merge dynamic shared memory across kernel boundaries.
tests/python/tir-transform/test_tir_transform_merge_dynamic_shared_memory_allocations.py
Required for TVM compatibility after apache/tvm#16569
* Check well-formedness in the parser
* Correct packed funcs in NN frontend
* Support the check_well_formed optional argument to I.ir_module
* Also check well-formedness in TIR
* Enable normalization for individual Relax functions and PrimFuncs
* Use the error raised by the TIR well-formed checker for the message
* Fix tvmscript test failures
* Whitespace
* Fix errors in verify_well_formed test
* Include a more helpful error message
* Fix TIR test failures
* Address well-formed failures in test_tir_specialize
* Correct well-formedness error in test_tir_analysis_oob
* Correct further well-formedness failures
* Remove __tvm_meta__ from test case to avoid parsing error
* Avoid circular import in entryy.py
* Formatting fixes
* lint fix
* Add pylint exceptions
* Fix whitespace
* Fix more failed test cases
* Catch inappropriate use of decl_function instead of segfaulting
* Fix test_lower.py
* Mark purity in test_relax_2d_buffer_allocation.py
* Mark purity in test_dma_builtin.py
* Remove __tvm_meta___ from test_tir_usmp_analysis_extract_bufferinfo.py
* Suppress well-formed check in test_tir_transform_convert_blocks_to_opaque.py
* Remove __tvm_meta__ in test_tir_usmp_algo.py
* Remove __tvm_meta__ from more USMP tests
* Fix incorrect var in test_tir_transform_storage_flatten.py
* Remove all remaining instances of __tvm_meta__
* Fix purity error in test_dataflow_pattern.py
* Fix purity error in test_ast_printer
* Fix test_arith_domain_touched example
* Okay to set check_well_formed to True in test_tir_analysis_identify_mcmcpy
* Define variable in test_tir_analysis_oob
* Typo fix
* Add explanatory comment to test case
* Define the undefined vars in test_tir_transform_common_subexpr_elim
* Exception no longer necessary in test_tir_transform_inject_rolling_buffer
* Remove unnecessary check exemption in test_tir_transform_convert_ssa
* Avoid checking exemption in test_inject_ptx_ldg32
* Note special case in test_distributed_transform_propagate_sharding
* Exempt well-formed error in dlight/test_benchmark
* Exempt well-formedness errors in test_ethosu/, mostly uninitialized vars
* Whitespace
* Include non-CUDA GPUs in IsScheduledOnGPU
* Fix thread binding bug by changing thread binding var dtype
* Include overrides in test_runtime_builtin_paged_attention_kv_cache.py
* add exemptions in test_ethosu/test_replace_conv2d
* Add more ethosu exemptions
* More exemptions for ethosu tests
* Remove unused reference
* Indicate purity in test_transform_rewrite_cuda_graph
* Indicate purity in test_transform_normalize
* Reorder MergeSharedMemoryAllocations in GPU codegen
* Add target parameter for FP8StorageLegalize and FP8ComputeLegalize
* Don't re-import Target in tvm/tir/transform/transform.py
As discussed in several TVM Open Development meetings, we were permitting invalid programs to be parsed by not checking well-formedness during parsing. This PR makes a small change to the parser to check well-formedness in both Relax and TIR proactively, though this required correcting bugs in tests. Those bugs should not have been there in the first place!
I had held off on making this PR until this PR on MLC-LLM, which ensured that there would not be invalid constructs in models constructed by MLC-LLM.
One issue is that the well-formed checker does not create the most readable error messages. If the normalizer's error reporting could also be used in the well-formed checker, that would be a useful change to make as well.