More uniformity + cleanup
theabhirath committed Aug 3, 2022
1 parent 8ce0dce commit bb0bd62
Showing 30 changed files with 202 additions and 230 deletions.
8 changes: 3 additions & 5 deletions src/Metalhead.jl
@@ -56,14 +56,12 @@ include("vit-based/vit.jl")
 include("pretrain.jl")

 export AlexNet, VGG, VGG11, VGG13, VGG16, VGG19,
-       ResNet, ResNet18, ResNet34, ResNet50, ResNet101, ResNet152, ResNeXt,
+       ResNet, ResNet18, ResNet34, ResNet50, ResNet101, ResNet152,
+       WideResNet, ResNeXt, SEResNet, SEResNeXt,
        DenseNet, DenseNet121, DenseNet161, DenseNet169, DenseNet201,
        GoogLeNet, Inception3, Inceptionv3, Inceptionv4, InceptionResNetv2, Xception,
        SqueezeNet, MobileNetv1, MobileNetv2, MobileNetv3, EfficientNet,
-       WideResNet, SEResNet, SEResNeXt,
-       MLPMixer, ResMLP, gMLP,
-       ViT,
-       ConvMixer, ConvNeXt
+       MLPMixer, ResMLP, gMLP, ViT, ConvMixer, ConvNeXt

 # use Flux._big_show to pretty print large models
 for T in (:AlexNet, :VGG, :ResNet, :ResNeXt, :DenseNet, :SEResNet, :SEResNeXt,
19 changes: 12 additions & 7 deletions src/convnets/alexnet.jl
@@ -1,11 +1,12 @@
 """
-    alexnet(; nclasses::Integer = 1000)
+    alexnet(; inchannels::Integer = 3, nclasses::Integer = 1000)
 Create an AlexNet model
 ([reference](https://papers.nips.cc/paper/2012/file/c399862d3b9d6b76c8436e924a68c45b-Paper.pdf)).
 # Arguments
+  - `inchannels`: The number of input channels.
   - `nclasses`: the number of output classes
 """
 function alexnet(; inchannels::Integer = 3, nclasses::Integer = 1000)
@@ -27,19 +28,23 @@ function alexnet(; inchannels::Integer = 3, nclasses::Integer = 1000)
 end

 """
-    AlexNet(; pretrain::Bool = false, nclasses::Integer = 1000)
+    AlexNet(; pretrain::Bool = false, inchannels::Integer = 3,
+            nclasses::Integer = 1000)
 Create a `AlexNet`.
-See also [`alexnet`](#).
-!!! warning
-    `AlexNet` does not currently support pretrained weights.
 ([reference](https://papers.nips.cc/paper/2012/file/c399862d3b9d6b76c8436e924a68c45b-Paper.pdf)).
 # Arguments
   - `pretrain`: set to `true` to load pre-trained weights for ImageNet
+  - `inchannels`: The number of input channels.
   - `nclasses`: the number of output classes
+!!! warning
+    `AlexNet` does not currently support pretrained weights.
+See also [`alexnet`](#).
 """
 struct AlexNet
     layers::Any
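A minimal usage sketch of the updated constructor (assuming Metalhead at this commit; the forward pass `model(x)` follows the same functor pattern the other models in this commit use):

```julia
using Metalhead

# 10-class AlexNet over RGB input; `pretrain` stays `false` because
# pretrained weights are not currently supported (see the warning above).
model = AlexNet(; inchannels = 3, nclasses = 10)

x = rand(Float32, 224, 224, 3, 1)  # one 224×224 RGB image, WHCN layout
y = model(x)                       # 10×1 logits
```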
35 changes: 16 additions & 19 deletions src/convnets/convmixer.jl
@@ -26,28 +26,28 @@ function convmixer(planes::Integer, depth::Integer; kernel_size = (9, 9),
                                                pad = SamePad())), +),
                           conv_norm((1, 1), planes, planes, activation; preact = true)...)
               for _ in 1:depth]
-    return Chain(Chain(stem..., Chain(blocks)), create_classifier(planes, nclasses))
+    return Chain(Chain(stem..., Chain(blocks...)), create_classifier(planes, nclasses))
 end

-const CONVMIXER_CONFIGS = Dict(:base => Dict(:planes => 1536, :depth => 20,
-                                             :kernel_size => (9, 9),
-                                             :patch_size => (7, 7)),
-                               :small => Dict(:planes => 768, :depth => 32,
-                                              :kernel_size => (7, 7),
-                                              :patch_size => (7, 7)),
-                               :large => Dict(:planes => 1024, :depth => 20,
-                                              :kernel_size => (9, 9),
-                                              :patch_size => (7, 7)))
+const CONVMIXER_CONFIGS = Dict(:base => ((1536, 20),
+                                         (kernel_size = (9, 9),
+                                          patch_size = (7, 7))),
+                               :small => ((768, 32),
+                                          (kernel_size = (7, 7),
+                                           patch_size = (7, 7))),
+                               :large => ((1024, 20),
+                                          (kernel_size = (9, 9),
+                                           patch_size = (7, 7))))

 """
-    ConvMixer(mode::Symbol; inchannels::Integer = 3, nclasses::Integer = 1000)
+    ConvMixer(config::Symbol; inchannels::Integer = 3, nclasses::Integer = 1000)
 Creates a ConvMixer model.
 ([reference](https://arxiv.org/abs/2201.09792))
 # Arguments
-  - `mode`: the mode of the model, either `:base`, `:small` or `:large`
+  - `config`: the size of the model, either `:base`, `:small` or `:large`
   - `inchannels`: The number of channels in the input.
   - `nclasses`: number of classes in the output
 """
@@ -56,13 +56,10 @@ struct ConvMixer
 end
 @functor ConvMixer

-function ConvMixer(mode::Symbol; inchannels::Integer = 3, nclasses::Integer = 1000)
-    _checkconfig(mode, keys(CONVMIXER_CONFIGS))
-    planes = CONVMIXER_CONFIGS[mode][:planes]
-    depth = CONVMIXER_CONFIGS[mode][:depth]
-    kernel_size = CONVMIXER_CONFIGS[mode][:kernel_size]
-    patch_size = CONVMIXER_CONFIGS[mode][:patch_size]
-    layers = convmixer(planes, depth; inchannels, kernel_size, patch_size, nclasses)
+function ConvMixer(config::Symbol; inchannels::Integer = 3, nclasses::Integer = 1000)
+    _checkconfig(config, keys(CONVMIXER_CONFIGS))
+    layers = convmixer(CONVMIXER_CONFIGS[config][1]...; CONVMIXER_CONFIGS[config][2]...,
+                       inchannels, nclasses)
     return ConvMixer(layers)
 end

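The new `CONVMIXER_CONFIGS` stores each entry as a positional tuple plus a `NamedTuple` of keywords, which the constructor forwards with two splats. A self-contained sketch of that pattern (the function `f` is hypothetical, not Metalhead code):

```julia
# (positional args, keyword args), mirroring the :base entry above
cfg = ((1536, 20), (kernel_size = (9, 9), patch_size = (7, 7)))

f(planes, depth; kernel_size, patch_size) = (planes, depth, kernel_size, patch_size)

# `cfg[1]...` splats the tuple into positional arguments;
# `cfg[2]...` after the `;` splats the NamedTuple into keyword arguments.
f(cfg[1]...; cfg[2]...)  # returns (1536, 20, (9, 9), (7, 7))
```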
30 changes: 15 additions & 15 deletions src/convnets/convnext.jl
@@ -22,7 +22,7 @@ function convnextblock(planes::Integer, drop_path_rate = 0.0, layerscale_init =
 end

 """
-    convnext(depths::Vector{<:Integer}, planes::Vector{<:Integer};
+    convnext(depths::AbstractVector{<:Integer}, planes::AbstractVector{<:Integer};
              drop_path_rate = 0.0, layerscale_init = 1.0f-6, inchannels::Integer = 3,
              nclasses::Integer = 1000)
@@ -31,27 +31,27 @@ Creates the layers for a ConvNeXt model.
 # Arguments
-  - `inchannels`: number of input channels.
   - `depths`: list with configuration for depth of each block
   - `planes`: list with configuration for number of output channels in each block
   - `drop_path_rate`: Stochastic depth rate.
   - `layerscale_init`: Initial value for [`LayerScale`](#)
     ([reference](https://arxiv.org/abs/2103.17239))
+  - `inchannels`: number of input channels.
   - `nclasses`: number of output classes
 """
-function convnext(depths::Vector{<:Integer}, planes::Vector{<:Integer};
+function convnext(depths::AbstractVector{<:Integer}, planes::AbstractVector{<:Integer};
                   drop_path_rate = 0.0, layerscale_init = 1.0f-6, inchannels::Integer = 3,
                   nclasses::Integer = 1000)
     @assert length(depths) == length(planes)
     "`planes` should have exactly one value for each block"
     downsample_layers = []
-    stem = Chain(Conv((4, 4), inchannels => planes[1]; stride = 4),
-                 ChannelLayerNorm(planes[1]))
-    push!(downsample_layers, stem)
+    push!(downsample_layers,
+          Chain(conv_norm((4, 4), inchannels => planes[1]; stride = 4,
+                          norm_layer = ChannelLayerNorm)...))
     for m in 1:(length(depths) - 1)
-        downsample_layer = Chain(ChannelLayerNorm(planes[m]),
-                                 Conv((2, 2), planes[m] => planes[m + 1]; stride = 2))
-        push!(downsample_layers, downsample_layer)
+        push!(downsample_layers,
+              Chain(conv_norm((2, 2), planes[m] => planes[m + 1]; stride = 2,
+                              norm_layer = ChannelLayerNorm, revnorm = true)...))
     end
     stages = []
     dp_rates = linear_scheduler(drop_path_rate; depth = sum(depths))
@@ -64,8 +64,7 @@ function convnext(depths::Vector{<:Integer}, planes::Vector{<:Integer};
     end
     backbone = collect(Iterators.flatten(Iterators.flatten(zip(downsample_layers, stages))))
     classifier = Chain(GlobalMeanPool(), MLUtils.flatten,
-                       LayerNorm(planes[end]),
-                       Dense(planes[end], nclasses))
+                       LayerNorm(planes[end]), Dense(planes[end], nclasses))
     return Chain(Chain(backbone...), classifier)
 end
@@ -77,13 +76,14 @@ const CONVNEXT_CONFIGS = Dict(:tiny => ([3, 3, 9, 3], [96, 192, 384, 768]),
                               :xlarge => ([3, 3, 27, 3], [256, 512, 1024, 2048]))

 """
-    ConvNeXt(mode::Symbol; inchannels::Integer = 3, nclasses::Integer = 1000)
+    ConvNeXt(config::Symbol; inchannels::Integer = 3, nclasses::Integer = 1000)
 Creates a ConvNeXt model.
 ([reference](https://arxiv.org/abs/2201.03545))
 # Arguments
+  - `config`: The size of the model, one of `tiny`, `small`, `base`, `large` or `xlarge`.
   - `inchannels`: The number of channels in the input.
   - `nclasses`: number of output classes
@@ -94,9 +94,9 @@ struct ConvNeXt
 end
 @functor ConvNeXt

-function ConvNeXt(mode::Symbol; inchannels::Integer = 3, nclasses::Integer = 1000)
-    _checkconfig(mode, keys(CONVNEXT_CONFIGS))
-    layers = convnext(CONVNEXT_CONFIGS[mode]...; inchannels, nclasses)
+function ConvNeXt(config::Symbol; inchannels::Integer = 3, nclasses::Integer = 1000)
+    _checkconfig(config, keys(CONVNEXT_CONFIGS))
+    layers = convnext(CONVNEXT_CONFIGS[config]...; inchannels, nclasses)
     return ConvNeXt(layers)
 end

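A brief sketch of the renamed `config` constructor, assuming the configs table above:

```julia
using Metalhead

# :tiny selects depths [3, 3, 9, 3] and planes [96, 192, 384, 768]
model = ConvNeXt(:tiny; nclasses = 100)

# Unknown configs now fail fast:
# ConvNeXt(:huge)  # throws, since :huge is not a key of CONVNEXT_CONFIGS
```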
3 changes: 2 additions & 1 deletion src/convnets/densenet.jl
@@ -99,7 +99,8 @@ Create a DenseNet model
   - `reduction`: the factor by which the number of feature maps is scaled across each transition
   - `nclasses`: the number of output classes
 """
-function densenet(nblocks::Vector{<:Integer}; growth_rate::Integer = 32, reduction = 0.5,
+function densenet(nblocks::AbstractVector{<:Integer}; growth_rate::Integer = 32,
+                  reduction = 0.5,
                   inchannels::Integer = 3, nclasses::Integer = 1000)
     return densenet(2 * growth_rate, [fill(growth_rate, n) for n in nblocks];
                     reduction, inchannels, nclasses)
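Loosening `Vector{<:Integer}` to `AbstractVector{<:Integer}` is a dispatch fix: ranges and views are `AbstractVector`s but not `Vector`s. A toy illustration (the functions `f` and `g` are hypothetical):

```julia
f(x::Vector{<:Integer}) = "Vector only"
g(x::AbstractVector{<:Integer}) = "any integer vector"

g([6, 12, 24, 16])           # works, and f would too
g(6:6:24)                    # works: ranges are AbstractVectors
g(view([6, 12, 24, 16], :))  # works: views are AbstractVectors
# f(6:6:24)                  # MethodError: a range is not a Vector
```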
57 changes: 15 additions & 42 deletions src/convnets/efficientnet.jl
@@ -22,8 +22,10 @@ Create an EfficientNet model ([reference](https://arxiv.org/abs/1905.11946v5)).
   - `max_width`: maximum number of output channels before the fully connected
     classification blocks
 """
-function efficientnet(scalings, block_configs; max_width::Integer = 1280,
-                      inchannels::Integer = 3, nclasses::Integer = 1000)
+function efficientnet(scalings::NTuple{2, Real},
+                      block_configs::AbstractVector{NTuple{6, Int}};
+                      max_width::Integer = 1280, inchannels::Integer = 3,
+                      nclasses::Integer = 1000)
     wscale, dscale = scalings
     scalew(w) = wscale ≈ 1 ? w : ceil(Int64, wscale * w)
     scaled(d) = dscale ≈ 1 ? d : ceil(Int64, dscale * d)
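The `≈` guards above skip rounding when a scaling factor is (numerically) 1; otherwise widths and depths round up. In isolation:

```julia
wscale = 1.2
scalew(w) = wscale ≈ 1 ? w : ceil(Int64, wscale * w)
scalew(32)  # 39, since ceil(1.2 * 32) = ceil(38.4)

dscale = 1.0
scaled(d) = dscale ≈ 1 ? d : ceil(Int64, dscale * d)
scaled(3)   # 3, returned untouched because dscale ≈ 1
```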
@@ -83,61 +85,32 @@ const EFFICIENTNET_GLOBAL_CONFIGS = Dict(:b0 => (224, (1.0, 1.0)),
                                          :b8 => (672, (2.2, 3.6)))

 """
-    EfficientNet(scalings, block_configs; max_width::Integer = 1280,
-                 inchannels::Integer = 3, nclasses::Integer = 1000)
+    EfficientNet(config::Symbol; pretrain::Bool = false)
 Create an EfficientNet model ([reference](https://arxiv.org/abs/1905.11946v5)).
 See also [`efficientnet`](#).
 # Arguments
-  - `scalings`: global width and depth scaling (given as a tuple)
-  - `block_configs`: configuration for each inverted residual block,
-    given as a vector of tuples with elements:
-      + `n`: number of block repetitions (will be scaled by global depth scaling)
-      + `k`: kernel size
-      + `s`: kernel stride
-      + `e`: expansion ratio
-      + `i`: block input channels (will be scaled by global width scaling)
-      + `o`: block output channels (will be scaled by global width scaling)
-  - `inchannels`: number of input channels
-  - `nclasses`: number of output classes
-  - `max_width`: maximum number of output channels before the fully connected
-    classification blocks
+  - `config`: name of default configuration
+    (can be `:b0`, `:b1`, `:b2`, `:b3`, `:b4`, `:b5`, `:b6`, `:b7`, `:b8`)
+  - `pretrain`: set to `true` to load the pre-trained weights for ImageNet
 """
 struct EfficientNet
     layers::Any
 end
 @functor EfficientNet

-function EfficientNet(scalings, block_configs; max_width::Integer = 1280,
-                      inchannels::Integer = 3, nclasses::Integer = 1000)
-    layers = efficientnet(scalings, block_configs; inchannels, nclasses, max_width)
-    return EfficientNet(layers)
+function EfficientNet(config::Symbol; pretrain::Bool = false)
+    _checkconfig(config, keys(EFFICIENTNET_GLOBAL_CONFIGS))
+    model = efficientnet(EFFICIENTNET_GLOBAL_CONFIGS[config][2], EFFICIENTNET_BLOCK_CONFIGS)
+    if pretrain
+        loadpretrain!(model, string("efficientnet-", config))
+    end
+    return model
 end

 (m::EfficientNet)(x) = m.layers(x)

 backbone(m::EfficientNet) = m.layers[1]
 classifier(m::EfficientNet) = m.layers[2]

-"""
-    EfficientNet(name::Symbol; pretrain::Bool = false)
-Create an EfficientNet model ([reference](https://arxiv.org/abs/1905.11946v5)).
-See also [`efficientnet`](#).
-# Arguments
-  - `name`: name of default configuration
-    (can be `:b0`, `:b1`, `:b2`, `:b3`, `:b4`, `:b5`, `:b6`, `:b7`, `:b8`)
-  - `pretrain`: set to `true` to load the pre-trained weights for ImageNet
-"""
-function EfficientNet(name::Symbol; pretrain::Bool = false)
-    _checkconfig(name, keys(EFFICIENTNET_GLOBAL_CONFIGS))
-    model = EfficientNet(EFFICIENTNET_GLOBAL_CONFIGS[name][2], EFFICIENTNET_BLOCK_CONFIGS)
-    pretrain && loadpretrain!(model, string("efficientnet-", name))
-    return model
-end
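With the positional constructor gone, the `Symbol` method is now the single entry point. A usage sketch, assuming the global configs above (pretrained weights may not be published for every variant):

```julia
using Metalhead

# :b0 picks the (1.0, 1.0) width/depth scalings from EFFICIENTNET_GLOBAL_CONFIGS
model = EfficientNet(:b0)

# EfficientNet(:b0; pretrain = true) would call loadpretrain! with the
# artifact name "efficientnet-b0", and error if no such weights exist.
```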
3 changes: 1 addition & 2 deletions src/convnets/inception/googlenet.jl
@@ -53,8 +53,7 @@ function googlenet(; inchannels::Integer = 3, nclasses::Integer = 1000)
                      MaxPool((3, 3); stride = 2, pad = 1),
                      _inceptionblock(832, 256, 160, 320, 32, 128, 128),
                      _inceptionblock(832, 384, 192, 384, 48, 128, 128))
-    classifier = create_classifier(1024, nclasses; dropout_rate = 0.4)
-    return Chain(backbone, classifier)
+    return Chain(backbone, create_classifier(1024, nclasses; dropout_rate = 0.4))
 end

"""
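`create_classifier(1024, nclasses; dropout_rate = 0.4)` is now inlined into the returned `Chain`; the same refactor repeats in the other Inception-family files below. A hedged sketch of the head such a helper plausibly builds (assumed layout, not Metalhead's exact implementation):

```julia
using Flux, MLUtils

nclasses = 1000
head = Chain(AdaptiveMeanPool((1, 1)),  # collapse the spatial dimensions
             MLUtils.flatten,           # (1, 1, 1024, N) -> (1024, N)
             Dropout(0.4),              # the dropout_rate keyword above
             Dense(1024 => nclasses))
```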
3 changes: 1 addition & 2 deletions src/convnets/inception/inceptionresnetv2.jl
@@ -92,8 +92,7 @@ function inceptionresnetv2(; inchannels::Integer = 3, dropout_rate = 0.0,
                       [block8(0.20f0) for _ in 1:9]...,
                       block8(; activation = relu),
                       conv_norm((1, 1), 2080, 1536)...)
-    classifier = create_classifier(1536, nclasses; dropout_rate)
-    return Chain(backbone, classifier)
+    return Chain(backbone, create_classifier(1536, nclasses; dropout_rate))
 end

"""
3 changes: 1 addition & 2 deletions src/convnets/inception/inceptionv3.jl
@@ -154,8 +154,7 @@ function inceptionv3(; inchannels::Integer = 3, nclasses::Integer = 1000)
                      inceptionv3_d(768),
                      inceptionv3_e(1280),
                      inceptionv3_e(2048))
-    classifier = create_classifier(2048, nclasses; dropout_rate = 0.2)
-    return Chain(backbone, classifier)
+    return Chain(backbone, create_classifier(2048, nclasses; dropout_rate = 0.2))
 end

"""
3 changes: 1 addition & 2 deletions src/convnets/inception/inceptionv4.jl
@@ -117,8 +117,7 @@ function inceptionv4(; dropout_rate = 0.0, inchannels::Integer = 3,
                      inceptionv4_c(),
                      inceptionv4_c(),
                      inceptionv4_c())
-    classifier = create_classifier(1536, nclasses; dropout_rate)
-    return Chain(backbone, classifier)
+    return Chain(backbone, create_classifier(1536, nclasses; dropout_rate))
 end

"""
4 changes: 2 additions & 2 deletions src/convnets/inception/xception.jl
@@ -45,15 +45,15 @@ function xception_block(inchannels::Integer, outchannels::Integer, nrepeats::Int
 end

 """
-    xception(; inchannels::Integer = 3, dropout_rate = 0.0, nclasses::Integer = 1000)
+    xception(; dropout_rate = 0.0, inchannels::Integer = 3, nclasses::Integer = 1000)
 Creates an Xception model.
 ([reference](https://arxiv.org/abs/1610.02357))
 # Arguments
-  - `inchannels`: number of input channels.
   - `dropout_rate`: rate of dropout in classifier head.
+  - `inchannels`: number of input channels.
   - `nclasses`: the number of output classes.
 """
 function xception(; dropout_rate = 0.0, inchannels::Integer = 3, nclasses::Integer = 1000)
8 changes: 4 additions & 4 deletions src/convnets/mobilenet/mobilenetv1.jl
@@ -1,5 +1,5 @@
 """
-    mobilenetv1(width_mult::Number, config::Vector{<:Tuple}; activation = relu,
+    mobilenetv1(width_mult::Real, config::AbstractVector{<:Tuple}; activation = relu,
                 inchannels::Integer = 3, nclasses::Integer = 1000)
 Create a MobileNetv1 model ([reference](https://arxiv.org/abs/1704.04861v1)).
@@ -19,11 +19,11 @@ Create a MobileNetv1 model ([reference](https://arxiv.org/abs/1704.04861v1)).
   - `inchannels`: The number of input channels. The default value is 3.
   - `nclasses`: The number of output classes
 """
-function mobilenetv1(width_mult::Number, config::Vector{<:Tuple}; activation = relu,
+function mobilenetv1(width_mult::Real, config::AbstractVector{<:Tuple}; activation = relu,
                      inchannels::Integer = 3, nclasses::Integer = 1000)
     layers = []
     for (dw, outch, stride, nrepeats) in config
-        outch = Int(outch * width_mult)
+        outch = floor(Int, outch * width_mult)
         for _ in 1:nrepeats
             layer = dw ?
                     depthwise_sep_conv_norm((3, 3), inchannels, outch, activation;
@@ -76,7 +76,7 @@ struct MobileNetv1
 end
 @functor MobileNetv1

-function MobileNetv1(width_mult::Number = 1; pretrain::Bool = false,
+function MobileNetv1(width_mult::Real = 1; pretrain::Bool = false,
                      inchannels::Integer = 3, nclasses::Integer = 1000)
     layers = mobilenetv1(width_mult, MOBILENETV1_CONFIGS; inchannels, nclasses)
     if pretrain
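The change from `Int(outch * width_mult)` to `floor(Int, outch * width_mult)` avoids an `InexactError` whenever the scaled channel count is not integral:

```julia
Int(64 * 0.75)         # 48: fine, the product happens to be integral
# Int(64 * 0.35)       # InexactError: 22.4 cannot be converted exactly
floor(Int, 64 * 0.35)  # 22: truncates, so any Real width_mult is safe
```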
