Skip to content


Merge pull request #103 from theabhirath/mlpmixer
Browse files Browse the repository at this point in the history
Implementation of MLPMixer
  • Loading branch information
darsnack authored Feb 4, 2022
2 parents 4dbdfd5 + 44de174 commit 1eb8a51
Show file tree
Hide file tree
Showing 18 changed files with 379 additions and 228 deletions.
5 changes: 3 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -8,17 +8,18 @@ BSON = "fbb218c0-5317-5bc6-957e-2ee96dd4b1f0"
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
LazyArtifacts = "4af54fe1-eca0-43a8-85a7-787d91b784e3"
NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"

BSON = "0.3.2"
Flux = "0.12"
Functors = "0.2"
julia = "1.4"
NNlib = "0.7.34"
julia = "1.4"

Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"

title = "Metalhead.jl"
Expand Down
2 changes: 1 addition & 1 deletion
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@
| ResNeXt-152 | `ResNeXt` | N |
| MobileNetv2 | `MobileNetv2` | N |
| MobileNetv3 | `MobileNetv3` | N |
| MLPMixer | `MLPMixer` | N |

## Getting Started

Expand Down
31 changes: 19 additions & 12 deletions src/Metalhead.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,32 +5,39 @@ using Flux: outputsize, Zygote
using Functors
using BSON
using Artifacts, LazyArtifacts
using Statistics

import Functors

# Models

# CNN models

# Other models

export AlexNet,
VGG, VGG11, VGG13, VGG16, VGG19,
ResNet, ResNet18, ResNet34, ResNet50, ResNet101, ResNet152,
GoogLeNet, Inception3, SqueezeNet,
DenseNet, DenseNet121, DenseNet161, DenseNet169, DenseNet201,
MobileNetv2, MobileNetv3
MobileNetv2, MobileNetv3,

# use Flux._big_show to pretty print large models
# Register a `text/plain` display method for every model type; each one forwards to
# `_maybe_big_show`, which receives the IO stream and the model.
for T in (:AlexNet, :VGG, :ResNet, :GoogLeNet, :Inception3, :SqueezeNet, :DenseNet, :ResNeXt,
          :MobileNetv2, :MobileNetv3, :MLPMixer)
    @eval Base.show(io::IO, ::MIME"text/plain", model::$T) = _maybe_big_show(io, model)
end

Expand Down
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
126 changes: 126 additions & 0 deletions src/layers.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,126 @@
"""
    conv_bn(kernelsize, inplanes, outplanes, activation = relu;
            rev = false,
            stride = 1, pad = 0, dilation = 1, groups = 1, [bias, weight, init],
            initβ = Flux.zeros32, initγ = Flux.ones32, ϵ = 1f-5, momentum = 1f-1)

Create a convolution + batch normalization pair and return the two layers as a vector
(in application order). `activation` (ReLU by default) is attached to whichever layer
runs last; the other layer uses `identity`.

# Arguments
- `kernelsize`: size of the convolution kernel (tuple)
- `inplanes`: number of input feature maps
- `outplanes`: number of output feature maps
- `activation`: the activation function for the final layer
- `rev`: set to `true` to place the batch norm before the convolution
- `stride`: stride of the convolution kernel
- `pad`: padding of the convolution kernel
- `dilation`: dilation of the convolution kernel
- `groups`: groups for the convolution kernel
- `bias`, `weight`, `init`: initialization for the convolution kernel (see [`Flux.Conv`](#))
- `initβ`, `initγ`: initialization for the batch norm (see [`Flux.BatchNorm`](#))
- `ϵ`, `momentum`: batch norm parameters (see [`Flux.BatchNorm`](#))

All keyword arguments other than `rev`, `initβ`, `initγ`, `ϵ` and `momentum` are forwarded
to the `Conv` constructor unchanged.
"""
function conv_bn(kernelsize, inplanes, outplanes, activation = relu;
                 rev = false,
                 initβ = Flux.zeros32, initγ = Flux.ones32, ϵ = 1f-5, momentum = 1f-1,
                 kwargs...)
    layers = []

    # When the batch norm comes first it sees `inplanes` channels and the convolution
    # carries the activation; otherwise the roles are swapped.
    if rev
        activations = (conv = activation, bn = identity)
        bnplanes = inplanes
    else
        activations = (conv = identity, bn = activation)
        bnplanes = outplanes
    end

    push!(layers,
          Conv(kernelsize, Int(inplanes) => Int(outplanes), activations.conv; kwargs...))
    push!(layers, BatchNorm(Int(bnplanes), activations.bn;
                            initβ = initβ, initγ = initγ, ϵ = ϵ, momentum = momentum))

    return rev ? reverse(layers) : layers
end

"""
    cat_channels(x, y)

Concatenate `x` and `y` along the channel dimension (third dimension).
Equivalent to `cat(x, y; dims=3)`.
Convenient binary reduction operator for use with `Parallel`.
"""
cat_channels(x, y) = cat(x, y; dims = 3)

"""
    skip_projection(inplanes, outplanes, downsample = false)

Create a skip projection: a 1×1 bias-free convolution followed by batch normalization
(built with `conv_bn` and `identity` activation), wrapped in a `Chain`.

# Arguments:
- `inplanes`: the number of input feature maps
- `outplanes`: the number of output feature maps
- `downsample`: set to `true` to downsample the input (the convolution then uses
  stride 2 instead of stride 1)
"""
function skip_projection(inplanes, outplanes, downsample = false)
    stride = downsample ? 2 : 1
    return Chain(conv_bn((1, 1), inplanes, outplanes, identity;
                         stride = stride, bias = false)...)
end

# array -> PaddedView(0, array, outplanes) for zero padding arrays
"""
    skip_identity(inplanes, outplanes[, downsample])

Create an identity projection.

When `outplanes > inplanes`, the result is a `Chain` that subsamples the input with a
1×1 max pool of stride 2 and then appends `outplanes - inplanes` zero-filled channels
along the third dimension. Otherwise the plain `identity` function is returned.

# Arguments:
- `inplanes`: the number of input feature maps
- `outplanes`: the number of output feature maps
- `downsample`: this argument is ignored but it is needed for compatibility with [`resnet`](#).
"""
function skip_identity(inplanes, outplanes)
    if outplanes > inplanes
        return Chain(MaxPool((1, 1), stride = 2),
                     y -> cat(y, zeros(eltype(y),
                                       size(y, 1),
                                       size(y, 2),
                                       outplanes - inplanes, size(y, 4)); dims = 3))
    else
        return identity
    end
end
skip_identity(inplanes, outplanes, downsample) = skip_identity(inplanes, outplanes)

# Patching layer used by many vision transformer-like models
"""
    Patching(patch_height, patch_width)
    Patching(patch_size)

Layer that cuts a 4-dimensional input of size `(h, w, c, n)` into non-overlapping
`patch_height × patch_width` patches and flattens each of them. The output has size
`(patch_height * patch_width * c, (h ÷ patch_height) * (w ÷ patch_width), n)`.
The single-argument constructor builds square patches.
"""
struct Patching{T <: Integer}
    patch_height::T
    patch_width::T
end
Patching(patch_size) = Patching(patch_size, patch_size)

function (p::Patching)(x)
    h, w, c, n = size(x)
    hp, wp = h ÷ p.patch_height, w ÷ p.patch_width
    # Julia arrays are column-major, so when a spatial axis is split in two the
    # within-patch offset must be the first (fastest) of the pair and the patch index the
    # second. Listing them the other way round only yields contiguous patches when the
    # patch grid happens to have the same size as a patch.
    xpatch = reshape(x, p.patch_height, hp, p.patch_width, wp, c, n)

    # Bring (row offset, column offset, channel) to the front, then flatten them into
    # the feature axis and the two patch-grid axes into the patch axis.
    return reshape(permutedims(xpatch, (1, 3, 5, 2, 4, 6)),
                   p.patch_height * p.patch_width * c, hp * wp, n)
end

@functor Patching

"""
    mlpblock(planes, hidden_planes, dropout = 0., dense = Dense; activation = gelu)

Feedforward block used in many vision transformer-like models. Builds
`dense(planes, hidden_planes, activation)`, dropout, `dense(hidden_planes, planes,
activation)`, dropout, in that order.

# Arguments
- `planes`: Number of dimensions in the input and output.
- `hidden_planes`: Number of dimensions in the intermediate layer.
- `dropout`: Dropout rate.
- `dense`: Type of dense layer to use in the feedforward block. It is called as
  `dense(in, out, activation)`.
- `activation`: Activation function to use.
"""
function mlpblock(planes, hidden_planes, dropout = 0., dense = Dense; activation = gelu)
    # NOTE(review): `activation` is applied to the output layer as well as the hidden
    # layer; confirm this is intended.
    return Chain(dense(planes, hidden_planes, activation), Dropout(dropout),
                 dense(hidden_planes, planes, activation), Dropout(dropout))
end
87 changes: 87 additions & 0 deletions src/other/mlpmixer.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
# Utility function for creating a residual block: `fn` followed by `LayerNorm(planes)`,
# wrapped in an additive skip connection.
# NOTE(review): despite the "prenorm" name, the normalization is applied to the output
# of `fn`, not to its input; confirm this ordering is intended.
function _residualprenorm(planes, fn)
    branch = Chain(fn, LayerNorm(planes))
    return SkipConnection(branch, +)
end

# Utility function for 1D convolution: a width-1 kernel mapping `inplanes` channels to
# `outplanes` channels, with `activation` applied to the result.
function _conv1d(inplanes, outplanes, activation)
    return Conv((1,), inplanes => outplanes, activation)
end

"""
    mlpmixer(imsize::NTuple{2} = (256, 256); inchannels = 3, patch_size = 16, planes = 512,
             depth = 12, expansion_factor = 4, dropout = 0., nclasses = 1000,
             token_mix = _conv1d, channel_mix = Dense)

Creates a model with the MLPMixer architecture. The result is a two-element `Chain`:
the backbone (patching, patch embedding and `depth` mixer blocks) followed by the
classification head.

# Arguments
- `imsize`: the size of the input image
- `inchannels`: the number of input channels
- `patch_size`: the size of the patches; both image dimensions must be divisible by it
- `planes`: the number of channels fed into the main model
- `depth`: the number of blocks in the main model
- `expansion_factor`: multiplier for the hidden width of each mixing MLP (the hidden
  layer has `expansion_factor` times as many units as the MLP's input)
- `dropout`: the dropout rate
- `nclasses`: the number of classes in the output
- `token_mix`: the function to use for the token mixing layer
- `channel_mix`: the function to use for the channel mixing layer
"""
function mlpmixer(imsize::NTuple{2} = (256, 256); inchannels = 3, patch_size = 16,
                  planes = 512, depth = 12, expansion_factor = 4, dropout = 0.,
                  nclasses = 1000, token_mix = _conv1d, channel_mix = Dense)

    im_height, im_width = imsize

    # The message must be passed in the same macro call; on a line of its own it is a
    # separate no-op expression and the assertion fails without any explanation.
    @assert((im_height % patch_size) == 0 && (im_width % patch_size == 0),
            "image size must be divisible by patch size")

    num_patches = (im_height ÷ patch_size) * (im_width ÷ patch_size)

    layers = []
    push!(layers, Patching(patch_size))
    # Embed every flattened patch into `planes` channels.
    push!(layers, Dense((patch_size^2) * inchannels, planes))
    # Each block mixes across patches (token mixing) and then across channels.
    append!(layers, [Chain(_residualprenorm(planes, mlpblock(num_patches,
                                                            expansion_factor * num_patches,
                                                            dropout, token_mix)),
                           _residualprenorm(planes, mlpblock(planes,
                                                            expansion_factor * planes,
                                                            dropout, channel_mix)))
                     for _ in 1:depth])

    classification_head = Chain(_seconddimmean, Dense(planes, nclasses))

    return Chain(Chain(layers...), classification_head)
end

# Wrapper type around the chain built by `mlpmixer`; `layers[1]` is the backbone and
# `layers[2]` the classification head.
struct MLPMixer
    layers
end

"""
    MLPMixer(imsize::NTuple{2} = (256, 256); inchannels = 3, patch_size = 16, planes = 512,
             depth = 12, expansion_factor = 4, dropout = 0., nclasses = 1000)

Creates a model with the MLPMixer architecture.

# Arguments
- `imsize`: the size of the input image
- `inchannels`: the number of input channels
- `patch_size`: the size of the patches
- `planes`: the number of channels fed into the main model
- `depth`: the number of blocks in the main model
- `expansion_factor`: multiplier for the hidden width of each mixing MLP
- `dropout`: the dropout rate
- `nclasses`: the number of classes in the output
"""
function MLPMixer(imsize::NTuple{2} = (256, 256); inchannels = 3, patch_size = 16,
                  planes = 512, depth = 12, expansion_factor = 4, dropout = 0.,
                  nclasses = 1000)

    # Keywords are spelled out as `name = name`; the implicit `; name` shorthand needs
    # Julia 1.5, while the package declares compatibility with Julia 1.4.
    layers = mlpmixer(imsize; inchannels = inchannels, patch_size = patch_size,
                      planes = planes, depth = depth,
                      expansion_factor = expansion_factor, dropout = dropout,
                      nclasses = nclasses)

    return MLPMixer(layers)
end

@functor MLPMixer

(m::MLPMixer)(x) = m.layers(x)

backbone(m::MLPMixer) = m.layers[1]
classifier(m::MLPMixer) = m.layers[2]
93 changes: 2 additions & 91 deletions src/utilities.jl
Original file line number Diff line number Diff line change
@@ -1,94 +1,5 @@
"""
    conv_bn(kernelsize, inplanes, outplanes, activation = relu;
            rev = false,
            stride = 1, pad = 0, dilation = 1, groups = 1, [bias, weight, init],
            initβ = Flux.zeros32, initγ = Flux.ones32, ϵ = 1f-5, momentum = 1f-1)

Create a convolution + batch normalization pair and return the two layers as a vector
(in application order). `activation` (ReLU by default) is attached to whichever layer
runs last; the other layer uses `identity`.

# Arguments
- `kernelsize`: size of the convolution kernel (tuple)
- `inplanes`: number of input feature maps
- `outplanes`: number of output feature maps
- `activation`: the activation function for the final layer
- `rev`: set to `true` to place the batch norm before the convolution
- `stride`: stride of the convolution kernel
- `pad`: padding of the convolution kernel
- `dilation`: dilation of the convolution kernel
- `groups`: groups for the convolution kernel
- `bias`, `weight`, `init`: initialization for the convolution kernel (see [`Flux.Conv`](#))
- `initβ`, `initγ`: initialization for the batch norm (see [`Flux.BatchNorm`](#))
- `ϵ`, `momentum`: batch norm parameters (see [`Flux.BatchNorm`](#))

All keyword arguments other than `rev`, `initβ`, `initγ`, `ϵ` and `momentum` are forwarded
to the `Conv` constructor unchanged.
"""
function conv_bn(kernelsize, inplanes, outplanes, activation = relu;
                 rev = false,
                 initβ = Flux.zeros32, initγ = Flux.ones32, ϵ = 1f-5, momentum = 1f-1,
                 kwargs...)
    layers = []

    # When the batch norm comes first it sees `inplanes` channels and the convolution
    # carries the activation; otherwise the roles are swapped.
    if rev
        activations = (conv = activation, bn = identity)
        bnplanes = inplanes
    else
        activations = (conv = identity, bn = activation)
        bnplanes = outplanes
    end

    push!(layers,
          Conv(kernelsize, Int(inplanes) => Int(outplanes), activations.conv; kwargs...))
    push!(layers, BatchNorm(Int(bnplanes), activations.bn;
                            initβ = initβ, initγ = initγ, ϵ = ϵ, momentum = momentum))

    return rev ? reverse(layers) : layers
end

"""
    cat_channels(x, y)

Concatenate `x` and `y` along the channel dimension (third dimension).
Equivalent to `cat(x, y; dims=3)`.
Convenient binary reduction operator for use with `Parallel`.
"""
cat_channels(x, y) = cat(x, y; dims = 3)

"""
    skip_projection(inplanes, outplanes, downsample = false)

Create a skip projection: a 1×1 bias-free convolution followed by batch normalization
(built with `conv_bn` and `identity` activation), wrapped in a `Chain`.

# Arguments:
- `inplanes`: the number of input feature maps
- `outplanes`: the number of output feature maps
- `downsample`: set to `true` to downsample the input (the convolution then uses
  stride 2 instead of stride 1)
"""
function skip_projection(inplanes, outplanes, downsample = false)
    stride = downsample ? 2 : 1
    return Chain(conv_bn((1, 1), inplanes, outplanes, identity;
                         stride = stride, bias = false)...)
end

# array -> PaddedView(0, array, outplanes) for zero padding arrays
"""
    skip_identity(inplanes, outplanes[, downsample])

Create an identity projection.

When `outplanes > inplanes`, the result is a `Chain` that subsamples the input with a
1×1 max pool of stride 2 and then appends `outplanes - inplanes` zero-filled channels
along the third dimension. Otherwise the plain `identity` function is returned.

# Arguments:
- `inplanes`: the number of input feature maps
- `outplanes`: the number of output feature maps
- `downsample`: this argument is ignored but it is needed for compatibility with [`resnet`](#).
"""
function skip_identity(inplanes, outplanes)
    if outplanes > inplanes
        return Chain(MaxPool((1, 1), stride = 2),
                     y -> cat(y, zeros(eltype(y),
                                       size(y, 1),
                                       size(y, 2),
                                       outplanes - inplanes, size(y, 4)); dims = 3))
    else
        return identity
    end
end
skip_identity(inplanes, outplanes, downsample) = skip_identity(inplanes, outplanes)
# Utility function for classifier head of vision transformer-like models:
# average over the second dimension and drop that (now singleton) dimension.
function _seconddimmean(x)
    averaged = mean(x, dims = 2)
    return averaged[:, 1, :]
end

Expand Down

0 comments on commit 1eb8a51

Please sign in to comment.