Implementation of MLPMixer #103
Conversation
Re-organizing the repo into CNNs and ViTs makes sense, but I would avoid introducing submodules. I would just have two folders and put all the includes in. I like TensorCast.jl, but I want to avoid taking on a dep if we can. I'll have to go through and see if we can rewrite using standard ops without being too messy. Now we've got ViTs, exciting!
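For illustration only (these file and folder names are hypothetical, not the repo's actual layout), the flat organisation being suggested would look something like plain includes grouped by folder, with no nested submodules:

# Hypothetical sketch of a flat layout: everything stays in the top-level module,
# with one include per model file and folders used purely for organisation.
module Metalhead

include("convnets/vgg.jl")       # CNN-style models
include("convnets/resnet.jl")
include("other/mlpmixer.jl")     # models that don't fit a category neatly
include("vit-based/vit.jl")      # attention-based models

end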
I forgot to ask this on the last PR, but can you update the table in the README with the models from the previous PRs + this one?
I don't think MLPMixer can be classified as a visual transformer, but maybe we could argue that this separation is fine since the spirit is the same.
Will do.
Yeah the MLPMixer one isn't that bad, it's just two places, but for the standard ViT and for other versions like LeViT and Swin it gets progressively worse.
Will do, I wasn't sure whether or not to do it because it wasn't tagged as a release yet xD.
Yeah, they share quite a few things with the patches and embeddings design, and the original JAX repository released the code together, so I think it's better to keep them together.
Only the main branch README will change. The docs will still reflect the tagged version.
True, but it doesn't really fit any category. Once the submodules are removed, this distinction will only affect our code organization and not the user interfaces.
Updated README
That's odd, the CI seems to be failing only on Linux and I don't see why... the error message says the file doesn't exist, but it very much does 🤨
Yeah, some sort of weird renaming issue - I'd done it locally but it didn't seem to have been pushed to the remote somehow. Fixed now.
Can you rebase with the latest changes? It shouldn't affect your work here. Also, I'll fix the README because I need to fix the docs CI too.
Yep, should be fine now. I checked the tests locally and they worked alright.
Bikeshedding the folder structure: I don't think there needs to be a premature assignment of MLPMixer into a specific category (e.g. where would DeiT fit in under this scheme?). Keeping it at the top level or in an "other" directory would be fine. If/when we get attention-based models, those can get their own dir.
I've made the change for MLPMixer to be put in an "other" directory, then. I think DeiT would slot into the ViT folder given that the paper explicitly refers to it as such (FAIR's repo released all their ViT-based models together too), but I get the reason for the MLPMixer contention, so I've taken care of that.
Playing around with this model, I realised that the
I have some suggestions which should resolve that issue. I've just been busy over the last few days, but I'll review both new PRs later this evening.
I think one major change here would be to avoid making MLPMixer a struct with a custom forward pass. What I recommend instead is to define a new patching layer (in a separate file) like:
using Flux: @functor

# Layer that turns an image batch of size (H, W, C, N) into an array of
# flattened patches of size (patch^2 * C, #patches, N).
struct Patching{T<:Integer}
    patch::T
end
@functor Patching

function (p::Patching)(x)
    h, w, c, n = size(x)
    hp, wp = h ÷ p.patch, w ÷ p.patch
    xpatch = reshape(x, hp, p.patch, wp, p.patch, c, n)
    return reshape(permutedims(xpatch, (1, 3, 5, 2, 4, 6)), p.patch^2 * c, hp * wp, n)
end
Then MLPMixer is nothing more than a Chain. Part of the goal with this package's design is to illustrate how Flux's features make it possible to build advanced models with minimal boilerplate. Reimplementing Chain-like forward passes is something we want to avoid when possible.
Overall, the implementation looks correct; we just need some design iterations. Great job!
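To make that concrete, here is a rough sketch of how the model could then be assembled as a plain Chain (this is my own illustration rather than the code in this PR; it assumes the Patching layer above, 224×224 RGB inputs with 16×16 patches, and made-up Mixer-S/16-like sizes):

using Flux, Statistics

# One mixer block: a token-mixing MLP that acts across patches and a
# channel-mixing MLP that acts across channels, each wrapped in a skip connection.
mixerblock(npatches, planes, hidden) =
    Chain(SkipConnection(Chain(LayerNorm(planes),
                               x -> permutedims(x, (2, 1, 3)),  # (patches, channels, batch)
                               Dense(npatches, hidden, gelu),
                               Dense(hidden, npatches),
                               x -> permutedims(x, (2, 1, 3))), +),
          SkipConnection(Chain(LayerNorm(planes),
                               Dense(planes, hidden, gelu),
                               Dense(hidden, planes)), +))

# The whole model is then just a Chain: patching, a per-patch embedding,
# a stack of mixer blocks, global average pooling over patches, and a classifier.
mlpmixer = Chain(Patching(16),
                 Dense(16^2 * 3, 512),                         # (768, 196, N) -> (512, 196, N)
                 [mixerblock(196, 512, 2048) for _ in 1:8]...,
                 x -> dropdims(mean(x; dims = 2); dims = 2),   # average over the patch dimension
                 Dense(512, 1000))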
Will do this, but if you don't mind, can I keep the TensorCast dep? I will remove it if there's a clear reason, but
Personally, I find the einsum notation more confusing than the plain Julia, but of course, that's just my subjective opinion. Mainly, I'm hesitant to take on another dependency just to mimic Python, especially since it seems to only serve indexing operations that Julia does quite well with.
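For what it's worth, the kind of operation in question looks like this in plain Julia (my illustration, not code from this PR): the token-mixing step just needs the patch and channel dimensions swapped, which an einsum-style macro would express as an index pattern but Base handles in one call:

# Swap (channels, patches, batch) -> (patches, channels, batch) for token mixing;
# an einops-style pattern for the same rearrangement would read roughly "c p n -> p c n".
token_first(x::AbstractArray{<:Any,3}) = permutedims(x, (2, 1, 3))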
There are a few einsum-related packages, so if we do adopt one it would ideally be relatively lightweight + performant (flexibility is not an issue since Python frameworks don't have much in their einsum implementations). Because that could take a minute, I vote to defer that discussion to the ViT PR since this one doesn't require any einsum ops. That will let us merge this one asap :)
Okay @theabhirath, does that sound like a reasonable plan? Use the non-einsum implementation here for now. Since this model is mostly there, this will let us merge without much back and forth. If we decide to take on an einsum dep in the ViT PR, we can update this model's code as well.
That makes sense to me. I'll make the changes 👍🏽
Almost done. A few minor changes and we can merge.
Looks good! There were a few things I missed on the last pass, sorry.
Also, FYI for the future: use
Thank you so much! The formatting is still a pain 😅 Some sort of auto-formatter really needs to be around to ensure this kind of thing doesn't sneak under the radar.
Yeah, I agree. There's a PR for Flux on this. I'm just waiting for that to settle on a style choice which I can duplicate here.
This is an implementation of one of the many models in the wake of the ViT explosion, MLPMixer.
There are two things I wanted to clarify: