
Implementation of MLPMixer #103

Merged 24 commits on Feb 4, 2022
The diff below shows changes from 8 of the 24 commits.

Commits (24):
42726a1  Initial commit for MLP mixer (theabhirath, Jan 29, 2022)
0ad2db9  Updated directory structure (theabhirath, Jan 29, 2022)
e2a3e4a  Rename test/ConvNets.jl to test/convnets.jl (theabhirath, Jan 29, 2022)
f4e71b9  Initial commit for MLP mixer (theabhirath, Jan 29, 2022)
4a3d82f  Updated directory structure (theabhirath, Jan 29, 2022)
4116add  Rename test/ConvNets.jl to test/convnets.jl (theabhirath, Jan 29, 2022)
0402ad9  Merge branch 'mlpmixer' of https://github.com/theabhirath/Metalhead.j… (theabhirath, Jan 29, 2022)
3c61e4b  Update runtests.jl (theabhirath, Jan 29, 2022)
3b58716  Initial commit for MLP mixer (theabhirath, Jan 29, 2022)
7a72ef4  Updated directory structure (theabhirath, Jan 29, 2022)
bbf0cdf  Rename test/ConvNets.jl to test/convnets.jl (theabhirath, Jan 29, 2022)
bb49697  Initial commit for MLP mixer (theabhirath, Jan 29, 2022)
b5382dd  Updated directory structure (theabhirath, Jan 29, 2022)
de7ebf1  Rename test/ConvNets.jl to test/convnets.jl (theabhirath, Jan 29, 2022)
d2ff28b  Update runtests.jl (theabhirath, Jan 29, 2022)
624c539  Updated MLPMixer category (theabhirath, Jan 30, 2022)
7adccec  Merge branch 'mlpmixer' of https://github.com/theabhirath/Metalhead.j… (theabhirath, Jan 30, 2022)
d7933b1  Clean up files (theabhirath, Jan 31, 2022)
ef6030c  Trimmed struct definition for MLPMixer model (theabhirath, Feb 2, 2022)
3b7e421  Cleaned up MLPMixer implementation (theabhirath, Feb 3, 2022)
04c78c6  Update Metalhead.jl (theabhirath, Feb 3, 2022)
eb45dee  Cleaned up API for model (theabhirath, Feb 3, 2022)
cabe8aa  Minor formatting tweaks (theabhirath, Feb 4, 2022)
44de174  Apply suggestions from code review (darsnack, Feb 4, 2022)
Project.toml (6 changes: 4 additions & 2 deletions)
@@ -8,17 +8,19 @@ BSON = "fbb218c0-5317-5bc6-957e-2ee96dd4b1f0"
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
LazyArtifacts = "4af54fe1-eca0-43a8-85a7-787d91b784e3"
NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
TensorCast = "02d47bb6-7ce6-556a-be16-bb1710789e2b"

[compat]
BSON = "0.3.2"
Flux = "0.12"
Functors = "0.2"
NNlib = "0.7.34"
julia = "1.4"

[extras]
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"

[publish]
title = "Metalhead.jl"
README.md (4 changes: 4 additions & 0 deletions)
@@ -36,6 +36,10 @@
| DenseNet-161 | [`DenseNet161`](https://fluxml.ai/Metalhead.jl/dev/docstrings/Metalhead.DenseNet161.html) | N |
| DenseNet-169 | [`DenseNet169`](https://fluxml.ai/Metalhead.jl/dev/docstrings/Metalhead.DenseNet169.html) | N |
| DenseNet-201 | [`DenseNet201`](https://fluxml.ai/Metalhead.jl/dev/docstrings/Metalhead.DenseNet201.html) | N |
| [ResNeXt](https://arxiv.org/abs/1611.05431) | [`ResNeXt`](https://fluxml.ai/Metalhead.jl/dev/docstrings/Metalhead.ResNeXt.html) | N |
| [MobileNetv2](https://arxiv.org/abs/1801.04381) | [`MobileNetv2`](https://fluxml.ai/Metalhead.jl/dev/docstrings/Metalhead.MobileNetv2.html) | N |
| [MobileNetv3](https://arxiv.org/abs/1905.02244) | [`MobileNetv3`](https://fluxml.ai/Metalhead.jl/dev/docstrings/Metalhead.MobileNetv3.html) | N |
| [MLPMixer](https://arxiv.org/pdf/2105.01601) | [`MLPMixer`](https://fluxml.ai/Metalhead.jl/dev/docstrings/Metalhead.MLPMixer.html) | N |

## Getting Started

src/Metalhead.jl (37 changes: 25 additions & 12 deletions)
@@ -5,20 +5,23 @@ using Flux: outputsize, Zygote
using Functors
using BSON
using Artifacts, LazyArtifacts
using TensorCast
using Statistics

import Functors

# Models
include("utilities.jl")
include("alexnet.jl")
include("vgg.jl")
include("resnet.jl")
include("googlenet.jl")
include("inception.jl")
include("squeezenet.jl")
include("densenet.jl")
include("resnext.jl")
include("mobilenet.jl")

# CNN models
include("convnets/alexnet.jl")
include("convnets/vgg.jl")
include("convnets/inception.jl")
include("convnets/googlenet.jl")
include("convnets/resnet.jl")
include("convnets/resnext.jl")
include("convnets/densenet.jl")
include("convnets/squeezenet.jl")
include("convnets/mobilenet.jl")

export AlexNet,
VGG, VGG11, VGG13, VGG16, VGG19,
@@ -30,8 +33,18 @@ export AlexNet,

# use Flux._big_show to pretty print large models
for T in (:AlexNet, :VGG, :ResNet, :GoogLeNet, :Inception3, :SqueezeNet, :DenseNet, :ResNeXt,
          :MobileNetv2, :MobileNetv3)
    @eval Base.show(io::IO, ::MIME"text/plain", model::$T) = _maybe_big_show(io, model)
end

# ViT-like models
include("vit-like/mlpmixer.jl")

export MLPMixer

# use Flux._big_show to pretty print large models
for T in (:MLPMixer,)
@eval Base.show(io::IO, ::MIME"text/plain", model::$T) = _maybe_big_show(io, model)
end

end # module
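The `@eval` loops above generate one `show` method per listed type; for `MLPMixer` the generated method is equivalent to the following (a sketch, assuming `_maybe_big_show` is the pretty-printing helper defined in `src/utilities.jl`):

```julia
Base.show(io::IO, ::MIME"text/plain", model::MLPMixer) = _maybe_big_show(io, model)
```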
9 files renamed without changes (the nine CNN model files moved from src/ to src/convnets/, matching the include changes above).
src/vit-like/mlpmixer.jl (87 changes: 87 additions & 0 deletions)
@@ -0,0 +1,87 @@
# Utility function for a residual block: applies `fn`, then LayerNorm, and adds the result to the input
residualprenorm(planes, fn) = SkipConnection(Chain(fn, LayerNorm(planes)), +)

# Utility function for 1D convolution
conv1d(inplanes, outplanes, activation) = Conv((1, ), inplanes => outplanes, activation)
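# For intuition: a kernel-size-1 Conv acts like a Dense layer applied independently
# at every position of the spatial dimension. On an array of size
# (planes, num_patches, batch) it mixes information across the patch (channel)
# dimension, with weights shared across the feature positions.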

"""
feedforward(planes, expansion_factor = 4, dropout = 0., dense = Dense)

Feedforward block in the MLPMixer architecture.
([reference](https://arxiv.org/pdf/2105.01601)).

# Arguments
- `planes`: Number of dimensions in the input and output.
- `expansion_factor`: Determines the number of dimensions in the intermediate layer.
- `dropout`: Dropout rate.
- `dense`: Type of dense layer to use in the feedforward block (e.g. `Dense` or the `conv1d` helper above). The activation is fixed to `gelu`.
"""
function feedforward(planes, expansion_factor = 4, dropout = 0., dense = Dense)
Chain(dense(planes, planes * expansion_factor, gelu),
Dropout(dropout),
dense(planes * expansion_factor, planes, gelu),
Dropout(dropout))
end
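# With the defaults, feedforward(512) builds:
# Chain(Dense(512, 2048, gelu), Dropout(0.0), Dense(2048, 512, gelu), Dropout(0.0))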

struct MLPMixer
channels
planes
patch_size
num_patches
token_mix
channel_mix
layers
nclasses
end

"""
MLPMixer(; image_size = 256, channels = 3, patch_size = 16, planes = 512,
depth = 12, expansion_factor = 4, dropout = 0., nclasses = 1000)

Creates a model with the MLPMixer architecture.
([reference](https://arxiv.org/pdf/2105.01601)).

# Arguments
- `image_size`: Size of the input image.
- `channels`: Number of channels in the input image.
- `patch_size`: Size of each patch fed into the network.
- `planes`: Number of dimensions in every layer after the patch expansion layer.
- `depth`: Number of layers in the network.
- `expansion_factor`: Determines the number of dimensions in the intermediate layers.
- `dropout`: Dropout rate in the feedforward blocks.
- `nclasses`: Number of classes in the output.
"""
function MLPMixer(; image_size = 256, channels = 3, patch_size = 16, planes = 512,
depth = 12, expansion_factor = 4, dropout = 0., nclasses = 1000)
@assert (image_size % patch_size) == 0 "image size must be divisible by patch size"

num_patches = (image_size ÷ patch_size) ^ 2
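# e.g. with the defaults: (256 ÷ 16)^2 = 256 patches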
token_mix = conv1d
channel_mix = Dense

layers = [Chain(residualprenorm(planes, feedforward(num_patches, expansion_factor,
dropout, token_mix)),
residualprenorm(planes, feedforward(planes, expansion_factor, dropout,
channel_mix)),) for _ in 1:depth]

MLPMixer(channels,
planes,
patch_size,
num_patches,
token_mix,
channel_mix,
layers,
nclasses)
end

function (m::MLPMixer)(x)
p = m.patch_size
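# split the image into non-overlapping p×p patches and flatten each patch:
# (patch_size^2 * channels, num_patches, batch)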
@cast x[(h2, w2, c), (h, w), b] := x[(h, h2), (w, w2), c, b] h2 in 1:p, w2 in 1:p
x = Dense((m.patch_size ^ 2) * m.channels, m.planes)(x)
x = Chain(LayerNorm(m.planes), m.layers...)(x)
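# global average pool over the patch dimension, giving (planes, batch)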
@reduce x[b, c] := mean(n) x[b, n, c]
x = Dense(m.planes, m.nclasses)(x)
end

@functor MLPMixer
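A minimal usage sketch of the new model (assuming the constructor defaults documented above; not part of this diff):

```julia
using Metalhead

m = MLPMixer()                     # 256×256 input, 16×16 patches, 1000 classes
x = rand(Float32, 256, 256, 3, 1)  # a single image in WHCN layout
y = m(x)                           # per the forward pass above, size(y) == (1000, 1)
```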
test/convnets.jl (127 changes: 127 additions & 0 deletions)
@@ -0,0 +1,127 @@
using Metalhead, Test
using Flux

# PRETRAINED_MODELS = [(VGG19, false), ResNet50, GoogLeNet, DenseNet121, SqueezeNet]
PRETRAINED_MODELS = []

@testset "AlexNet" begin
model = AlexNet()
@test size(model(rand(Float32, 256, 256, 3, 2))) == (1000, 2)
@test_throws ArgumentError AlexNet(pretrain = true)
@test_skip gradtest(model, rand(Float32, 256, 256, 3, 2))
end

@testset "VGG" begin
@testset "$model(BN=$bn)" for model in [VGG11, VGG13, VGG16, VGG19], bn in [true, false]
imsize = (224, 224)
m = model(batchnorm = bn)

@test size(m(rand(Float32, imsize..., 3, 2))) == (1000, 2)
if (model, bn) in PRETRAINED_MODELS
@test (model(batchnorm = bn, pretrain = true); true)
else
@test_throws ArgumentError model(batchnorm = bn, pretrain = true)
end
@test_skip gradtest(m, rand(Float32, imsize..., 3, 2))
end
end

@testset "ResNet" begin
@testset for model in [ResNet18, ResNet34, ResNet50, ResNet101, ResNet152]
m = model()

@test size(m(rand(Float32, 256, 256, 3, 2))) == (1000, 2)
if model in PRETRAINED_MODELS
@test (model(pretrain = true); true)
else
@test_throws ArgumentError model(pretrain = true)
end
@test_skip gradtest(m, rand(Float32, 256, 256, 3, 2))
end

@testset "Shortcut C" begin
m = Metalhead.resnet(Metalhead.basicblock, :C;
channel_config = [1, 1],
block_config = [2, 2, 2, 2])

@test size(m(rand(Float32, 256, 256, 3, 2))) == (1000, 2)
end
end

@testset "ResNeXt" begin
@testset for depth in [50, 101, 152]
m = ResNeXt(depth)

@test size(m(rand(Float32, 224, 224, 3, 2))) == (1000, 2)
if ResNeXt in PRETRAINED_MODELS
@test (ResNeXt(depth, pretrain = true); true)
else
@test_throws ArgumentError ResNeXt(depth, pretrain = true)
end
@test_skip gradtest(m, rand(Float32, 224, 224, 3, 2))
end
end

@testset "GoogLeNet" begin
m = GoogLeNet()
@test size(m(rand(Float32, 224, 224, 3, 2))) == (1000, 2)
@test_throws ArgumentError GoogLeNet(pretrain = true)
@test_skip gradtest(m, rand(Float32, 224, 224, 3, 2))
end

@testset "Inception3" begin
m = Inception3()
@test size(m(rand(Float32, 299, 299, 3, 2))) == (1000, 2)
@test_throws ArgumentError Inception3(pretrain = true)
@test_skip gradtest(m, rand(Float32, 299, 299, 3, 2))
end

@testset "SqueezeNet" begin
m = SqueezeNet()
@test size(m(rand(Float32, 227, 227, 3, 2))) == (1000, 2)
@test_throws ArgumentError SqueezeNet(pretrain = true)
@test_skip gradtest(m, rand(Float32, 227, 227, 3, 2))
end

@testset "DenseNet" begin
@testset for model in [DenseNet121, DenseNet161, DenseNet169, DenseNet201]
m = model()

@test size(m(rand(Float32, 224, 224, 3, 2))) == (1000, 2)
if model in PRETRAINED_MODELS
@test (model(pretrain = true); true)
else
@test_throws ArgumentError model(pretrain = true)
end
@test_skip gradtest(m, rand(Float32, 224, 224, 3, 2))
end
end

@testset "MobileNet" verbose = true begin
@testset "MobileNetv2" begin

m = MobileNetv2()

@test size(m(rand(Float32, 224, 224, 3, 2))) == (1000, 2)
if MobileNetv2 in PRETRAINED_MODELS
@test (MobileNetv2(pretrain = true); true)
else
@test_throws ArgumentError MobileNetv2(pretrain = true)
end
@test_skip gradtest(m, rand(Float32, 224, 224, 3, 2))
end

@testset "MobileNetv3" verbose = true begin
@testset for mode in [:small, :large]
m = MobileNetv3(mode)

@test size(m(rand(Float32, 224, 224, 3, 2))) == (1000, 2)
if MobileNetv3 in PRETRAINED_MODELS
@test (MobileNetv3(mode; pretrain = true); true)
else
@test_throws ArgumentError MobileNetv3(mode; pretrain = true)
end
@test_skip gradtest(m, rand(Float32, 224, 224, 3, 2))
end
end
end