[Transform] Modify FuseTIR pass to propagate buffer attributes #17075
Conversation
I like the fix overall, thank you! It looks like there are some implicit assumptions about the Relax/TIR being consumed, which may not hold in general.
src/relax/transform/fuse_tir.cc (Outdated)
```cpp
/*! \brief The IRModule */
const IRModule& mod_;
// size_t call_num_inputs_ = -1;
Map<Var, tir::Buffer> relax_to_tir_var_map_;
```
This data structure assumes that there is a 1:1 mapping from `relax::Var` to `tir::Buffer` across the entire fused function. This would produce incorrect results when the same tensor is used as multiple arguments (e.g. `R.add(A, A)`), or when the same tensor is used as an argument to more than one function (e.g. the tensor `A` corresponds to two different TIR buffers in the sequence `mean = R.mean(A); norm = R.sqrt(mean); A_norm = R.divide(A, norm)`).
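The failure mode described here can be sketched in plain Python, modeling the `Map<Var, tir::Buffer>` as a dict (the names and attribute dicts below are illustrative, not TVM's actual API):

```python
# A 1:1 var -> buffer map silently loses information when one Relax var
# feeds several TIR buffers with different attribute requirements.

def collect_var_to_buffer(call_args):
    """call_args: (relax_var, tir_buffer) pairs gathered across all fused calls."""
    var_to_buffer = {}
    for relax_var, tir_buffer in call_args:
        # A plain Set/dict overwrites any earlier entry for the same var.
        var_to_buffer[relax_var] = tir_buffer
    return var_to_buffer

# The tensor "A" is consumed by two calls with distinct buffer attributes,
# as in the mean/divide sequence above.
calls = [
    ("A", {"axis_separators": [1]}),  # hypothetical buffer for the first call
    ("A", {"axis_separators": []}),   # hypothetical buffer for the second call
]
mapping = collect_var_to_buffer(calls)
# Only the last buffer survives; the axis_separators=[1] entry is lost.
print(mapping["A"])  # {'axis_separators': []}
```

This is why the later discussion moves toward validating that all buffers bound to one var agree, rather than assuming a single entry suffices.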
Yes, I did think about this issue, but I assumed that even though the same Relax var might map to different buffers, those buffers should have the same attributes (since their source is the same Relax var). I've also added a validation ICHECK to verify that the buffer attributes match (using structural equality).

I've also added a test case for this scenario, as suggested in the comment below.
src/relax/transform/fuse_tir.cc (Outdated)
```cpp
const auto& tir_var = Downcast<tir::Var>(tir_args[i]);
if (i < num_inputs) {
  const auto& relax_var = Downcast<Var>(relax_args[i]);
  relax_to_tir_var_map_.Set(relax_var, buffer_map[tir_var]);
```
The `buffer_map` does not necessarily contain an entry for `tir_var`. For example, the `relax_var` could have `PrimStructInfo` to pass a primitive scalar to the TIR function. Even if `relax_var` has `TensorStructInfo`, the TIR function may treat the `DLTensor*` as an opaque pointer, passing it to a `PackedFunc` without having an entry in the `buffer_map`.

The best way to handle these cases is to wrap this line in an `if (auto tir_buffer = buffer_map.Get(tir_var))` conditional, and then use `tir_buffer.value()` inside the conditional instead of `buffer_map[tir_var]`.
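The suggested pattern can be sketched in Python terms (the real code is C++ in fuse_tir.cc; the dict and attribute values below are stand-ins, not TVM's API):

```python
# Only record a buffer when the buffer map actually has an entry for the
# TIR var. Scalar parameters, or DLTensor* handles treated as opaque
# pointers, legitimately have no entry and must be skipped.

def record_buffer(relax_to_tir_map, buffer_map, relax_var, tir_var):
    # Python analogue of: if (auto tir_buffer = buffer_map.Get(tir_var)) { ... }
    tir_buffer = buffer_map.get(tir_var)
    if tir_buffer is not None:
        relax_to_tir_map[relax_var] = tir_buffer
    # else: nothing to propagate for this argument

mapping = {}
buffer_map = {"v_tensor": {"storage_scope": "global"}}  # no entry for "v_scalar"
record_buffer(mapping, buffer_map, "x", "v_tensor")
record_buffer(mapping, buffer_map, "n", "v_scalar")  # safely skipped
print(mapping)  # {'x': {'storage_scope': 'global'}}
```

An unconditional `buffer_map[tir_var]` lookup would instead fail (or, in C++, insert a bogus entry) for the scalar argument.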
Thanks for the catch. I did not consider this case. Fixed.
src/relax/transform/fuse_tir.cc (Outdated)
```cpp
for (size_t i = 0; i < tir_args.size(); ++i) {
  const auto& tir_var = Downcast<tir::Var>(tir_args[i]);
  if (i < num_inputs) {
    const auto& relax_var = Downcast<Var>(relax_args[i]);
```
This `Downcast<Var>` is not guaranteed to work. While the normalizer will pull most `relax.Var` instances out to their own variable bindings, `R.const` arguments may still appear inline.
Thanks, yes, constants are possible. I've updated the output to be a map from `Expr` to `Buffer` instead of from `Var`.
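The distinction can be sketched with stand-in classes (these are illustrative placeholders for `relax::Var` and `relax::Constant`, not TVM's classes):

```python
# An inline R.const argument is a Constant, not a Var, so a Var-only cast
# (the C++ Downcast<Var>) would fail on it. Keying the map by the general
# Expr base accepts both.

class Var:        # stand-in for relax::Var
    def __init__(self, name): self.name = name

class Constant:   # stand-in for relax::Constant (an inline R.const)
    def __init__(self, data): self.data = data

def map_key(expr):
    # Accept any Expr-like argument instead of asserting it is a Var.
    if isinstance(expr, (Var, Constant)):
        return expr
    raise TypeError(f"unexpected argument kind: {type(expr).__name__}")

args = [Var("x"), Constant([1.0, 2.0])]
keys = [map_key(a) for a in args]  # both succeed; a Var-only check would not
```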
```python
cls = Before
with R.dataflow():
    w = R.call_tir(
        cls.add, [x, y], out_sinfo=R.Tensor([T.int64(16), T.int64(32)], "float32")
```
Can we add a test case for incompatible usage of a single Relax var? As currently written, we could have a single Relax variable that is used in two separate `R.call_tir` statements, where the functions being called impose different restrictions on it. For example, `x` could be used in both `cls.add1`, which requires `axis_separators=[1]`, and `cls.add2`, which requires `axis_separators=[]`. We should be able to identify this case and raise an error when it occurs.

(Ideally, that should never happen, but this would be the last point at which we'd have enough information to catch this failure mode at compile time.)
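The requested check can be sketched as follows (a minimal Python model of the validation, assuming equality of attribute dicts stands in for TVM's structural equality; names are illustrative):

```python
# If one Relax var is bound to TIR buffers whose attributes disagree
# (e.g. axis_separators [1] vs []), raise an error instead of silently
# picking one binding.

def set_checked(var_to_buffer, relax_var, tir_buffer):
    existing = var_to_buffer.get(relax_var)
    if existing is not None and existing != tir_buffer:
        raise ValueError(
            f"Inconsistent buffer attributes for {relax_var}: "
            f"{existing} vs {tir_buffer}"
        )
    var_to_buffer[relax_var] = tir_buffer

mapping = {}
set_checked(mapping, "x", {"axis_separators": [1]})  # first function's requirement
try:
    set_checked(mapping, "x", {"axis_separators": []})  # second one disagrees
except ValueError as err:
    print("caught:", err)
```

This mirrors the validation ICHECK the author describes earlier in the thread: compatible re-bindings pass through, and a genuine conflict fails loudly at compile time.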
I've added the test case to check possible inconsistencies.
Arguments of a fused TIR PrimFunc generated from a fused Relax function do not retain all the buffer attributes from their original PrimFuncs, because the buffers are created from the StructInfo of the Relax vars. This patch collects a mapping from Relax vars to their corresponding TIR buffers in a fused Relax function and uses that information to propagate buffer attributes such as `axis_separators` and `storage_scope`.
@Lunderberg I think I've addressed your comments. When you get a chance, could you please take a look?
Thank you for making the changes, and looks good!
Thanks for taking the time to review and provide feedback.
I found that this MR would change