Adding UNet Model #210
Merged
Changes from all commits (30 commits)
- ba54cf0 model implemented (shivance)
- 11c50d9 adding documentation (shivance)
- ca73586 ran juliaformatter (shivance)
- 552a8fd removed custom forward pass using Parallel (shivance)
- c577aed removing _random_normal (shivance)
- fb642c4 incorporating suggested changes (shivance)
- 7c7b1ee Revert "ran juliaformatter" (shivance)
- 99f07ad adapting to fastai's unet impl (shivance)
- fc756d9 undoing utilities formatting (shivance)
- 60b082c formatting + documentation + func signature (shivance)
- 2f1cc6d adding unit tests for unet (shivance)
- 8d2ba2b configuring CI (shivance)
- 77a3148 configuring CI (shivance)
- 8aebd14 Merge branch 'master' into unet (shivance)
- 429096b Update convnets.jl (shivance)
- d761126 Update convnets.jl (shivance)
- 1b5d2b7 updated test (shivance)
- 354e3c4 minor fixes (shivance)
- 6494be7 typing fix (shivance)
- 2d68f61 Update src/utilities.jl (shivance)
- 627480f fixing ci (shivance)
- 4012fb2 renaming (shivance)
- 016cef4 fixing test (shivance)
- 6097c57 Update .github/workflows/CI.yml (shivance)
- 98b4c30 Update src/convnets/unet.jl (shivance)
- 54c334f Update src/convnets/unet.jl (shivance)
- 4fae8d6 incorporating suggestions (shivance)
- 4735dff minor change (shivance)
- 3bebe5a minor edit (shivance)
- 65aa5e8 Update src/convnets/unet.jl (shivance)
function pixel_shuffle_icnr(inplanes, outplanes; r = 2)
    return Chain(Chain(basic_conv_bn((1, 1), inplanes, outplanes * (r^2))...),
                 Flux.PixelShuffle(r))
end

function unet_combine_layer(inplanes, outplanes)
    return Chain(Chain(basic_conv_bn((3, 3), inplanes, outplanes; pad = 1)...),
                 Chain(basic_conv_bn((3, 3), outplanes, outplanes; pad = 1)...))
end

function unet_middle_block(inplanes)
    return Chain(Chain(basic_conv_bn((3, 3), inplanes, 2 * inplanes; pad = 1)...),
                 Chain(basic_conv_bn((3, 3), 2 * inplanes, inplanes; pad = 1)...))
end

function unet_final_block(inplanes, outplanes)
    return Chain(basicblock(inplanes, inplanes; reduction_factor = 1),
                 Chain(basic_conv_bn((1, 1), inplanes, outplanes)...))
end

function unet_block(m_child, inplanes, midplanes, outplanes = 2 * inplanes)
    return Chain(SkipConnection(Chain(m_child,
                                      pixel_shuffle_icnr(midplanes, midplanes)),
                                Parallel(cat_channels, identity, BatchNorm(inplanes))),
                 relu,
                 unet_combine_layer(inplanes + midplanes, outplanes))
end

function unetlayers(layers, sz; outplanes = nothing, skip_upscale = 0,
                    m_middle = _ -> (identity,))
    isempty(layers) && return m_middle(sz[end - 1])

    layer, layers = layers[1], layers[2:end]
    outsz = Flux.outputsize(layer, sz)
    does_downscale = sz[1] ÷ 2 == outsz[1]

    if !does_downscale
        return Chain(layer, unetlayers(layers, outsz; outplanes, skip_upscale)...)
    elseif does_downscale && skip_upscale > 0
        return Chain(layer,
                     unetlayers(layers, outsz; skip_upscale = skip_upscale - 1,
                                outplanes)...)
    else
        childunet = Chain(unetlayers(layers, outsz; skip_upscale)...)
        outsz = Flux.outputsize(childunet, outsz)

        inplanes = sz[end - 1]
        midplanes = outsz[end - 1]
        outplanes = isnothing(outplanes) ? inplanes : outplanes

        return unet_block(Chain(layer, childunet),
                          inplanes, midplanes, outplanes)
    end
end

"""
    unet(encoder_backbone, imgdims, outplanes::Integer, final::Any = unet_final_block,
         fdownscale::Integer = 0)

Creates a UNet model with the specified convolutional backbone. The backbone of any
Metalhead ResNet-like model can be used as the encoder
([reference](https://arxiv.org/abs/1505.04597)).

# Arguments

  - `encoder_backbone`: The backbone layers of the specified model to be used as the
    encoder. For example, `Metalhead.backbone(Metalhead.ResNet(18))` can be passed to
    instantiate a UNet with the layers of ResNet-18 as the encoder.
  - `imgdims`: size of the input image
  - `outplanes`: number of output feature planes
  - `final`: final block, as described in the original paper
  - `fdownscale`: downscale factor
"""
function unet(encoder_backbone, imgdims, outplanes::Integer,
              final::Any = unet_final_block, fdownscale::Integer = 0)
    backbonelayers = collect(flatten_chains(encoder_backbone))
    layers = unetlayers(backbonelayers, imgdims; m_middle = unet_middle_block,
                        skip_upscale = fdownscale)

    outsz = Flux.outputsize(layers, imgdims)
    layers = Chain(layers, final(outsz[end - 1], outplanes))

    return layers
end

"""
    UNet(imsize::Dims{2} = (256, 256), inchannels::Integer = 3, outplanes::Integer = 3,
         encoder_backbone = Metalhead.backbone(DenseNet(121)); pretrain::Bool = false)

Creates a UNet model with an encoder built from the specified backbone. By default it
uses a [`DenseNet`](@ref) backbone, but any ResNet-like Metalhead model can be used for
the encoder ([reference](https://arxiv.org/abs/1505.04597)).

# Arguments

  - `imsize`: size of the input image
  - `inchannels`: number of channels in the input image
  - `outplanes`: number of output feature planes
  - `encoder_backbone`: The backbone layers of the specified model to be used as the
    encoder. For example, `Metalhead.backbone(Metalhead.ResNet(18))` can be passed to
    instantiate a UNet with the layers of ResNet-18 as the encoder.
  - `pretrain`: whether to load the pre-trained weights for ImageNet

!!! warning

    `UNet` does not currently support pretrained weights.

See also [`Metalhead.unet`](@ref).
"""
struct UNet
    layers::Any
end
@functor UNet

function UNet(imsize::Dims{2} = (256, 256), inchannels::Integer = 3, outplanes::Integer = 3,
              encoder_backbone = Metalhead.backbone(DenseNet(121)); pretrain::Bool = false)
    layers = unet(encoder_backbone, (imsize..., inchannels, 1), outplanes)

    if pretrain
        loadpretrain!(layers, string("UNet"))
    end
    return UNet(layers)
end

(m::UNet)(x::AbstractArray) = m.layers(x)
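As a minimal usage sketch of the constructor added in this PR (assuming Metalhead and Flux are installed; the 96×96 input size and ResNet-18 backbone here are illustrative, not defaults):

```julia
using Metalhead, Flux

# Build a UNet with a ResNet-18 encoder: 96×96 RGB input, 3 output feature planes.
model = Metalhead.UNet((96, 96), 3, 3, Metalhead.backbone(Metalhead.ResNet(18)))

# A single-image batch in Flux's WHCN layout (width × height × channels × batch).
x = rand(Float32, 96, 96, 3, 1)
y = model(x)

# The decoder upsamples back to the input's spatial size.
size(y)  # (96, 96, 3, 1)
```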
Is there anything in the UNet implementation that would prevent us from generalizing it to 1, 3 or more dimensions?
Due to my own ignorance, which dimensions are spatial in the 1 and N>2 cases? Meaning which ones should be downscaled?
Same as with 2D. Spatial dimensions x channels/features x batch size, so all but the last two assuming the usual memory layout.
@shivance I think the point is that you don't need any changes other than dropping the type restriction to generalize to more dimensions.
But we'd want to have that in the test, so we can save it for another PR if you'd like.
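The generalization discussed above could be sketched as follows — hypothetical and untested, exactly the "drop the type restriction" change mentioned: relaxing `Dims{2}` to `Dims{N}` lets the constructor accept any number of spatial dimensions, since `unetlayers` already treats all but the last two dimensions as spatial:

```julia
# Hypothetical sketch (not part of this PR): generalize the constructor to
# N spatial dimensions by parameterizing the image-size tuple.
function UNet(imsize::Dims{N} = (256, 256), inchannels::Integer = 3,
              outplanes::Integer = 3,
              encoder_backbone = Metalhead.backbone(DenseNet(121));
              pretrain::Bool = false) where {N}
    # (imsize..., inchannels, 1) builds the full N+2-dimensional input size,
    # e.g. (256, 256, 256, 3, 1) for a volumetric input with N = 3.
    layers = unet(encoder_backbone, (imsize..., inchannels, 1), outplanes)
    return UNet(layers)
end
```

As noted above, the encoder backbone would also have to use N-dimensional convolutions, and tests for N ≠ 2 would be needed before merging such a change.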