
[TIR] CSE-TIR Pass - More deterministic behavior #10663

Merged
merged 6 commits into from
Mar 18, 2022

Conversation

AndrewZhaoLuo
Contributor

@AndrewZhaoLuo AndrewZhaoLuo commented Mar 17, 2022

Running the CSE-TIR pass led to some variation in the TIR lowered from the same schedule across runs. I tracked this down to iteration over an unordered map. I think consistency in lowering to TIR is a desirable property, so I made the offending code a little better in this regard.

Example test script:
https://github.com/AndrewZhaoLuo/TVM-Sandbox/blob/19284ddbd6bb28af61c0c2aa8bb334c5c53731a7/tir/test_inconsistent_tir_lowering.py#L1

^--- Running the above off main produces different hashes every time, running after this diff it produces the same hash.
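The linked script checks determinism by hashing the lowered output of repeated runs. A self-contained sketch of that kind of check (the `lower` callable and the toy expression here are stand-ins, not TVM's API; the real script lowers a TVM schedule):

```python
def assert_deterministic_lowering(lower, runs=5):
    """Run `lower` several times and require a single distinct result hash."""
    hashes = {hash(str(lower())) for _ in range(runs)}
    assert len(hashes) == 1, f"got {len(hashes)} distinct hashes in {runs} runs"

# Toy stand-in for a lowering pipeline; trivially deterministic.
assert_deterministic_lowering(lambda: ("add", "a", "b"))
```

With the real pipeline, running this off main would trip the assertion; after this diff it passes.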

There may be more examples of non-determinism in CSE-TIR but this is what I found for now 👀

@AndrewZhaoLuo AndrewZhaoLuo changed the title [TIR] CSE-TIR Pass - More deterministic Behavior [TIR] CSE-TIR Pass - More deterministic behavior Mar 17, 2022
@AndrewZhaoLuo
Contributor Author

cc some folks from #9482: @masahi @FranckQC @mbs-octoml @jroesch

@masahi
Member

masahi commented Mar 17, 2022

@AndrewZhaoLuo maybe we can also solve this by changing the comparison functor in

using ComputationTable = std::unordered_map<PrimExpr, size_t, StructuralHash, ExprDeepEqual>;

@AndrewZhaoLuo
Contributor Author

AndrewZhaoLuo commented Mar 17, 2022

Hmm, so the unordered map doesn't have an ordering; all it needs is some way of determining equality of elements (plus a hash).

An alternative would be std::map, which does have a notion of ordering (and you can define a comparator). However, its performance characteristics are different (usually a self-balancing binary tree, so O(log n) lookup instead of amortized O(1)).

In general though, yeah, we need to be on the lookout for iterating over unordered collections because it will be non-deterministic, so it might be good for maintainability to use std::map.

I'll try refactoring unordered_map --> map sometime later, though I need to read more about performance first (it's already kind of slow when processing a statement with 100 common subexpressions).
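The fix this PR takes keeps the unordered_map but copies its entries into a vector sorted by a deterministic, purely structural key before iterating. A minimal Python sketch of that idea, with tuples standing in for PrimExprs and `repr()` standing in for a structural ordering (both are assumptions for illustration, not TVM's actual API):

```python
def stable_entries(computation_table):
    """Return (expr, count) pairs in an order independent of hash order."""
    # repr() of the expression is a stand-in for a deterministic structural
    # comparison; the point is that the sort key depends only on syntax,
    # never on memory addresses or hash-table bucket layout.
    return sorted(computation_table.items(), key=lambda kv: repr(kv[0]))

# Two tables built in different insertion orders yield the same iteration order.
t1 = {("add", "a", "b"): 2, ("add", "c", "d"): 2}
t2 = {("add", "c", "d"): 2, ("add", "a", "b"): 2}
assert stable_entries(t1) == stable_entries(t2)
```

Sorting once before the main loop costs O(n log n) but makes every downstream decision reproducible across runs.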

@masahi
Member

masahi commented Mar 17, 2022

Hmm, so the unordered map doesn't have an ordering

The actual order doesn't matter, does it? unordered_map doesn't store elements in a sorted order, but it does have some ordering internally determined by the comparison functor. And what we need is just a fixed order, not necessarily the sorted order, so I think unordered map should do the job.

@masahi
Member

masahi commented Mar 17, 2022

ah sorry, I was confused with something, "but it does have some ordering internally determined by the comparison functor" is not correct.

@masahi
Member

masahi commented Mar 17, 2022

@AndrewZhaoLuo I think we can merge this as is for now. The CI failed on the mac os, probably need to kick another job.

@masahi masahi merged commit b01e3fc into apache:main Mar 18, 2022
@FranckQC
Contributor

FranckQC commented Mar 21, 2022

Hi,

Sorry I couldn't discuss this with you last Thursday/Friday. Thanks for improving the CSE!

Yes, indeed, iterating through the hash table will not necessarily access its elements in the same order every time. That's because the iteration order depends on the hashes (and we use structural_hash for that), which are themselves based on pointers, so different runs will give different addresses, which will give different hashes, and thus a different order in which elements are accessed.

I guess it is not an issue as far as the CSE is concerned, because after producing this hashtable, we use it to produce a vector of semantic entities (which are pairs <expr, count>).
The main loop of the CSE will later iterate through the elements of this vector.
But before that, this vector gets sorted by the size (called "complexity") of the expression (the first component).

So the order of the end result (the vector) is non-deterministic only for elements with the same size / complexity, for which it does not matter (again, only for the CSE) in which order we common them (i.e., it should not add or remove opportunities for commoning later). Said differently, this non-determinism should only appear for orthogonal choices.
I guess as in:

Mem[i1] = (a+b)+ (c+d);
Mem[i2] = a+b;               // We can common this one before or after
Mem[i3] = c+d;               // this one, it does not matter for the CSE
Mem[i4] = (a+b) + (c+d) + e;

for which we could produce:

let cse_var_2 = a+b in
let cse_var_3 = c+d in
let cse_var_1 = cse_var_2 + cse_var_3 in
Mem[i1] = cse_var_1;
Mem[i2] = cse_var_2;
Mem[i3] = cse_var_3;
Mem[i4] = cse_var_1 + e;

Or the alternative version where (c+d) gets introduced before (a+b). In practice, what is chosen currently depends on the addresses of the free variables 'a', 'b', 'c' and 'd', because they are used for the hashes.

I did not think that having this level of non-determinism would be an issue, but you're clearly right that it's better to always produce exactly the same resulting TIR code; that just makes testing easier. Sorry I did not think about it when I wrote it!
Your way of making it deterministic (by first copying the hashtable into an array, which you then sort using the actual syntax that the expressions represent) is perfectly fine by me.

According to the doc about structural_hash (here: https://tvm.apache.org/docs/reference/api/python/ir.html), it seems an alternative could have been to set its map_free_vars parameter to true, as that's the reason the hashes use pointers. If that was the only place where addresses are used in the hashes, it should make everything deterministic too. I'll see if that works, out of curiosity.
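For intuition, here is a toy version of what mapping free variables buys you for hashing (the `Var` class and tuple encoding are invented for illustration; this is not TVM code): instead of hashing a free variable by its run-dependent object identity, hash it by its index of first occurrence, so nothing depends on memory addresses.

```python
class Var:
    """Toy stand-in for a TIR variable; compared by identity, like a pointer."""


def canonical_hash(expr, env=None):
    # env maps each free Var (by identity) to its index of first occurrence,
    # so the resulting hash never depends on the Var's address.
    env = {} if env is None else env
    if isinstance(expr, Var):
        return hash(("var", env.setdefault(id(expr), len(env))))
    op, *args = expr
    return hash((op,) + tuple(canonical_hash(a, env) for a in args))

a, b, c, d = Var(), Var(), Var(), Var()
# (a + b) and (c + d) are structurally identical up to renaming, so they hash
# identically, independently of where a, b, c, d happen to live in memory.
assert canonical_hash(("add", a, b)) == canonical_hash(("add", c, d))
```

With address-based hashing the two expressions above would hash differently on (almost) every run, which is exactly the instability observed in this PR.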

Many thanks for having spotted that and for having fixed it!

Kind regards,

Franck

pfk-beta pushed a commit to pfk-beta/tvm that referenced this pull request Apr 11, 2022
* iterate through sorted keys

* masa comments -- simplify iteration

* test

* tests

* simplify vector construction

* jostle ci
tkonolige pushed a commit that referenced this pull request Jun 9, 2022
…orm - avoids comparison of terms (#11574)

The CSE pass had been designed for potentially allowing comparisons (and commonings) of equivalent terms (like (x+y)+z and x+(y+z)), where **the notion of being equivalent was customizable, and no assumption was made about it**. That means that the implementation of the equivalence test function `EquivalentTerms()` - which was at the moment just calling the syntactical equality test `EqualTerms()` - could be replaced later by a cleverer equality test.

However, having such a generic way of comparing elements meant that in the function `SyntacticToSemanticComputations()`, where we were going from a hashtable of syntactical entities to what I called a vector of "semantic entities" (which are just canonical forms/representatives of classes of equivalence of terms), **the only way was to compare each pair**.
That resulted in a quadratic behavior of this function, but there was no way around it as in order to merge equivalent entities into their class of equivalence, we had to compare them.

**This PR essentially does the following:**

- When computing the classes of equivalences of terms (therefore transforming a ComputationTable (i.e. a hashtable) into a vector of classes of equivalence) : **instead of comparing each pair of terms, relies on a normalization procedure to obtain a normal form for each of them**.
That takes a small part of the algorithm from quadratic to n.log n. However, it's difficult to see improvements in practice, in particular for average-sized programs, as that part was a "small" quadratic next to a "big" n.log n (finding things in a hash table, copying it to a vector, etc.).
It was probably going from a complexity of ~O(((n²-n)/2) + n.logn) to a complexity of ~O(3n + n.logn), so gains would only be expected for very large programs.

- Completely gives the user the possibility to turn ON/OFF the semantic comparison of terms. It is turned OFF by default (as compilation is quite a bit slower with it ON, unsurprisingly), which means that by default, the equivalence coincides with the (syntactic) equality of terms.
    As the pass was written with the possibility of doing these additional commonings (like (x+y)+z and x+(y+z)), it was a good time to plug that in completely, up to the Python user, who can now turn it ON if they want to. But again, it is OFF by default, so no real change there.

To run it ON, simply do:
`with tvm.transform.PassContext(config={'tir.enable_equiv_terms_in_cse_tir':True}):`
before calling `build()`

- When this boolean is set to ON, it uses a simple implementation of the normalization function with equivalences that uses `arith::Analyzer::Simplify`, as noted in #10544. Note that this is not a real normalization procedure as it is incomplete (i.e., it is not guaranteed to converge to the normal form), but it is correct, and it works well with most properties: associativity of +, distributivity of * on +, etc.

- Clarifies and enhances the test base for the pass. In particular, it adds the tests that were written in #10544 but which did not make it through.

- Also adds the test ( https://github.com/AndrewZhaoLuo/TVM-Sandbox/blob/19284ddbd6bb28af61c0c2aa8bb334c5c53731a7/tir/test_inconsistent_tir_lowering.py#L1 ) demonstrating the (older) non-deterministic lowering, turned into a proper test, as I found it useful for making sure that this does not happen again. It has been copied from #10663 and only slightly adapted (in particular, the comparison of hashes is now done automatically instead of printing them and relying on a human to compare them).
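The normalization-based grouping described above can be sketched in a few self-contained lines (toy tuple expressions and a toy flatten-and-sort normalizer stand in for TIR terms and `arith::Analyzer::Simplify`; none of this is TVM's actual API): bucket terms by their normal form instead of comparing every pair.

```python
def normalize(expr):
    # Toy normalization: flatten and sort the operands of '+', which makes
    # (x+y)+z and x+(y+z) compare equal. In the real pass, Simplify plays
    # this role (incompletely, but correctly).
    if isinstance(expr, tuple) and expr[0] == "+":
        operands = []
        for a in expr[1:]:
            na = normalize(a)
            if isinstance(na, tuple) and na[0] == "+":
                operands.extend(na[1:])  # flatten nested sums
            else:
                operands.append(na)
        return ("+",) + tuple(sorted(operands, key=repr))
    return expr

def equivalence_classes(terms):
    # O(n) hashing work instead of O(n^2) pairwise comparisons.
    classes = {}
    for t in terms:
        classes.setdefault(normalize(t), []).append(t)
    return classes

t1 = ("+", ("+", "x", "y"), "z")   # (x+y)+z
t2 = ("+", "x", ("+", "y", "z"))   # x+(y+z)
assert normalize(t1) == normalize(t2)
assert len(equivalence_classes([t1, t2, ("+", "x", "y")])) == 2
```

Bucketing by a key computed once per term is what removes the quadratic pairwise-comparison step, at the price of depending on the quality of the normalizer.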
Kathryn-cat pushed a commit to Kathryn-cat/tvm that referenced this pull request Jun 10, 2022
…orm - avoids comparison of terms (apache#11574)

(same commit message as above)