-
Notifications
You must be signed in to change notification settings - Fork 3.5k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[Relax][Transform] Preserve param names in LiftTransformParams #16594
[Relax][Transform] Preserve param names in LiftTransformParams #16594
Conversation
The `relax.transform.LiftTransformParams` pass splits apart a relax function, extracting the steps that could be performed at compile-time. Prior to this commit, the transformed parameters were named `param0`, `param1`, and so on. This commit updates the `LiftTransformParams` pass to preserve any human-readable parameter names. The parameter names for the updated function are taken from the original parameter names, if no transformation is performed, or from the internal variable binding, if a transformation is applied. This implementation uses `LambdaLift` internally, relying on the changes made in apache#16306.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is a substantial reorganization of LiftTransformParams and it's nicely modular. Thanks for the changes!
// A post-proc utility to reproduce the previous behavior of | ||
// LiftTransformParams. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
IMO (total nitpick) this is not phrasing we should really have in the codebase itself (as opposed to in the PR description) since someone reading the code will not know what the "previous" version was. We could probably just leave out the bit about the "previous behavior."
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
No, I completely agree with you there. I should have kept references to "previous behavior" in the commit message instead, where there's a clear before/after. Updated the comment to describe this as the third step in LiftTransformParams
, and to include the names of the other two steps.
// LiftTransformParams. | ||
// | ||
// 1. Partition each function into a compile-time and run-time | ||
// lambda functions. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
// lambda functions. | |
// lambda function. |
Typo
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thank you, and fixed!
std::unordered_map<Var, Expr, ObjectPtrHash, ObjectPtrEqual> param_remap_; | ||
// The plan of lifting the transform params | ||
LiftTransformParamsInfoPlan lift_plan_; | ||
// Adapted from https://stackoverflow.com/a/2072890 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Cool
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thank you. At some point, this may be useful as a member function of tvm::String
, similar to how python has str.endswith
and str.startswith
.
|
||
/*! \brief Bindings that can be lifted out into a pre-processing | ||
* | ||
* - All bindings in `liftable_bindings` are suitable for use in a |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is assume liftable_bindings
is an old name that should be updated?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thank you, and that is exactly correct. I replaced all references, but didn't grep the comments for spots that needed to be updated. Fixed.
} | ||
} | ||
|
||
// Cond 5. Do not lift declarations of external functions | ||
if (binding->value.as<relax::ExternFuncNode>()) { | ||
can_lift = false; | ||
if (value.as<relax::ExternFuncNode>()) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm curious when we ever have an ExternFunc as the RHS of a binding.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It happened more often prior to #15900, but rarely happens now.
At some point, I'd like to allow a relax function to specify when tvm::runtime::Registry::Get
is invoked. It's legal for a R.ExternFunc
to appear within the IRModule::functions
table, on the RHS of a binding, or inline for a relax::CallNode::op
. Currently, all three of those are resolved when the relax::VirtualMachine
is instantiated, but they could be used to represent binding at instantiation-time, at function start-time, and at function execution-time. In that case, the ExternFunc would become more common as the RHS of a binding.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I guess ExternFuncs are permitted to be used in that way per the spec as first-class values, so this is good to handle.
@@ -247,7 +247,8 @@ class BindingCanonicalizer : public ExprMutator { | |||
} | |||
|
|||
Expr VisitExpr_(const VarNode* var) override { | |||
Var new_var = GetRef<Var>(var); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Was there any bug you encountered due to not doing this visit before? If so, that would be a good regression test to throw in.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Hmm, this was definitely required during initial development, and was intended to handle cases where a replacement in ExprMutator::var_remap_
occurs outside of the planned replacements in the ReplacementPlan.replace_usage_
. This can happen if the removal of trivial bindings triggers InferStructInfo
in a downstream operation. However, that doesn't reproduce the bug, nor does that bug occur in anything on my local dev branch.
I'm removing this line from the PR, and can make a separate PR (with test case!) if/when the bug resurfaces.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I don't think it's necessary to remove, especially since it seems like a correct change. I'm glad that there wasn't a bug that you encountered because of it.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Double-checking, I think the reason this line is no longer necessary is due to the ExprMutator::VisitExpr_(new_var.get());
at the end of the function. It may have been necessary during some of the development steps, but is no longer necessary. Thinking through the reasoning:
- The
var_remap_
would apply ifVisitVarDef
- Canonicalization updates a variable in
VisitVarDef
if theplan_.replace_binding
is populated. - From (1) and (2), anything that would be populated in
var_remap_
would also be populated inplan_.replace_bindings
- Whenever we populate
plan_.replace_binding
, we also populateplan_.replace_usage
. - From (3) and (4), anything that would be replaced by
var_remap_
could instead be replaced byplan_.replace_usage
.
So, I think it isn't necessary to have two calls to ExprMutator::VisitExpr_(const Varnode*)
in the implementation.
The
relax.transform.LiftTransformParams
pass splits apart a relax function, extracting the steps that could be performed at compile-time. Prior to this commit, the transformed parameters were namedparam0
,param1
, and so on.This commit updates the
LiftTransformParams
pass to preserve any human-readable parameter names. The parameter names for the updated function are taken from the original parameter names, if no transformation is performed, or from the internal variable binding, if a transformation is applied. This implementation usesLambdaLift
internally, relying on the changes made in#16306.