Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Overhaul of ResNet API #174

Merged
merged 67 commits into from
Aug 2, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
67 commits
Select commit Hold shift + click to select a range
cd0edef
Add `DropBlock`
theabhirath Jun 16, 2022
271b430
Initial commit for new ResNet API
theabhirath Jun 21, 2022
866dbcc
Cleanup
theabhirath Jun 22, 2022
a038ff8
Get some stuff to work
theabhirath Jun 23, 2022
de079bc
Tweaks - I
theabhirath Jun 23, 2022
4fa28d4
Make pretrain condition explicit
theabhirath Jun 25, 2022
7846f8b
More declarative interface for ResNet
theabhirath Jun 28, 2022
a1d5ddc
Make `DropBlock` really work
theabhirath Jun 28, 2022
3be1d81
Construct the stem outside and pass it into `resnet`
theabhirath Jun 29, 2022
16cbcd0
Add ResNeXt back
theabhirath Jun 29, 2022
e5294ec
Enable CI for Windows
theabhirath Jun 30, 2022
a439bdf
Add more general implementation of SE layer
theabhirath Jun 29, 2022
441ade8
Tweaks III + Some more docs
theabhirath Jul 1, 2022
5d059f5
Fix `DropBlock` on the GPU
theabhirath Jul 3, 2022
226e96a
Add `SEResNet` and `SEResNeXt`
theabhirath Jul 3, 2022
3a4ffbf
More docs, more tweaks
theabhirath Jul 4, 2022
2f755cf
More aggressive GC
theabhirath Jul 8, 2022
5ba4b84
Tweaks don't stop
theabhirath Jul 9, 2022
aaf2abb
Reorganisation and formatting
theabhirath Jul 9, 2022
326f36c
Refactor shortcut connections
theabhirath Jul 9, 2022
4e01443
Generalise `resnet` further
theabhirath Jul 10, 2022
e8d3488
Documentation
theabhirath Jul 10, 2022
92ed4fa
Add classifier and backbone methods
theabhirath Jul 12, 2022
96a7d31
Refactor of resnet core
theabhirath Jul 17, 2022
9540299
Add `DropBlock`
theabhirath Jun 16, 2022
588d703
Initial commit for new ResNet API
theabhirath Jun 21, 2022
2a5d0cc
Cleanup
theabhirath Jun 22, 2022
07c1e95
Get some stuff to work
theabhirath Jun 23, 2022
2e88201
Tweaks - I
theabhirath Jun 23, 2022
01eaa8b
Make pretrain condition explicit
theabhirath Jun 25, 2022
546b131
More declarative interface for ResNet
theabhirath Jun 28, 2022
3f45f27
Make `DropBlock` really work
theabhirath Jun 28, 2022
f373f45
Construct the stem outside and pass it into `resnet`
theabhirath Jun 29, 2022
51d0757
Add ResNeXt back
theabhirath Jun 29, 2022
106f260
Add more general implementation of SE layer
theabhirath Jun 29, 2022
7147309
Tweaks III + Some more docs
theabhirath Jul 1, 2022
7ed20d4
Fix `DropBlock` on the GPU
theabhirath Jul 3, 2022
f0051b7
Add `SEResNet` and `SEResNeXt`
theabhirath Jul 3, 2022
e5d2295
More docs, more tweaks
theabhirath Jul 4, 2022
4a91fc4
More aggressive GC
theabhirath Jul 8, 2022
cf538bb
Tweaks don't stop
theabhirath Jul 9, 2022
5be45ef
Reorganisation and formatting
theabhirath Jul 9, 2022
1e509df
Refactor shortcut connections
theabhirath Jul 9, 2022
e4930f1
Generalise `resnet` further
theabhirath Jul 10, 2022
80bdcde
Documentation
theabhirath Jul 10, 2022
ab37901
Add classifier and backbone methods
theabhirath Jul 12, 2022
68abbb7
Refactor of resnet core
theabhirath Jul 17, 2022
7ad362b
Refactor of resnet core II
theabhirath Jul 22, 2022
93fb500
Merge branch 'resnet-plus' of https://github.com/theabhirath/Metalhea…
theabhirath Jul 22, 2022
13ed5ac
Allow `prenorm`
theabhirath Jul 22, 2022
6c005d3
Cleanup
theabhirath Jul 23, 2022
bd443f1
Reorganisation
theabhirath Jul 23, 2022
ce1da45
Reorganisation
theabhirath Jul 23, 2022
ed57c8f
Merge branch 'resnet-plus' of https://github.com/theabhirath/Metalhea…
theabhirath Jul 27, 2022
8c9f73f
Remove templating for now
theabhirath Jul 27, 2022
ca53acb
Fix tests, hopefully
theabhirath Jul 28, 2022
54ea529
Revert "Remove templating for now"
theabhirath Jul 28, 2022
541fabd
Merge branch 'master' into resnet-plus
theabhirath Jul 29, 2022
cff07cb
MobileNet tweaks
theabhirath Jul 29, 2022
674b27e
Make templating work again
theabhirath Jul 29, 2022
aa2a9ef
Tests just don't fix themselves
theabhirath Jul 29, 2022
b143b95
Fifth refactor is a charm
theabhirath Jul 29, 2022
fc74aa1
Cleanup - docs and code
theabhirath Jul 29, 2022
99eb25a
Make all config dicts `const` and capitalise
theabhirath Jul 29, 2022
73131bf
Formatting, and some tweaks
theabhirath Jul 30, 2022
73df024
Add WideResNet
theabhirath Jul 30, 2022
72cd4a9
Don't use globals
theabhirath Aug 2, 2022
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 8 additions & 9 deletions .github/workflows/CI.yml
Original file line number Diff line number Diff line change
Expand Up @@ -27,16 +27,15 @@ jobs:
- x64
suite:
- '["AlexNet", "VGG"]'
- '["GoogLeNet", "SqueezeNet"]'
- '["EfficientNet", "MobileNet"]'
- '[r"/*/ResNet*", "ResNeXt"]'
- 'r"/*/Inception/Inceptionv*"'
- '["InceptionResNetv2", "Xception"]'
- '["GoogLeNet", "SqueezeNet", "MobileNet"]'
- '["EfficientNet"]'
- 'r"/*/ResNet*"'
- '[r"ResNeXt", r"SEResNet"]'
- '"Inception"'
- '"DenseNet"'
- '"ConvNeXt"'
- '"ConvMixer"'
- '"ViT"'
- '"Other"'
- '["ConvNeXt", "ConvMixer"]'
- 'r"ViTs"'
- 'r"Mixers"'
steps:
- uses: actions/checkout@v2
- uses: julia-actions/setup-julia@v1
Expand Down
4 changes: 4 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,15 @@ version = "0.8.0-DEV"
[deps]
Artifacts = "56f22d72-fd6d-98f1-02f0-08ddc0907c33"
BSON = "fbb218c0-5317-5bc6-957e-2ee96dd4b1f0"
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
LazyArtifacts = "4af54fe1-eca0-43a8-85a7-787d91b784e3"
MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54"
NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
NNlibCUDA = "a00861dc-f156-4864-bf3c-e6376f28a68d"
PartialFunctions = "570af359-4316-4cb7-8c74-252c00c2016b"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"

Expand Down
4 changes: 2 additions & 2 deletions docs/make.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
using Pkg

Pkg.develop(path = "..")
Pkg.develop(; path = "..")

using Publish
using Artifacts, LazyArtifacts
Expand All @@ -13,5 +13,5 @@ p = Publish.Project(Metalhead)

function build_and_deploy(label)
rm(label; recursive = true, force = true)
deploy(Metalhead; root = "/Metalhead.jl", label = label)
return deploy(Metalhead; root = "/Metalhead.jl", label = label)
end
2 changes: 1 addition & 1 deletion docs/serve.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
using Pkg

Pkg.develop(path = "..")
Pkg.develop(; path = "..")

using Revise
using Publish
Expand Down
37 changes: 27 additions & 10 deletions src/Metalhead.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ using BSON
using Artifacts, LazyArtifacts
using Statistics
using MLUtils
using PartialFunctions
using Random

import Functors
Expand All @@ -20,38 +21,54 @@ using .Layers
# CNN models
include("convnets/alexnet.jl")
include("convnets/vgg.jl")
include("convnets/inception.jl")
include("convnets/googlenet.jl")
include("convnets/resnet.jl")
include("convnets/resnext.jl")
## ResNets
include("convnets/resnets/core.jl")
include("convnets/resnets/resnet.jl")
include("convnets/resnets/resnext.jl")
include("convnets/resnets/seresnet.jl")
## Inceptions
include("convnets/inception/googlenet.jl")
include("convnets/inception/inceptionv3.jl")
include("convnets/inception/inceptionv4.jl")
include("convnets/inception/inceptionresnetv2.jl")
include("convnets/inception/xception.jl")
## MobileNets
include("convnets/mobilenet/mobilenetv1.jl")
include("convnets/mobilenet/mobilenetv2.jl")
include("convnets/mobilenet/mobilenetv3.jl")
## Others
include("convnets/densenet.jl")
include("convnets/squeezenet.jl")
include("convnets/mobilenet.jl")
include("convnets/efficientnet.jl")
include("convnets/convnext.jl")
include("convnets/convmixer.jl")

# Other models
include("other/mlpmixer.jl")
# Mixers
include("mixers/core.jl")
include("mixers/mlpmixer.jl")
include("mixers/resmlp.jl")
include("mixers/gmlp.jl")

# ViT-based models
# ViTs
include("vit-based/vit.jl")

# Load pretrained weights
include("pretrain.jl")

export AlexNet, VGG, VGG11, VGG13, VGG16, VGG19,
ResNet, ResNet18, ResNet34, ResNet50, ResNet101, ResNet152, ResNeXt,
DenseNet, DenseNet121, DenseNet161, DenseNet169, DenseNet201,
GoogLeNet, Inception3, Inceptionv3, Inceptionv4, InceptionResNetv2, Xception,
SqueezeNet, MobileNetv1, MobileNetv2, MobileNetv3, EfficientNet,
WideResNet, SEResNet, SEResNeXt,
MLPMixer, ResMLP, gMLP,
ViT,
ConvMixer, ConvNeXt

# use Flux._big_show to pretty print large models
for T in (:AlexNet, :VGG, :ResNet, :ResNeXt, :DenseNet,
for T in (:AlexNet, :VGG, :ResNet, :ResNeXt, :DenseNet, :SEResNet, :SEResNeXt,
:GoogLeNet, :Inceptionv3, :Inceptionv4, :InceptionResNetv2, :Xception,
:SqueezeNet, :MobileNetv1, :MobileNetv2, :MobileNetv3,
:SqueezeNet, :MobileNetv1, :MobileNetv2, :MobileNetv3, :EfficientNet,
:MLPMixer, :ResMLP, :gMLP, :ViT, :ConvMixer, :ConvNeXt)
@eval Base.show(io::IO, ::MIME"text/plain", model::$T) = _maybe_big_show(io, model)
end
Expand Down
8 changes: 4 additions & 4 deletions src/convnets/alexnet.jl
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@ function alexnet(; nclasses = 1000)
Dropout(0.5),
Dense(4096, 4096, relu),
Dense(4096, nclasses)))

return layers
end

Expand All @@ -46,15 +45,16 @@ See also [`alexnet`](#).
struct AlexNet
layers::Any
end
@functor AlexNet

function AlexNet(; pretrain = false, nclasses = 1000)
layers = alexnet(; nclasses = nclasses)
pretrain && loadpretrain!(layers, "AlexNet")
if pretrain
loadpretrain!(layers, "AlexNet")
end
return AlexNet(layers)
end

@functor AlexNet

(m::AlexNet)(x) = m.layers(x)

backbone(m::AlexNet) = m.layers[1]
Expand Down
44 changes: 23 additions & 21 deletions src/convnets/convmixer.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,32 +9,34 @@ Creates a ConvMixer model.

- `planes`: number of planes in the output of each block
- `depth`: number of layers
- `inchannels`: The number of channels in the input. The default value is 3.
- `inchannels`: The number of channels in the input.
- `kernel_size`: kernel size of the convolutional layers
- `patch_size`: size of the patches
- `activation`: activation function used after the convolutional layers
- `nclasses`: number of classes in the output
"""
function convmixer(planes, depth; inchannels = 3, kernel_size = (9, 9),
patch_size::Dims{2} = (7, 7), activation = gelu, nclasses = 1000)
stem = conv_bn(patch_size, inchannels, planes, activation; preact = true,
stride = patch_size[1])
blocks = [Chain(SkipConnection(Chain(conv_bn(kernel_size, planes, planes, activation;
preact = true, groups = planes,
pad = SamePad())), +),
conv_bn((1, 1), planes, planes, activation; preact = true)...)
stem = conv_norm(patch_size, inchannels, planes, activation; preact = true,
stride = patch_size[1])
blocks = [Chain(SkipConnection(Chain(conv_norm(kernel_size, planes, planes, activation;
preact = true, groups = planes,
pad = SamePad())), +),
conv_norm((1, 1), planes, planes, activation; preact = true)...)
for _ in 1:depth]
head = Chain(AdaptiveMeanPool((1, 1)), MLUtils.flatten, Dense(planes, nclasses))
return Chain(Chain(stem..., Chain(blocks)), head)
end

convmixer_config = Dict(:base => Dict(:planes => 1536, :depth => 20, :kernel_size => (9, 9),
:patch_size => (7, 7)),
:small => Dict(:planes => 768, :depth => 32, :kernel_size => (7, 7),
:patch_size => (7, 7)),
:large => Dict(:planes => 1024, :depth => 20,
:kernel_size => (9, 9),
:patch_size => (7, 7)))
const CONVMIXER_CONFIGS = Dict(:base => Dict(:planes => 1536, :depth => 20,
:kernel_size => (9, 9),
:patch_size => (7, 7)),
:small => Dict(:planes => 768, :depth => 32,
:kernel_size => (7, 7),
:patch_size => (7, 7)),
:large => Dict(:planes => 1024, :depth => 20,
:kernel_size => (9, 9),
:patch_size => (7, 7)))

"""
ConvMixer(mode::Symbol = :base; inchannels = 3, activation = gelu, nclasses = 1000)
Expand All @@ -45,26 +47,26 @@ Creates a ConvMixer model.
# Arguments

- `mode`: the mode of the model, either `:base`, `:small` or `:large`
- `inchannels`: The number of channels in the input. The default value is 3.
- `inchannels`: The number of channels in the input.
- `activation`: activation function used after the convolutional layers
- `nclasses`: number of classes in the output
"""
struct ConvMixer
layers::Any
end
@functor ConvMixer

function ConvMixer(mode::Symbol = :base; inchannels = 3, activation = gelu, nclasses = 1000)
planes = convmixer_config[mode][:planes]
depth = convmixer_config[mode][:depth]
kernel_size = convmixer_config[mode][:kernel_size]
patch_size = convmixer_config[mode][:patch_size]
_checkconfig(mode, keys(CONVMIXER_CONFIGS))
planes = CONVMIXER_CONFIGS[mode][:planes]
depth = CONVMIXER_CONFIGS[mode][:depth]
kernel_size = CONVMIXER_CONFIGS[mode][:kernel_size]
patch_size = CONVMIXER_CONFIGS[mode][:patch_size]
layers = convmixer(planes, depth; inchannels, kernel_size, patch_size, activation,
nclasses)
return ConvMixer(layers)
end

@functor ConvMixer

(m::ConvMixer)(x) = m.layers(x)

backbone(m::ConvMixer) = m.layers[1]
Expand Down
47 changes: 18 additions & 29 deletions src/convnets/convnext.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
Creates a single block of ConvNeXt.
([reference](https://arxiv.org/abs/2201.03545))

# Arguments:
# Arguments

- `planes`: number of input channels.
- `drop_path_rate`: Stochastic depth rate.
Expand All @@ -27,7 +27,7 @@ end
Creates the layers for a ConvNeXt model.
([reference](https://arxiv.org/abs/2201.03545))

# Arguments:
# Arguments

- `inchannels`: number of input channels.
- `depths`: list with configuration for depth of each block
Expand All @@ -39,60 +39,53 @@ Creates the layers for a ConvNeXt model.
"""
function convnext(depths, planes; inchannels = 3, drop_path_rate = 0.0, λ = 1.0f-6,
nclasses = 1000)
@assert length(depths)==length(planes) "`planes` should have exactly one value for each block"

@assert length(depths) == length(planes)
"`planes` should have exactly one value for each block"
downsample_layers = []
stem = Chain(Conv((4, 4), inchannels => planes[1]; stride = 4),
ChannelLayerNorm(planes[1]; ϵ = 1.0f-6))
ChannelLayerNorm(planes[1]))
push!(downsample_layers, stem)
for m in 1:(length(depths) - 1)
downsample_layer = Chain(ChannelLayerNorm(planes[m]; ϵ = 1.0f-6),
downsample_layer = Chain(ChannelLayerNorm(planes[m]),
Conv((2, 2), planes[m] => planes[m + 1]; stride = 2))
push!(downsample_layers, downsample_layer)
end

stages = []
dp_rates = LinRange{Float32}(0.0, drop_path_rate, sum(depths))
dp_rates = linear_scheduler(drop_path_rate; depth = sum(depths))
cur = 0
for i in 1:length(depths)
for i in eachindex(depths)
push!(stages, [convnextblock(planes[i], dp_rates[cur + j], λ) for j in 1:depths[i]])
cur += depths[i]
end

backbone = collect(Iterators.flatten(Iterators.flatten(zip(downsample_layers, stages))))
head = Chain(GlobalMeanPool(),
MLUtils.flatten,
LayerNorm(planes[end]),
Dense(planes[end], nclasses))

return Chain(Chain(backbone), head)
end

# Configurations for ConvNeXt models
convnext_configs = Dict(:tiny => Dict(:depths => [3, 3, 9, 3],
:planes => [96, 192, 384, 768]),
:small => Dict(:depths => [3, 3, 27, 3],
:planes => [96, 192, 384, 768]),
:base => Dict(:depths => [3, 3, 27, 3],
:planes => [128, 256, 512, 1024]),
:large => Dict(:depths => [3, 3, 27, 3],
:planes => [192, 384, 768, 1536]),
:xlarge => Dict(:depths => [3, 3, 27, 3],
:planes => [256, 512, 1024, 2048]))
const CONVNEXT_CONFIGS = Dict(:tiny => ([3, 3, 9, 3], [96, 192, 384, 768]),
:small => ([3, 3, 27, 3], [96, 192, 384, 768]),
:base => ([3, 3, 27, 3], [128, 256, 512, 1024]),
:large => ([3, 3, 27, 3], [192, 384, 768, 1536]),
:xlarge => ([3, 3, 27, 3], [256, 512, 1024, 2048]))

struct ConvNeXt
layers::Any
end
@functor ConvNeXt

"""
ConvNeXt(mode::Symbol = :base; inchannels = 3, drop_path_rate = 0., λ = 1f-6, nclasses = 1000)

Creates a ConvNeXt model.
([reference](https://arxiv.org/abs/2201.03545))

# Arguments:
# Arguments

- `inchannels`: The number of channels in the input. The default value is 3.
- `inchannels`: The number of channels in the input.
- `drop_path_rate`: Stochastic depth rate.
- `λ`: Init value for [LayerScale](https://arxiv.org/abs/2103.17239)
- `nclasses`: number of output classes
Expand All @@ -101,16 +94,12 @@ See also [`Metalhead.convnext`](#).
"""
function ConvNeXt(mode::Symbol = :base; inchannels = 3, drop_path_rate = 0.0, λ = 1.0f-6,
nclasses = 1000)
@assert mode in keys(convnext_configs) "`size` must be one of $(collect(keys(convnext_configs)))"
depths = convnext_configs[mode][:depths]
planes = convnext_configs[mode][:planes]
layers = convnext(depths, planes; inchannels, drop_path_rate, λ, nclasses)
_checkconfig(mode, keys(CONVNEXT_CONFIGS))
layers = convnext(CONVNEXT_CONFIGS[mode]...; inchannels, drop_path_rate, λ, nclasses)
return ConvNeXt(layers)
end

(m::ConvNeXt)(x) = m.layers(x)

@functor ConvNeXt

backbone(m::ConvNeXt) = m.layers[1]
classifier(m::ConvNeXt) = m.layers[2]
Loading