diff --git a/Project.toml b/Project.toml index 0aaf992b7..d4053cbb8 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "Metalhead" uuid = "dbeba491-748d-5e0e-a39e-b530a07fa0cc" -version = "0.8.0-DEV" +version = "0.8.0" [deps] Artifacts = "56f22d72-fd6d-98f1-02f0-08ddc0907c33" diff --git a/docs/Project.toml b/docs/Project.toml index 6335fa0e9..2bb540e4c 100644 --- a/docs/Project.toml +++ b/docs/Project.toml @@ -5,4 +5,4 @@ Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4" Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c" Images = "916415d5-f1e6-5110-898d-aaa5f9f070e0" LazyArtifacts = "4af54fe1-eca0-43a8-85a7-787d91b784e3" -OneHotArrays = "0b1bfda6-eb8a-41d2-88d8-f5af5cad476f" +Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2" diff --git a/docs/make.jl b/docs/make.jl index 72fb486c9..08d667b88 100644 --- a/docs/make.jl +++ b/docs/make.jl @@ -1,26 +1,41 @@ -using Documenter, Metalhead, Artifacts, LazyArtifacts, Images, OneHotArrays, DataAugmentation, Flux +using Documenter, Metalhead, Artifacts, LazyArtifacts, Images, DataAugmentation, Flux DocMeta.setdocmeta!(Metalhead, :DocTestSetup, :(using Metalhead); recursive = true) -makedocs(modules = [Metalhead, Artifacts, LazyArtifacts, Images, OneHotArrays, DataAugmentation, Flux], +makedocs(; modules = [Metalhead, Artifacts, LazyArtifacts, Images, DataAugmentation, Flux], sitename = "Metalhead.jl", doctest = false, pages = ["Home" => "index.md", - "Tutorials" => [ - "tutorials/quickstart.md", - ], - "Developer guide" => "contributing.md", - "API reference" => [ - "api/reference.md", - ], - ], - format = Documenter.HTML( - canonical = "https://fluxml.ai/Metalhead.jl/stable/", - # analytics = "UA-36890222-9", - assets = ["assets/flux.css"], - prettyurls = get(ENV, "CI", nothing) == "true"), - ) + "Tutorials" => [ + "tutorials/quickstart.md", + "tutorials/pretrained.md", + ], + "API reference" => [ + "Convolutional Neural Networks" => [ + "api/others.md", + "api/inception.md", + "api/resnet.md", + "api/densenet.md", + "api/hybrid.md", + "api/layers.md", + ], + "Mixers" => [ + "api/mixers.md", + ], + "Vision Transformers" => [ + "api/vit.md", + ], + "api/utilities.md" + ], + "How To" => [ + "howto/resnet.md", + ], + "Contributing to Metalhead" => "contributing.md", + ], + format = Documenter.HTML(; canonical = "https://fluxml.ai/Metalhead.jl/stable/", + # analytics = "UA-36890222-9", + assets = ["assets/flux.css"], + prettyurls = get(ENV, "CI", nothing) == "true")) -deploydocs(repo = "github.com/FluxML/Metalhead.jl.git", - target = "build", +deploydocs(; repo = "github.com/FluxML/Metalhead.jl.git", target = "build", push_preview = true) diff --git a/docs/src/api/densenet.md b/docs/src/api/densenet.md new file mode 100644 index 000000000..2c57a3897 --- /dev/null +++ b/docs/src/api/densenet.md @@ -0,0 +1,15 @@ +# DenseNet + +This is the API reference for the DenseNet model present in Metalhead.jl. + +## The higher level model + +```@docs +DenseNet +``` + +## The core function + +```@docs +Metalhead.densenet +``` diff --git a/docs/src/api/hybrid.md b/docs/src/api/hybrid.md new file mode 100644 index 000000000..adcb65822 --- /dev/null +++ b/docs/src/api/hybrid.md @@ -0,0 +1,17 @@ +# Hybrid CNN architectures + +These models are hybrid CNN architectures that borrow certain ideas from vision transformer models. + +## The higher-level model constructors + +```@docs +ConvMixer +ConvNeXt +``` + +## The mid-level functions + +```@docs +Metalhead.convmixer +Metalhead.convnext +``` \ No newline at end of file diff --git a/docs/src/api/inception.md b/docs/src/api/inception.md new file mode 100644 index 000000000..e7470f4ec --- /dev/null +++ b/docs/src/api/inception.md @@ -0,0 +1,23 @@ +# Inception models + +This is the API reference for the Inception family of models supported by Metalhead.jl. + +## The higher-level model constructors + +```@docs +GoogLeNet +Inceptionv3 +Inceptionv4 +InceptionResNetv2 +Xception +``` + +## The mid-level functions + +```@docs +Metalhead.googlenet +Metalhead.inceptionv3 +Metalhead.inceptionv4 +Metalhead.inceptionresnetv2 +Metalhead.xception +``` diff --git a/docs/src/api/layers.md b/docs/src/api/layers.md new file mode 100644 index 000000000..3b23693bb --- /dev/null +++ b/docs/src/api/layers.md @@ -0,0 +1,18 @@ +# Layers + +Metalhead also defines a module called `Layers` which contains some more modern layers that are not available in Flux. To use the functions defined in the `Layers` module, you need to import it. + +```julia +using Metalhead: Layers +``` + +This page contains the API reference for the `Layers` module. + +!!! warning + + The `Layers` module is still a work in progress. While we will endeavour to keep the API stable, we cannot guarantee that it will not change in the future. If you find any of the functions in this + module do not work as expected, please open an issue on GitHub. + +```@autodocs +Modules = [Metalhead.Layers] +``` diff --git a/docs/src/api/mixers.md b/docs/src/api/mixers.md new file mode 100644 index 000000000..42a19f28f --- /dev/null +++ b/docs/src/api/mixers.md @@ -0,0 +1,26 @@ +# MLPMixer-like models + +This is the API reference for the MLPMixer-like models supported by Metalhead.jl. + +## The higher-level model constructors + +```@docs +MLPMixer +ResMLP +gMLP +``` + +## The core MLPMixer function + +```@docs +Metalhead.mlpmixer +``` + +## The block functions + +```@docs +Metalhead.mixerblock +Metalhead.resmixerblock +Metalhead.SpatialGatingUnit +Metalhead.spatialgatingblock +``` \ No newline at end of file diff --git a/docs/src/api/others.md b/docs/src/api/others.md new file mode 100644 index 000000000..9cadfaee4 --- /dev/null +++ b/docs/src/api/others.md @@ -0,0 +1,19 @@ +# Other models + +This is the API reference for some of the other models supported by Metalhead.jl that do not fit into the other categories. + +## The higher-level model constructors + +```@docs +AlexNet +VGG +SqueezeNet +``` + +## The mid-level functions + +```@docs +Metalhead.alexnet +Metalhead.vgg +Metalhead.squeezenet +``` diff --git a/docs/src/api/reference.md b/docs/src/api/reference.md deleted file mode 100644 index 1699e81fe..000000000 --- a/docs/src/api/reference.md +++ /dev/null @@ -1,17 +0,0 @@ -# API Reference - -The API reference of `Metalhead.jl`. - -**Note**: This page is still in progress. - -```@autodocs -Modules = [Metalhead] -``` - -```@docs -Metalhead.create_classifier -Metalhead.squeeze_excite -Metalhead.LayerScale -Metalhead.DropBlock -Metalhead.StochasticDepth -``` diff --git a/docs/src/api/resnet.md b/docs/src/api/resnet.md new file mode 100644 index 000000000..e41766c62 --- /dev/null +++ b/docs/src/api/resnet.md @@ -0,0 +1,61 @@ +# ResNet-like models + +This is the API reference for the ResNet inspired model structures present in Metalhead.jl. + +## The higher-level model constructors + +```@docs +ResNet +WideResNet +ResNeXt +SEResNet +SEResNeXt +Res2Net +Res2NeXt +``` + +## The mid-level function + +```@docs +Metalhead.resnet +``` + +## Lower-level functions and builders + +### Block functions + +```@docs +Metalhead.basicblock +Metalhead.bottleneck +Metalhead.bottle2neck +``` + +### Downsampling functions + +```@docs +Metalhead.downsample_identity +Metalhead.downsample_conv +Metalhead.downsample_pool +``` + +### Block builders + +```@docs +Metalhead.basicblock_builder +Metalhead.bottleneck_builder +Metalhead.bottle2neck_builder +``` + +### Generic ResNet model builder + +```@docs +Metalhead.build_resnet +``` + +## Utility callbacks + +```@docs +Metalhead.resnet_planes +Metalhead.resnet_stride +Metalhead.resnet_stem +``` diff --git a/docs/src/api/utilities.md b/docs/src/api/utilities.md new file mode 100644 index 000000000..e80a17365 --- /dev/null +++ b/docs/src/api/utilities.md @@ -0,0 +1,10 @@ +# Utilities + +Metalhead provides some utility functions for making it easier to work with the models inside the library or to build new ones. The API reference for these is documented below. + +## `backbone` and `classifier` + +```@docs +backbone +classifier +``` \ No newline at end of file diff --git a/docs/src/api/vit.md b/docs/src/api/vit.md new file mode 100644 index 000000000..d1c303cc1 --- /dev/null +++ b/docs/src/api/vit.md @@ -0,0 +1,15 @@ +# Vision Transformer models + +This is the API reference for the Vision Transformer models supported by Metalhead.jl. + +## The higher-level model constructors + +```@docs +ViT +``` + +## The mid-level functions + +```@docs +Metalhead.vit +``` diff --git a/docs/src/contributing.md b/docs/src/contributing.md index f126d7bb8..a04fe9d41 100644 --- a/docs/src/contributing.md +++ b/docs/src/contributing.md @@ -1,4 +1,4 @@ -# Contributing to Metalhead.jl +# [Contribute to Metalhead.jl](@id contributing) We welcome contributions from anyone to Metalhead.jl! Thank you for taking the time to make our ecosystem better. @@ -16,7 +16,7 @@ To add a new model architecture to Metalhead.jl, you can [open a PR](https://git - reuse layers from Flux as much as possible (e.g. use `Parallel` before defining a `Bottleneck` struct) - adhere as closely as possible to a reference such as a published paper (i.e. the structure of your model should follow intuitively from the paper) -- use generic functional builders (e.g. [`Metalhead.resnet`](@ref) is the core function that builds "ResNet-like" models) +- use generic functional builders (e.g. [`Metalhead.resnet`](@ref) is the underlying function that builds "ResNet-like" models) - use multiple dispatch to add convenience constructors that wrap your functional builder When in doubt, just open a PR! We are more than happy to help review your code to help it align with the rest of the library. After adding a model, you might consider adding some pre-trained weights (see below). @@ -28,10 +28,11 @@ To add pre-trained weights for an existing model or new model, you can [open a P All Metalhead.jl model artifacts are hosted using HuggingFace. You can find the FluxML account [here](https://huggingface.co/FluxML). This [documentation from HuggingFace](https://huggingface.co/docs/hub/models) will provide you with an introduction to their ModelHub. In short, the Model Hub is a collection of Git repositories, similar to Julia packages on GitHub. This means you can [make a pull request to our HuggingFace repositories](https://huggingface.co/docs/hub/repositories-pull-requests-discussions) to upload updated weight artifacts just like you would make a PR on GitHub to upload code. 1. Train your model or port the weights from another framework. -2. Save the model using [BSON.jl](https://github.com/JuliaIO/BSON.jl) with `BSON.@save "modelname.bson" model`. It is important that your model is saved under the key `model`. +2. Save the model using [BSON.jl](https://github.com/JuliaIO/BSON.jl) with `BSON.@save "modelname.bson" model`. It is important that your model is saved under the key `model`. Note that due to the way this +process works, to maintain compatibility with different Julia versions, the model must be saved using the LTS version of Julia (currently 1.6). 3. Compress the saved model as a tarball using `tar -cvzf modelname.tar.gz modelname.bson`. 4. Obtain the SHAs (see the [Pkg docs](https://pkgdocs.julialang.org/v1/artifacts/#Basic-Usage)). Edit the `Artifacts.toml` file in the Metalhead.jl repository and add entry for your model. You can leave the URL empty for now. -5. Open a PR on Metalhead.jl. Be sure to ping a maintainer (e.g. `@darsnack`) to let us know that you are adding a pre-trained weight. We will create a model repository on HuggingFace if it does not already exist. +5. Open a PR on Metalhead.jl. Be sure to ping a maintainer (e.g. `@darsnack` or `@theabhirath`) to let us know that you are adding a pre-trained weight. We will create a model repository on HuggingFace if it does not already exist. 6. Open a PR to the [corresponding HuggingFace repo](https://huggingface.co/FluxML). Do this by going to the "Community" tab in the HuggingFace repository. PRs and discussions are shown as the same thing in the HuggingFace web app. You can use your local Git program to make clone the repo and make PRs if you wish. Check out the [guide on PRs to HuggingFace](https://huggingface.co/docs/hub/repositories-pull-requests-discussions) for more information. 7. Copy the download URL for the model file that you added to HuggingFace. Make sure to grab the URL for a specific commit and not for the `main` branch. 8. Update your Metalhead.jl PR by adding the URL to the Artifacts.toml. diff --git a/docs/src/howto/resnet.md b/docs/src/howto/resnet.md new file mode 100644 index 000000000..3adbe2bea --- /dev/null +++ b/docs/src/howto/resnet.md @@ -0,0 +1,21 @@ +# Using the ResNet model family in Metalhead.jl + +ResNets are one of the most common convolutional neural network (CNN) models used today. Originally proposed by He et al. in [**Deep Residual Learning for Image Recognition**](https://arxiv.org/abs/1512.03385), they use a residual structure to learn identity mappings that strengthens gradient propagation, thereby helping to prevent the vanishing gradient problem and allow the advent of truly deep neural networks as used today. + +Many variants on the original ResNet structure have since become widely used such as [Wide-ResNet](https://arxiv.org/abs/1605.07146), [ResNeXt](https://arxiv.org/abs/1611.05431v2), [SE-ResNet](https://arxiv.org/abs/1709.01507) and [Res2Net](https://www.notion.so/ResNet-user-guide-b4c09e5bb5ae41328165a3f160a104f6). Apart from suggesting modifications to the structure of the residual block, papers have also suggested modifying the stem of the network, adding newer regularisation options in the form of stochastic depth and DropBlock, and changing the downsampling path for the blocks to improve performance. + +Metalhead provides an extensible, hackable yet powerful interface for working with ResNets that provides built-in toggles for commonly used options in papers and other deep learning libraries, while also allowing the user to build custom model structures if they want very easily. + +## Pre-trained models + +Metalhead provides a variety of pretrained models in the ResNet family to allow users to get started quickly with tasks like transfer learning. Pretrained models for [`ResNet`](@ref) with depth 18, 34, 50, 101 and 152 is supported, as is [`WideResNet`](@ref) with depths 50 and 101. [`ResNeXt`](@ref) also supports some configurations of pretrained models - to know more, check out the documentation for the model. + +This is as easy as setting the `pretrain` keyword to `true` when constructing the model. For example, to load a pretrained `ResNet` with depth 50, you can do the following: + +```julia +using Metalhead + +model = ResNet(50; pretrain=true) +``` + +To check out more about using pretrained models, check out the [pretrained models guide](@ref pretrained). diff --git a/docs/src/tutorials/pretrained.md b/docs/src/tutorials/pretrained.md new file mode 100644 index 000000000..e7bb41a10 --- /dev/null +++ b/docs/src/tutorials/pretrained.md @@ -0,0 +1,74 @@ +# [Working with pre-trained models from Metalhead](@id pretrained) + +Using a model from Metalhead is as simple as selecting a model from the table of [available models](@ref API-Reference). For example, below we use the pre-trained ResNet-18 model. + +```@example 1 +using Metalhead + +model = ResNet(18; pretrain = true); +``` + +## Using pre-trained models as feature extractors + +The `backbone` and `classifier` functions do exactly what their names suggest - they are used to extract the backbone and classifier of a model respectively. For example, to extract the backbone of a pre-trained ResNet-18 model: + +```@example 1 +backbone(model); +``` + +The `backbone` function could also be useful for people looking to just use specific sections of the model for transfer learning. The function returns a `Chain` of the layers of the model, so you can easily index into it to get the layers you want. For example, to get the first five layers of a pre-trained ResNet model, +you can just write `backbone(model)[1:5]`. + +## Training + +Now, we can use this model with Flux like any other model. First, let's check the accuracy on a test image from ImageNet. + +```@example 1 +using Images + +# test image +img = Images.load(download("https://cdn.pixabay.com/photo/2015/05/07/11/02/guitar-756326_960_720.jpg")); +``` + +We'll use the popular [DataAugmentation.jl](https://github.com/lorenzoh/DataAugmentation.jl) library to crop our input image, convert it to a plain array, and normalize the pixels. + +```@example 1 +using DataAugmentation +using Flux +using Flux: onecold + +DATA_MEAN = (0.485, 0.456, 0.406) +DATA_STD = (0.229, 0.224, 0.225) + +augmentations = CenterCrop((224, 224)) |> + ImageToTensor() |> + Normalize(DATA_MEAN, DATA_STD) + +data = apply(augmentations, Image(img)) |> itemdata + +# ImageNet labels +labels = readlines(download("https://raw.githubusercontent.com/pytorch/hub/master/imagenet_classes.txt")) + +println(onecold(model(Flux.unsqueeze(data, 4)), labels)) +``` + +That is fairly accurate! Below, we train the model on some randomly generated data: + +```@example 1 +using Optimisers +using Flux: onehotbatch +using Flux.Losses: logitcrossentropy + +batchsize = 1 +data = [(rand(Float32, 224, 224, 3, batchsize), onehotbatch(rand(1:1000, batchsize), 1:1000)) + for _ in 1:3] +opt = Optimisers.Adam() +state = Optimisers.setup(rule, model); # initialise this optimiser's state +for (i, (image, y)) in enumerate(data) + @info "Starting batch $i ..." + gs, _ = gradient(model, image) do m, x # calculate the gradients + logitcrossentropy(m(x), y) + end; + state, model = Optimisers.update(state, model, gs); +end +``` diff --git a/docs/src/tutorials/quickstart.md b/docs/src/tutorials/quickstart.md index a00854627..ef6a7b5cd 100644 --- a/docs/src/tutorials/quickstart.md +++ b/docs/src/tutorials/quickstart.md @@ -1,57 +1,29 @@ -# Quickstart +# A guide to getting started with Metalhead -```julia -using Flux, Metalhead -``` +Metalhead.jl is a library written in Flux.jl that is a collection of image models, layers and utilities for deep learning in computer vision. -Using a model from Metalhead is as simple as selecting a model from the table of [available models](@ref API-Reference). For example, below we use the pre-trained ResNet-18 model. -```julia -using Flux, Metalhead +## Pre-trained models -model = ResNet(18; pretrain = true) -``` +In Metalhead.jl, camel-cased functions mimicking the naming style followed in the paper such as [`ResNet`](@ref) or [`ResNeXt`](@ref) are considered the "higher" level API for models. These are the functions that end-users who do not want to experiment much with model architectures should use. These models also support the option for loading pre-trained weights from ImageNet. -Now, we can use this model with Flux like any other model. +!!! note -First, let's check the accuracy on a test image from ImageNet. -```julia -using Images + Metalhead is still under active development and thus not all models have pre-trained weights supported. While we are working on expanding the footprint of the pre-trained models, if you would like to help contribute model weights yourself, please check out the [contributing guide](@ref contributing) guide. -# test image -img = Images.load(download("https://cdn.pixabay.com/photo/2015/05/07/11/02/guitar-756326_960_720.jpg")); -``` -We'll use the popular [DataAugmentation.jl](https://github.com/lorenzoh/DataAugmentation.jl) library to crop our input image, convert it to a plain array, and normalize the pixels. -```julia -using DataAugmentation, OneHotArrays +To use a pre-trained model, just instantiate the model with the `pretrain` keyword argument set to `true`: -DATA_MEAN = (0.485, 0.456, 0.406) -DATA_STD = (0.229, 0.224, 0.225) +```julia +using Metalhead + +model = ResNet(18; pretrain = true); +``` -augmentations = CenterCrop((224, 224)) |> - ImageToTensor() |> - Normalize(DATA_MEAN, DATA_STD) -data = apply(augmentations, Image(img)) |> itemdata +Refer to the pretraining guide for more details on how to use pre-trained models. -# image net labels -labels = readlines(download("https://raw.githubusercontent.com/pytorch/hub/master/imagenet_classes.txt")) +## More model configuration options -onecold(model(Flux.unsqueeze(data, 4)), labels) -``` +For users who want to use more options for model configuration, Metalhead provides a "mid-level" API for models. The model functions that are in lowercase such as [`resnet`](@ref) or [`mobilenetv3`](@ref) are the "lower" level API for models. These are the functions that end-users who want to experiment with model architectures should use. These models do not support the option for loading pre-trained weights from ImageNet out of the box. -Below, we train it on some randomly generated data. +To use any of these models, check out the docstrings for the model functions. Note that these functions typically require more configuration options to be passed in, but offer a lot more flexibility in terms of model architecture. -```julia -using OneHotArrays: onehotbatch - -batchsize = 1 -data = [(rand(Float32, 224, 224, 3, batchsize), onehotbatch(rand(1:1000, batchsize), 1:1000)) - for _ in 1:3] -opt = ADAM() -ps = Flux.params(model) -loss(x, y, m) = Flux.Losses.logitcrossentropy(m(x), y) -for (i, (x, y)) in enumerate(data) - @info "Starting batch $i ..." - gs = gradient(() -> loss(x, y, model), ps) - Flux.update!(opt, ps, gs) -end -``` +## \ No newline at end of file diff --git a/src/Metalhead.jl b/src/Metalhead.jl index 41b394f24..26121b479 100644 --- a/src/Metalhead.jl +++ b/src/Metalhead.jl @@ -12,7 +12,9 @@ using Random import Functors +# Utilities include("utilities.jl") +include("core.jl") # Custom Layers include("layers/Layers.jl") @@ -25,21 +27,24 @@ include("convnets/builders/mbconv.jl") include("convnets/builders/resblocks.jl") include("convnets/builders/resnet.jl") include("convnets/builders/stages.jl") -## AlexNet and VGG +## Older CNN models include("convnets/alexnet.jl") include("convnets/vgg.jl") -## ResNets -include("convnets/resnets/core.jl") -include("convnets/resnets/res2net.jl") -include("convnets/resnets/resnet.jl") -include("convnets/resnets/resnext.jl") -include("convnets/resnets/seresnet.jl") +include("convnets/squeezenet.jl") ## Inceptions include("convnets/inceptions/googlenet.jl") include("convnets/inceptions/inceptionv3.jl") include("convnets/inceptions/inceptionv4.jl") include("convnets/inceptions/inceptionresnetv2.jl") include("convnets/inceptions/xception.jl") +## ResNets +include("convnets/resnets/core.jl") +include("convnets/resnets/res2net.jl") +include("convnets/resnets/resnet.jl") +include("convnets/resnets/resnext.jl") +include("convnets/resnets/seresnet.jl") +## DenseNet +include("convnets/densenet.jl") ## EfficientNets include("convnets/efficientnets/efficientnet.jl") include("convnets/efficientnets/efficientnetv2.jl") @@ -51,9 +56,10 @@ include("convnets/mobilenets/mnasnet.jl") ## Others include("convnets/densenet.jl") include("convnets/squeezenet.jl") -include("convnets/convnext.jl") -include("convnets/convmixer.jl") include("convnets/unet.jl") +## Hybrid models +include("convnets/hybrid/convnext.jl") +include("convnets/hybrid/convmixer.jl") # Mixers include("mixers/core.jl") @@ -67,16 +73,17 @@ include("vit-based/vit.jl") # Load pretrained weights include("pretrain.jl") -export AlexNet, - VGG, - ResNet, - WideResNet, ResNeXt, SEResNet, SEResNeXt, Res2Net, Res2NeXt, - DenseNet, +# export model functions +export AlexNet, VGG, ResNet, WideResNet, ResNeXt, DenseNet, GoogLeNet, Inceptionv3, Inceptionv4, InceptionResNetv2, Xception, + SEResNet, SEResNeXt, Res2Net, Res2NeXt, SqueezeNet, MobileNetv1, MobileNetv2, MobileNetv3, MNASNet, EfficientNet, EfficientNetv2, ConvMixer, ConvNeXt, MLPMixer, ResMLP, gMLP, ViT, UNet +# useful for feature extraction +export backbone, classifier + # use Flux._big_show to pretty print large models for T in (:AlexNet, :VGG, :SqueezeNet, :ResNet, :WideResNet, :ResNeXt, :SEResNet, :SEResNeXt, :Res2Net, :Res2NeXt, :GoogLeNet, :DenseNet, diff --git a/src/convnets/builders/invresmodel.jl b/src/convnets/builders/invresmodel.jl index 6faeca992..c6b928170 100644 --- a/src/convnets/builders/invresmodel.jl +++ b/src/convnets/builders/invresmodel.jl @@ -1,3 +1,25 @@ +""" + build_invresmodel(scalings::NTuple{2, Real}, + block_configs::AbstractVector{<:Tuple}; + inplanes::Integer = 32, connection = +, activation = relu, + norm_layer = BatchNorm, divisor::Integer = 8, + tail_conv::Bool = true, expanded_classifier::Bool = false, + stochastic_depth_prob = nothing, headplanes::Integer, + dropout_prob = nothing, inchannels::Integer = 3, + nclasses::Integer = 1000, kwargs...) + +Creates a generic inverted residual model structure with the specified configuration. + +# Arguments + + - `scalings`: a tuple of two numbers that specify the width and depth scaling factors. + - `block_configs`: This is a `Vector` of `Tuple`s that specifies the configuration of the + inverted residual blocks. This can take several forms: + + for `dwsep_conv_norm`, or depthwise separable convolutional blocks, the tuple + should be of the form `(dwsep_conv_norm, kernel size, output channels, stride, + number of repeats, activation function)`. For example, the following configuration + is valid: `(dwsep_conv_norm, 3, 64, 1, 1, relu6)`. +""" function build_invresmodel(scalings::NTuple{2, Real}, block_configs::AbstractVector{<:Tuple}; inplanes::Integer = 32, connection = +, activation = relu, diff --git a/src/convnets/builders/resblocks.jl b/src/convnets/builders/resblocks.jl index bdf36c0f3..b8502a9fc 100644 --- a/src/convnets/builders/resblocks.jl +++ b/src/convnets/builders/resblocks.jl @@ -42,7 +42,7 @@ function basicblock_builder(block_repeats::AbstractVector{<:Integer}; # Also get `planes_vec` needed for block `inplanes` and `planes` calculations sdschedule = linear_scheduler(stochastic_depth_prob; depth = sum(block_repeats)) dbschedule = linear_scheduler(dropblock_prob; depth = sum(block_repeats)) - planes_vec = collect(planes_fn(block_repeats)) + planes_vec = planes_fn(block_repeats) # closure over `idxs` function get_layers(stage_idx::Integer, block_idx::Integer) # DropBlock, StochasticDepth both take in probabilities based on a linear scaling schedule @@ -109,7 +109,7 @@ function bottleneck_builder(block_repeats::AbstractVector{<:Integer}; downsample_tuple = (downsample_conv, downsample_identity)) sdschedule = linear_scheduler(stochastic_depth_prob; depth = sum(block_repeats)) dbschedule = linear_scheduler(dropblock_prob; depth = sum(block_repeats)) - planes_vec = collect(planes_fn(block_repeats)) + planes_vec = planes_fn(block_repeats) # closure over `idxs` function get_layers(stage_idx::Integer, block_idx::Integer) # DropBlock, StochasticDepth both take in rates based on a linear scaling schedule diff --git a/src/convnets/densenet.jl b/src/convnets/densenet.jl index 4799e22a7..6b890833e 100644 --- a/src/convnets/densenet.jl +++ b/src/convnets/densenet.jl @@ -11,12 +11,10 @@ Create a Densenet bottleneck layer (and scaling factor for inner feature maps; see ref) """ function dense_bottleneck(inplanes::Integer, outplanes::Integer; expansion::Integer = 4) - inner_channels = expansion * outplanes - return SkipConnection(Chain(conv_norm((1, 1), inplanes, inner_channels; + return SkipConnection(Chain(conv_norm((1, 1), inplanes, expansion * outplanes; revnorm = true)..., - conv_norm((3, 3), inner_channels, outplanes; pad = 1, - revnorm = true)...), - cat_channels) + conv_norm((3, 3), expansion * outplanes, outplanes; + pad = 1, revnorm = true)...), cat_channels) end """ @@ -70,8 +68,9 @@ Create a DenseNet model - `dropout_prob`: the dropout probability for the classifier head. Set to `nothing` to disable dropout. - `nclasses`: the number of output classes """ -function densenet(inplanes::Integer, growth_rates; reduction = 0.5, dropout_prob = nothing, - inchannels::Integer = 3, nclasses::Integer = 1000) +function build_densenet(inplanes::Integer, growth_rates; reduction = 0.5, + dropout_prob = nothing, + inchannels::Integer = 3, nclasses::Integer = 1000) layers = [] append!(layers, conv_norm((7, 7), inchannels, inplanes; stride = 2, pad = (3, 3))) @@ -89,7 +88,9 @@ function densenet(inplanes::Integer, growth_rates; reduction = 0.5, dropout_prob end """ - densenet(nblocks; growth_rate = 32, reduction = 0.5, nclasses::Integer = 1000) + densenet(nblocks::AbstractVector{<:Integer}; growth_rate::Integer = 32, + reduction = 0.5, dropout_prob = nothing, inchannels::Integer = 3, + nclasses::Integer = 1000) Create a DenseNet model ([reference](https://arxiv.org/abs/1608.06993)). @@ -99,13 +100,15 @@ Create a DenseNet model - `nblocks`: number of dense blocks between transitions - `growth_rate`: the output feature map growth probability of dense blocks (i.e. `k` in the ref) - `reduction`: the factor by which the number of feature maps is scaled across each transition + - `dropout_prob`: the dropout probability for the classifier head. Set to `nothing` to disable dropout + - `inchannels`: the number of input channels - `nclasses`: the number of output classes """ function densenet(nblocks::AbstractVector{<:Integer}; growth_rate::Integer = 32, reduction = 0.5, dropout_prob = nothing, inchannels::Integer = 3, nclasses::Integer = 1000) - return densenet(2 * growth_rate, [fill(growth_rate, n) for n in nblocks]; - reduction, dropout_prob, inchannels, nclasses) + return build_densenet(2 * growth_rate, [fill(growth_rate, n) for n in nblocks]; + reduction, dropout_prob, inchannels, nclasses) end const DENSENET_CONFIGS = Dict(121 => [6, 12, 24, 16], @@ -114,12 +117,20 @@ const DENSENET_CONFIGS = Dict(121 => [6, 12, 24, 16], 201 => [6, 12, 48, 32]) """ - DenseNet(config::Integer; pretrain::Bool = false, nclasses::Integer = 1000) - DenseNet(transition_configs::NTuple{N,Integer}) + DenseNet(config::Integer; pretrain::Bool = false, growth_rate::Integer = 32, + reduction = 0.5, inchannels::Integer = 3, nclasses::Integer = 1000) Create a DenseNet model with specified configuration. Currently supported values are (121, 161, 169, 201) ([reference](https://arxiv.org/abs/1608.06993)). -Set `pretrain = true` to load the model with pre-trained weights for ImageNet. + +# Arguments + + - `config`: the configuration of the model + - `pretrain`: whether to load the model with pre-trained weights for ImageNet. + - `growth_rate`: the output feature map growth probability of dense blocks (i.e. `k` in the ref) + - `reduction`: the factor by which the number of feature maps is scaled across each transition + - `inchannels`: the number of input channels + - `nclasses`: the number of output classes !!! warning diff --git a/src/convnets/convmixer.jl b/src/convnets/hybrid/convmixer.jl similarity index 100% rename from src/convnets/convmixer.jl rename to src/convnets/hybrid/convmixer.jl diff --git a/src/convnets/convnext.jl b/src/convnets/hybrid/convnext.jl similarity index 77% rename from src/convnets/convnext.jl rename to src/convnets/hybrid/convnext.jl index a38840d0b..e11abaea1 100644 --- a/src/convnets/convnext.jl +++ b/src/convnets/hybrid/convnext.jl @@ -22,9 +22,9 @@ function convnextblock(planes::Integer, stochastic_depth_prob = 0.0, end """ - convnext(depths::AbstractVector{<:Integer}, planes::AbstractVector{<:Integer}; - stochastic_depth_prob = 0.0, layerscale_init = 1.0f-6, inchannels::Integer = 3, - nclasses::Integer = 1000) + build_convnext(depths::AbstractVector{<:Integer}, planes::AbstractVector{<:Integer}; + stochastic_depth_prob = 0.0, layerscale_init = 1.0f-6, + inchannels::Integer = 3, nclasses::Integer = 1000) Creates the layers for a ConvNeXt model. ([reference](https://arxiv.org/abs/2201.03545)) @@ -39,10 +39,10 @@ Creates the layers for a ConvNeXt model. - `inchannels`: number of input channels. - `nclasses`: number of output classes """ -function convnext(depths::AbstractVector{<:Integer}, planes::AbstractVector{<:Integer}; - stochastic_depth_prob = 0.0, layerscale_init = 1.0f-6, - inchannels::Integer = 3, - nclasses::Integer = 1000) +function build_convnext(depths::AbstractVector{<:Integer}, + planes::AbstractVector{<:Integer}; + stochastic_depth_prob = 0.0, layerscale_init = 1.0f-6, + inchannels::Integer = 3, nclasses::Integer = 1000) @assert length(depths) == length(planes) "`planes` should have exactly one value for each block" downsample_layers = [] @@ -69,10 +69,26 @@ function convnext(depths::AbstractVector{<:Integer}, planes::AbstractVector{<:In return Chain(Chain(backbone...), classifier) end +""" + convnext(config::Symbol; stochastic_depth_prob = 0.0, layerscale_init = 1.0f-6, + inchannels::Integer = 3, nclasses::Integer = 1000) + +Creates a ConvNeXt model. +([reference](https://arxiv.org/abs/2201.03545)) + +# Arguments + + - `config`: The size of the model, one of `tiny`, `small`, `base`, `large` or `xlarge`. + - `stochastic_depth_prob`: Stochastic depth probability. + - `layerscale_init`: Initial value for [`LayerScale`](@ref) + ([reference](https://arxiv.org/abs/2103.17239)) + - `inchannels`: number of input channels. + - `nclasses`: number of output classes +""" function convnext(config::Symbol; stochastic_depth_prob = 0.0, layerscale_init = 1.0f-6, inchannels::Integer = 3, nclasses::Integer = 1000) - return convnext(CONVNEXT_CONFIGS[config]...; stochastic_depth_prob, layerscale_init, - inchannels, nclasses) + return build_convnext(CONVNEXT_CONFIGS[config]...; stochastic_depth_prob, + layerscale_init, inchannels, nclasses) end # Configurations for ConvNeXt models diff --git a/src/convnets/resnets/core.jl b/src/convnets/resnets/core.jl index c5b2f1c7d..f45a01b50 100644 --- a/src/convnets/resnets/core.jl +++ b/src/convnets/resnets/core.jl @@ -307,8 +307,8 @@ ResNet-50, which has 3 blocks in the first stage, 4 blocks in the second stage, third stage and 3 blocks in the fourth stage. """ function resnet_planes(block_repeats::AbstractVector{<:Integer}) - return Iterators.flatten((64 * 2^(stage_idx - 1) for _ in 1:stages) - for (stage_idx, stages) in enumerate(block_repeats)) + return collect(Iterators.flatten((64 * 2^(stage_idx - 1) for _ in 1:stages) + for (stage_idx, stages) in enumerate(block_repeats))) end """ @@ -340,7 +340,7 @@ end imsize::Dims{2} = (256, 256), inchannels::Integer = 3, nclasses::Integer = 1000, kwargs...) -Creates a generic ResNet-like model that is used to create the higher level models like ResNet, +Creates a generic ResNet-like model that is used to create The higher-level model constructors like ResNet, Wide ResNet, ResNeXt and Res2Net. For an _even_ more generic model API, see [`Metalhead.build_resnet`](@ref). # Arguments @@ -377,10 +377,17 @@ Wide ResNet, ResNeXt and Res2Net. For an _even_ more generic model API, see [`Me - `use_conv`: Set to true to use convolutions instead of identity operations in the model. - `dropblock_prob`: `DropBlock` probability to be used in the model. Set to `nothing` to disable DropBlock. See [`Metalhead.DropBlock`](@ref) for more details. - - `stochastic_depth_prob`: `StochasticDepth` probability to be used in the model. Set to `nothing` to disable - StochasticDepth. See [`Metalhead.StochasticDepth`](@ref) for more details. + - `stochastic_depth_prob`: `StochasticDepth` probability to be used in the model. Set to `nothing` + to disable StochasticDepth. See [`Metalhead.StochasticDepth`](@ref) for more details. - `dropout_prob`: `Dropout` probability to be used in the classifier head. Set to `nothing` to disable Dropout. + - `imsize`: The size of the input (height, width). + - `inchannels`: The number of input channels. + - `nclasses`: The number of output classes. + - `kwargs`: Additional keyword arguments to be passed to the block builder (note: ignore this + argument if you are not sure what it does. To know more about how this works, check out the + section of the documentation that talks about builders in Metalhead and specifically for the + ResNet block functions). """ function resnet(block_type, block_repeats::AbstractVector{<:Integer}, downsample_opt::NTuple{2, Any} = (downsample_conv, downsample_identity); diff --git a/src/convnets/resnets/resnext.jl b/src/convnets/resnets/resnext.jl index 6d65708d0..c8d03d8e1 100644 --- a/src/convnets/resnets/resnext.jl +++ b/src/convnets/resnets/resnext.jl @@ -7,7 +7,7 @@ Creates a ResNeXt model with the specified depth, cardinality, and base width. # Arguments - - `depth`: one of `[18, 34, 50, 101, 152]`. The depth of the ResNet model. + - `depth`: one of `[50, 101, 152]`. The depth of the ResNet model. - `pretrain`: set to `true` to load the model with pre-trained weights for ImageNet. Supported configurations are: diff --git a/src/core.jl b/src/core.jl new file mode 100644 index 000000000..eb5999845 --- /dev/null +++ b/src/core.jl @@ -0,0 +1,19 @@ +""" + backbone(model) + +This function returns the backbone of a model that can be used for feature extraction. +A `Flux.Chain` is returned, which can be indexed/sliced into to get the desired layer(s). +Note that the model used here as input must be the "camel-cased" version of the model, +e.g. `ResNet` instead of `resnet`. +""" +backbone + +""" + classifier(model) + +This function returns the classifier head of a model. This is sometimes useful for fine-tuning +a model on a different dataset. A `Flux.Chain` is returned, which can be indexed/sliced into to +get the desired layer(s). Note that the model used here as input must be the "camel-cased" +version of the model, e.g. `ResNet` instead of `resnet`. +""" +classifier diff --git a/src/layers/Layers.jl b/src/layers/Layers.jl index 423ff41ec..d395ccbc9 100644 --- a/src/layers/Layers.jl +++ b/src/layers/Layers.jl @@ -1,7 +1,7 @@ module Layers using Flux -using Flux: rng_from_array +using Flux: default_rng_value using CUDA using NNlib, NNlibCUDA using Functors diff --git a/src/layers/drop.jl b/src/layers/drop.jl index f7f5d95bd..bcf6df582 100644 --- a/src/layers/drop.jl +++ b/src/layers/drop.jl @@ -11,7 +11,7 @@ ChainRulesCore.@non_differentiable _dropblock_mask(rng, x, gamma, clipped_block_ # TODO add experimental `DropBlock` options from timm such as gaussian noise and # more precise `DropBlock` to deal with edges (#188) """ - dropblock([rng = rng_from_array(x)], x::AbstractArray{T, 4}, drop_block_prob, block_size, + dropblock([rng = default_rng_value(x)], x::AbstractArray{T, 4}, drop_block_prob, block_size, gamma_scale, active::Bool = true) The dropblock function. If `active` is `true`, for each input, it zeroes out continguous @@ -55,7 +55,7 @@ dropblock_mask(rng, x, gamma, bs) = _dropblock_mask(rng, x, gamma, bs) """ DropBlock(drop_block_prob = 0.1, block_size = 7, gamma_scale = 1.0, - rng = rng_from_array()) + rng = default_rng_value()) The `DropBlock` layer. While training, it zeroes out continguous regions of size `block_size` in the input. During inference, it simply returns the input `x`. @@ -107,7 +107,7 @@ function Flux.testmode!(m::DropBlock, mode = true) end function DropBlock(drop_block_prob = 0.1, block_size::Integer = 7, gamma_scale = 1.0, - rng = rng_from_array()) + rng = default_rng_value()) if isnothing(drop_block_prob) return identity end @@ -122,7 +122,7 @@ function Base.show(io::IO, d::DropBlock) end """ - StochasticDepth(p, mode = :row; rng = rng_from_array()) + StochasticDepth(p, mode = :row; rng = default_rng_value()) Implements Stochastic Depth. This is a `Dropout` layer from Flux that drops values with probability `p`. @@ -144,7 +144,7 @@ equivalent to `identity`. for more information on the behaviour of this argument. Custom RNGs are only supported on the CPU. """ -function StochasticDepth(p, mode = :row; rng = rng_from_array()) +function StochasticDepth(p, mode = :row; rng = default_rng_value()) if isnothing(p) return identity else diff --git a/src/layers/mbconv.jl b/src/layers/mbconv.jl index 13eef8ae1..14e84f486 100644 --- a/src/layers/mbconv.jl +++ b/src/layers/mbconv.jl @@ -55,6 +55,11 @@ This is a sequence of layers: - a 1x1 convolution from `explanes => outplanes` - a (batch) normalisation layer + `activation` +!!! warning + This function does not handle the residual connection by default. The user must add + this manually to use this block as a standalone. To construct a model, check out the + builders, which handle the residual connection and other details. + First introduced in the MobileNetv2 paper. (See Fig. 3 in [reference](https://arxiv.org/abs/1801.04381v4).) @@ -113,6 +118,11 @@ This is a sequence of layers: - a 1x1 convolution from `explanes => outplanes` followed by a (batch) normalisation layer + `activation` if `inplanes != explanes` +!!! warning + This function does not handle the residual connection by default. The user must add + this manually to use this block as a standalone. To construct a model, check out the + builders, which handle the residual connection and other details. + Originally introduced by Google in [EfficientNet-EdgeTPU: Creating Accelerator-Optimized Neural Networks with AutoML](https://ai.googleblog.com/2019/08/efficientnet-edgetpu-creating.html). Later used in the EfficientNetv2 paper. diff --git a/src/layers/pool.jl b/src/layers/pool.jl index 60447ddea..18518731a 100644 --- a/src/layers/pool.jl +++ b/src/layers/pool.jl @@ -1,6 +1,5 @@ """ - AdaptiveMeanMaxPool(connection = +, output_size::Tuple = (1, 1)) - AdaptiveMeanMaxPool(output_size::Tuple = (1, 1)) + AdaptiveMeanMaxPool([connection = +], output_size::Tuple = (1, 1)) A type of adaptive pooling layer which uses both mean and max pooling and combines them to produce a single output. Note that this is equivalent to diff --git a/src/mixers/gmlp.jl b/src/mixers/gmlp.jl index 845786686..6319bfef6 100644 --- a/src/mixers/gmlp.jl +++ b/src/mixers/gmlp.jl @@ -1,13 +1,14 @@ """ - SpatialGatingUnit(norm, proj) + SpatialGatingUnit(planes::Integer, npatches::Integer; norm_layer = LayerNorm) Creates a spatial gating unit as described in the gMLP paper. ([reference](https://arxiv.org/abs/2105.08050)) # Arguments - - `norm`: the normalisation layer to use - - `proj`: the projection layer to use + - `planes`: the number of planes in the block + - `npatches`: the number of patches of the input + - `norm_layer`: the normalisation layer to use """ struct SpatialGatingUnit{T, F} norm::T @@ -15,18 +16,6 @@ struct SpatialGatingUnit{T, F} end @functor SpatialGatingUnit -""" - SpatialGatingUnit(planes::Integer, npatches::Integer; norm_layer = LayerNorm) - -Creates a spatial gating unit as described in the gMLP paper. -([reference](https://arxiv.org/abs/2105.08050)) - -# Arguments - - - `planes`: the number of planes in the block - - `npatches`: the number of patches of the input - - `norm_layer`: the normalisation layer to use -""" function SpatialGatingUnit(planes::Integer, npatches::Integer; norm_layer = LayerNorm) gateplanes = planes รท 2 norm = norm_layer(gateplanes) @@ -42,9 +31,10 @@ function (m::SpatialGatingUnit)(x) end """ - spatial_gating_block(planes::Integer, npatches::Integer; mlp_ratio = 4.0, - norm_layer = LayerNorm, mlp_layer = gated_mlp_block, - dropout_prob = 0.0, stochastic_depth_prob = 0.0, activation = gelu) + spatialgatingblock(planes::Integer, npatches::Integer; mlp_ratio = 4.0, + norm_layer = LayerNorm, mlp_layer = gated_mlp_block, + dropout_prob = 0.0, stochastic_depth_prob = 0.0, + activation = gelu) Creates a feedforward block based on the gMLP model architecture described in the paper. ([reference](https://arxiv.org/abs/2105.08050)) @@ -60,10 +50,10 @@ Creates a feedforward block based on the gMLP model architecture described in th - `stochastic_depth_prob`: Stochastic depth probability - `activation`: the activation function to use in the MLP blocks """ -function spatial_gating_block(planes::Integer, npatches::Integer; mlp_ratio = 4.0, - norm_layer = LayerNorm, mlp_layer = gated_mlp_block, - dropout_prob = 0.0, stochastic_depth_prob = 0.0, - activation = gelu) +function spatialgatingblock(planes::Integer, npatches::Integer; mlp_ratio = 4.0, + norm_layer = LayerNorm, mlp_layer = gated_mlp_block, + dropout_prob = 0.0, stochastic_depth_prob = 0.0, + activation = gelu) channelplanes = floor(Int, mlp_ratio * planes) sgu = inplanes -> SpatialGatingUnit(inplanes, npatches; norm_layer) return SkipConnection(Chain(norm_layer(planes), @@ -97,7 +87,7 @@ end function gMLP(config::Symbol; imsize::Dims{2} = (224, 224), patch_size::Dims{2} = (16, 16), pretrain::Bool = false, inchannels::Integer = 3, nclasses::Integer = 1000) _checkconfig(config, keys(MIXER_CONFIGS)) - layers = mlpmixer(spatial_gating_block, imsize; mlp_layer = gated_mlp_block, patch_size, + layers = mlpmixer(spatialgatingblock, imsize; mlp_layer = gated_mlp_block, patch_size, MIXER_CONFIGS[config]..., inchannels, nclasses) if pretrain loadpretrain!(layers, string("gmlp", config)) diff --git a/src/utilities.jl b/src/utilities.jl index 316d884c6..18c322bb5 100644 --- a/src/utilities.jl +++ b/src/utilities.jl @@ -42,25 +42,12 @@ cat_channels(x::Tuple) = cat_channels(x...) """ swapdims(perm) -Convenience function for permuting the dimensions of an array. +Convenience function that returns a closure which permutes the dimensions of an array. `perm` is a vector or tuple specifying a permutation of the input dimensions. Equivalent to `permutedims(x, perm)`. """ swapdims(perm) = Base.Fix2(permutedims, perm) -# Utility function for pretty printing large models -function _maybe_big_show(io, model) - if isdefined(Flux, :_big_show) - if isnothing(get(io, :typeinfo, nothing)) # e.g. top level in REPL - Flux._big_show(io, model) - else - show(io, model) - end - else - show(io, model) - end -end - """ linear_scheduler(drop_prob = 0.0; start_value = 0.0, depth) linear_scheduler(drop_prob::Nothing; depth::Integer) @@ -90,3 +77,15 @@ into a single iterator. flatten_chains(m::Chain) = Iterators.flatten(flatten_chains(l) for l in m.layers) flatten_chains(m) = (m,) +# Utility function for pretty printing large models +function _maybe_big_show(io, model) + if isdefined(Flux, :_big_show) + if isnothing(get(io, :typeinfo, nothing)) # e.g. top level in REPL + Flux._big_show(io, model) + else + show(io, model) + end + else + show(io, model) + end +end