-
Notifications
You must be signed in to change notification settings - Fork 3.5k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[TIR] CSE-TIR Pass - More deterministic behavior #10663
[TIR] CSE-TIR Pass - More deterministic behavior #10663
Conversation
cc some folks from #9482: @masahi @FranckQC @mbs-octoml @jroesch |
@AndrewZhaoLuo maybe we can also solve this by changing the comparison functor in
|
Hmm, so the unordered map doesn't have an ordering so all it needs is some way of determining equality of elements. An alternative would be In general though yeah, we need to be on the lookout for iterating over unordered collections cause it will be non-determinstic so might be good for maintainability to use std::map I'll try refactoring |
The actual order doesn't matter, does it? |
ah sorry, I was confused with something, "but it does have some ordering internally determined by the comparison functor" is not correct. |
@AndrewZhaoLuo I think we can merge this as is for now. The CI failed on the mac os, probably need to kick another job. |
Hi, Sorry I couldn't discuss this with you last Thursday/Friday. Thanks for improving the CSE! Yes, indeed, iterating through the hash table will not necessary lead to accessing its elements always in the same order. That's because the iteration order depends on the hashes (and we use I guess it is not an issue as far as the CSE is concerned, because after producing this hashtable, we use it to produce a vector of semantic entities (which are pairs <expr, count>). So the order of the end result (the vector) is non-deterministic only for elements with the same size / complexity, for which it does not matter (again, only for the CSE) in which order we common them. (i.e. it should not add or remove opportunities for commoning more later). Said differently, this non-determinism should appear for orthogonal choices.
for which we could produce:
Or the alternative version were (c+d) gets introduced before (a+b). In practice, what is chosen currently depends on the addresses of the free variables 'a', 'b', 'c' and 'd' because they are used for the hashes. I did not think that having this level of non-determinism would be an issue, but you're clearly right that its better to always produce exactly the same resulting TIR code, that just makes testing easier. Sorry I did not think about it when I wrote it! According to the doc about Many thanks for having spotted that and for having fixed it! Kind regards, Franck |
* iterate through sorted keys * masa comments -- simplify iteration * test * tests * simplify vector construciton * jostle ci
…orm - avoids comparison of terms (#11574) The CSE pass had been designed for potentially allowing comparisons (and commonings) of equivalent terms (like (x+y)+z and x+(y+z)), where **the notion of being equivalent was customizable, and no assumption was made about it**. That means that the implementation of the equivalence test function `EquivalentTerms()` - which was at the moment just calling the syntactical equality test `EqualTerms()` - could be replaced later by a cleverer equality test. However, having such a generic way of comparing elements meant that in the function `SyntacticToSemanticComputations()`, where we were going from a hashtable of syntactical entities to what I called a vector of "semantical entites" (which are just canonical forms/representants of classes of equivalence of terms), **the only way was to compare each pair**. That resulted in a quadratic behavior of this function, but there was no way around it as in order to merge equivalent entities into their class of equivalence, we had to compare them. **This PR essentially does the following:** - When computing the classes of equivalences of terms (therefore transforming a ComputationTable (i.e. a hashtable) into a vector of classes of equivalence) : **instead of comparing each pair of terms, relies on a normalization procedure to obtain a normal form for each of them**. That transforms a small part of the algorithm that was quadratic to n.logn. However, it's difficult to see improvements in practice, in particular for average sized programs, as that part was a "small" quadratic to a "big" n.logn (finding things in a hash-table, copying it to a vector, etc). It was probably going from a complexity of ~O(((n²-n)/2) + n.logn) to a complexity of ~O(3n + n.logn), so potential gains would only be expected for very large programs. - Completely gives the user the possibility to turn ON/OFF the semantical comparisons of terms. It is turned OFF by default (as it's quite longer to compile with it ON, unsurprisingly), which means that by default, the equivalence coincides with the (syntactical) equality of terms. As the pass was written with the possibility to do these additional commonings (like (x+y)+z and x+(y+z)), it was a good time to fully plug that completely, up to the Python user who can now turn that ON if he wants to. But again, it is OFF by default, so no real change on that. To run it ON, simply do: `with tvm.transform.PassContext(config={'tir.enable_equiv_terms_in_cse_tir':True}):` before calling `build()` - When this boolean is set to ON, it uses a simple implementation of the normalization function with equivalences that uses `arith::Analyzer::Simplify` as noted by in #10544 . Note that this is not a real normalization procedure as it is incomplete (i.e., it is not guarantee to converge to the normal form), but it is correct, and it works well with most properties : associativity of +, distributivity of * on +, etc. - Clarifies and enhance the test base for the pass. In particular, it adds the tests that were written in #10544 but which did not make it through. - Also add the test ( https://github.com/AndrewZhaoLuo/TVM-Sandbox/blob/19284ddbd6bb28af61c0c2aa8bb334c5c53731a7/tir/test_inconsistent_tir_lowering.py#L1 ) demonstrating the (older) non-deterministic lowering and put it into a proper test, as I found it useful for making sure that this does not happen again. It has been copied from #10663 and only slightly adapted (in particular for doing the comparison of hashes automatically instead of printing them and relying on a human to compare them).
…orm - avoids comparison of terms (apache#11574) The CSE pass had been designed for potentially allowing comparisons (and commonings) of equivalent terms (like (x+y)+z and x+(y+z)), where **the notion of being equivalent was customizable, and no assumption was made about it**. That means that the implementation of the equivalence test function `EquivalentTerms()` - which was at the moment just calling the syntactical equality test `EqualTerms()` - could be replaced later by a cleverer equality test. However, having such a generic way of comparing elements meant that in the function `SyntacticToSemanticComputations()`, where we were going from a hashtable of syntactical entities to what I called a vector of "semantical entites" (which are just canonical forms/representants of classes of equivalence of terms), **the only way was to compare each pair**. That resulted in a quadratic behavior of this function, but there was no way around it as in order to merge equivalent entities into their class of equivalence, we had to compare them. **This PR essentially does the following:** - When computing the classes of equivalences of terms (therefore transforming a ComputationTable (i.e. a hashtable) into a vector of classes of equivalence) : **instead of comparing each pair of terms, relies on a normalization procedure to obtain a normal form for each of them**. That transforms a small part of the algorithm that was quadratic to n.logn. However, it's difficult to see improvements in practice, in particular for average sized programs, as that part was a "small" quadratic to a "big" n.logn (finding things in a hash-table, copying it to a vector, etc). It was probably going from a complexity of ~O(((n²-n)/2) + n.logn) to a complexity of ~O(3n + n.logn), so potential gains would only be expected for very large programs. - Completely gives the user the possibility to turn ON/OFF the semantical comparisons of terms. It is turned OFF by default (as it's quite longer to compile with it ON, unsurprisingly), which means that by default, the equivalence coincides with the (syntactical) equality of terms. As the pass was written with the possibility to do these additional commonings (like (x+y)+z and x+(y+z)), it was a good time to fully plug that completely, up to the Python user who can now turn that ON if he wants to. But again, it is OFF by default, so no real change on that. To run it ON, simply do: `with tvm.transform.PassContext(config={'tir.enable_equiv_terms_in_cse_tir':True}):` before calling `build()` - When this boolean is set to ON, it uses a simple implementation of the normalization function with equivalences that uses `arith::Analyzer::Simplify` as noted by in apache#10544 . Note that this is not a real normalization procedure as it is incomplete (i.e., it is not guarantee to converge to the normal form), but it is correct, and it works well with most properties : associativity of +, distributivity of * on +, etc. - Clarifies and enhance the test base for the pass. In particular, it adds the tests that were written in apache#10544 but which did not make it through. - Also add the test ( https://github.com/AndrewZhaoLuo/TVM-Sandbox/blob/19284ddbd6bb28af61c0c2aa8bb334c5c53731a7/tir/test_inconsistent_tir_lowering.py#L1 ) demonstrating the (older) non-deterministic lowering and put it into a proper test, as I found it useful for making sure that this does not happen again. It has been copied from apache#10663 and only slightly adapted (in particular for doing the comparison of hashes automatically instead of printing them and relying on a human to compare them).
Running the CSE-TIR pass led to some variation in lowering to TIR from the same schedule multiple times. I tracked this down to iterating over an unordered map. I think having consistency in lowering to TIR is a desirable property to have so I made the offending code a little better in this regard.
Example test script:
https://github.com/AndrewZhaoLuo/TVM-Sandbox/blob/19284ddbd6bb28af61c0c2aa8bb334c5c53731a7/tir/test_inconsistent_tir_lowering.py#L1
^--- Running the above off main produces different hashes every time, running after this diff it produces the same hash.
There may be more examples of non-determinism in CSE-TIR but this is what I found for now 👀