Skip to content

Commit

Permalink
Merge pull request #174 from theabhirath/resnet-plus
Browse files Browse the repository at this point in the history
  • Loading branch information
darsnack authored Aug 2, 2022
2 parents 2b1fbd1 + 72cd4a9 commit 7e4f9db
Show file tree
Hide file tree
Showing 50 changed files with 2,539 additions and 2,201 deletions.
17 changes: 8 additions & 9 deletions .github/workflows/CI.yml
Original file line number Diff line number Diff line change
Expand Up @@ -27,16 +27,15 @@ jobs:
- x64
suite:
- '["AlexNet", "VGG"]'
- '["GoogLeNet", "SqueezeNet"]'
- '["EfficientNet", "MobileNet"]'
- '[r"/*/ResNet*", "ResNeXt"]'
- 'r"/*/Inception/Inceptionv*"'
- '["InceptionResNetv2", "Xception"]'
- '["GoogLeNet", "SqueezeNet", "MobileNet"]'
- '["EfficientNet"]'
- 'r"/*/ResNet*"'
- '[r"ResNeXt", r"SEResNet"]'
- '"Inception"'
- '"DenseNet"'
- '"ConvNeXt"'
- '"ConvMixer"'
- '"ViT"'
- '"Other"'
- '["ConvNeXt", "ConvMixer"]'
- 'r"ViTs"'
- 'r"Mixers"'
steps:
- uses: actions/checkout@v2
- uses: julia-actions/setup-julia@v1
Expand Down
4 changes: 4 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,15 @@ version = "0.8.0-DEV"
[deps]
Artifacts = "56f22d72-fd6d-98f1-02f0-08ddc0907c33"
BSON = "fbb218c0-5317-5bc6-957e-2ee96dd4b1f0"
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
LazyArtifacts = "4af54fe1-eca0-43a8-85a7-787d91b784e3"
MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54"
NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
NNlibCUDA = "a00861dc-f156-4864-bf3c-e6376f28a68d"
PartialFunctions = "570af359-4316-4cb7-8c74-252c00c2016b"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"

Expand Down
4 changes: 2 additions & 2 deletions docs/make.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
using Pkg

Pkg.develop(path = "..")
Pkg.develop(; path = "..")

using Publish
using Artifacts, LazyArtifacts
Expand All @@ -13,5 +13,5 @@ p = Publish.Project(Metalhead)

function build_and_deploy(label)
# Delete anything already present at `label`; `recursive = true` removes directory
# contents and `force = true` keeps a missing path from being treated as an error.
rm(label; recursive = true, force = true)
# NOTE(review): the two `deploy` lines below are identical apart from the `return`
# keyword. They look like the before/after versions of one statement from a rendered
# diff; read literally, `deploy` would be invoked twice — confirm which line is live.
deploy(Metalhead; root = "/Metalhead.jl", label = label)
return deploy(Metalhead; root = "/Metalhead.jl", label = label)
end
2 changes: 1 addition & 1 deletion docs/serve.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
using Pkg

Pkg.develop(path = "..")
Pkg.develop(; path = "..")

using Revise
using Publish
Expand Down
37 changes: 27 additions & 10 deletions src/Metalhead.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ using BSON
using Artifacts, LazyArtifacts
using Statistics
using MLUtils
using PartialFunctions
using Random

import Functors
Expand All @@ -20,38 +21,54 @@ using .Layers
# CNN models
include("convnets/alexnet.jl")
include("convnets/vgg.jl")
include("convnets/inception.jl")
include("convnets/googlenet.jl")
include("convnets/resnet.jl")
include("convnets/resnext.jl")
## ResNets
include("convnets/resnets/core.jl")
include("convnets/resnets/resnet.jl")
include("convnets/resnets/resnext.jl")
include("convnets/resnets/seresnet.jl")
## Inceptions
include("convnets/inception/googlenet.jl")
include("convnets/inception/inceptionv3.jl")
include("convnets/inception/inceptionv4.jl")
include("convnets/inception/inceptionresnetv2.jl")
include("convnets/inception/xception.jl")
## MobileNets
include("convnets/mobilenet/mobilenetv1.jl")
include("convnets/mobilenet/mobilenetv2.jl")
include("convnets/mobilenet/mobilenetv3.jl")
## Others
include("convnets/densenet.jl")
include("convnets/squeezenet.jl")
include("convnets/mobilenet.jl")
include("convnets/efficientnet.jl")
include("convnets/convnext.jl")
include("convnets/convmixer.jl")

# Other models
include("other/mlpmixer.jl")
# Mixers
include("mixers/core.jl")
include("mixers/mlpmixer.jl")
include("mixers/resmlp.jl")
include("mixers/gmlp.jl")

# ViT-based models
# ViTs
include("vit-based/vit.jl")

# Load pretrained weights
include("pretrain.jl")

export AlexNet, VGG, VGG11, VGG13, VGG16, VGG19,
ResNet, ResNet18, ResNet34, ResNet50, ResNet101, ResNet152, ResNeXt,
DenseNet, DenseNet121, DenseNet161, DenseNet169, DenseNet201,
GoogLeNet, Inception3, Inceptionv3, Inceptionv4, InceptionResNetv2, Xception,
SqueezeNet, MobileNetv1, MobileNetv2, MobileNetv3, EfficientNet,
WideResNet, SEResNet, SEResNeXt,
MLPMixer, ResMLP, gMLP,
ViT,
ConvMixer, ConvNeXt

# use Flux._big_show to pretty print large models
# For every model type named in the tuple, `@eval` defines a `text/plain` `Base.show`
# method that forwards to `_maybe_big_show` (defined elsewhere; not visible here).
# NOTE(review): there are two `for T in (` headers and two `:SqueezeNet, ...` rows but only
# one `end`. These look like before/after lines of a rendered diff; the second of each
# pair additionally lists `:SEResNet`, `:SEResNeXt` and `:EfficientNet`.
for T in (:AlexNet, :VGG, :ResNet, :ResNeXt, :DenseNet,
for T in (:AlexNet, :VGG, :ResNet, :ResNeXt, :DenseNet, :SEResNet, :SEResNeXt,
:GoogLeNet, :Inceptionv3, :Inceptionv4, :InceptionResNetv2, :Xception,
:SqueezeNet, :MobileNetv1, :MobileNetv2, :MobileNetv3,
:SqueezeNet, :MobileNetv1, :MobileNetv2, :MobileNetv3, :EfficientNet,
:MLPMixer, :ResMLP, :gMLP, :ViT, :ConvMixer, :ConvNeXt)
@eval Base.show(io::IO, ::MIME"text/plain", model::$T) = _maybe_big_show(io, model)
end
Expand Down
8 changes: 4 additions & 4 deletions src/convnets/alexnet.jl
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@ function alexnet(; nclasses = 1000)
Dropout(0.5),
Dense(4096, 4096, relu),
Dense(4096, nclasses)))

return layers
end

Expand All @@ -46,15 +45,16 @@ See also [`alexnet`](#).
struct AlexNet
# Single untyped field; the methods below call it (`m.layers(x)`) and index it
# (`m.layers[1]`).
layers::Any
end
# NOTE(review): `@functor AlexNet` appears both here and again after the constructor.
# This looks like a moved line shown twice by a rendered diff — keep one registration.
@functor AlexNet

function AlexNet(; pretrain = false, nclasses = 1000)
# Build the layer stack with `alexnet`, forwarding `nclasses`.
layers = alexnet(; nclasses = nclasses)
# NOTE(review): the short-circuit call and the `if` block below do the same thing
# (call `loadpretrain!(layers, "AlexNet")` when `pretrain` is true). They look like the
# before/after forms of one statement; read literally the load would run twice.
pretrain && loadpretrain!(layers, "AlexNet")
if pretrain
loadpretrain!(layers, "AlexNet")
end
return AlexNet(layers)
end

@functor AlexNet

# Calling the model applies the wrapped `layers` to `x`.
(m::AlexNet)(x) = m.layers(x)

# `backbone` returns the first element of `layers`.
backbone(m::AlexNet) = m.layers[1]
Expand Down
44 changes: 23 additions & 21 deletions src/convnets/convmixer.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,32 +9,34 @@ Creates a ConvMixer model.
- `planes`: number of planes in the output of each block
- `depth`: number of layers
- `inchannels`: The number of channels in the input. The default value is 3.
- `inchannels`: The number of channels in the input.
- `kernel_size`: kernel size of the convolutional layers
- `patch_size`: size of the patches
- `activation`: activation function used after the convolutional layers
- `nclasses`: number of classes in the output
"""
function convmixer(planes, depth; inchannels = 3, kernel_size = (9, 9),
patch_size::Dims{2} = (7, 7), activation = gelu, nclasses = 1000)
# NOTE(review): the stem/blocks construction appears twice below, once calling `conv_bn`
# and once calling `conv_norm`, with otherwise identical arguments. This looks like the
# before/after lines of a rendered diff; the text does not parse as written, so only one
# of the two versions can be live code — confirm against the real file.
# Stem: kernel of size `patch_size` with stride `patch_size[1]`, `inchannels` -> `planes`.
# Blocks: `depth` copies of a skip-connected convolution (`groups = planes`, `SamePad()`)
# followed by a (1, 1) convolution, all `planes` -> `planes` with `preact = true`.
stem = conv_bn(patch_size, inchannels, planes, activation; preact = true,
stride = patch_size[1])
blocks = [Chain(SkipConnection(Chain(conv_bn(kernel_size, planes, planes, activation;
preact = true, groups = planes,
pad = SamePad())), +),
conv_bn((1, 1), planes, planes, activation; preact = true)...)
stem = conv_norm(patch_size, inchannels, planes, activation; preact = true,
stride = patch_size[1])
blocks = [Chain(SkipConnection(Chain(conv_norm(kernel_size, planes, planes, activation;
preact = true, groups = planes,
pad = SamePad())), +),
conv_norm((1, 1), planes, planes, activation; preact = true)...)
for _ in 1:depth]
# Head: adaptive mean pool to (1, 1), flatten, then `Dense(planes, nclasses)`.
head = Chain(AdaptiveMeanPool((1, 1)), MLUtils.flatten, Dense(planes, nclasses))
# The result nests (stem + blocks) as the first element and the head as the second.
return Chain(Chain(stem..., Chain(blocks)), head)
end

# Per-mode ConvMixer hyperparameters, keyed by `:base`, `:small` and `:large`; each entry
# provides `:planes`, `:depth`, `:kernel_size` and `:patch_size`.
# NOTE(review): two tables with the same entries follow — a non-`const` global
# `convmixer_config` and a `const CONVMIXER_CONFIGS`. They look like the before/after
# versions of one definition from a rendered diff; confirm which name is live.
convmixer_config = Dict(:base => Dict(:planes => 1536, :depth => 20, :kernel_size => (9, 9),
:patch_size => (7, 7)),
:small => Dict(:planes => 768, :depth => 32, :kernel_size => (7, 7),
:patch_size => (7, 7)),
:large => Dict(:planes => 1024, :depth => 20,
:kernel_size => (9, 9),
:patch_size => (7, 7)))
const CONVMIXER_CONFIGS = Dict(:base => Dict(:planes => 1536, :depth => 20,
:kernel_size => (9, 9),
:patch_size => (7, 7)),
:small => Dict(:planes => 768, :depth => 32,
:kernel_size => (7, 7),
:patch_size => (7, 7)),
:large => Dict(:planes => 1024, :depth => 20,
:kernel_size => (9, 9),
:patch_size => (7, 7)))

"""
ConvMixer(mode::Symbol = :base; inchannels = 3, activation = gelu, nclasses = 1000)
Expand All @@ -45,26 +47,26 @@ Creates a ConvMixer model.
# Arguments
- `mode`: the mode of the model, either `:base`, `:small` or `:large`
- `inchannels`: The number of channels in the input. The default value is 3.
- `inchannels`: The number of channels in the input.
- `activation`: activation function used after the convolutional layers
- `nclasses`: number of classes in the output
"""
struct ConvMixer
# Single untyped field; the methods below call it and index its first element.
layers::Any
end
# NOTE(review): `@functor ConvMixer` appears here and again after the constructor — this
# looks like a moved line shown twice by a rendered diff; keep one registration.
@functor ConvMixer

function ConvMixer(mode::Symbol = :base; inchannels = 3, activation = gelu, nclasses = 1000)
# NOTE(review): each hyperparameter is looked up twice, first from `convmixer_config` and
# then from `CONVMIXER_CONFIGS`. These look like before/after lines of a rendered diff;
# only one of the two tables is expected to exist in the real file.
planes = convmixer_config[mode][:planes]
depth = convmixer_config[mode][:depth]
kernel_size = convmixer_config[mode][:kernel_size]
patch_size = convmixer_config[mode][:patch_size]
# `_checkconfig` is defined elsewhere; presumably it rejects a `mode` that is not one of
# the given keys — TODO confirm.
_checkconfig(mode, keys(CONVMIXER_CONFIGS))
planes = CONVMIXER_CONFIGS[mode][:planes]
depth = CONVMIXER_CONFIGS[mode][:depth]
kernel_size = CONVMIXER_CONFIGS[mode][:kernel_size]
patch_size = CONVMIXER_CONFIGS[mode][:patch_size]
# Keyword shorthand: each name after `;` is passed as `name = name`.
layers = convmixer(planes, depth; inchannels, kernel_size, patch_size, activation,
nclasses)
return ConvMixer(layers)
end

@functor ConvMixer

# Calling the model applies the wrapped `layers` to `x`.
(m::ConvMixer)(x) = m.layers(x)

# `backbone` returns the first element of `layers`.
backbone(m::ConvMixer) = m.layers[1]
Expand Down
47 changes: 18 additions & 29 deletions src/convnets/convnext.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
Creates a single block of ConvNeXt.
([reference](https://arxiv.org/abs/2201.03545))
# Arguments:
# Arguments
- `planes`: number of input channels.
- `drop_path_rate`: Stochastic depth rate.
Expand All @@ -27,7 +27,7 @@ end
Creates the layers for a ConvNeXt model.
([reference](https://arxiv.org/abs/2201.03545))
# Arguments:
# Arguments
- `inchannels`: number of input channels.
- `depths`: list with configuration for depth of each block
Expand All @@ -39,60 +39,53 @@ Creates the layers for a ConvNeXt model.
"""
function convnext(depths, planes; inchannels = 3, drop_path_rate = 0.0, λ = 1.0f-6,
nclasses = 1000)
# NOTE(review): this body interleaves what look like before/after lines of a rendered diff
# (two `@assert`s, two `ChannelLayerNorm` arguments for the stem and for each downsample
# layer, two `dp_rates` definitions, two `for` headers). It does not parse as written.
@assert length(depths)==length(planes) "`planes` should have exactly one value for each block"

# NOTE(review): in this second form the message string sits on its own line. A macro call
# written without parentheses ends at the end of the line, so the string below is parsed
# as a separate statement and is NOT the assertion message. Join it onto the `@assert`
# line (or use the parenthesised form) so a failure reports the intended text.
@assert length(depths) == length(planes)
"`planes` should have exactly one value for each block"
# Untyped vector collecting the stem and the per-stage downsample layers.
downsample_layers = []
# Stem: (4, 4) convolution with stride 4, `inchannels` -> `planes[1]`, then a channel norm.
stem = Chain(Conv((4, 4), inchannels => planes[1]; stride = 4),
ChannelLayerNorm(planes[1]; ϵ = 1.0f-6))
ChannelLayerNorm(planes[1]))
push!(downsample_layers, stem)
# One downsample layer per stage transition: channel norm, then a (2, 2) stride-2
# convolution from `planes[m]` to `planes[m + 1]`.
for m in 1:(length(depths) - 1)
downsample_layer = Chain(ChannelLayerNorm(planes[m]; ϵ = 1.0f-6),
downsample_layer = Chain(ChannelLayerNorm(planes[m]),
Conv((2, 2), planes[m] => planes[m + 1]; stride = 2))
push!(downsample_layers, downsample_layer)
end

stages = []
# One drop-path rate per block, `sum(depths)` in total. The first form is an explicit
# `LinRange` from 0.0 to `drop_path_rate`; `linear_scheduler` is defined elsewhere and
# presumably yields an equivalent sequence — TODO confirm.
dp_rates = LinRange{Float32}(0.0, drop_path_rate, sum(depths))
dp_rates = linear_scheduler(drop_path_rate; depth = sum(depths))
# `cur` is the number of blocks already emitted, used as the offset into `dp_rates`.
cur = 0
for i in 1:length(depths)
for i in eachindex(depths)
push!(stages, [convnextblock(planes[i], dp_rates[cur + j], λ) for j in 1:depths[i]])
cur += depths[i]
end

# Pair each downsample layer with its stage (`zip`) and flatten two levels into one list.
backbone = collect(Iterators.flatten(Iterators.flatten(zip(downsample_layers, stages))))
# Head: global mean pool, flatten, layer norm, then `Dense(planes[end], nclasses)`.
head = Chain(GlobalMeanPool(),
MLUtils.flatten,
LayerNorm(planes[end]),
Dense(planes[end], nclasses))

return Chain(Chain(backbone), head)
end

# Configurations for ConvNeXt models
# Keyed by `:tiny`, `:small`, `:base`, `:large` and `:xlarge`; each entry gives the
# per-stage depths and the per-stage plane counts.
# NOTE(review): two tables with the same values follow. `convnext_configs` (non-`const`)
# stores nested Dicts with `:depths`/`:planes` keys; `const CONVNEXT_CONFIGS` stores
# `(depths, planes)` tuples. They look like the before/after versions of one definition
# from a rendered diff; the closing `)` of the first table is on its `:xlarge` line.
convnext_configs = Dict(:tiny => Dict(:depths => [3, 3, 9, 3],
:planes => [96, 192, 384, 768]),
:small => Dict(:depths => [3, 3, 27, 3],
:planes => [96, 192, 384, 768]),
:base => Dict(:depths => [3, 3, 27, 3],
:planes => [128, 256, 512, 1024]),
:large => Dict(:depths => [3, 3, 27, 3],
:planes => [192, 384, 768, 1536]),
:xlarge => Dict(:depths => [3, 3, 27, 3],
:planes => [256, 512, 1024, 2048]))
const CONVNEXT_CONFIGS = Dict(:tiny => ([3, 3, 9, 3], [96, 192, 384, 768]),
:small => ([3, 3, 27, 3], [96, 192, 384, 768]),
:base => ([3, 3, 27, 3], [128, 256, 512, 1024]),
:large => ([3, 3, 27, 3], [192, 384, 768, 1536]),
:xlarge => ([3, 3, 27, 3], [256, 512, 1024, 2048]))

struct ConvNeXt
# Single untyped field holding the model's layers.
layers::Any
end
# `@functor` registration for the wrapper type (macro defined outside this view).
@functor ConvNeXt

"""
ConvNeXt(mode::Symbol = :base; inchannels = 3, drop_path_rate = 0., λ = 1f-6, nclasses = 1000)
Creates a ConvNeXt model.
([reference](https://arxiv.org/abs/2201.03545))
# Arguments:
# Arguments
- `inchannels`: The number of channels in the input. The default value is 3.
- `inchannels`: The number of channels in the input.
- `drop_path_rate`: Stochastic depth rate.
- `λ`: Init value for [LayerScale](https://arxiv.org/abs/2103.17239)
- `nclasses`: number of output classes
Expand All @@ -101,16 +94,12 @@ See also [`Metalhead.convnext`](#).
"""
function ConvNeXt(mode::Symbol = :base; inchannels = 3, drop_path_rate = 0.0, λ = 1.0f-6,
nclasses = 1000)
# NOTE(review): two versions of the body are interleaved here and look like before/after
# lines of a rendered diff. The first validates with `@assert` (its message says `size`
# although the argument is named `mode`) and reads `:depths`/`:planes` from
# `convnext_configs`; the second calls `_checkconfig` (defined elsewhere) and splats the
# `CONVNEXT_CONFIGS[mode]` entry into `convnext`. Read literally, `layers` is built twice.
@assert mode in keys(convnext_configs) "`size` must be one of $(collect(keys(convnext_configs)))"
depths = convnext_configs[mode][:depths]
planes = convnext_configs[mode][:planes]
layers = convnext(depths, planes; inchannels, drop_path_rate, λ, nclasses)
_checkconfig(mode, keys(CONVNEXT_CONFIGS))
layers = convnext(CONVNEXT_CONFIGS[mode]...; inchannels, drop_path_rate, λ, nclasses)
return ConvNeXt(layers)
end

# Calling the model applies the wrapped `layers` to `x`.
(m::ConvNeXt)(x) = m.layers(x)

# NOTE(review): `@functor ConvNeXt` also appears right after the struct definition; this
# looks like a moved line shown twice by a rendered diff — keep one registration.
@functor ConvNeXt

# `backbone` is the first element of `layers`, `classifier` the second.
backbone(m::ConvNeXt) = m.layers[1]
classifier(m::ConvNeXt) = m.layers[2]
Loading

0 comments on commit 7e4f9db

Please sign in to comment.