Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Relax][Transform] Preserve param names in LiftTransformParams #16594

Conversation

Lunderberg
Copy link
Contributor

The relax.transform.LiftTransformParams pass splits apart a relax function, extracting the steps that could be performed at compile-time. Prior to this commit, the transformed parameters were named param0, param1, and so on.

This commit updates the LiftTransformParams pass to preserve any human-readable parameter names. The parameter names for the updated function are taken from the original parameter names, if no transformation is performed, or from the internal variable binding, if a transformation is applied. This implementation uses LambdaLift internally, relying on the changes made in
#16306.

The `relax.transform.LiftTransformParams` pass splits apart a relax
function, extracting the steps that could be performed at
compile-time.  Prior to this commit, the transformed parameters were
named `param0`, `param1`, and so on.

This commit updates the `LiftTransformParams` pass to preserve any
human-readable parameter names.  The parameter names for the updated
function are taken from the original parameter names, if no
transformation is performed, or from the internal variable binding, if
a transformation is applied.  This implementation uses `LambdaLift`
internally, relying on the changes made in
apache#16306.
Copy link
Contributor

@slyubomirsky slyubomirsky left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a substantial reorganization of LiftTransformParams and it's nicely modular. Thanks for the changes!

Comment on lines 424 to 425
// A post-proc utility to reproduce the previous behavior of
// LiftTransformParams.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

IMO (total nitpick) this is not phrasing we should really have in the codebase itself (as opposed to in the PR description) since someone reading the code will not know what the "previous" version was. We could probably just leave out the bit about the "previous behavior."

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No, I completely agree with you there. I should have kept references to "previous behavior" in the commit message instead, where there's a clear before/after. Updated the comment to describe this as the third step in LiftTransformParams, and to include the names of the other two steps.

// LiftTransformParams.
//
// 1. Partition each function into a compile-time and run-time
// lambda functions.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
// lambda functions.
// lambda function.

Typo

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you, and fixed!

std::unordered_map<Var, Expr, ObjectPtrHash, ObjectPtrEqual> param_remap_;
// The plan of lifting the transform params
LiftTransformParamsInfoPlan lift_plan_;
// Adapted from https://stackoverflow.com/a/2072890
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Cool

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you. At some point, this may be useful as a member function of tvm::String, similar to how python has str.endswith and str.startswith.


/*! \brief Bindings that can be lifted out into a pre-processing
*
* - All bindings in `liftable_bindings` are suitable for use in a
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is assume liftable_bindings is an old name that should be updated?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you, and that is exactly correct. I replaced all references, but didn't grep the comments for spots that needed to be updated. Fixed.

}
}

// Cond 5. Do not lift declarations of external functions
if (binding->value.as<relax::ExternFuncNode>()) {
can_lift = false;
if (value.as<relax::ExternFuncNode>()) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm curious when we ever have an ExternFunc as the RHS of a binding.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It happened more often prior to #15900, but rarely happens now.

At some point, I'd like to allow a relax function to specify when tvm::runtime::Registry::Get is invoked. It's legal for a R.ExternFunc to appear within the IRModule::functions table, on the RHS of a binding, or inline for a relax::CallNode::op. Currently, all three of those are resolved when the relax::VirtualMachine is instantiated, but they could be used to represent binding at instantiation-time, at function start-time, and at function execution-time. In that case, the ExternFunc would become more common as the RHS of a binding.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I guess ExternFuncs are permitted to be used in that way per the spec as first-class values, so this is good to handle.

@@ -247,7 +247,8 @@ class BindingCanonicalizer : public ExprMutator {
}

Expr VisitExpr_(const VarNode* var) override {
Var new_var = GetRef<Var>(var);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Was there any bug you encountered due to not doing this visit before? If so, that would be a good regression test to throw in.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmm, this was definitely required during initial development, and was intended to handle cases where a replacement in ExprMutator::var_remap_ occurs outside of the planned replacements in the ReplacementPlan.replace_usage_. This can happen if the removal of trivial bindings triggers InferStructInfo in a downstream operation. However, that doesn't reproduce the bug, nor does that bug occur in anything on my local dev branch.

I'm removing this line from the PR, and can make a separate PR (with test case!) if/when the bug resurfaces.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think it's necessary to remove, especially since it seems like a correct change. I'm glad that there wasn't a bug that you encountered because of it.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Double-checking, I think the reason this line is no longer necessary is due to the ExprMutator::VisitExpr_(new_var.get()); at the end of the function. It may have been necessary during some of the development steps, but is no longer necessary. Thinking through the reasoning:

  1. The var_remap_ would apply if VisitVarDef
  2. Canonicalization updates a variable in VisitVarDef if the plan_.replace_binding is populated.
  3. From (1) and (2), anything that would be populated in var_remap_ would also be populated in plan_.replace_bindings
  4. Whenever we populate plan_.replace_binding, we also populate plan_.replace_usage.
  5. From (3) and (4), anything that would be replaced by var_remap_ could instead be replaced by plan_.replace_usage.

So, I think it isn't necessary to have two calls to ExprMutator::VisitExpr_(const Varnode*) in the implementation.

@Lunderberg Lunderberg merged commit e715814 into apache:main Feb 23, 2024
21 checks passed
@Lunderberg Lunderberg deleted the relax_preserve_names_in_lift_transform_params branch February 23, 2024 14:36
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants