Merge pull request #103 from theabhirath/mlpmixer
Implementation of MLPMixer
Showing 18 changed files with 379 additions and 228 deletions.
9 files renamed without changes.
@@ -0,0 +1,126 @@
""" | ||
conv_bn(kernelsize, inplanes, outplanes, activation = relu; | ||
rev = false, | ||
stride = 1, pad = 0, dilation = 1, groups = 1, [bias, weight, init], | ||
initβ = Flux.zeros32, initγ = Flux.ones32, ϵ = 1f-5, momentum = 1f-1) | ||
Create a convolution + batch normalization pair with ReLU activation. | ||
# Arguments | ||
- `kernelsize`: size of the convolution kernel (tuple) | ||
- `inplanes`: number of input feature maps | ||
- `outplanes`: number of output feature maps | ||
- `activation`: the activation function for the final layer | ||
- `rev`: set to `true` to place the batch norm before the convolution | ||
- `stride`: stride of the convolution kernel | ||
- `pad`: padding of the convolution kernel | ||
- `dilation`: dilation of the convolution kernel | ||
- `groups`: groups for the convolution kernel | ||
- `bias`, `weight`, `init`: initialization for the convolution kernel (see [`Flux.Conv`](#)) | ||
- `initβ`, `initγ`: initialization for the batch norm (see [`Flux.BatchNorm`](#)) | ||
- `ϵ`, `momentum`: batch norm parameters (see [`Flux.BatchNorm`](#)) | ||
""" | ||
function conv_bn(kernelsize, inplanes, outplanes, activation = relu;
                 rev = false,
                 initβ = Flux.zeros32, initγ = Flux.ones32, ϵ = 1f-5, momentum = 1f-1,
                 kwargs...)
    layers = []

    if rev
        activations = (conv = activation, bn = identity)
        bnplanes = inplanes
    else
        activations = (conv = identity, bn = activation)
        bnplanes = outplanes
    end

    push!(layers, Conv(kernelsize, Int(inplanes) => Int(outplanes), activations.conv; kwargs...))
    push!(layers, BatchNorm(Int(bnplanes), activations.bn;
                            initβ = initβ, initγ = initγ, ϵ = ϵ, momentum = momentum))

    return rev ? reverse(layers) : layers
end
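A quick usage sketch (not part of this diff; shapes assume a 224×224 RGB input and that Flux is loaded). Since `conv_bn` returns a vector of layers, it is typically splatted into a `Chain`:

# Sketch: splice the conv + batchnorm pair into a Chain.
layers = conv_bn((3, 3), 3, 64; stride = 2, pad = 1, bias = false)
model = Chain(layers...)

x = rand(Float32, 224, 224, 3, 1)   # WHCN image batch
size(model(x))                      # (112, 112, 64, 1)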
""" | ||
cat_channels(x, y) | ||
Concatenate `x` and `y` along the channel dimension (third dimension). | ||
Equivalent to `cat(x, y; dims=3)`. | ||
Convenient binary reduction operator for use with `Parallel`. | ||
""" | ||
cat_channels(x, y) = cat(x, y; dims = 3) | ||
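For illustration (a sketch, not part of the diff), `cat_channels` can serve as the reduction for a two-branch `Parallel` layer, as in DenseNet/Inception-style blocks:

# Sketch: two branches whose outputs are concatenated along channels.
branches = Parallel(cat_channels,
                    Conv((1, 1), 16 => 8),
                    Conv((3, 3), 16 => 8; pad = 1))
x = rand(Float32, 32, 32, 16, 1)
size(branches(x))   # (32, 32, 16, 1) — 8 + 8 channels after concatenation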
""" | ||
skip_projection(inplanes, outplanes, downsample = false) | ||
Create a skip projection | ||
([reference](https://arxiv.org/abs/1512.03385v1)). | ||
# Arguments: | ||
- `inplanes`: the number of input feature maps | ||
- `outplanes`: the number of output feature maps | ||
- `downsample`: set to `true` to downsample the input | ||
""" | ||
skip_projection(inplanes, outplanes, downsample = false) = downsample ? | ||
Chain(conv_bn((1, 1), inplanes, outplanes, identity; stride = 2, bias = false)...) : | ||
Chain(conv_bn((1, 1), inplanes, outplanes, identity; stride = 1, bias = false)...) | ||
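A shape sketch (hypothetical sizes, not part of the diff) showing the downsampling projection used on a residual shortcut path:

# Sketch: a strided 1×1 projection halves the spatial dims and matches channels.
proj = skip_projection(64, 128, true)
x = rand(Float32, 56, 56, 64, 1)
size(proj(x))   # (28, 28, 128, 1)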
# array -> PaddedView(0, array, outplanes) for zero padding arrays
"""
    skip_identity(inplanes, outplanes[, downsample])

Create an identity projection
([reference](https://arxiv.org/abs/1512.03385v1)).

# Arguments
- `inplanes`: the number of input feature maps
- `outplanes`: the number of output feature maps
- `downsample`: this argument is ignored but it is needed for compatibility with [`resnet`](#).
"""
function skip_identity(inplanes, outplanes)
    if outplanes > inplanes
        return Chain(MaxPool((1, 1), stride = 2),
                     y -> cat(y, zeros(eltype(y),
                                       size(y, 1),
                                       size(y, 2),
                                       outplanes - inplanes, size(y, 4)); dims = 3))
    else
        return identity
    end
end
skip_identity(inplanes, outplanes, downsample) = skip_identity(inplanes, outplanes)
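A sanity-check sketch (hypothetical sizes, not part of the diff): when `outplanes > inplanes`, the identity shortcut downsamples with a strided 1×1 max-pool and zero-pads the extra channels:

shortcut = skip_identity(64, 128)
x = rand(Float32, 56, 56, 64, 1)
size(shortcut(x))   # (28, 28, 128, 1)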
# Patching layer used by many vision transformer-like models
struct Patching{T <: Integer}
    patch_height::T
    patch_width::T
end
Patching(patch_size) = Patching(patch_size, patch_size)
function (p::Patching)(x)
    h, w, c, n = size(x)
    hp, wp = h ÷ p.patch_height, w ÷ p.patch_width
    # Split each spatial dimension into (patch pixels, patch index) so that
    # pixels belonging to the same patch are contiguous; the original order
    # (hp, patch_height, ...) interleaved pixels across patches whenever the
    # patch size differed from the patch-grid size.
    xpatch = reshape(x, p.patch_height, hp, p.patch_width, wp, c, n)

    return reshape(permutedims(xpatch, (1, 3, 5, 2, 4, 6)),
                   p.patch_height * p.patch_width * c, hp * wp, n)
end
@functor Patching
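A shape sketch (not part of the diff): a 224×224 RGB image split into 16×16 patches gives 196 patches of 768 flattened values each:

p = Patching(16)
x = rand(Float32, 224, 224, 3, 1)
size(p(x))   # (768, 196, 1) — (patch_height * patch_width * channels, patches, batch)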
""" | ||
mlpblock(planes, expansion_factor = 4, dropout = 0., dense = Dense) | ||
Feedforward block used in many vision transformer-like models. | ||
# Arguments | ||
`planes`: Number of dimensions in the input and output. | ||
`hidden_planes`: Number of dimensions in the intermediate layer. | ||
`dropout`: Dropout rate. | ||
`dense`: Type of dense layer to use in the feedforward block. | ||
`activation`: Activation function to use. | ||
""" | ||
function mlpblock(planes, hidden_planes, dropout = 0., dense = Dense; activation = gelu) | ||
Chain(dense(planes, hidden_planes, activation), Dropout(dropout), | ||
dense(hidden_planes, planes, activation), Dropout(dropout)) | ||
end |
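A usage sketch (hypothetical sizes, not part of the diff), with the customary 4× hidden expansion on 512-dimensional features:

block = mlpblock(512, 4 * 512)
x = rand(Float32, 512, 196, 1)   # (planes, patches, batch)
size(block(x))                   # (512, 196, 1)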
@@ -0,0 +1,87 @@
# Utility function for creating a residual block with LayerNorm placed
# before the residual connection
_residualprenorm(planes, fn) = SkipConnection(Chain(fn, LayerNorm(planes)), +)

# Utility function for 1D convolution; with a length-1 kernel this acts as a
# dense layer over the second (channel) dimension of a (features, patches, batch)
# array, which is what lets the same `mlpblock` mix tokens instead of channels
_conv1d(inplanes, outplanes, activation) = Conv((1, ), inplanes => outplanes, activation)
""" | ||
mlpmixer(imsize::NTuple{2} = (256, 256); inchannels = 3, patch_size = 16, planes = 512, | ||
depth = 12, expansion_factor = 4, dropout = 0., nclasses = 1000, token_mix = | ||
_conv1d, channel_mix = Dense)) | ||
Creates a model with the MLPMixer architecture. | ||
([reference](https://arxiv.org/pdf/2105.01601)). | ||
# Arguments | ||
- imsize: the size of the input image | ||
- inchannels: the number of input channels | ||
- patch_size: the size of the patches | ||
- planes: the number of channels fed into the main model | ||
- depth: the number of blocks in the main model | ||
- expansion_factor: the number of channels in each block | ||
- dropout: the dropout rate | ||
- nclasses: the number of classes in the output | ||
- token_mix: the function to use for the token mixing layer | ||
- channel_mix: the function to use for the channel mixing layer | ||
""" | ||
function mlpmixer(imsize::NTuple{2} = (256, 256); inchannels = 3, patch_size = 16, planes = 512,
                  depth = 12, expansion_factor = 4, dropout = 0., nclasses = 1000,
                  token_mix = _conv1d, channel_mix = Dense)
    im_height, im_width = imsize

    # The message must be on the same line as the condition, or `@assert` drops it.
    @assert (im_height % patch_size == 0) && (im_width % patch_size == 0) "image size must be divisible by patch size"

    num_patches = (im_height ÷ patch_size) * (im_width ÷ patch_size)

    layers = []
    push!(layers, Patching(patch_size))
    push!(layers, Dense((patch_size ^ 2) * inchannels, planes))
    append!(layers, [Chain(_residualprenorm(planes, mlpblock(num_patches,
                                                             expansion_factor * num_patches,
                                                             dropout, token_mix)),
                           _residualprenorm(planes, mlpblock(planes,
                                                             expansion_factor * planes, dropout,
                                                             channel_mix)),) for _ in 1:depth])

    # `_seconddimmean` (a shared utility) averages over the patch dimension
    # before the final classification layer
    classification_head = Chain(_seconddimmean, Dense(planes, nclasses))

    return Chain(Chain(layers...), classification_head)
end
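A forward-pass sketch (not part of the diff) with the defaults, which patch a 256×256 RGB batch into 256 tokens of 768 features before projection to `planes = 512`:

m = mlpmixer()
x = rand(Float32, 256, 256, 3, 1)
size(m(x))   # (1000, 1)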
struct MLPMixer
    layers
end
""" | ||
MLPMixer(imsize::NTuple{2} = (256, 256); inchannels = 3, patch_size = 16, planes = 512, | ||
depth = 12, expansion_factor = 4, dropout = 0., nclasses = 1000) | ||
Creates a model with the MLPMixer architecture. | ||
([reference](https://arxiv.org/pdf/2105.01601)). | ||
# Arguments | ||
- imsize: the size of the input image | ||
- inchannels: the number of input channels | ||
- patch_size: the size of the patches | ||
- planes: the number of channels fed into the main model | ||
- depth: the number of blocks in the main model | ||
- expansion_factor: the number of channels in each block | ||
- dropout: the dropout rate | ||
- nclasses: the number of classes in the output | ||
""" | ||
function MLPMixer(imsize::NTuple{2} = (256, 256); inchannels = 3, patch_size = 16, planes = 512,
                  depth = 12, expansion_factor = 4, dropout = 0., nclasses = 1000)
    layers = mlpmixer(imsize; inchannels, patch_size, planes, depth, expansion_factor, dropout,
                      nclasses)
    MLPMixer(layers)
end

@functor MLPMixer

(m::MLPMixer)(x) = m.layers(x)

backbone(m::MLPMixer) = m.layers[1]
classifier(m::MLPMixer) = m.layers[2]
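A usage sketch for the exported wrapper (hypothetical hyperparameters, not part of the diff), including the `backbone`/`classifier` accessors:

m = MLPMixer(; patch_size = 32, depth = 8)
x = rand(Float32, 256, 256, 3, 1)
size(m(x))                      # (1000, 1)
features = backbone(m)(x)       # (512, 64, 1) patch features
size(classifier(m)(features))   # (1000, 1)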