make Linear support overlapping input/axis names? #53
Comments
I think it's easiest to just always rename all axes. Is there a reason not to do this (e.g., performance issues)?
It messes up FSDP, or at least it makes it so you have to specify that both Embed_in and Embed_out are sharded, which is a bit noisier.
I don't know how sharding works in Haliax. Would you mind explaining why it messes up sharding?
Well, "messes up" is a bit strong, but the key idea behind sharding in Haliax is mapping named axes to a device mesh axis (cf. the tutorial https://colab.research.google.com/drive/1QX4yH3zRFF3Xiibf1aahETcSQ5nbcUMz). Currently, to set up FSDP, we do:
and this means that every "embed" axis in the model is sharded across the data axis of the device mesh. To add tensor parallelism, you'd do something like:
With your change, we'd have to do
which seems noisier. WDYT?
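A minimal sketch of the kind of named-axis-to-mesh-axis mappings being compared here, written as plain Python dicts. The axis names (`embed`, `mlp`, `embed_in`, `embed_out`), the mesh axis names (`data`, `model`), and the helper `mesh_axis_for` are all hypothetical illustrations; this is neither the snippet from the comment above nor Haliax's actual API.

```python
# Hypothetical mappings from named array axes to device-mesh axes.
# An axis that does not appear in the mapping is treated as replicated.

# FSDP: every "embed" axis is sharded across the "data" mesh axis.
fsdp_mapping = {"embed": "data"}

# FSDP plus tensor parallelism: additionally shard "mlp" across "model".
fsdp_tp_mapping = {"embed": "data", "mlp": "model"}

# If Linear always renamed its axes, both derived names would need entries.
renamed_fsdp_mapping = {"embed_in": "data", "embed_out": "data"}


def mesh_axis_for(axis_name, mapping):
    """Return the mesh axis a named axis is sharded over, or None if replicated."""
    return mapping.get(axis_name)


# With the single-name mapping, the renamed axes are silently left unsharded.
print(mesh_axis_for("embed", fsdp_mapping))              # sharded
print(mesh_axis_for("embed_out", fsdp_mapping))          # not sharded
print(mesh_axis_for("embed_out", renamed_fsdp_mapping))  # sharded again
```

The point of the sketch is only that a mapping keyed on one name stops matching once a layer renames that axis, so every derived name has to be listed explicitly.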
They seem about equally noisy to me, and I think it's fine to make that change. In the first one you need to have separate names for all your axes in a sequence of linear layers, which can be just as confusing. I think it ultimately comes down to needing a disjoint union of axis specs, not a union, and I don't think this is possible without renaming things. Perhaps one could create some kind of tree (or DAG) of axes that are derived from other axes and then, when sharding, automagically also shard any sub-axes, but that feels like overcomplicating things.
Currently Haliax requires that all names in a single named array be unique. In general I think this is a good constraint. However, for Linear layers it's frequently a nuisance, since one often projects to something of the same shape, or you might want to keep the same name ("hidden").
So, it might be a good idea to support overlapping names. This will complicate the implementation quite a bit but simplify some juggling outside. I think this is worth the complexity?
Probably we'd rename overlapping "output" names to `${name}_out` and then rename them in the result back to `${name}`. If we make this a contract, then you can use it to control sharding.
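The renaming contract proposed above can be sketched on axis names alone, as below. This is a rough illustration under stated assumptions, not how Haliax implements `Linear`: the functions `linear_axis_names` and `restore_output_names` are hypothetical, and axes are represented as plain strings rather than Haliax axis objects.

```python
def linear_axis_names(in_names, out_names):
    """Compute axis names for a Linear weight when input and output names may overlap.

    Any output name that also appears among the inputs is renamed to
    "<name>_out" so that the weight's axis names stay unique.
    Returns (weight_names, internal_out_names).
    """
    overlap = set(in_names) & set(out_names)
    internal_out = tuple(f"{n}_out" if n in overlap else n for n in out_names)
    weight_names = tuple(in_names) + internal_out
    if len(set(weight_names)) != len(weight_names):
        # For example, an input already named "hidden_out" next to an overlapping "hidden".
        raise ValueError(f"axis names still collide after renaming: {weight_names}")
    return weight_names, internal_out


def restore_output_names(internal_out, out_names):
    """Map the internal "<name>_out" names back to the names the caller asked for."""
    return dict(zip(internal_out, out_names))


# Projecting "hidden" to "hidden": the weight gets unique names internally,
# and the result is renamed back so the caller still sees "hidden".
weight_names, internal_out = linear_axis_names(("hidden",), ("hidden",))
print(weight_names)
print(restore_output_names(internal_out, ("hidden",)))
```

Because the `_out` suffix is predictable, a sharding mapping could target `hidden_out` explicitly, which is what making the rename a contract would allow.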