diff --git a/src/Metalhead.jl b/src/Metalhead.jl index 374f28615..78073c154 100644 --- a/src/Metalhead.jl +++ b/src/Metalhead.jl @@ -56,14 +56,12 @@ include("vit-based/vit.jl") include("pretrain.jl") export AlexNet, VGG, VGG11, VGG13, VGG16, VGG19, - ResNet, ResNet18, ResNet34, ResNet50, ResNet101, ResNet152, ResNeXt, + ResNet, ResNet18, ResNet34, ResNet50, ResNet101, ResNet152, + WideResNet, ResNeXt, SEResNet, SEResNeXt, DenseNet, DenseNet121, DenseNet161, DenseNet169, DenseNet201, GoogLeNet, Inception3, Inceptionv3, Inceptionv4, InceptionResNetv2, Xception, SqueezeNet, MobileNetv1, MobileNetv2, MobileNetv3, EfficientNet, - WideResNet, SEResNet, SEResNeXt, - MLPMixer, ResMLP, gMLP, - ViT, - ConvMixer, ConvNeXt + MLPMixer, ResMLP, gMLP, ViT, ConvMixer, ConvNeXt # use Flux._big_show to pretty print large models for T in (:AlexNet, :VGG, :ResNet, :ResNeXt, :DenseNet, :SEResNet, :SEResNeXt, diff --git a/src/convnets/alexnet.jl b/src/convnets/alexnet.jl index 6b384f80c..3c713839e 100644 --- a/src/convnets/alexnet.jl +++ b/src/convnets/alexnet.jl @@ -1,11 +1,12 @@ """ - alexnet(; nclasses::Integer = 1000) + alexnet(; inchannels::Integer = 3, nclasses::Integer = 1000) Create an AlexNet model ([reference](https://papers.nips.cc/paper/2012/file/c399862d3b9d6b76c8436e924a68c45b-Paper.pdf)). # Arguments + - `inchannels`: The number of input channels. - `nclasses`: the number of output classes """ function alexnet(; inchannels::Integer = 3, nclasses::Integer = 1000) @@ -27,19 +28,23 @@ function alexnet(; inchannels::Integer = 3, nclasses::Integer = 1000) end """ - AlexNet(; pretrain::Bool = false, nclasses::Integer = 1000) + AlexNet(; pretrain::Bool = false, inchannels::Integer = 3, + nclasses::Integer = 1000) Create a `AlexNet`. -See also [`alexnet`](#). - -!!! warning - - `AlexNet` does not currently support pretrained weights. +([reference](https://papers.nips.cc/paper/2012/file/c399862d3b9d6b76c8436e924a68c45b-Paper.pdf)). # Arguments - `pretrain`: set to `true` to load pre-trained weights for ImageNet + - `inchannels`: The number of input channels. - `nclasses`: the number of output classes + +!!! warning + + `AlexNet` does not currently support pretrained weights. + +See also [`alexnet`](#). """ struct AlexNet layers::Any diff --git a/src/convnets/convmixer.jl b/src/convnets/convmixer.jl index efde886cb..309989d2d 100644 --- a/src/convnets/convmixer.jl +++ b/src/convnets/convmixer.jl @@ -26,28 +26,28 @@ function convmixer(planes::Integer, depth::Integer; kernel_size = (9, 9), pad = SamePad())), +), conv_norm((1, 1), planes, planes, activation; preact = true)...) for _ in 1:depth] - return Chain(Chain(stem..., Chain(blocks)), create_classifier(planes, nclasses)) + return Chain(Chain(stem..., Chain(blocks...)), create_classifier(planes, nclasses)) end -const CONVMIXER_CONFIGS = Dict(:base => Dict(:planes => 1536, :depth => 20, - :kernel_size => (9, 9), - :patch_size => (7, 7)), - :small => Dict(:planes => 768, :depth => 32, - :kernel_size => (7, 7), - :patch_size => (7, 7)), - :large => Dict(:planes => 1024, :depth => 20, - :kernel_size => (9, 9), - :patch_size => (7, 7))) +const CONVMIXER_CONFIGS = Dict(:base => ((1536, 20), + (kernel_size = (9, 9), + patch_size = (7, 7))), + :small => ((768, 32), + (kernel_size = (7, 7), + patch_size = (7, 7))), + :large => ((1024, 20), + (kernel_size = (9, 9), + patch_size = (7, 7)))) """ - ConvMixer(mode::Symbol; inchannels::Integer = 3, nclasses::Integer = 1000) + ConvMixer(config::Symbol; inchannels::Integer = 3, nclasses::Integer = 1000) Creates a ConvMixer model. ([reference](https://arxiv.org/abs/2201.09792)) # Arguments - - `mode`: the mode of the model, either `:base`, `:small` or `:large` + - `config`: the size of the model, either `:base`, `:small` or `:large` - `inchannels`: The number of channels in the input. - `nclasses`: number of classes in the output """ @@ -56,13 +56,10 @@ struct ConvMixer end @functor ConvMixer -function ConvMixer(mode::Symbol; inchannels::Integer = 3, nclasses::Integer = 1000) - _checkconfig(mode, keys(CONVMIXER_CONFIGS)) - planes = CONVMIXER_CONFIGS[mode][:planes] - depth = CONVMIXER_CONFIGS[mode][:depth] - kernel_size = CONVMIXER_CONFIGS[mode][:kernel_size] - patch_size = CONVMIXER_CONFIGS[mode][:patch_size] - layers = convmixer(planes, depth; inchannels, kernel_size, patch_size, nclasses) +function ConvMixer(config::Symbol; inchannels::Integer = 3, nclasses::Integer = 1000) + _checkconfig(config, keys(CONVMIXER_CONFIGS)) + layers = convmixer(CONVMIXER_CONFIGS[config][1]...; CONVMIXER_CONFIGS[config][2]..., + inchannels, nclasses) return ConvMixer(layers) end diff --git a/src/convnets/convnext.jl b/src/convnets/convnext.jl index 7bb265c24..040a409ab 100644 --- a/src/convnets/convnext.jl +++ b/src/convnets/convnext.jl @@ -22,7 +22,7 @@ function convnextblock(planes::Integer, drop_path_rate = 0.0, layerscale_init = end """ - convnext(depths::Vector{<:Integer}, planes::Vector{<:Integer}; + convnext(depths::AbstractVector{<:Integer}, planes::AbstractVector{<:Integer}; drop_path_rate = 0.0, layerscale_init = 1.0f-6, inchannels::Integer = 3, nclasses::Integer = 1000) @@ -31,27 +31,27 @@ Creates the layers for a ConvNeXt model. # Arguments - - `inchannels`: number of input channels. - `depths`: list with configuration for depth of each block - `planes`: list with configuration for number of output channels in each block - `drop_path_rate`: Stochastic depth rate. - `layerscale_init`: Initial value for [`LayerScale`](#) ([reference](https://arxiv.org/abs/2103.17239)) + - `inchannels`: number of input channels. - `nclasses`: number of output classes """ -function convnext(depths::Vector{<:Integer}, planes::Vector{<:Integer}; +function convnext(depths::AbstractVector{<:Integer}, planes::AbstractVector{<:Integer}; drop_path_rate = 0.0, layerscale_init = 1.0f-6, inchannels::Integer = 3, nclasses::Integer = 1000) @assert length(depths) == length(planes) "`planes` should have exactly one value for each block" downsample_layers = [] - stem = Chain(Conv((4, 4), inchannels => planes[1]; stride = 4), - ChannelLayerNorm(planes[1])) - push!(downsample_layers, stem) + push!(downsample_layers, + Chain(conv_norm((4, 4), inchannels => planes[1]; stride = 4, + norm_layer = ChannelLayerNorm)...)) for m in 1:(length(depths) - 1) - downsample_layer = Chain(ChannelLayerNorm(planes[m]), - Conv((2, 2), planes[m] => planes[m + 1]; stride = 2)) - push!(downsample_layers, downsample_layer) + push!(downsample_layers, + Chain(conv_norm((2, 2), planes[m] => planes[m + 1]; stride = 2, + norm_layer = ChannelLayerNorm, revnorm = true)...)) end stages = [] dp_rates = linear_scheduler(drop_path_rate; depth = sum(depths)) @@ -64,8 +64,7 @@ function convnext(depths::Vector{<:Integer}, planes::Vector{<:Integer}; end backbone = collect(Iterators.flatten(Iterators.flatten(zip(downsample_layers, stages)))) classifier = Chain(GlobalMeanPool(), MLUtils.flatten, - LayerNorm(planes[end]), - Dense(planes[end], nclasses)) + LayerNorm(planes[end]), Dense(planes[end], nclasses)) return Chain(Chain(backbone...), classifier) end @@ -77,13 +76,14 @@ const CONVNEXT_CONFIGS = Dict(:tiny => ([3, 3, 9, 3], [96, 192, 384, 768]), :xlarge => ([3, 3, 27, 3], [256, 512, 1024, 2048])) """ - ConvNeXt(mode::Symbol; inchannels::Integer = 3, nclasses::Integer = 1000) + ConvNeXt(config::Symbol; inchannels::Integer = 3, nclasses::Integer = 1000) Creates a ConvNeXt model. ([reference](https://arxiv.org/abs/2201.03545)) # Arguments + - `config`: The size of the model, one of `tiny`, `small`, `base`, `large` or `xlarge`. - `inchannels`: The number of channels in the input. - `nclasses`: number of output classes @@ -94,9 +94,9 @@ struct ConvNeXt end @functor ConvNeXt -function ConvNeXt(mode::Symbol; inchannels::Integer = 3, nclasses::Integer = 1000) - _checkconfig(mode, keys(CONVNEXT_CONFIGS)) - layers = convnext(CONVNEXT_CONFIGS[mode]...; inchannels, nclasses) +function ConvNeXt(config::Symbol; inchannels::Integer = 3, nclasses::Integer = 1000) + _checkconfig(config, keys(CONVNEXT_CONFIGS)) + layers = convnext(CONVNEXT_CONFIGS[config]...; inchannels, nclasses) return ConvNeXt(layers) end diff --git a/src/convnets/densenet.jl b/src/convnets/densenet.jl index b82f138fb..ab833bd41 100644 --- a/src/convnets/densenet.jl +++ b/src/convnets/densenet.jl @@ -99,7 +99,8 @@ Create a DenseNet model - `reduction`: the factor by which the number of feature maps is scaled across each transition - `nclasses`: the number of output classes """ -function densenet(nblocks::Vector{<:Integer}; growth_rate::Integer = 32, reduction = 0.5, +function densenet(nblocks::AbstractVector{<:Integer}; growth_rate::Integer = 32, + reduction = 0.5, inchannels::Integer = 3, nclasses::Integer = 1000) return densenet(2 * growth_rate, [fill(growth_rate, n) for n in nblocks]; reduction, inchannels, nclasses) diff --git a/src/convnets/efficientnet.jl b/src/convnets/efficientnet.jl index 86ba9373f..91986fb92 100644 --- a/src/convnets/efficientnet.jl +++ b/src/convnets/efficientnet.jl @@ -22,8 +22,10 @@ Create an EfficientNet model ([reference](https://arxiv.org/abs/1905.11946v5)). - `max_width`: maximum number of output channels before the fully connected classification blocks """ -function efficientnet(scalings, block_configs; max_width::Integer = 1280, - inchannels::Integer = 3, nclasses::Integer = 1000) +function efficientnet(scalings::NTuple{2, Real}, + block_configs::AbstractVector{NTuple{6, Int}}; + max_width::Integer = 1280, inchannels::Integer = 3, + nclasses::Integer = 1000) wscale, dscale = scalings scalew(w) = wscale ≈ 1 ? w : ceil(Int64, wscale * w) scaled(d) = dscale ≈ 1 ? d : ceil(Int64, dscale * d) @@ -83,61 +85,32 @@ const EFFICIENTNET_GLOBAL_CONFIGS = Dict(:b0 => (224, (1.0, 1.0)), :b8 => (672, (2.2, 3.6))) """ - EfficientNet(scalings, block_configs; max_width::Integer = 1280, - inchannels::Integer = 3, nclasses::Integer = 1000) + EfficientNet(config::Symbol; pretrain::Bool = false) Create an EfficientNet model ([reference](https://arxiv.org/abs/1905.11946v5)). See also [`efficientnet`](#). # Arguments - - `scalings`: global width and depth scaling (given as a tuple) - - - `block_configs`: configuration for each inverted residual block, - given as a vector of tuples with elements: - - + `n`: number of block repetitions (will be scaled by global depth scaling) - + `k`: kernel size - + `s`: kernel stride - + `e`: expansion ratio - + `i`: block input channels (will be scaled by global width scaling) - + `o`: block output channels (will be scaled by global width scaling) - - `inchannels`: number of input channels - - `nclasses`: number of output classes - - `max_width`: maximum number of output channels before the fully connected - classification blocks + - `config`: name of default configuration + (can be `:b0`, `:b1`, `:b2`, `:b3`, `:b4`, `:b5`, `:b6`, `:b7`, `:b8`) + - `pretrain`: set to `true` to load the pre-trained weights for ImageNet """ struct EfficientNet layers::Any end @functor EfficientNet -function EfficientNet(scalings, block_configs; max_width::Integer = 1280, - inchannels::Integer = 3, nclasses::Integer = 1000) - layers = efficientnet(scalings, block_configs; inchannels, nclasses, max_width) - return EfficientNet(layers) +function EfficientNet(config::Symbol; pretrain::Bool = false) + _checkconfig(config, keys(EFFICIENTNET_GLOBAL_CONFIGS)) + model = efficientnet(EFFICIENTNET_GLOBAL_CONFIGS[config][2], EFFICIENTNET_BLOCK_CONFIGS) + if pretrain + loadpretrain!(model, string("efficientnet-", config)) + end + return model end (m::EfficientNet)(x) = m.layers(x) backbone(m::EfficientNet) = m.layers[1] classifier(m::EfficientNet) = m.layers[2] - -""" - EfficientNet(name::Symbol; pretrain::Bool = false) - -Create an EfficientNet model ([reference](https://arxiv.org/abs/1905.11946v5)). -See also [`efficientnet`](#). - -# Arguments - - - `name`: name of default configuration - (can be `:b0`, `:b1`, `:b2`, `:b3`, `:b4`, `:b5`, `:b6`, `:b7`, `:b8`) - - `pretrain`: set to `true` to load the pre-trained weights for ImageNet -""" -function EfficientNet(name::Symbol; pretrain::Bool = false) - _checkconfig(name, keys(EFFICIENTNET_GLOBAL_CONFIGS)) - model = EfficientNet(EFFICIENTNET_GLOBAL_CONFIGS[name][2], EFFICIENTNET_BLOCK_CONFIGS) - pretrain && loadpretrain!(model, string("efficientnet-", name)) - return model -end diff --git a/src/convnets/inception/googlenet.jl b/src/convnets/inception/googlenet.jl index a72ba5e6c..11d4dd7d3 100644 --- a/src/convnets/inception/googlenet.jl +++ b/src/convnets/inception/googlenet.jl @@ -53,8 +53,7 @@ function googlenet(; inchannels::Integer = 3, nclasses::Integer = 1000) MaxPool((3, 3); stride = 2, pad = 1), _inceptionblock(832, 256, 160, 320, 32, 128, 128), _inceptionblock(832, 384, 192, 384, 48, 128, 128)) - classifier = create_classifier(1024, nclasses; dropout_rate = 0.4) - return Chain(backbone, classifier) + return Chain(backbone, create_classifier(1024, nclasses; dropout_rate = 0.4)) end """ diff --git a/src/convnets/inception/inceptionresnetv2.jl b/src/convnets/inception/inceptionresnetv2.jl index 96b391b65..c2855191b 100644 --- a/src/convnets/inception/inceptionresnetv2.jl +++ b/src/convnets/inception/inceptionresnetv2.jl @@ -92,8 +92,7 @@ function inceptionresnetv2(; inchannels::Integer = 3, dropout_rate = 0.0, [block8(0.20f0) for _ in 1:9]..., block8(; activation = relu), conv_norm((1, 1), 2080, 1536)...) - classifier = create_classifier(1536, nclasses; dropout_rate) - return Chain(backbone, classifier) + return Chain(backbone, create_classifier(1536, nclasses; dropout_rate)) end """ diff --git a/src/convnets/inception/inceptionv3.jl b/src/convnets/inception/inceptionv3.jl index 8a5e19849..e5083feb5 100644 --- a/src/convnets/inception/inceptionv3.jl +++ b/src/convnets/inception/inceptionv3.jl @@ -154,8 +154,7 @@ function inceptionv3(; inchannels::Integer = 3, nclasses::Integer = 1000) inceptionv3_d(768), inceptionv3_e(1280), inceptionv3_e(2048)) - classifier = create_classifier(2048, nclasses; dropout_rate = 0.2) - return Chain(backbone, classifier) + return Chain(backbone, create_classifier(2048, nclasses; dropout_rate = 0.2)) end """ diff --git a/src/convnets/inception/inceptionv4.jl b/src/convnets/inception/inceptionv4.jl index 8d4f00eb2..cd4971742 100644 --- a/src/convnets/inception/inceptionv4.jl +++ b/src/convnets/inception/inceptionv4.jl @@ -117,8 +117,7 @@ function inceptionv4(; dropout_rate = 0.0, inchannels::Integer = 3, inceptionv4_c(), inceptionv4_c(), inceptionv4_c()) - classifier = create_classifier(1536, nclasses; dropout_rate) - return Chain(backbone, classifier) + return Chain(backbone, create_classifier(1536, nclasses; dropout_rate)) end """ diff --git a/src/convnets/inception/xception.jl b/src/convnets/inception/xception.jl index 71a4efc15..1c97daddc 100644 --- a/src/convnets/inception/xception.jl +++ b/src/convnets/inception/xception.jl @@ -45,15 +45,15 @@ function xception_block(inchannels::Integer, outchannels::Integer, nrepeats::Int end """ - xception(; inchannels::Integer = 3, dropout_rate = 0.0, nclasses::Integer = 1000) + xception(; dropout_rate = 0.0, inchannels::Integer = 3, nclasses::Integer = 1000) Creates an Xception model. ([reference](https://arxiv.org/abs/1610.02357)) # Arguments - - `inchannels`: number of input channels. - `dropout_rate`: rate of dropout in classifier head. + - `inchannels`: number of input channels. - `nclasses`: the number of output classes. """ function xception(; dropout_rate = 0.0, inchannels::Integer = 3, nclasses::Integer = 1000) diff --git a/src/convnets/mobilenet/mobilenetv1.jl b/src/convnets/mobilenet/mobilenetv1.jl index ca20b4a64..b6d9fe8ee 100644 --- a/src/convnets/mobilenet/mobilenetv1.jl +++ b/src/convnets/mobilenet/mobilenetv1.jl @@ -1,5 +1,5 @@ """ - mobilenetv1(width_mult::Number, config::Vector{<:Tuple}; activation = relu, + mobilenetv1(width_mult::Real, config::AbstractVector{<:Tuple}; activation = relu, inchannels::Integer = 3, nclasses::Integer = 1000) Create a MobileNetv1 model ([reference](https://arxiv.org/abs/1704.04861v1)). @@ -19,11 +19,11 @@ Create a MobileNetv1 model ([reference](https://arxiv.org/abs/1704.04861v1)). - `inchannels`: The number of input channels. The default value is 3. - `nclasses`: The number of output classes """ -function mobilenetv1(width_mult::Number, config::Vector{<:Tuple}; activation = relu, +function mobilenetv1(width_mult::Real, config::AbstractVector{<:Tuple}; activation = relu, inchannels::Integer = 3, nclasses::Integer = 1000) layers = [] for (dw, outch, stride, nrepeats) in config - outch = Int(outch * width_mult) + outch = floor(Int, outch * width_mult) for _ in 1:nrepeats layer = dw ? depthwise_sep_conv_norm((3, 3), inchannels, outch, activation; @@ -76,7 +76,7 @@ struct MobileNetv1 end @functor MobileNetv1 -function MobileNetv1(width_mult::Number = 1; pretrain::Bool = false, +function MobileNetv1(width_mult::Real = 1; pretrain::Bool = false, inchannels::Integer = 3, nclasses::Integer = 1000) layers = mobilenetv1(width_mult, MOBILENETV1_CONFIGS; inchannels, nclasses) if pretrain diff --git a/src/convnets/mobilenet/mobilenetv2.jl b/src/convnets/mobilenet/mobilenetv2.jl index 59e147829..84162e985 100644 --- a/src/convnets/mobilenet/mobilenetv2.jl +++ b/src/convnets/mobilenet/mobilenetv2.jl @@ -1,5 +1,5 @@ """ - mobilenetv2(width_mult::Number, configs::Vector{<:Tuple}; + mobilenetv2(width_mult::Real, configs::AbstractVector{<:Tuple}; max_width::Integer = 1280, inchannels::Integer = 3, nclasses::Integer = 1000) @@ -22,7 +22,7 @@ Create a MobileNetv2 model. - `max_width`: The maximum number of feature maps in any layer of the network - `nclasses`: The number of output classes """ -function mobilenetv2(width_mult::Number, configs::Vector{<:Tuple}; +function mobilenetv2(width_mult::Real, configs::AbstractVector{<:Tuple}; max_width::Integer = 1280, inchannels::Integer = 3, nclasses::Integer = 1000) divisor = width_mult == 0.1 ? 4 : 8 @@ -83,10 +83,9 @@ struct MobileNetv2 end @functor MobileNetv2 -function MobileNetv2(width_mult::Number = 1; pretrain::Bool = false, +function MobileNetv2(width_mult::Real = 1; pretrain::Bool = false, inchannels::Integer = 3, nclasses::Integer = 1000) layers = mobilenetv2(width_mult, MOBILENETV2_CONFIGS; inchannels, nclasses) - pretrain && loadpretrain!(layers, string("MobileNetv2")) if pretrain loadpretrain!(layers, string("MobileNetv2")) end diff --git a/src/convnets/mobilenet/mobilenetv3.jl b/src/convnets/mobilenet/mobilenetv3.jl index 1c5e5825b..7d06ab14d 100644 --- a/src/convnets/mobilenet/mobilenetv3.jl +++ b/src/convnets/mobilenet/mobilenetv3.jl @@ -1,5 +1,5 @@ """ - mobilenetv3(width_mult::Number, configs::Vector{<:Tuple}; + mobilenetv3(width_mult::Real, configs::AbstractVector{<:Tuple}; max_width::Integer = 1024, inchannels::Integer = 3, nclasses::Integer = 1000) @@ -24,7 +24,7 @@ Create a MobileNetv3 model. - `max_width`: The maximum number of feature maps in any layer of the network - `nclasses`: the number of output classes """ -function mobilenetv3(width_mult::Number, configs::Vector{<:Tuple}; +function mobilenetv3(width_mult::Real, configs::AbstractVector{<:Tuple}; max_width::Integer = 1024, inchannels::Integer = 3, nclasses::Integer = 1000) # building first layer @@ -91,7 +91,7 @@ const MOBILENETV3_CONFIGS = Dict(:small => [ ]) """ - MobileNetv3(mode::Symbol; width_mult::Number = 1, pretrain::Bool = false, + MobileNetv3(config::Symbol; width_mult::Real = 1, pretrain::Bool = false, inchannels::Integer = 3, nclasses::Integer = 1000) Create a MobileNetv3 model with the specified configuration. @@ -100,7 +100,7 @@ Set `pretrain = true` to load the model with pre-trained weights for ImageNet. # Arguments - - `mode`: :small or :large for the size of the model (see paper). + - `config`: :small or :large for the size of the model (see paper). - `width_mult`: Controls the number of output feature maps in each block (with 1.0 being the default in the paper; this is usually a value between 0.1 and 1.4) @@ -115,14 +115,14 @@ struct MobileNetv3 end @functor MobileNetv3 -function MobileNetv3(mode::Symbol; width_mult::Number = 1, pretrain::Bool = false, +function MobileNetv3(config::Symbol; width_mult::Real = 1, pretrain::Bool = false, inchannels::Integer = 3, nclasses::Integer = 1000) - _checkconfig(mode, [:small, :large]) - max_width = (mode == :large) ? 1280 : 1024 - layers = mobilenetv3(width_mult, MOBILENETV3_CONFIGS[mode]; inchannels, max_width, + _checkconfig(config, [:small, :large]) + max_width = (config == :large) ? 1280 : 1024 + layers = mobilenetv3(width_mult, MOBILENETV3_CONFIGS[config]; max_width, inchannels, nclasses) if pretrain - loadpretrain!(layers, string("MobileNetv3", mode)) + loadpretrain!(layers, string("MobileNetv3", config)) end return MobileNetv3(layers) end diff --git a/src/convnets/resnets/core.jl b/src/convnets/resnets/core.jl index 940565f3a..79deadfb2 100644 --- a/src/convnets/resnets/core.jl +++ b/src/convnets/resnets/core.jl @@ -65,7 +65,7 @@ function bottleneck(inplanes::Integer, planes::Integer; stride::Integer, norm_layer = BatchNorm, revnorm::Bool = false, drop_block = identity, drop_path = identity, attn_fn = planes -> identity) - width = floor(Int, planes * (base_width / 64)) * cardinality + width = fld(planes * base_width, 64) * cardinality first_planes = width ÷ reduction_factor outplanes = planes * 4 conv_bn1 = conv_norm((1, 1), inplanes => first_planes, activation; norm_layer, revnorm, @@ -190,15 +190,16 @@ function resnet_stem(stem_type::Symbol = :default; inchannels::Integer = 3, return Chain(conv1, bn1, stempool) end -function resnet_planes(block_repeats::Vector{<:Integer}) +function resnet_planes(block_repeats::AbstractVector{<:Integer}) return Iterators.flatten((64 * 2^(stage_idx - 1) for _ in 1:stages) for (stage_idx, stages) in enumerate(block_repeats)) end -function basicblock_builder(block_repeats::Vector{<:Integer}; inplanes::Integer = 64, - reduction_factor::Integer = 1, expansion::Integer = 1, - norm_layer = BatchNorm, revnorm::Bool = false, - activation = relu, attn_fn = planes -> identity, +function basicblock_builder(block_repeats::AbstractVector{<:Integer}; + inplanes::Integer = 64, reduction_factor::Integer = 1, + expansion::Integer = 1, norm_layer = BatchNorm, + revnorm::Bool = false, activation = relu, + attn_fn = planes -> identity, drop_block_rate = 0.0, drop_path_rate = 0.0, stride_fn = resnet_stride, planes_fn = resnet_planes, downsample_tuple = (downsample_conv, downsample_identity)) @@ -228,11 +229,12 @@ function basicblock_builder(block_repeats::Vector{<:Integer}; inplanes::Integer return get_layers end -function bottleneck_builder(block_repeats::Vector{<:Integer}; inplanes::Integer = 64, - cardinality::Integer = 1, base_width::Integer = 64, - reduction_factor::Integer = 1, expansion::Integer = 4, - norm_layer = BatchNorm, revnorm::Bool = false, - activation = relu, attn_fn = planes -> identity, +function bottleneck_builder(block_repeats::AbstractVector{<:Integer}; + inplanes::Integer = 64, cardinality::Integer = 1, + base_width::Integer = 64, reduction_factor::Integer = 1, + expansion::Integer = 4, norm_layer = BatchNorm, + revnorm::Bool = false, activation = relu, + attn_fn = planes -> identity, drop_block_rate = 0.0, drop_path_rate = 0.0, stride_fn = resnet_stride, planes_fn = resnet_planes, downsample_tuple = (downsample_conv, downsample_identity)) @@ -265,7 +267,7 @@ function bottleneck_builder(block_repeats::Vector{<:Integer}; inplanes::Integer return get_layers end -function resnet_stages(get_layers, block_repeats::Vector{<:Integer}, connection) +function resnet_stages(get_layers, block_repeats::AbstractVector{<:Integer}, connection) # Construct each stage stages = [] for (stage_idx, num_blocks) in enumerate(block_repeats) @@ -277,7 +279,8 @@ function resnet_stages(get_layers, block_repeats::Vector{<:Integer}, connection) return Chain(stages...) end -function resnet(img_dims, stem, get_layers, block_repeats::Vector{<:Integer}, connection, +function resnet(img_dims, stem, get_layers, block_repeats::AbstractVector{<:Integer}, + connection, classifier_fn) # Build stages of the ResNet stage_blocks = resnet_stages(get_layers, block_repeats, connection) @@ -288,7 +291,7 @@ function resnet(img_dims, stem, get_layers, block_repeats::Vector{<:Integer}, co return Chain(backbone, classifier) end -function resnet(block_type::Symbol, block_repeats::Vector{<:Integer}; +function resnet(block_type::Symbol, block_repeats::AbstractVector{<:Integer}; downsample_opt::NTuple{2, Any} = (downsample_conv, downsample_identity), cardinality::Integer = 1, base_width::Integer = 64, inplanes::Integer = 64, reduction_factor::Integer = 1, imsize::Dims{2} = (256, 256), diff --git a/src/convnets/vgg.jl b/src/convnets/vgg.jl index e685620a3..de232d9a3 100644 --- a/src/convnets/vgg.jl +++ b/src/convnets/vgg.jl @@ -40,7 +40,7 @@ Create VGG convolution layers - `batchnorm`: set to `true` to include batch normalization after each convolution - `inchannels`: number of input channels """ -function vgg_convolutional_layers(config::Vector{<:Tuple}, batchnorm::Bool, +function vgg_convolutional_layers(config::AbstractVector{<:Tuple}, batchnorm::Bool, inchannels::Integer) layers = [] ifilters = inchannels @@ -69,7 +69,7 @@ Create VGG classifier (fully connected) layers function vgg_classifier_layers(imsize::NTuple{3, <:Integer}, nclasses::Integer, fcsize::Integer, dropout_rate) return Chain(MLUtils.flatten, - Dense(Int(prod(imsize)), fcsize, relu), + Dense(prod(imsize), fcsize, relu), Dropout(dropout_rate), Dense(fcsize, fcsize, relu), Dropout(dropout_rate), @@ -107,10 +107,7 @@ const VGG_CONV_CONFIGS = Dict(:A => [(64, 1), (128, 1), (256, 2), (512, 2), (512 :D => [(64, 2), (128, 2), (256, 3), (512, 3), (512, 3)], :E => [(64, 2), (128, 2), (256, 4), (512, 4), (512, 4)]) -const VGG_CONFIGS = Dict(11 => :A, - 13 => :B, - 16 => :D, - 19 => :E) +const VGG_CONFIGS = Dict(11 => :A, 13 => :B, 16 => :D, 19 => :E) """ VGG(imsize::Dims{2}; config, inchannels, batchnorm = false, nclasses, fcsize, dropout_rate) diff --git a/src/layers/attention.jl b/src/layers/attention.jl index e2276aa01..b8fd38165 100644 --- a/src/layers/attention.jl +++ b/src/layers/attention.jl @@ -1,5 +1,6 @@ """ - MHAttention(planes::Integer, nheads::Integer = 8; qkv_bias::Bool = false, attn_dropout_rate = 0., proj_dropout_rate = 0.) + MHAttention(planes::Integer, nheads::Integer = 8; qkv_bias::Bool = false, + attn_dropout_rate = 0., proj_dropout_rate = 0.) Multi-head self-attention layer. diff --git a/src/layers/conv.jl b/src/layers/conv.jl index 75b40708c..c355eac2f 100644 --- a/src/layers/conv.jl +++ b/src/layers/conv.jl @@ -1,7 +1,11 @@ """ - conv_norm(kernel_size, inplanes::Int, outplanes::Int, activation = relu; - norm_layer = BatchNorm, revnorm = false, preact = false, use_norm = true, - stride = 1, pad = 0, dilation = 1, groups = 1, [bias, weight, init]) + conv_norm(kernel_size, inplanes::Integer, outplanes::Integer, activation = relu; + norm_layer = BatchNorm, revnorm::Bool = false, preact::Bool = false, + use_norm::Bool = true, stride::Integer = 1, pad::Integer = 0, + dilation::Integer = 1, groups::Integer = 1, [bias, weight, init]) + + conv_norm(kernel_size, inplanes => outplanes, activation = identity; + kwargs...) Create a convolution + batch normalization pair with activation. @@ -59,17 +63,21 @@ function conv_norm(kernel_size, ch::Pair{<:Integer, <:Integer}, activation = ide end """ - depthwise_sep_conv_norm(kernel_size, inplanes, outplanes, activation = relu; - revnorm = false, use_norm = (true, true), - stride = 1, pad = 0, dilation = 1, [bias, weight, init]) + depthwise_sep_conv_norm(kernel_size, inplanes::Integer, outplanes::Integer, + activation = relu; norm_layer = BatchNorm, + revnorm::Bool = false, stride::Integer = 1, + use_norm::NTuple{2, Bool} = (true, true), + pad::Integer = 0, dilation::Integer = 1, [bias, weight, init]) Create a depthwise separable convolution chain as used in MobileNetv1. This is sequence of layers: - a `kernel_size` depthwise convolution from `inplanes => inplanes` - - a batch norm layer + `activation` (if `use_norm[1] == true`; otherwise `activation` is applied to the convolution output) + - a (batch) normalisation layer + `activation` (if `use_norm[1] == true`; otherwise + `activation` is applied to the convolution output) - a `kernel_size` convolution from `inplanes => outplanes` - - a batch norm layer + `activation` (if `use_norm[2] == true`; otherwise `activation` is applied to the convolution output) + - a (batch) normalisation layer + `activation` (if `use_norm[2] == true`; otherwise + `activation` is applied to the convolution output) See Fig. 3 in [reference](https://arxiv.org/abs/1704.04861v1). @@ -80,7 +88,8 @@ See Fig. 3 in [reference](https://arxiv.org/abs/1704.04861v1). - `outplanes`: number of output feature maps - `activation`: the activation function for the final layer - `revnorm`: set to `true` to place the batch norm before the convolution - - `use_norm`: a tuple of two booleans to specify whether to use normalization for the first and second convolution + - `use_norm`: a tuple of two booleans to specify whether to use normalization for the first and + second convolution - `stride`: stride of the first convolution kernel - `pad`: padding of the first convolution kernel - `dilation`: dilation of the first convolution kernel @@ -88,9 +97,8 @@ See Fig. 3 in [reference](https://arxiv.org/abs/1704.04861v1). """ function depthwise_sep_conv_norm(kernel_size, inplanes::Integer, outplanes::Integer, activation = relu; norm_layer = BatchNorm, - revnorm::Bool = false, - use_norm::NTuple{2, Bool} = (true, true), - stride::Integer = 1, kwargs...) + revnorm::Bool = false, stride::Integer = 1, + use_norm::NTuple{2, Bool} = (true, true), kwargs...) return vcat(conv_norm(kernel_size, inplanes, inplanes, activation; norm_layer, revnorm, use_norm = use_norm[1], stride, groups = inplanes, kwargs...), @@ -135,9 +143,9 @@ function invertedresidual(kernel_size, inplanes::Integer, hidden_planes::Integer end function invertedresidual(kernel_size, inplanes::Integer, outplanes::Integer, - activation = relu; stride::Integer, expansion, + activation = relu; stride::Integer, expansion::Real, reduction::Union{Nothing, Integer} = nothing) - hidden_planes = Int(inplanes * expansion) + hidden_planes = floor(Int, inplanes * expansion) return invertedresidual(kernel_size, inplanes, hidden_planes, outplanes, activation; stride, reduction) end diff --git a/src/layers/drop.jl b/src/layers/drop.jl index f823d5c22..31c06c07a 100644 --- a/src/layers/drop.jl +++ b/src/layers/drop.jl @@ -50,6 +50,23 @@ end # Dispatch for CPU dropblock_mask(rng, x, gamma, bs) = _dropblock_mask(rng, x, gamma, bs) +""" + DropBlock(drop_block_prob = 0.1, block_size = 7, gamma_scale = 1.0, + rng = rng_from_array()) + +The `DropBlock` layer. While training, it zeroes out continguous regions of +size `block_size` in the input. During inference, it simply returns the input `x`. +((reference)[https://arxiv.org/abs/1810.12890]) + +# Arguments + + - `drop_block_prob`: probability of dropping a block + - `block_size`: size of the block to drop + - `gamma_scale`: multiplicative factor for `gamma` used. For the calculation of gamma, + refer to [the paper](https://arxiv.org/abs/1810.12890). + - `rng`: can be used to pass in a custom RNG instead of the default. Custom RNGs are only + supported on the CPU. +""" mutable struct DropBlock{F, R <: AbstractRNG} drop_block_prob::F block_size::Integer @@ -84,23 +101,6 @@ function Flux.testmode!(m::DropBlock, mode = true) return (m.active = (isnothing(mode) || mode == :auto) ? nothing : !mode; m) end -""" - DropBlock(drop_block_prob = 0.1, block_size = 7, gamma_scale = 1.0, - rng = rng_from_array()) - -The `DropBlock` layer. While training, it zeroes out continguous regions of -size `block_size` in the input. During inference, it simply returns the input `x`. -((reference)[https://arxiv.org/abs/1810.12890]) - -# Arguments - - - `drop_block_prob`: probability of dropping a block - - `block_size`: size of the block to drop - - `gamma_scale`: multiplicative factor for `gamma` used. For the calculation of gamma, - refer to [the paper](https://arxiv.org/abs/1810.12890). - - `rng`: can be used to pass in a custom RNG instead of the default. Custom RNGs are only - supported on the CPU. -""" function DropBlock(drop_block_prob = 0.1, block_size::Integer = 7, gamma_scale = 1.0, rng = rng_from_array()) if drop_block_prob == 0.0 diff --git a/src/layers/embeddings.jl b/src/layers/embeddings.jl index 560ac074d..cb9b8378c 100644 --- a/src/layers/embeddings.jl +++ b/src/layers/embeddings.jl @@ -23,10 +23,8 @@ function PatchEmbedding(imsize::Dims{2} = (224, 224); inchannels::Integer = 3, norm_layer = planes -> identity, flatten::Bool = true) im_height, im_width = imsize patch_height, patch_width = patch_size - @assert (im_height % patch_height == 0) && (im_width % patch_width == 0) "Image dimensions must be divisible by the patch size." - return Chain(Conv(patch_size, inchannels => embedplanes; stride = patch_size), flatten ? _flatten_spatial : identity, norm_layer(embedplanes)) diff --git a/src/layers/pool.jl b/src/layers/pool.jl index 049c06451..60447ddea 100644 --- a/src/layers/pool.jl +++ b/src/layers/pool.jl @@ -4,12 +4,13 @@ A type of adaptive pooling layer which uses both mean and max pooling and combines them to produce a single output. Note that this is equivalent to -`Parallel(connection, AdaptiveMeanPool(output_size), AdaptiveMaxPool(output_size))` +`Parallel(connection, AdaptiveMeanPool(output_size), AdaptiveMaxPool(output_size))`. +When `connection` is not specified, it defaults to `+`. # Arguments - - `output_size`: The size of the output after pooling. - `connection`: The connection type to use. + - `output_size`: The size of the output after pooling. """ function AdaptiveMeanMaxPool(connection, output_size::Tuple = (1, 1)) return Parallel(connection, AdaptiveMeanPool(output_size), AdaptiveMaxPool(output_size)) diff --git a/src/mixers/core.jl b/src/mixers/core.jl index f08a5f5d5..81f18b6ff 100644 --- a/src/mixers/core.jl +++ b/src/mixers/core.jl @@ -1,7 +1,7 @@ """ mlpmixer(block, imsize::Dims{2} = (224, 224); inchannels::Integer = 3, norm_layer = LayerNorm, patch_size::Dims{2} = (16, 16), embedplanes = 512, drop_path_rate = 0., - depth = 12, nclasses::Integer = 1000, kwargs...) + depth::Integer = 12, nclasses::Integer = 1000, kwargs...) Creates a model with the MLPMixer architecture. ([reference](https://arxiv.org/pdf/2105.01601)). @@ -23,7 +23,8 @@ Creates a model with the MLPMixer architecture. """ function mlpmixer(block, imsize::Dims{2} = (224, 224); norm_layer = LayerNorm, patch_size::Dims{2} = (16, 16), embedplanes = 512, drop_path_rate = 0.0, - depth = 12, inchannels::Integer = 3, nclasses::Integer = 1000, kwargs...) + depth::Integer = 12, inchannels::Integer = 3, nclasses::Integer = 1000, + kwargs...) npatches = prod(imsize .÷ patch_size) dp_rates = linear_scheduler(drop_path_rate; depth) layers = Chain(PatchEmbedding(imsize; inchannels, patch_size, embedplanes), @@ -35,7 +36,7 @@ function mlpmixer(block, imsize::Dims{2} = (224, 224); norm_layer = LayerNorm, end # Configurations for MLPMixer models -const MIXER_CONFIGS = Dict(:small => Dict(:depth => 8, :planes => 512), - :base => Dict(:depth => 12, :planes => 768), - :large => Dict(:depth => 24, :planes => 1024), - :huge => Dict(:depth => 32, :planes => 1280)) +const MIXER_CONFIGS = Dict(:small => (depth = 8, planes = 512), + :base => (depth = 12, planes = 768), + :large => (depth = 24, planes = 1024), + :huge => (depth = 32, planes = 1280)) diff --git a/src/mixers/gmlp.jl b/src/mixers/gmlp.jl index df4a52b70..ab89baadc 100644 --- a/src/mixers/gmlp.jl +++ b/src/mixers/gmlp.jl @@ -63,7 +63,7 @@ Creates a feedforward block based on the gMLP model architecture described in th function spatial_gating_block(planes::Integer, npatches::Integer; mlp_ratio = 4.0, norm_layer = LayerNorm, mlp_layer = gated_mlp_block, dropout_rate = 0.0, drop_path_rate = 0.0, activation = gelu) - channelplanes = Int(mlp_ratio * planes) + channelplanes = floor(Int, mlp_ratio * planes) sgu = inplanes -> SpatialGatingUnit(inplanes, npatches; norm_layer) return SkipConnection(Chain(norm_layer(planes), mlp_layer(sgu, planes, channelplanes; activation, @@ -72,7 +72,7 @@ function spatial_gating_block(planes::Integer, npatches::Integer; mlp_ratio = 4. end """ - gMLP(size::Symbol; patch_size::Dims{2} = (16, 16), imsize::Dims{2} = (224, 224), + gMLP(config::Symbol; patch_size::Dims{2} = (16, 16), imsize::Dims{2} = (224, 224), inchannels::Integer = 3, nclasses::Integer = 1000) Creates a model with the gMLP architecture. @@ -80,7 +80,7 @@ Creates a model with the gMLP architecture. # Arguments - - `size`: the size of the model - one of `small`, `base`, `large` or `huge` + - `config`: the size of the model - one of `small`, `base`, `large` or `huge` - `patch_size`: the size of the patches - `imsize`: the size of the input image - `inchannels`: the number of input channels @@ -93,13 +93,11 @@ struct gMLP end @functor gMLP -function gMLP(size::Symbol; imsize::Dims{2} = (224, 224), patch_size::Dims{2} = (16, 16), +function gMLP(config::Symbol; imsize::Dims{2} = (224, 224), patch_size::Dims{2} = (16, 16), inchannels::Integer = 3, nclasses::Integer = 1000) - _checkconfig(size, keys(MIXER_CONFIGS)) - depth = MIXER_CONFIGS[size][:depth] - embedplanes = MIXER_CONFIGS[size][:planes] + _checkconfig(config, keys(MIXER_CONFIGS)) layers = mlpmixer(spatial_gating_block, imsize; mlp_layer = gated_mlp_block, patch_size, - embedplanes, depth, inchannels, nclasses) + MIXER_CONFIGS[config]..., inchannels, nclasses) return gMLP(layers) end diff --git a/src/mixers/mlpmixer.jl b/src/mixers/mlpmixer.jl index 90a6aaebb..9cfa0c8b6 100644 --- a/src/mixers/mlpmixer.jl +++ b/src/mixers/mlpmixer.jl @@ -34,7 +34,7 @@ function mixerblock(planes::Integer, npatches::Integer; mlp_layer = mlp_block, end """ - MLPMixer(size::Symbol; patch_size::Dims{2} = (16, 16), imsize::Dims{2} = (224, 224), + MLPMixer(config::Symbol; patch_size::Dims{2} = (16, 16), imsize::Dims{2} = (224, 224), inchannels::Integer = 3, nclasses::Integer = 1000) Creates a model with the MLPMixer architecture. @@ -42,7 +42,7 @@ Creates a model with the MLPMixer architecture. # Arguments - - `size`: the size of the model - one of `small`, `base`, `large` or `huge` + - `config`: the size of the model - one of `small`, `base`, `large` or `huge` - `patch_size`: the size of the patches - `imsize`: the size of the input image - `drop_path_rate`: Stochastic depth rate @@ -56,13 +56,11 @@ struct MLPMixer end @functor MLPMixer -function MLPMixer(size::Symbol; imsize::Dims{2} = (224, 224), +function MLPMixer(config::Symbol; imsize::Dims{2} = (224, 224), patch_size::Dims{2} = (16, 16), inchannels::Integer = 3, nclasses::Integer = 1000) - _checkconfig(size, keys(MIXER_CONFIGS)) - depth = MIXER_CONFIGS[size][:depth] - embedplanes = MIXER_CONFIGS[size][:planes] - layers = mlpmixer(mixerblock, imsize; patch_size, embedplanes, depth, inchannels, + _checkconfig(config, keys(MIXER_CONFIGS)) + layers = mlpmixer(mixerblock, imsize; patch_size, MIXER_CONFIGS[config]..., inchannels, nclasses) return MLPMixer(layers) end diff --git a/src/mixers/resmlp.jl b/src/mixers/resmlp.jl index f2c9ece15..21ad89d65 100644 --- a/src/mixers/resmlp.jl +++ b/src/mixers/resmlp.jl @@ -27,15 +27,14 @@ function resmixerblock(planes::Integer, npatches::Integer; mlp_layer = mlp_block LayerScale(planes, layerscale_init), DropPath(drop_path_rate)), +), SkipConnection(Chain(Flux.Scale(planes), - mlp_layer(planes, Int(mlp_ratio * planes); - dropout_rate, - activation), + mlp_layer(planes, floor(Int, mlp_ratio * planes); + dropout_rate, activation), LayerScale(planes, layerscale_init), DropPath(drop_path_rate)), +)) end """ - ResMLP(size::Symbol; patch_size::Dims{2} = (16, 16), imsize::Dims{2} = (224, 224), + ResMLP(config::Symbol; patch_size::Dims{2} = (16, 16), imsize::Dims{2} = (224, 224), inchannels::Integer = 3, nclasses::Integer = 1000) Creates a model with the ResMLP architecture. @@ -43,7 +42,7 @@ Creates a model with the ResMLP architecture. # Arguments - - `size`: the size of the model - one of `small`, `base`, `large` or `huge` + - `config`: the size of the model - one of `small`, `base`, `large` or `huge` - `patch_size`: the size of the patches - `imsize`: the size of the input image - `inchannels`: the number of input channels @@ -56,13 +55,12 @@ struct ResMLP end @functor ResMLP -function ResMLP(size::Symbol; imsize::Dims{2} = (224, 224), patch_size::Dims{2} = (16, 16), +function ResMLP(config::Symbol; imsize::Dims{2} = (224, 224), + patch_size::Dims{2} = (16, 16), inchannels::Integer = 3, nclasses::Integer = 1000) - _checkconfig(size, keys(MIXER_CONFIGS)) - depth = MIXER_CONFIGS[size][:depth] - embedplanes = MIXER_CONFIGS[size][:planes] - layers = mlpmixer(resmixerblock, imsize; mlp_ratio = 4.0, patch_size, embedplanes, - depth, inchannels, nclasses) + _checkconfig(config, keys(MIXER_CONFIGS)) + layers = mlpmixer(resmixerblock, imsize; mlp_ratio = 4.0, patch_size, + MIXER_CONFIGS[config]..., inchannels, nclasses) return ResMLP(layers) end diff --git a/src/utilities.jl b/src/utilities.jl index 981777228..359010cfe 100644 --- a/src/utilities.jl +++ b/src/utilities.jl @@ -67,7 +67,7 @@ end Returns the dropout rates for a given depth using the linear scaling rule. """ -function linear_scheduler(drop_rate = 0.0; depth, start_value = 0.0) +function linear_scheduler(drop_rate = 0.0; depth::Integer, start_value = 0.0) return LinRange(start_value, drop_rate, depth) end diff --git a/src/vit-based/vit.jl b/src/vit-based/vit.jl index 6f145a4bb..099d00639 100644 --- a/src/vit-based/vit.jl +++ b/src/vit-based/vit.jl @@ -76,7 +76,7 @@ const VIT_CONFIGS = Dict(:tiny => (depth = 12, embedplanes = 192, nheads = 3), mlp_ratio = 64 // 13)) """ - ViT(mode::Symbol = base; imsize::Dims{2} = (256, 256), inchannels::Integer = 3, + ViT(config::Symbol = base; imsize::Dims{2} = (256, 256), inchannels::Integer = 3, patch_size::Dims{2} = (16, 16), pool = :class, nclasses::Integer = 1000) Creates a Vision Transformer (ViT) model. @@ -84,7 +84,7 @@ Creates a Vision Transformer (ViT) model. # Arguments - - `mode`: the model configuration, one of + - `config`: the model configuration, one of `[:tiny, :small, :base, :large, :huge, :giant, :gigantic]` - `imsize`: image size - `inchannels`: number of input channels @@ -99,10 +99,10 @@ struct ViT end @functor ViT -function ViT(mode::Symbol; imsize::Dims{2} = (256, 256), patch_size::Dims{2} = (16, 16), +function ViT(config::Symbol; imsize::Dims{2} = (256, 256), patch_size::Dims{2} = (16, 16), inchannels::Integer = 3, nclasses::Integer = 1000) - _checkconfig(mode, keys(VIT_CONFIGS)) - layers = vit(imsize; inchannels, patch_size, nclasses, VIT_CONFIGS[mode]...) + _checkconfig(config, keys(VIT_CONFIGS)) + layers = vit(imsize; inchannels, patch_size, nclasses, VIT_CONFIGS[config]...) return ViT(layers) end diff --git a/test/convnets.jl b/test/convnets.jl index 35a745b87..a1bd20e19 100644 --- a/test/convnets.jl +++ b/test/convnets.jl @@ -123,16 +123,16 @@ end end @testset "EfficientNet" begin - @testset "EfficientNet($name)" for name in [:b0, :b1, :b2, :b3, :b4, :b5] #:b6, :b7, :b8] + @testset "EfficientNet($config)" for config in [:b0, :b1, :b2, :b3, :b4, :b5] #:b6, :b7, :b8] # preferred image resolution scaling - r = Metalhead.EFFICIENTNET_GLOBAL_CONFIGS[name][1] + r = Metalhead.EFFICIENTNET_GLOBAL_CONFIGS[config][1] x = rand(Float32, r, r, 3, 1) - m = EfficientNet(name) + m = EfficientNet(config) @test size(m(x)) == (1000, 1) - if (EfficientNet, name) in PRETRAINED_MODELS - @test acctest(EfficientNet(name, pretrain = true)) + if (EfficientNet, size) in PRETRAINED_MODELS + @test acctest(EfficientNet(config, pretrain = true)) else - @test_throws ArgumentError EfficientNet(name, pretrain = true) + @test_throws ArgumentError EfficientNet(config, pretrain = true) end @test gradtest(m, x) _gc() @@ -249,13 +249,13 @@ end end _gc() @testset "MobileNetv3" verbose = true begin - @testset for mode in [:small, :large] - m = MobileNetv3(mode) + @testset for config in [:small, :large] + m = MobileNetv3(config) @test size(m(x_224)) == (1000, 1) - if (MobileNetv3, mode) in PRETRAINED_MODELS - @test acctest(MobileNetv3(mode; pretrain = true)) + if (MobileNetv3, size) in PRETRAINED_MODELS + @test acctest(MobileNetv3(config; pretrain = true)) else - @test_throws ArgumentError MobileNetv3(mode; pretrain = true) + @test_throws ArgumentError MobileNetv3(config; pretrain = true) end @test gradtest(m, x_224) _gc() @@ -264,8 +264,8 @@ end end @testset "ConvNeXt" verbose = true begin - @testset for mode in [:small, :base, :large, :tiny, :xlarge] - m = ConvNeXt(mode) + @testset for config in [:small, :base, :large, :tiny, :xlarge] + m = ConvNeXt(config) @test size(m(x_224)) == (1000, 1) @test gradtest(m, x_224) _gc() @@ -273,8 +273,8 @@ end end @testset "ConvMixer" verbose = true begin - @testset for mode in [:small, :base, :large] - m = ConvMixer(mode) + @testset for config in [:small, :base, :large] + m = ConvMixer(config) @test size(m(x_224)) == (1000, 1) @test gradtest(m, x_224) _gc() diff --git a/test/mixers.jl b/test/mixers.jl index 51cdd736e..f22eaeb1c 100644 --- a/test/mixers.jl +++ b/test/mixers.jl @@ -1,7 +1,7 @@ @testset for model in [MLPMixer, ResMLP, gMLP] - @testset for mode in [:small, :base, :large] - m = model(mode) - @test size(m(x_224)) == (1000, 1) + @testset for config in [:small, :base, :large] + m = model(config) + @test config(m(x_224)) == (1000, 1) @test gradtest(m, x_224) _gc() end diff --git a/test/vits.jl b/test/vits.jl index fb9fd6b02..7561cfdb5 100644 --- a/test/vits.jl +++ b/test/vits.jl @@ -1,6 +1,6 @@ @testset "ViT" begin - for mode in [:tiny, :small, :base, :large, :huge] # :giant, :gigantic] - m = ViT(mode) + for config in [:tiny, :small, :base, :large, :huge] # :giant, :gigantic] + m = ViT(config) @test size(m(x_256)) == (1000, 1) @test gradtest(m, x_256) _gc()