[SLM] Allow modules to define pre-processing of weights #16785
Conversation
If there is a trivial binding of `Var = DataflowVar`, but the non-dataflow variable is never used outside the dataflow block in which it is declared, then we should keep the name of the upstream `DataflowVar`, as it is more likely to be the human-readable name (e.g. a function parameter).
Follow-up to apache#16777, adding unit tests demonstrating the desired behavior.
Prior to this commit, the weights used by `nn.Module` instances were required to be `nn.Parameter` instances. This commit allows the weights to instead be `nn.Tensor` instances, defined in terms of other `nn.Parameter` weights. This allows a model to define both the original weights that would be present in an external checkpoint (e.g. a PyTorch or Safetensors file), and the pre-processing that should be performed on those weights. This is a re-implementation of apache#16757, which was reverted in apache#16777. The re-implementation preserves the handling of dynamic shapes specified as Python strings, enabling the test cases that were added in apache#16784.
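As a minimal sketch of the intent (the `TransposedLinear` module below is hypothetical, not taken from this PR's tests, and it assumes the pre-processing expression can be built where the module defines its attributes), a module can expose the checkpoint's original weight as an `nn.Parameter` while computing with an `nn.Tensor` derived from it:

```python
# Hypothetical example of the pattern this PR enables; the attribute and class
# names are illustrative, not part of the PR.
from tvm.relax.frontend import nn
from tvm.relax.frontend.nn import op


class TransposedLinear(nn.Module):
    def __init__(self, in_features: int, out_features: int):
        # The weight as it would appear in an external checkpoint
        # (e.g. a PyTorch or Safetensors file).
        self.weight = nn.Parameter((out_features, in_features), dtype="float32")
        # Pre-processing defined in terms of the parameter.  With this change
        # the module may hold an nn.Tensor computed from nn.Parameter weights.
        self.weight_t = op.permute_dims(self.weight)

    def forward(self, x: nn.Tensor) -> nn.Tensor:
        return op.matmul(x, self.weight_t)
```

The external checkpoint would then only need to provide `weight`; the transpose becomes part of the exported compute rather than a separate conversion step.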
I am not fully versed on SLM but I did not see anything objectionable in these changes. Some of the TODOs would be good to keep an eye on.
# TODO(Lunderberg): Make this easier to call. Inferring
# struct info for a nested expression should be doable in
# a free function, without requiring an active
# BlockBuilder and an active FunctionFrame.
Yeah... we might be able to decouple some of the functionality and allow for calling things like that separately.
Long-term, I think it would be nice to distinguish between local struct inference and non-local struct inference. The local inference could be applied when a relax object is constructed, which would avoid the current two-phase initialization of relax objects. Since this step can only perform local struct inference, which would be applied by default, this entire conditional could be removed.
There are some kinks that would need to be worked out first. Some of the struct inference for tensor operations currently throws errors a bit more often than I think it should. (e.g. `R.matmul` throws an exception if the arguments are not `R.Tensor`. If the arguments are `R.Object`, the exception is still thrown, even though `R.Tensor` is a subtype of `R.Object`.) These fallbacks would probably get more exercise with local inference, as there may be less information available.
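A constructed illustration of that limitation (this snippet is an assumption about the current behavior described above, not code from this PR): struct inference for `relax.op.matmul` raises when one argument is only annotated as an object, rather than falling back to a less specific result.

```python
from tvm import relax

# One argument has full tensor information; the other is only known to be
# an Object (of which Tensor is a subtype).
a = relax.Var("a", relax.TensorStructInfo([4, 8], "float32"))
b = relax.Var("b", relax.ObjectStructInfo())

bb = relax.BlockBuilder()
with bb.function("main", [a, b]):
    with bb.dataflow():
        # Per the discussion above, struct inference raises here instead of
        # producing a tensor of unknown shape/dtype (or an Object) as output.
        lv = bb.emit(relax.op.matmul(a, b))
        gv = bb.emit_output(lv)
    bb.emit_func_output(gv)
```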
# TODO(Lunderberg): Make a `ir.transform.ConvertSSA`,
# similar to the existing `tir.transform.ConvertSSA`,
# that converts an entire module to SSA, including TIR
# variable definitions used in either TIR or Relax.
What would this conversion do on the Relax side? I thought vars already had exactly one point of definition in Relax.
Both Relax and TIR require SSA to be well-formed. However, there are a number of cases where a module could be unambiguously converted to SSA. (e.g. Two functions use the same `relax.Var` as a parameter, which can be fixed by substituting a new variable in one of the functions.)

So, it wouldn't be a pass that would be called directly by end users, but would be for internal use. If a pass is most easily written in a way that results in the same symbolic variable occurring in multiple different functions, then this would be used as a post-processing pass. (e.g. Apply `BindSymbolicVars` to one variable in a function, then save the result as a new function in the same IRModule. Useful, but would duplicate all other symbolic variables.)
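As a hedged sketch of the first case (constructed here for illustration; the `ir.transform.ConvertSSA` pass itself does not exist yet), a module can end up with the same `relax.Var` defined in two functions:

```python
from tvm import relax

# The same Var object is reused as a parameter of two functions, so the
# module has two points of definition for `x` at module scope.  A
# module-level ConvertSSA could fix this mechanically by substituting a
# fresh variable in one of the functions.
x = relax.Var("x", relax.TensorStructInfo([16], "float32"))

bb = relax.BlockBuilder()
with bb.function("f", [x]):
    bb.emit_func_output(x)
with bb.function("g", [x]):
    bb.emit_func_output(x)

mod = bb.get()
```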
new_shape = [_get_var(x) if isinstance(x, tir.Var) else x for x in param.shape]
var = core.Tensor.placeholder(new_shape, param.dtype, name)._expr
new_shape = [_normalize_dim(dim) for dim in param._shape]
# var_cls = rx.DataflowVar if mode == "packed" else rx.Var
Is this line meant to be used at any point?
Whoops, that was a test during dev work. Removing the commented-out `var_cls` line.
@@ -676,12 +676,31 @@ def permute_dims(x: Tensor, axes: Optional[List[int]] = None, name: str = None)
    result : Tensor
        The transposed result.
    """

    # TODO(Lunderberg): This is a more extensive auto-naming than
    # intended here. Is this still worth it?
Do we expect these chains of definitions to be deep? If they can be, this might be undesirable.
Long-term, I want to move this automatic naming from the `nn.Module` side to the Relax side, since it could then be performed after removal of trivial bindings. I don't expect these chains to be deep, as it only tracks trivial bindings. The trivial binding from the Relax function parameter to the parameter's `param._expr` field should be the only one that would be tracked.
# Relax variable names may contain '.' even though it
# cannot be expressed in TVMScript.
I wonder if this is something we should just check for and prohibit.
I could go either way. It's nice to have the 1:1 mapping between Relax and TVMScript, which would forbid the period within a Relax variable name. However, it's also nice to have a 1:1 mapping between a Relax function parameter and a weight tensor's name in a PyTorch or Safetensors file, where the names are usually written with a period.
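For illustration (the key names below are hypothetical, not from any specific model), the dotted keys in a PyTorch state dict are what that 1:1 mapping would preserve:

```python
import torch

# Keys in a PyTorch state_dict are dotted paths; keeping the same names on
# the Relax function parameters gives a direct mapping back to the checkpoint.
state_dict = {
    "decoder.layers.0.self_attn.q_proj.weight": torch.zeros(8, 8),
    "decoder.layers.0.self_attn.q_proj.bias": torch.zeros(8),
}
relax_param_names = list(state_dict.keys())
print(relax_param_names)
# ['decoder.layers.0.self_attn.q_proj.weight', 'decoder.layers.0.self_attn.q_proj.bias']
```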
I've done some experimenting, and I think this use case may be handled with Pros:
Cons: