Skip to content

Commit

Permalink
Minor formatting tweaks
Browse files Browse the repository at this point in the history
  • Loading branch information
theabhirath committed Feb 4, 2022
1 parent eb45dee commit 329489f
Show file tree
Hide file tree
Showing 3 changed files with 9 additions and 7 deletions.
2 changes: 1 addition & 1 deletion src/Metalhead.jl
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ export AlexNet,

# use Flux._big_show to pretty print large models
for T in (:AlexNet, :VGG, :ResNet, :GoogLeNet, :Inception3, :SqueezeNet, :DenseNet, :ResNeXt,
:MobileNetv2, :MobileNetv3, :MLPMixer)
:MobileNetv2, :MobileNetv3, :MLPMixer)
@eval Base.show(io::IO, ::MIME"text/plain", model::$T) = _maybe_big_show(io, model)
end

Expand Down
10 changes: 6 additions & 4 deletions src/other/mlpmixer.jl
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ function mlpmixer(imsize::NTuple{2} = (256, 256); inchannels = 3, patch_size = 1

classification_head = Chain(_seconddimmean, Dense(planes, nclasses))

return Chain(layers..., classification_head)
return Chain(Chain(layers...), classification_head)
end

struct MLPMixer
Expand All @@ -73,14 +73,16 @@ Creates a model with the MLPMixer architecture.
- nclasses: the number of classes in the output
"""
function MLPMixer(imsize::NTuple{2} = (256, 256); inchannels = 3, patch_size = 16, planes = 512,
depth = 12, expansion_factor = 4, dropout = 0., pretrain = false, nclasses = 1000)
depth = 12, expansion_factor = 4, dropout = 0., nclasses = 1000)

layers = mlpmixer(imsize; inchannels, patch_size, planes, depth, expansion_factor, dropout,
nclasses)
pretrain && loadpretrain!(layers, string("MLPMixer"))
MLPMixer(layers)
end

@functor MLPMixer

(m::MLPMixer)(x) = m.layers(x)

@functor MLPMixer
backbone(m::MLPMixer) = m.layers[1]
classifier(m::MLPMixer) = m.layers[2:end]
4 changes: 2 additions & 2 deletions test/other.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,6 @@ using Metalhead, Test
using Flux

@testset "MLPMixer" begin
@test size(MLPMixer()(rand(Float32, 256, 256, 3, 67))) == (1000, 67)
@test_skip gradtest(MLPMixer(), rand(Float32, 256, 256, 3, 67))
@test size(MLPMixer()(rand(Float32, 256, 256, 3, 2))) == (1000, 2)
@test_skip gradtest(MLPMixer(), rand(Float32, 256, 256, 3, 2))
end

0 comments on commit 329489f

Please sign in to comment.