Skip to content

Commit

Permalink
Merge pull request #103 from theabhirath/mlpmixer
Browse files Browse the repository at this point in the history
Implementation of MLPMixer
  • Loading branch information
darsnack authored Feb 4, 2022
2 parents 4dbdfd5 + 44de174 commit 1eb8a51
Show file tree
Hide file tree
Showing 18 changed files with 379 additions and 228 deletions.
5 changes: 3 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -8,17 +8,18 @@ BSON = "fbb218c0-5317-5bc6-957e-2ee96dd4b1f0"
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
LazyArtifacts = "4af54fe1-eca0-43a8-85a7-787d91b784e3"
NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"

[compat]
BSON = "0.3.2"
Flux = "0.12"
Functors = "0.2"
julia = "1.4"
NNlib = "0.7.34"
julia = "1.4"

[extras]
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"

[publish]
title = "Metalhead.jl"
Expand Down
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@
| ResNeXt-152 | [`ResNeXt`](https://fluxml.ai/Metalhead.jl/dev/docstrings/Metalhead.ResNeXt.html) | N |
| [MobileNetv2](https://arxiv.org/abs/1801.04381) | [`MobileNetv2`](https://fluxml.ai/Metalhead.jl/dev/docstrings/Metalhead.MobileNetv2.html) | N |
| [MobileNetv3](https://arxiv.org/abs/1905.02244) | [`MobileNetv3`](https://fluxml.ai/Metalhead.jl/dev/docstrings/Metalhead.MobileNetv3.html) | N |

| [MLPMixer](https://arxiv.org/pdf/2105.01601) | [`MLPMixer`](https://fluxml.ai/Metalhead.jl/dev/docstrings/Metalhead.MLPMixer.html) | N |

## Getting Started

Expand Down
31 changes: 19 additions & 12 deletions src/Metalhead.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,32 +5,39 @@ using Flux: outputsize, Zygote
using Functors
using BSON
using Artifacts, LazyArtifacts
using Statistics

import Functors

# Models
include("utilities.jl")
include("alexnet.jl")
include("vgg.jl")
include("resnet.jl")
include("googlenet.jl")
include("inception.jl")
include("squeezenet.jl")
include("densenet.jl")
include("resnext.jl")
include("mobilenet.jl")
include("layers.jl")

# CNN models
include("convnets/alexnet.jl")
include("convnets/vgg.jl")
include("convnets/inception.jl")
include("convnets/googlenet.jl")
include("convnets/resnet.jl")
include("convnets/resnext.jl")
include("convnets/densenet.jl")
include("convnets/squeezenet.jl")
include("convnets/mobilenet.jl")

# Other models
include("other/mlpmixer.jl")

export AlexNet,
VGG, VGG11, VGG13, VGG16, VGG19,
ResNet, ResNet18, ResNet34, ResNet50, ResNet101, ResNet152,
GoogLeNet, Inception3, SqueezeNet,
DenseNet, DenseNet121, DenseNet161, DenseNet169, DenseNet201,
ResNeXt,
MobileNetv2, MobileNetv3
MobileNetv2, MobileNetv3,
MLPMixer

# use Flux._big_show to pretty print large models
for T in (:AlexNet, :VGG, :ResNet, :GoogLeNet, :Inception3, :SqueezeNet, :DenseNet, :ResNeXt,
:MobileNetv2, :MobileNetv3)
:MobileNetv2, :MobileNetv3, :MLPMixer)
@eval Base.show(io::IO, ::MIME"text/plain", model::$T) = _maybe_big_show(io, model)
end

Expand Down
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
126 changes: 126 additions & 0 deletions src/layers.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,126 @@
"""
conv_bn(kernelsize, inplanes, outplanes, activation = relu;
rev = false,
stride = 1, pad = 0, dilation = 1, groups = 1, [bias, weight, init],
initβ = Flux.zeros32, initγ = Flux.ones32, ϵ = 1f-5, momentum = 1f-1)
Create a convolution + batch normalization pair with ReLU activation.
# Arguments
- `kernelsize`: size of the convolution kernel (tuple)
- `inplanes`: number of input feature maps
- `outplanes`: number of output feature maps
- `activation`: the activation function for the final layer
- `rev`: set to `true` to place the batch norm before the convolution
- `stride`: stride of the convolution kernel
- `pad`: padding of the convolution kernel
- `dilation`: dilation of the convolution kernel
- `groups`: groups for the convolution kernel
- `bias`, `weight`, `init`: initialization for the convolution kernel (see [`Flux.Conv`](#))
- `initβ`, `initγ`: initialization for the batch norm (see [`Flux.BatchNorm`](#))
- `ϵ`, `momentum`: batch norm parameters (see [`Flux.BatchNorm`](#))
"""
function conv_bn(kernelsize, inplanes, outplanes, activation = relu;
rev = false,
initβ = Flux.zeros32, initγ = Flux.ones32, ϵ = 1f-5, momentum = 1f-1,
kwargs...)
layers = []

if rev
activations = (conv = activation, bn = identity)
bnplanes = inplanes
else
activations = (conv = identity, bn = activation)
bnplanes = outplanes
end

push!(layers, Conv(kernelsize, Int(inplanes) => Int(outplanes), activations.conv; kwargs...))
push!(layers, BatchNorm(Int(bnplanes), activations.bn;
initβ = initβ, initγ = initγ, ϵ = ϵ, momentum = momentum))

return rev ? reverse(layers) : layers
end

"""
cat_channels(x, y)
Concatenate `x` and `y` along the channel dimension (third dimension).
Equivalent to `cat(x, y; dims=3)`.
Convenient binary reduction operator for use with `Parallel`.
"""
cat_channels(x, y) = cat(x, y; dims = 3)

"""
skip_projection(inplanes, outplanes, downsample = false)
Create a skip projection
([reference](https://arxiv.org/abs/1512.03385v1)).
# Arguments:
- `inplanes`: the number of input feature maps
- `outplanes`: the number of output feature maps
- `downsample`: set to `true` to downsample the input
"""
skip_projection(inplanes, outplanes, downsample = false) = downsample ?
Chain(conv_bn((1, 1), inplanes, outplanes, identity; stride = 2, bias = false)...) :
Chain(conv_bn((1, 1), inplanes, outplanes, identity; stride = 1, bias = false)...)

# array -> PaddedView(0, array, outplanes) for zero padding arrays
"""
skip_identity(inplanes, outplanes[, downsample])
Create a identity projection
([reference](https://arxiv.org/abs/1512.03385v1)).
# Arguments:
- `inplanes`: the number of input feature maps
- `outplanes`: the number of output feature maps
- `downsample`: this argument is ignored but it is needed for compatibility with [`resnet`](#).
"""
function skip_identity(inplanes, outplanes)
if outplanes > inplanes
return Chain(MaxPool((1, 1), stride = 2),
y -> cat(y, zeros(eltype(y),
size(y, 1),
size(y, 2),
outplanes - inplanes, size(y, 4)); dims = 3))
else
return identity
end
end
skip_identity(inplanes, outplanes, downsample) = skip_identity(inplanes, outplanes)

# Patching layer used by many vision transformer-like models
struct Patching{T <: Integer}
patch_height::T
patch_width::T
end
Patching(patch_size) = Patching(patch_size, patch_size)

function (p::Patching)(x)
h, w, c, n = size(x)
hp, wp = h ÷ p.patch_height, w ÷ p.patch_width
xpatch = reshape(x, hp, p.patch_height, wp, p.patch_width, c, n)

return reshape(permutedims(xpatch, (1, 3, 5, 2, 4, 6)), p.patch_height * p.patch_width * c,
hp * wp, n)
end

@functor Patching

"""
mlpblock(planes, expansion_factor = 4, dropout = 0., dense = Dense)
Feedforward block used in many vision transformer-like models.
# Arguments
`planes`: Number of dimensions in the input and output.
`hidden_planes`: Number of dimensions in the intermediate layer.
`dropout`: Dropout rate.
`dense`: Type of dense layer to use in the feedforward block.
`activation`: Activation function to use.
"""
function mlpblock(planes, hidden_planes, dropout = 0., dense = Dense; activation = gelu)
Chain(dense(planes, hidden_planes, activation), Dropout(dropout),
dense(hidden_planes, planes, activation), Dropout(dropout))
end
87 changes: 87 additions & 0 deletions src/other/mlpmixer.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
# Utility function for creating a residual block with LayerNorm before the residual connection
_residualprenorm(planes, fn) = SkipConnection(Chain(fn, LayerNorm(planes)), +)

# Utility function for 1D convolution
_conv1d(inplanes, outplanes, activation) = Conv((1, ), inplanes => outplanes, activation)

"""
mlpmixer(imsize::NTuple{2} = (256, 256); inchannels = 3, patch_size = 16, planes = 512,
depth = 12, expansion_factor = 4, dropout = 0., nclasses = 1000, token_mix =
_conv1d, channel_mix = Dense))
Creates a model with the MLPMixer architecture.
([reference](https://arxiv.org/pdf/2105.01601)).
# Arguments
- imsize: the size of the input image
- inchannels: the number of input channels
- patch_size: the size of the patches
- planes: the number of channels fed into the main model
- depth: the number of blocks in the main model
- expansion_factor: the number of channels in each block
- dropout: the dropout rate
- nclasses: the number of classes in the output
- token_mix: the function to use for the token mixing layer
- channel_mix: the function to use for the channel mixing layer
"""
function mlpmixer(imsize::NTuple{2} = (256, 256); inchannels = 3, patch_size = 16, planes = 512,
depth = 12, expansion_factor = 4, dropout = 0., nclasses = 1000, token_mix =
_conv1d, channel_mix = Dense)

im_height, im_width = imsize

@assert (im_height % patch_size) == 0 && (im_width % patch_size == 0)
"image size must be divisible by patch size"

num_patches = (im_height ÷ patch_size) * (im_width ÷ patch_size)

layers = []
push!(layers, Patching(patch_size))
push!(layers, Dense((patch_size ^ 2) * inchannels, planes))
append!(layers, [Chain(_residualprenorm(planes, mlpblock(num_patches,
expansion_factor * num_patches,
dropout, token_mix)),
_residualprenorm(planes, mlpblock(planes,
expansion_factor * planes, dropout,
channel_mix)),) for _ in 1:depth])

classification_head = Chain(_seconddimmean, Dense(planes, nclasses))

return Chain(Chain(layers...), classification_head)
end

struct MLPMixer
layers
end

"""
MLPMixer(imsize::NTuple{2} = (256, 256); inchannels = 3, patch_size = 16, planes = 512,
depth = 12, expansion_factor = 4, dropout = 0., nclasses = 1000)
Creates a model with the MLPMixer architecture.
([reference](https://arxiv.org/pdf/2105.01601)).
# Arguments
- imsize: the size of the input image
- inchannels: the number of input channels
- patch_size: the size of the patches
- planes: the number of channels fed into the main model
- depth: the number of blocks in the main model
- expansion_factor: the number of channels in each block
- dropout: the dropout rate
- nclasses: the number of classes in the output
"""
function MLPMixer(imsize::NTuple{2} = (256, 256); inchannels = 3, patch_size = 16, planes = 512,
depth = 12, expansion_factor = 4, dropout = 0., nclasses = 1000)

layers = mlpmixer(imsize; inchannels, patch_size, planes, depth, expansion_factor, dropout,
nclasses)
MLPMixer(layers)
end

@functor MLPMixer

(m::MLPMixer)(x) = m.layers(x)

backbone(m::MLPMixer) = m.layers[1]
classifier(m::MLPMixer) = m.layers[2]
93 changes: 2 additions & 91 deletions src/utilities.jl
Original file line number Diff line number Diff line change
@@ -1,94 +1,5 @@
"""
conv_bn(kernelsize, inplanes, outplanes, activation = relu;
rev = false,
stride = 1, pad = 0, dilation = 1, groups = 1, [bias, weight, init],
initβ = Flux.zeros32, initγ = Flux.ones32, ϵ = 1f-5, momentum = 1f-1)
Create a convolution + batch normalization pair with ReLU activation.
# Arguments
- `kernelsize`: size of the convolution kernel (tuple)
- `inplanes`: number of input feature maps
- `outplanes`: number of output feature maps
- `activation`: the activation function for the final layer
- `rev`: set to `true` to place the batch norm before the convolution
- `stride`: stride of the convolution kernel
- `pad`: padding of the convolution kernel
- `dilation`: dilation of the convolution kernel
- `groups`: groups for the convolution kernel
- `bias`, `weight`, `init`: initialization for the convolution kernel (see [`Flux.Conv`](#))
- `initβ`, `initγ`: initialization for the batch norm (see [`Flux.BatchNorm`](#))
- `ϵ`, `momentum`: batch norm parameters (see [`Flux.BatchNorm`](#))
"""
function conv_bn(kernelsize, inplanes, outplanes, activation = relu;
rev = false,
initβ = Flux.zeros32, initγ = Flux.ones32, ϵ = 1f-5, momentum = 1f-1,
kwargs...)
layers = []

if rev
activations = (conv = activation, bn = identity)
bnplanes = inplanes
else
activations = (conv = identity, bn = activation)
bnplanes = outplanes
end

push!(layers, Conv(kernelsize, Int(inplanes) => Int(outplanes), activations.conv; kwargs...))
push!(layers, BatchNorm(Int(bnplanes), activations.bn;
initβ = initβ, initγ = initγ, ϵ = ϵ, momentum = momentum))

return rev ? reverse(layers) : layers
end

"""
cat_channels(x, y)
Concatenate `x` and `y` along the channel dimension (third dimension).
Equivalent to `cat(x, y; dims=3)`.
Convenient binary reduction operator for use with `Parallel`.
"""
cat_channels(x, y) = cat(x, y; dims = 3)

"""
skip_projection(inplanes, outplanes, downsample = false)
Create a skip projection
([reference](https://arxiv.org/abs/1512.03385v1)).
# Arguments:
- `inplanes`: the number of input feature maps
- `outplanes`: the number of output feature maps
- `downsample`: set to `true` to downsample the input
"""
skip_projection(inplanes, outplanes, downsample = false) = downsample ?
Chain(conv_bn((1, 1), inplanes, outplanes, identity; stride = 2, bias = false)...) :
Chain(conv_bn((1, 1), inplanes, outplanes, identity; stride = 1, bias = false)...)

# array -> PaddedView(0, array, outplanes) for zero padding arrays
"""
skip_identity(inplanes, outplanes[, downsample])
Create a identity projection
([reference](https://arxiv.org/abs/1512.03385v1)).
# Arguments:
- `inplanes`: the number of input feature maps
- `outplanes`: the number of output feature maps
- `downsample`: this argument is ignored but it is needed for compatibility with [`resnet`](#).
"""
function skip_identity(inplanes, outplanes)
if outplanes > inplanes
return Chain(MaxPool((1, 1), stride = 2),
y -> cat(y, zeros(eltype(y),
size(y, 1),
size(y, 2),
outplanes - inplanes, size(y, 4)); dims = 3))
else
return identity
end
end
skip_identity(inplanes, outplanes, downsample) = skip_identity(inplanes, outplanes)
# Utility function for classifier head of vision transformer-like models
_seconddimmean(x) = mean(x, dims = 2)[:, 1, :]

"""
weights(model)
Expand Down
Loading

0 comments on commit 1eb8a51

Please sign in to comment.