Skip to content

Commit

Permalink
API tweaks and added docstrings for utility layers
Browse files Browse the repository at this point in the history
  • Loading branch information
theabhirath committed Feb 4, 2022
1 parent 9ec00dc commit 0c3d873
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 28 deletions.
39 changes: 22 additions & 17 deletions src/layers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,11 @@ function mlpblock(planes, hidden_planes, dropout = 0., dense = Dense; activation
dense(hidden_planes, planes, activation), Dropout(dropout))
end

# Patching layer used by many vision transformer-like models
"""
Patching{T <: Integer}
Patching layer used by many vision transformer-like models to split the input image into patches.
Can be instantiated with a tuple `(patch_height, patch_width)` or a single value `patch_size`.
"""
struct Patching{T <: Integer}
patch_height::T
patch_width::T
Expand All @@ -125,32 +129,33 @@ end

@functor Patching

# Positional embedding layer used by many vision transformer-like models
struct PosEmbedding
embedding_vector
"""
PosEmbedding{T}
Positional embedding layer used by many vision transformer-like models. Instantiated with an
embedding vector which is a learnable parameter.
"""
struct PosEmbedding{T}
embedding_vector::T
end

(p::PosEmbedding)(x) = x .+ p.embedding_vector[:, 1:size(x)[2], :]

@functor PosEmbedding

# Class tokens used by many vision transformer-like models
struct CLSTokens
cls_token
"""
CLSTokens{T}
Appends class tokens to the input that are used for classfication by many vision
transformer-like models. Instantiated with a class token vector which is a learnable parameter.
"""
struct CLSTokens{T}
cls_token::T
end

function(m::CLSTokens)(x)
cls_tokens = repeat(m.cls_token, 1, 1, size(x)[3])
x = cat(cls_tokens, x; dims = 2)
return cat(cls_tokens, x; dims = 2)
end

@functor CLSTokens

# Utility function to decide if mean pooling happens inside the model
struct CLSPooling
mode
end

(m::CLSPooling)(x) = (m.mode == "cls") ? x[:, 1, :] : _seconddimmean(x)

@functor CLSPooling
14 changes: 3 additions & 11 deletions src/vit-based/vit.jl
Original file line number Diff line number Diff line change
Expand Up @@ -46,10 +46,6 @@ end

@functor MHAttention

struct Transformer
layers
end

"""
Transformer(planes, depth, heads, headplanes, mlppanes, dropout = 0.)
Expand All @@ -69,13 +65,9 @@ function Transformer(planes, depth, heads, headplanes, mlpplanes, dropout = 0.)
SkipConnection(prenorm(planes, mlpblock(planes, mlpplanes, dropout)), +))
for _ in 1:depth]

Transformer(Chain(layers...))
Chain(layers...)
end

(m::Transformer)(x) = m.layers(x)

@functor Transformer

"""
vit(imsize::NTuple{2} = (256, 256); inchannels = 3, patch_size = (16, 16), planes = 1024,
depth = 6, heads = 16, mlppanes = 2048, headplanes = 64, dropout = 0.1, emb_dropout = 0.1,
Expand Down Expand Up @@ -120,7 +112,7 @@ function vit(imsize::NTuple{2} = (256, 256); inchannels = 3, patch_size = (16, 1
PosEmbedding(rand(Float32, (planes, num_patches + 1, 1))),
Dropout(emb_dropout),
Transformer(planes, depth, heads, headplanes, mlppanes, dropout),
CLSPooling(pool),
(pool == "cls") ? x -> x[:, 1, :] : x -> _seconddimmean(x),
Chain(LayerNorm(planes), Dense(planes, nclasses)))
end

Expand Down Expand Up @@ -164,6 +156,6 @@ end
(m::ViT)(x) = m.layers(x)

backbone(m::ViT) = m.layers[1:end-1]
classifier(m::MLPMixer) = m.layers[end]
classifier(m::ViT) = m.layers[end]

@functor ViT

0 comments on commit 0c3d873

Please sign in to comment.