Implementation of MLPMixer #103

Merged: 24 commits from theabhirath:mlpmixer into FluxML:master on Feb 4, 2022

Conversation

theabhirath

This is an implementation of MLPMixer, one of the many models that appeared in the wake of the ViT explosion.

There are two things I wanted to clarify:

  1. I added TensorCast as a dep because einops-like operations are very commonplace in ViT model implementations, and I thought it would be easier to work with (see the sketch after this list). If you think I should change that up somehow or maybe just use standard array operations, lemme know and I'll revert - it's very painful though, I tried already 🥲
  2. I also made some changes to the organisation of the repo (might be a bit of an understatement) because I thought it would be better to separate out the CNN models and the ViT models (I have some others in the works), but lemme know if you think it's overkill and I'll revert to the original repo structure.
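For reference, here's a minimal sketch of the einops-like style in question, using TensorCast's @cast macro (the array and sizes are purely illustrative):

```julia
using TensorCast

# Collapse the two spatial axes of a WHCN array into one token axis,
# einops-style; equivalent to reshape(x, 14 * 14, 512, 8).
x = rand(Float32, 14, 14, 512, 8)
@cast tokens[(i, j), c, n] := x[i, j, c, n]
size(tokens)  # (196, 512, 8)
```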

@darsnack

Re-organizing the repo into CNNs and ViTs makes sense, but I would avoid introducing submodules. I would just have two folders and put all the includes in Metalhead.jl (grouped by type to make it look nice).
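Something like this, for example (a sketch only - the file names are illustrative):

```julia
# in src/Metalhead.jl: plain includes grouped by model type, no submodules
# CNN models
include("cnn/alexnet.jl")
include("cnn/vgg.jl")
include("cnn/resnet.jl")

# ViT-like models
include("vit-like/mlpmixer.jl")
```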

I like TensorCast.jl, but I want to avoid taking on a dep if we can. I'll have to go through and see if we can rewrite using standard ops without being too messy.

Now we've got ViTs, exciting!

@darsnack

I forgot to ask this on the last PR, but can you update the table in the README with the models from the previous PRs + this one?

@CarloLucibello

I don't think MLPMixer can be classified as a vision transformer, but maybe we could argue that this separation is fine since the spirit is the same.

@theabhirath

Re-organizing the repo into CNNs and ViTs makes sense, but I would avoid introducing submodules. I would just have two folders and put all the includes in Metalhead.jl (grouped by type to make it look nice).

Will do.

I like TensorCast.jl, but I want to avoid taking on a dep if we can. I'll have to go through and see if we can rewrite using standard ops without being too messy.

Yeah, the MLPMixer one isn't that bad (it's just two places), but for the standard ViT and other variants like LeViT and Swin it gets progressively worse.

I forgot to ask this on the last PR, but can you update the table in the README with the models from the previous PRs + this one?

Will do. I wasn't sure whether to do it because it wasn't tagged as a release yet xD.

I don't think MLPMixer can be classified as a vision transformer, but maybe we could argue that this separation is fine since the spirit is the same.

Yeah, they share quite a few things with the patches-and-embeddings design, and the original JAX repository released the code together, so I think it's better to keep them together.

@darsnack

Will do. I wasn't sure whether to do it because it wasn't tagged as a release yet xD.

Only the main branch README will change. The docs will still reflect the tagged version.

I don't think MLPMixer can be classified as a vision transformer, but maybe we could argue that this separation is fine since the spirit is the same.

True, but it doesn't really fit any category. Once the submodules are removed, this distinction will only affect our code organization and not the user interfaces.

@theabhirath

That's odd, the CI seems to be failing only on Linux and I don't see why... the error message says the file doesn't exist, but it very much does 🤨

@darsnack darsnack closed this Jan 29, 2022
@darsnack darsnack reopened this Jan 29, 2022
@theabhirath

Yeah, some sort of weird renaming issue - I'd done it locally, but it didn't seem to have been pushed to the remote somehow. Fixed now.

@darsnack

darsnack commented Jan 29, 2022

Can you rebase with the latest changes? It shouldn't affect your work here. Also, I'll fix the README since I need to fix the docs CI too.

@theabhirath

Yep, should be fine now. I checked the tests locally and they worked alright.

@ToucheSir

Bikeshedding the folder structure, I don't think there needs to be a premature assignment of MLP Mixer into a specific category (e.g. where would DeiT fit in under this scheme?). Keeping it at the top level or in an "other" directory would be fine. If/when we get attention-based models, those can get their own dir.

@theabhirath

Bikeshedding the folder structure, I don't think there needs to be a premature assignment of MLP Mixer into a specific category (e.g. where would DeiT fit in under this scheme?). Keeping it at the top level or in an "other" directory would be fine. If/when we get attention-based models, those can get their own dir.

I've made the change to put MLPMixer in an "other" directory, then. I think DeiT would slot into the ViT folder, given that the paper explicitly refers to it as such (FAIR's repo released all their ViT-based models together too), but I get the reason for the MLPMixer contention, so I've taken care of that.

@theabhirath

Playing around with this model, I realised that the show functionality doesn't work as expected because I've defined a custom model instead of writing the layers with, say, Chain or SkipConnection, which are already present in Flux and thus likely implement show on their own. Is there an easier way to get it done than having to write it layer by layer for the custom operations? There's also the use of TensorCast, which I imagine will complicate things slightly.
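(For context, this is the kind of per-layer boilerplate I mean - a minimal sketch, assuming a custom struct with hypothetical patch and depth fields:)

```julia
# Hand-written pretty-printing for a custom model type; Chain-based
# models get a nested printout like this from Flux for free.
Base.show(io::IO, m::MLPMixer) =
    print(io, "MLPMixer(patch = ", m.patch, ", depth = ", m.depth, ")")
```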

@darsnack

darsnack commented Feb 2, 2022

I have some suggestions which should resolve that issue. I've just been busy over the last few days, but I'll review both new PRs later this evening.


@darsnack darsnack left a comment


I think one major change here would be to avoid making MLPMixer a struct with a custom forward pass. What I recommend instead is to define a new patching layer (in a separate file), like:

```julia
using Flux: @functor

struct Patching{T<:Integer}
  patch::T
end

@functor Patching

function (p::Patching)(x)
  h, w, c, n = size(x)
  hp, wp = h ÷ p.patch, w ÷ p.patch
  # split each spatial axis into (within-patch, patch-index) parts
  xpatch = reshape(x, p.patch, hp, p.patch, wp, c, n)

  # group the within-patch axes with the channels, then flatten each
  # patch into a feature vector: (patch^2 * c, patches, batch)
  return reshape(permutedims(xpatch, (1, 3, 5, 2, 4, 6)), p.patch^2 * c, hp * wp, n)
end
```

Then MLPMixer is nothing more than a Chain. Part of the goal with this package's design is to illustrate how Flux's features make it possible to build advanced models with minimal boilerplate. Reimplementing Chain-like forward passes is something we want to avoid when possible.
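Roughly like this (a sketch only - the sizes and the mixer_block helper are illustrative, not a final API):

```julia
using Flux, Statistics

# Illustrative token-mixing + channel-mixing block (LayerNorm omitted for brevity).
mixer_block(planes, npatches) = Chain(
    SkipConnection(Chain(x -> permutedims(x, (2, 1, 3)),
                         Dense(npatches, npatches, gelu),      # mix across patches
                         x -> permutedims(x, (2, 1, 3))), +),
    SkipConnection(Dense(planes, planes, gelu), +))            # mix across channels

mlpmixer = Chain(Patching(16),
                 Dense(16^2 * 3, 512),                         # per-patch embedding
                 [mixer_block(512, 196) for _ in 1:8]...,      # (224 ÷ 16)^2 = 196 patches
                 x -> dropdims(mean(x; dims = 2); dims = 2),   # pool over patches
                 Dense(512, 1000))
```

Everything in there is a Chain, an anonymous function, or a stock Flux layer, so show and functor both work with no extra code.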

Overall, the implementation looks correct; we just need some design iterations. Great job!

Review threads (outdated, resolved): src/Metalhead.jl, src/vit-like/mlpmixer.jl (2 threads)
@theabhirath

I think one major change here would be to avoid making MLPMixer a struct with a custom forward pass. What I recommend instead is to define a new patching layer (in a separate file), like:

```julia
using Flux: @functor

struct Patching{T<:Integer}
  patch::T
end

@functor Patching

function (p::Patching)(x)
  h, w, c, n = size(x)
  hp, wp = h ÷ p.patch, w ÷ p.patch
  # split each spatial axis into (within-patch, patch-index) parts
  xpatch = reshape(x, p.patch, hp, p.patch, wp, c, n)

  # group the within-patch axes with the channels, then flatten each
  # patch into a feature vector: (patch^2 * c, patches, batch)
  return reshape(permutedims(xpatch, (1, 3, 5, 2, 4, 6)), p.patch^2 * c, hp * wp, n)
end
```

Then MLPMixer is nothing more than a Chain. Part of the goal with this package's design is to illustrate how Flux's features make it possible to build advanced models with minimal boilerplate. Reimplementing Chain-like forward passes is something we want to avoid when possible.

Will do this, but if you don't mind, can I keep the TensorCast dep? I'll remove it if there's a clear reason, but einops notation has become so ubiquitous across model implementations in Python frameworks that I thought it would be a lot more intuitive for it to be the same way here.

@darsnack

darsnack commented Feb 3, 2022

Personally, I find the einsum notation more confusing than the plain Julia, but of course, that's just my subjective opinion.

Mainly, I'm hesitant to take on another dependency just to mimic Python. Especially since it seems to only serve indexing operations that Julia does quite well with reshape. If it turns out that ViT can be made faster by using TensorCast, then that seems like a good reason to take on the dep. Maybe other contributors can weigh in (cc @ToucheSir).

@ToucheSir

There are a few einsum-related packages, so if we do adopt one it would ideally be relatively lightweight + performant (flexibility is not an issue since Python frameworks don't have much in their einsum implementations). Because that could take a minute, I vote to defer that discussion to the ViT PR since this one doesn't require any einsum ops. That will let us merge this one asap :)

@darsnack

darsnack commented Feb 3, 2022

Okay @theabhirath, does that sound like a reasonable plan? Use the non-einsum implementation here for now. Since this model is mostly there, this will let us merge without much back and forth.

If we decide to take on an einsum dep in the ViT PR, we can update this model's code as well.

@theabhirath

That makes sense to me. I'll make the changes 👍🏽

@theabhirath theabhirath requested a review from darsnack February 3, 2022 07:08

@darsnack darsnack left a comment


Almost done. A few minor changes and we can merge.

Review threads (outdated, resolved): src/other/mlpmixer.jl (5 threads)
@theabhirath theabhirath requested a review from darsnack February 3, 2022 16:46

@darsnack darsnack left a comment


Looks good! There were a few things I missed on the last pass, sorry.

Review threads (outdated, resolved): src/Metalhead.jl, src/other/mlpmixer.jl (2 threads), test/other.jl
@darsnack

darsnack commented Feb 3, 2022

Also, FYI for the future: use git rebase instead of git merge when pulling in upstream changes to your PR branch. git merge will show all the upstream changes in the diff, making it hard to parse what's actually changed vs upstream and what's just showing as changed even though it's the same.
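For example (assuming the upstream remote is named upstream; mlpmixer is the PR branch):

```sh
git fetch upstream
git rebase upstream/master                 # replay your commits on top of upstream
git push --force-with-lease origin mlpmixer  # update the PR branch safely
```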

Review threads (outdated, resolved): src/Metalhead.jl, src/other/mlpmixer.jl
@darsnack darsnack merged commit 1eb8a51 into FluxML:master Feb 4, 2022
@theabhirath

Thank you so much! The formatting is still a pain 😅 Some sort of auto-formatter really needs to be around to ensure this kind of thing doesn't sneak under the radar.

@theabhirath theabhirath deleted the mlpmixer branch February 4, 2022 03:01
@darsnack

darsnack commented Feb 4, 2022

Yeah, I agree. There's a PR for Flux on this; I'm just waiting for that to settle on a style choice, which I can duplicate here.
