Adding UNet Model #210

Merged · 30 commits · Jan 27, 2023

Commits:
ba54cf0 model implemented (shivance, Dec 27, 2022)
11c50d9 adding documentation (shivance, Dec 27, 2022)
ca73586 ran juliaformatter (shivance, Dec 28, 2022)
552a8fd removed custom forward pass using Parallel (shivance, Jan 1, 2023)
c577aed removing _random_normal (shivance, Jan 1, 2023)
fb642c4 incorporating suggested changes (shivance, Jan 2, 2023)
7c7b1ee Revert "ran juliaformatter" (shivance, Jan 3, 2023)
99f07ad adapting to fastai's unet impl (shivance, Jan 10, 2023)
fc756d9 undoing utilities formatting (shivance, Jan 10, 2023)
60b082c formatting + documentation + func signature (shivance, Jan 10, 2023)
2f1cc6d adding unit tests for unet (shivance, Jan 10, 2023)
8d2ba2b configuring CI (shivance, Jan 10, 2023)
77a3148 configuring CI (shivance, Jan 10, 2023)
8aebd14 Merge branch 'master' into unet (shivance, Jan 10, 2023)
429096b Update convnets.jl (shivance, Jan 10, 2023)
d761126 Update convnets.jl (shivance, Jan 10, 2023)
1b5d2b7 updated test (shivance, Jan 11, 2023)
354e3c4 minor fixes (shivance, Jan 12, 2023)
6494be7 typing fix (shivance, Jan 12, 2023)
2d68f61 Update src/utilities.jl (shivance, Jan 12, 2023)
627480f fixing ci (shivance, Jan 12, 2023)
4012fb2 renaming: (shivance, Jan 16, 2023)
016cef4 fixing test (shivance, Jan 22, 2023)
6097c57 Update .github/workflows/CI.yml (shivance, Jan 22, 2023)
98b4c30 Update src/convnets/unet.jl (shivance, Jan 22, 2023)
54c334f Update src/convnets/unet.jl (shivance, Jan 22, 2023)
4fae8d6 incorporating suggestions (shivance, Jan 22, 2023)
4735dff minor change (shivance, Jan 22, 2023)
3bebe5a minor edit (shivance, Jan 22, 2023)
65aa5e8 Update src/convnets/unet.jl (shivance, Jan 26, 2023)
1 change: 1 addition & 0 deletions .github/workflows/CI.yml
@@ -40,6 +40,7 @@ jobs:
- '[r"Res2Net", r"Res2NeXt"]'
- '"Inception"'
- '"DenseNet"'
- '"UNet"'
- '["ConvNeXt", "ConvMixer"]'
- 'r"Mixers"'
- 'r"ViTs"'
1 change: 1 addition & 0 deletions docs/src/index.md
@@ -36,6 +36,7 @@ julia> ]add Metalhead
| [ViT](https://arxiv.org/abs/2010.11929) | [`ViT`](@ref) | N |
| [ConvNeXt](https://arxiv.org/abs/2201.03545) | [`ConvNeXt`](@ref) | N |
| [ConvMixer](https://arxiv.org/abs/2201.09792) | [`ConvMixer`](@ref) | N |
| [UNet](https://arxiv.org/abs/1505.04597v1) | [`UNet`](@ref) | N |

To contribute new models, see our [contributing docs](@ref Contributing-to-Metalhead.jl).

5 changes: 3 additions & 2 deletions src/Metalhead.jl
@@ -53,6 +53,7 @@ include("convnets/densenet.jl")
include("convnets/squeezenet.jl")
include("convnets/convnext.jl")
include("convnets/convmixer.jl")
include("convnets/unet.jl")

# Mixers
include("mixers/core.jl")
@@ -73,15 +74,15 @@ export AlexNet, VGG, VGG11, VGG13, VGG16, VGG19,
GoogLeNet, Inception3, Inceptionv3, Inceptionv4, InceptionResNetv2, Xception,
SqueezeNet, MobileNetv1, MobileNetv2, MobileNetv3, MNASNet,
EfficientNet, EfficientNetv2, ConvMixer, ConvNeXt,
MLPMixer, ResMLP, gMLP, ViT
MLPMixer, ResMLP, gMLP, ViT, UNet

# use Flux._big_show to pretty print large models
for T in (:AlexNet, :VGG, :SqueezeNet, :ResNet, :WideResNet, :ResNeXt,
:SEResNet, :SEResNeXt, :Res2Net, :Res2NeXt, :GoogLeNet, :DenseNet,
:Inceptionv3, :Inceptionv4, :InceptionResNetv2, :Xception,
:MobileNetv1, :MobileNetv2, :MobileNetv3, :MNASNet,
:EfficientNet, :EfficientNetv2, :ConvMixer, :ConvNeXt,
:MLPMixer, :ResMLP, :gMLP, :ViT)
:MLPMixer, :ResMLP, :gMLP, :ViT, :UNet)
@eval Base.show(io::IO, ::MIME"text/plain", model::$T) = _maybe_big_show(io, model)
end

125 changes: 125 additions & 0 deletions src/convnets/unet.jl
@@ -0,0 +1,125 @@
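# Upsampling block: a 1×1 conv+BN expands channels to `outplanes * r^2`, then
# `Flux.PixelShuffle(r)` rearranges those channels into an r×-larger spatial grid.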
function pixel_shuffle_icnr(inplanes, outplanes; r = 2)
return Chain(Chain(basic_conv_bn((1, 1), inplanes, outplanes * (r^2))...),
Flux.PixelShuffle(r))
end

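# Decoder convolutions applied after each skip concatenation: two 3×3 conv+BN layers.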
function unet_combine_layer(inplanes, outplanes)
return Chain(Chain(basic_conv_bn((3, 3), inplanes, outplanes; pad = 1)...),
Chain(basic_conv_bn((3, 3), outplanes, outplanes; pad = 1)...))
end

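# Bottleneck between encoder and decoder: 3×3 conv+BN layers that double and then
# restore the channel count.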
function unet_middle_block(inplanes)
return Chain(Chain(basic_conv_bn((3, 3), inplanes, 2 * inplanes; pad = 1)...),
Chain(basic_conv_bn((3, 3), 2 * inplanes, inplanes; pad = 1)...))
end

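# Output head: a residual `basicblock` followed by a 1×1 conv down to `outplanes`.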
function unet_final_block(inplanes, outplanes)
return Chain(basicblock(inplanes, inplanes; reduction_factor = 1),
Chain(basic_conv_bn((1, 1), inplanes, outplanes)...))
end

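# One decoder stage. `SkipConnection` feeds both the upsampled child output and the
# stage input to `Parallel`, which concatenates them channel-wise (the skip branch is
# batch-normalized first) before the combining 3×3 convolutions.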
function unet_block(m_child, inplanes, midplanes, outplanes = 2 * inplanes)
return Chain(SkipConnection(Chain(m_child,
pixel_shuffle_icnr(midplanes, midplanes)),
Parallel(cat_channels, identity, BatchNorm(inplanes))),
relu,
unet_combine_layer(inplanes + midplanes, outplanes))
end

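# Recursively wraps the flattened backbone layers into a UNet. A backbone layer that
# halves the spatial size starts a new encoder stage: the remaining layers become the
# child UNet, and the pair is wrapped into a skip-connected decoder stage by
# `unet_block`. `skip_upscale` leaves the first few downscaling stages without skip
# connections, so the output stays downscaled by that factor.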
function unetlayers(layers, sz; outplanes = nothing, skip_upscale = 0,
m_middle = _ -> (identity,))
isempty(layers) && return m_middle(sz[end - 1])

layer, layers = layers[1], layers[2:end]
outsz = Flux.outputsize(layer, sz)
does_downscale = sz[1] ÷ 2 == outsz[1]

if !does_downscale
return Chain(layer, unetlayers(layers, outsz; outplanes, skip_upscale)...)
elseif does_downscale && skip_upscale > 0
return Chain(layer,
unetlayers(layers, outsz; skip_upscale = skip_upscale - 1,
outplanes)...)
else
childunet = Chain(unetlayers(layers, outsz; skip_upscale)...)
outsz = Flux.outputsize(childunet, outsz)

inplanes = sz[end - 1]
midplanes = outsz[end - 1]
outplanes = isnothing(outplanes) ? inplanes : outplanes

return unet_block(Chain(layer, childunet),
inplanes, midplanes, outplanes)
end
end

"""
unet(encoder_backbone, imgdims, outplanes::Integer, final::Any = unet_final_block,
fdownscale::Integer = 0)

Creates a UNet model with the specified convolutional backbone.
The backbone of any Metalhead ResNet-like model can be used as the encoder
([reference](https://arxiv.org/abs/1505.04597)).

# Arguments

- `encoder_backbone`: the backbone layers of the specified model to be used as the encoder.
  For example, `Metalhead.backbone(Metalhead.ResNet(18))` can be passed
  to instantiate a UNet with the layers of a ResNet-18 as the encoder.
- `imgdims`: size of the input image
- `outplanes`: number of output feature planes
- `final`: the final block, as described in the original paper
- `fdownscale`: downscale factor
"""
function unet(encoder_backbone, imgdims, outplanes::Integer,
final::Any = unet_final_block, fdownscale::Integer = 0)
backbonelayers = collect(flatten_chains(encoder_backbone))
layers = unetlayers(backbonelayers, imgdims; m_middle = unet_middle_block,
skip_upscale = fdownscale)

outsz = Flux.outputsize(layers, imgdims)
layers = Chain(layers, final(outsz[end - 1], outplanes))

return layers
end
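# Illustrative usage (editorial note, not part of the diff): a call such as
#     Metalhead.unet(Metalhead.backbone(ResNet(18)), (256, 256, 3, 1), 10)
# builds the raw Chain that the `UNet` constructor below wraps.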

"""
UNet(imsize::Dims{2} = (256, 256), inchannels::Integer = 3, outplanes::Integer = 3,
encoder_backbone = Metalhead.backbone(DenseNet(121)); pretrain::Bool = false)

Creates a UNet model with an encoder built from the specified backbone. By default it uses
a [`DenseNet`](@ref) backbone, but the backbone of any ResNet-like Metalhead model can be
used for the encoder ([reference](https://arxiv.org/abs/1505.04597)).

# Arguments

- `imsize`: size of the input image
- `inchannels`: number of channels in the input image
- `outplanes`: number of output feature planes
- `encoder_backbone`: the backbone layers of the specified model to be used as the encoder.
  For example, `Metalhead.backbone(Metalhead.ResNet(18))` can be passed to instantiate
  a UNet with the layers of a ResNet-18 as the encoder.
- `pretrain`: whether to load pre-trained ImageNet weights

!!! warning

`UNet` does not currently support pretrained weights.

See also [`Metalhead.unet`](@ref).
"""
struct UNet
layers::Any
end
@functor UNet

function UNet(imsize::Dims{2} = (256, 256), inchannels::Integer = 3, outplanes::Integer = 3,
[Review thread on the line above]

Member (suggested change):

    function UNet(imsize::Dims{2} = (256, 256), inchannels::Integer = 3, outplanes::Integer = 3,
    function UNet(imsize::Dims = (256, 256), inchannels::Integer = 3, outplanes::Integer = 3,

Is there anything in the UNet implementation that would prevent us from generalizing it to 1, 3 or more dimensions?

Member: Due to my own ignorance, which dimensions are spatial in the 1 and N>2 cases? Meaning which ones should be downscaled?

Member: Same as with 2D. Spatial dimensions x channels/features x batch size, so all but the last two assuming the usual memory layout.

Member: @shivance I think the point is that you don't need any changes other than dropping the type restriction to generalize to more dimensions. But we'd want to have that in the test, so we can save it for another PR if you'd like.

[The function signature continues below:]
encoder_backbone = Metalhead.backbone(DenseNet(121)); pretrain::Bool = false)
layers = unet(encoder_backbone, (imsize..., inchannels, 1), outplanes)

if pretrain
loadpretrain!(layers, string("UNet"))
end
return UNet(layers)
end

(m::UNet)(x::AbstractArray) = m.layers(x)
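
A quick usage sketch (an editorial illustration based on the tests below, not part of
the diff; assumes Flux and Metalhead are loaded):

    using Flux, Metalhead

    # Build a UNet over 256×256 RGB inputs with 10 output planes,
    # using the layers of a ResNet-18 as the encoder.
    encoder = Metalhead.backbone(ResNet(18))
    model = UNet((256, 256), 3, 10, encoder)

    x = rand(Float32, 256, 256, 3, 1)  # width × height × channels × batch
    y = model(x)
    size(y)  # (256, 256, 10, 1)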
11 changes: 11 additions & 0 deletions src/utilities.jl
@@ -79,3 +79,14 @@ linear_scheduler(drop_prob::Nothing; depth::Integer) = fill(drop_prob, depth)
function _checkconfig(config, configs)
@assert config in configs "Invalid configuration. Must be one of $(sort(collect(configs)))."
end

"""
flatten_chains(m::Chain)
flatten_chains(m)

Convenience function that traverses the nested layers of a `Chain` and flattens them
into a single iterator.
"""
flatten_chains(m::Chain) = Iterators.flatten(flatten_chains(l) for l in m.layers)
flatten_chains(m) = (m,)
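
For illustration (an editorial sketch, not part of the diff): flattening a nested
`Chain` yields its leaf layers in order.

    m = Chain(Dense(2, 3), Chain(Dense(3, 3), relu), Dense(3, 1))
    collect(flatten_chains(m))  # [Dense(2 => 3), Dense(3 => 3), relu, Dense(3 => 1)]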

11 changes: 11 additions & 0 deletions test/convnets.jl
@@ -347,3 +347,14 @@ end
_gc()
end
end

@testset "UNet" begin
encoder = Metalhead.backbone(ResNet(18))
model = UNet((256, 256), 3, 10, encoder)
@test size(model(x_256)) == (256, 256, 10, 1)
@test gradtest(model, x_256)

model = UNet()
@test size(model(x_256)) == (256, 256, 3, 1)
_gc()
end