From 95e1ce05d829350a69f872977b5f8766dbc7fdb3 Mon Sep 17 00:00:00 2001 From: Carlo Lucibello Date: Sun, 7 May 2023 11:49:07 +0200 Subject: [PATCH 01/10] hook new weights --- .gitignore | 2 + Artifacts.toml | 92 ++++++++++---------- Project.toml | 10 ++- README.md | 26 +++--- scripts/manage_huggingface_org.jl | 11 ++- src/Metalhead.jl | 1 + src/convnets/alexnet.jl | 5 +- src/convnets/densenet.jl | 6 +- src/convnets/efficientnets/efficientnet.jl | 5 +- src/convnets/efficientnets/efficientnetv2.jl | 5 +- src/convnets/hybrid/convmixer.jl | 5 +- src/convnets/hybrid/convnext.jl | 5 +- src/convnets/mobilenets/mnasnet.jl | 5 +- src/convnets/mobilenets/mobilenetv1.jl | 5 +- src/convnets/mobilenets/mobilenetv2.jl | 5 +- src/convnets/mobilenets/mobilenetv3.jl | 5 +- src/convnets/resnets/res2net.jl | 15 ++-- src/convnets/resnets/resnet.jl | 20 ++++- src/convnets/resnets/resnext.jl | 14 ++- src/convnets/resnets/seresnet.jl | 6 +- src/convnets/squeezenet.jl | 5 +- src/convnets/unet.jl | 7 +- src/convnets/vgg.jl | 10 ++- src/pretrain.jl | 44 +++++++--- src/vit-based/vit.jl | 6 +- test/convnets.jl | 6 +- 26 files changed, 199 insertions(+), 127 deletions(-) diff --git a/.gitignore b/.gitignore index 70fa7ce17..1a4d377b4 100644 --- a/.gitignore +++ b/.gitignore @@ -6,6 +6,8 @@ .DS_Store .CondaPkg/ scripts/weights/ +scripts/Artifacts.toml + # manifests docs/Manifest.toml Manifest.toml diff --git a/Artifacts.toml b/Artifacts.toml index 340f510d3..a2013b125 100644 --- a/Artifacts.toml +++ b/Artifacts.toml @@ -39,19 +39,19 @@ lazy = true url = "https://huggingface.co/FluxML/resnet101/resolve/main/resnet101.tar.gz" [resnet101-IMAGENET1K_V1] -git-tree-sha1 = "f4fc4e12e6865f2917d305fa5374dca19993c9dc" +git-tree-sha1 = "2447af36b981f5136e7e3985e4ac97e8521e6b91" lazy = true [[resnet101-IMAGENET1K_V1.download]] - sha256 = "d2b642c78fbc345fe2dd224d096908aae60328fc318b12b4aadde6ee2413e08d" + sha256 = "5c0ddad654a26ef2cd6e567b8939a06c67044cbf331c3c81c7ee5f5bd2ad5e5e" url = 
"https://huggingface.co/FluxML/resnet101/resolve/main/resnet101-IMAGENET1K_V1.tar.gz" [resnet101-IMAGENET1K_V2] -git-tree-sha1 = "f69139437d659038890a9d9255cc556a4481cd9d" +git-tree-sha1 = "5ca085c5d85b8cb239bb0960ba719c3ad04d4177" lazy = true [[resnet101-IMAGENET1K_V2.download]] - sha256 = "4e8c7994e9e72963c84c257c68869e0cafd1915c8cd000aa14686932145b49ac" + sha256 = "7d3b98e2f5f9f89751b2dc93d10e34e55dfd5122938dfd16df2d8ac6071576c8" url = "https://huggingface.co/FluxML/resnet101/resolve/main/resnet101-IMAGENET1K_V2.tar.gz" [resnet152] @@ -71,11 +71,11 @@ lazy = true url = "https://huggingface.co/FluxML/resnet18/resolve/main/resnet18.tar.gz" [resnet18-IMAGENET1K_V1] -git-tree-sha1 = "2e7afe62b9497ad0aca89bc247ddb19eaf8599eb" +git-tree-sha1 = "7401801a43ad8b1a7525849bd56d0ab7a7fd30ca" lazy = true [[resnet18-IMAGENET1K_V1.download]] - sha256 = "c885025e64caeb6d4bd8cb4ec322d7ba133ef681471072dd16e6ce7e4fc13d07" + sha256 = "4036b99872d8124294e555140be0bd3dedeeb48ec5a9b9242bcca30c2fa8ddd8" url = "https://huggingface.co/FluxML/resnet18/resolve/main/resnet18-IMAGENET1K_V1.tar.gz" [resnet34] @@ -87,11 +87,11 @@ lazy = true url = "https://huggingface.co/FluxML/resnet34/resolve/main/resnet34.tar.gz" [resnet34-IMAGENET1K_V1] -git-tree-sha1 = "7b55514680b27a708be82e22dad16631b2100f67" +git-tree-sha1 = "1c5e4b20afb3f9ed37ec8c9aefec80d90ad171b9" lazy = true [[resnet34-IMAGENET1K_V1.download]] - sha256 = "93843ba3bd8fa56bf5d1655f10db04ecc3d54b7db12a912a95aa7c685a235005" + sha256 = "1871e8aba336d7aac491b8fdafc28d6c0d4cebe2e7fe79b2832e183bc3a72dd0" url = "https://huggingface.co/FluxML/resnet34/resolve/main/resnet34-IMAGENET1K_V1.tar.gz" [resnet50] @@ -103,19 +103,19 @@ lazy = true url = "https://huggingface.co/FluxML/resnet50/resolve/main/resnet50.tar.gz" [resnet50-IMAGENET1K_V1] -git-tree-sha1 = "448f9178be86210a6d1baca42acfc1d58aab4780" +git-tree-sha1 = "d53fc9a15c08b47229b05e40f171e61a90d77219" lazy = true [[resnet50-IMAGENET1K_V1.download]] - sha256 = 
"b9b6b91f436764ca84b75da74b9e41ac5aed718f58dd1c733b66b88c7cb2f6ff" + sha256 = "70279fdf0292f36de83ba04917b25e3a42ee20a9a36e510740ccc86058e94558" url = "https://huggingface.co/FluxML/resnet50/resolve/main/resnet50-IMAGENET1K_V1.tar.gz" [resnet50-IMAGENET1K_V2] -git-tree-sha1 = "1ffc5569bd5fcc07814aac5ab0c8ae9cbbbb28ce" +git-tree-sha1 = "053257695200a6a5dddddc048b3355ef8376705c" lazy = true [[resnet50-IMAGENET1K_V2.download]] - sha256 = "61dca257a1c4a9fc64bcfbd90b6621fb8e0b90b30db2dfc9f722d9e58fcc5a0f" + sha256 = "fc10822cfa8a43fa21fe44370808e94598ed6f97c4b1e7c2bebea59e895b3fc0" url = "https://huggingface.co/FluxML/resnet50/resolve/main/resnet50-IMAGENET1K_V2.tar.gz" [resnext101_32x8d] @@ -127,19 +127,19 @@ lazy = true url = "https://huggingface.co/FluxML/resnext101_32x8d/resolve/main/resnext101_32x8d.tar.gz" [resnext101_32x8d-IMAGENET1K_V1] -git-tree-sha1 = "793f0910bf6aa03a7ee10c634100cff9766be4b0" +git-tree-sha1 = "22b6e6abc7a2d6d7044ab55f8e770f1af504b50e" lazy = true [[resnext101_32x8d-IMAGENET1K_V1.download]] - sha256 = "61318406addc929f31d29602316d28aa9c6af2bd533db7f85b166e938588bf9e" + sha256 = "779b532780b1f81d290ff5064873cf0aa75285cf3fc30f533b8c4695139348b7" url = "https://huggingface.co/FluxML/resnext101_32x8d/resolve/main/resnext101_32x8d-IMAGENET1K_V1.tar.gz" [resnext101_32x8d-IMAGENET1K_V2] -git-tree-sha1 = "5c5d7b61c18ba6c5a924178dcfb4f7784583d170" +git-tree-sha1 = "61ce835d27a72b53d074deb241cc198dbc4efdcf" lazy = true [[resnext101_32x8d-IMAGENET1K_V2.download]] - sha256 = "4fb0af43466225f9557aa3dd8d61b14d7ced6f723750ba3f294d8250ba38525c" + sha256 = "39aa9153aff5ed0b96072be643f331964c81190654a8694e4b01f512e6590ab1" url = "https://huggingface.co/FluxML/resnext101_32x8d/resolve/main/resnext101_32x8d-IMAGENET1K_V2.tar.gz" [resnext101_64x4d] @@ -151,11 +151,11 @@ lazy = true url = "https://huggingface.co/FluxML/resnext101_64x4d/resolve/main/resnext101_64x4d.tar.gz" [resnext101_64x4d-IMAGENET1K_V1] -git-tree-sha1 = "bbdf9a5dc2775ec002a42041a75a1f99292698e9" 
+git-tree-sha1 = "950ad2fe08acb3c6e0a130d13ef59fd573d71d7f" lazy = true [[resnext101_64x4d-IMAGENET1K_V1.download]] - sha256 = "7c7dd823f9155b8d806378490880900e6386587b04341f1c65e39dff38fc1c13" + sha256 = "660a1559f2849df27407e65d6d50e694dbc6f3e32b90fd4eb2854b8bc647c2c8" url = "https://huggingface.co/FluxML/resnext101_64x4d/resolve/main/resnext101_64x4d-IMAGENET1K_V1.tar.gz" [resnext50_32x4d] @@ -167,19 +167,19 @@ lazy = true url = "https://huggingface.co/FluxML/resnext50_32x4d/resolve/main/resnext50_32x4d.tar.gz" [resnext50_32x4d-IMAGENET1K_V1] -git-tree-sha1 = "ab496fa9aff44f2b10a74e70d24de19a57c76c53" +git-tree-sha1 = "a44be12b302a21e0f8db4155b054e2a4b9df79c2" lazy = true [[resnext50_32x4d-IMAGENET1K_V1.download]] - sha256 = "de7d37ebe690ad8fca7023fb479684d5f998d2813efed3be875cde8f471d3967" + sha256 = "6647ea4264da1e09fa5b459aa5b1cf5db11e0427ae254919ed0589d18b4ab355" url = "https://huggingface.co/FluxML/resnext50_32x4d/resolve/main/resnext50_32x4d-IMAGENET1K_V1.tar.gz" [resnext50_32x4d-IMAGENET1K_V2] -git-tree-sha1 = "65b0924fe15d8ca03def45ae493a59498f435f2a" +git-tree-sha1 = "2686dc6dbe6fa19ae154fa6779596bb0b76101cd" lazy = true [[resnext50_32x4d-IMAGENET1K_V2.download]] - sha256 = "456f151be9ccf519daf4a5381fe5bb160da085693be3415833c580f2ef91aba4" + sha256 = "1d604c8fddd9a3f0af059ea4a616a3f4e04e27e54141a63a9ab6252d0799239d" url = "https://huggingface.co/FluxML/resnext50_32x4d/resolve/main/resnext50_32x4d-IMAGENET1K_V2.tar.gz" [squeezenet] @@ -199,11 +199,11 @@ lazy = true url = "https://huggingface.co/FluxML/vgg11/resolve/main/vgg11.tar.gz" [vgg11-IMAGENET1K_V1] -git-tree-sha1 = "4370fedbcffd7d4bc53560661093fa8b109d9d79" +git-tree-sha1 = "73fdc2303007b062ec632fe870ee7ca73262c3f5" lazy = true [[vgg11-IMAGENET1K_V1.download]] - sha256 = "4c7fe295439a85626930f7ff02a80bbc546d6664946b2f240c4df4dcd46f0d98" + sha256 = "b0159696a0878cf25b2c4310e41330241c174b407fbb6a7d0398fddef4a8ba93" url = "https://huggingface.co/FluxML/vgg11/resolve/main/vgg11-IMAGENET1K_V1.tar.gz" 
[vgg13] @@ -215,11 +215,11 @@ lazy = true url = "https://huggingface.co/FluxML/vgg13/resolve/main/vgg13.tar.gz" [vgg13-IMAGENET1K_V1] -git-tree-sha1 = "3d4bb6eb7776baca567b423ddb539b2d307ac01f" +git-tree-sha1 = "9ccb470a3ec9e56352e5c7f7944c13bedb825945" lazy = true [[vgg13-IMAGENET1K_V1.download]] - sha256 = "da731e52b50dbed4f6ecdd762cfc938516d2c52fd04ccfa2300ce59f90b99aff" + sha256 = "8109cccc14b31cb79846a32f002dbb7ad9322f336daa882aeab2c825def8fd97" url = "https://huggingface.co/FluxML/vgg13/resolve/main/vgg13-IMAGENET1K_V1.tar.gz" [vgg16] @@ -231,11 +231,11 @@ lazy = true url = "https://huggingface.co/FluxML/vgg16/resolve/main/vgg16.tar.gz" [vgg16-IMAGENET1K_V1] -git-tree-sha1 = "33669954aad1159f2be55608cc100ee55af9048d" +git-tree-sha1 = "0af8d0d61097020bec4438892af9efe658ae239c" lazy = true [[vgg16-IMAGENET1K_V1.download]] - sha256 = "2ab0b0d3ef5944db5fed51481fcb83d979291e718b12985089ed5305adcb02b1" + sha256 = "3c1d99577345def358f25cabb747aff0f42f660cd998456ec419f7a82d426779" url = "https://huggingface.co/FluxML/vgg16/resolve/main/vgg16-IMAGENET1K_V1.tar.gz" [vgg19] @@ -247,43 +247,43 @@ lazy = true url = "https://huggingface.co/FluxML/vgg19/resolve/main/vgg19.tar.gz" [vgg19-IMAGENET1K_V1] -git-tree-sha1 = "7205d226bae6a377f5c985a48dc0da377b5db3a1" +git-tree-sha1 = "85f19750c0caa35c47e3e376b72424cd324663cb" lazy = true [[vgg19-IMAGENET1K_V1.download]] - sha256 = "1941d9180d570e2d855b7c86f10686179a3b77e08bac79f914e029eec406445c" + sha256 = "ecbfbd59a8b9545e8d5116f968f4e409a0c42264220376e346920ca110a61eb1" url = "https://huggingface.co/FluxML/vgg19/resolve/main/vgg19-IMAGENET1K_V1.tar.gz" [vit_b_16-IMAGENET1K_V1] -git-tree-sha1 = "60524d3d87fcf741c4dc9f22b5a4cd265e9158a0" +git-tree-sha1 = "f8e2e75cf329cd24da4f092dfb76036f4e24402b" lazy = true [[vit_b_16-IMAGENET1K_V1.download]] - sha256 = "9f208083e1b4a360098516c694fc13c02f22fc7c62b751ed96caedcede6c5435" + sha256 = "8543956979591e83aa98d6cd7d2370cc3325fb70c79913d27a1187c331c101e6" url = 
"https://huggingface.co/FluxML/vit_b_16/resolve/main/vit_b_16-IMAGENET1K_V1.tar.gz" [vit_b_32-IMAGENET1K_V1] -git-tree-sha1 = "d5c51ea70396046f2734f599066ec5bb7b2195a1" +git-tree-sha1 = "5d638ff94c5f763025ad3118adf1101faa2ec86e" lazy = true [[vit_b_32-IMAGENET1K_V1.download]] - sha256 = "fd32efbeddbcaeac7e1143b26f3524d93699b149e759c4045114427127520fdf" + sha256 = "694d9d4e5d012f31c4108af3b35cb450bc4adb40a4a669788d7bc569450c9620" url = "https://huggingface.co/FluxML/vit_b_32/resolve/main/vit_b_32-IMAGENET1K_V1.tar.gz" [vit_l_16-IMAGENET1K_V1] -git-tree-sha1 = "1eb2d8d4a103306d988432715d6af9e32b10e735" +git-tree-sha1 = "6246a41e497334a1f0aebfd9364f4fc4d6dac98f" lazy = true [[vit_l_16-IMAGENET1K_V1.download]] - sha256 = "8ec0a15947519462df603a7a85d79541d6e284c00edb03038ff50b64f45add2b" + sha256 = "6ab996ac74e1c9677153058d7f24ee46c2e1fccc27a2a1dd1679b60416cf0275" url = "https://huggingface.co/FluxML/vit_l_16/resolve/main/vit_l_16-IMAGENET1K_V1.tar.gz" [vit_l_32-IMAGENET1K_V1] -git-tree-sha1 = "ca247073538543c8aa94ececa4f661014459a976" +git-tree-sha1 = "edb9e37ecc210446824e1590b47a9d8fea8a71ce" lazy = true [[vit_l_32-IMAGENET1K_V1.download]] - sha256 = "648d58e2ffefb32e4f9ee5f809638fb60f65df9faf36036d63939e94173ef511" + sha256 = "d552a4489e106ccd1cce53e2b8f0151d9d791b5963527f63d7e979abded2d473" url = "https://huggingface.co/FluxML/vit_l_32/resolve/main/vit_l_32-IMAGENET1K_V1.tar.gz" [wideresnet101] @@ -295,19 +295,19 @@ lazy = true url = "https://huggingface.co/FluxML/wideresnet101/resolve/main/wideresnet101.tar.gz" [wideresnet101-IMAGENET1K_V1] -git-tree-sha1 = "96a40a4471f3fc9109e19ed5efe80eee0bd99069" +git-tree-sha1 = "3547e50e95e2323de760706569342cb7e7ca9730" lazy = true [[wideresnet101-IMAGENET1K_V1.download]] - sha256 = "e44b248ec8e3c251069a635d8a0bc7072f715bb0ec69bd2f797357f22270b83e" + sha256 = "7db16af88b0fb6b10dd50c88380bac9d0a94b88a127099ca5630c74bafa462cb" url = "https://huggingface.co/FluxML/wideresnet101/resolve/main/wideresnet101-IMAGENET1K_V1.tar.gz" 
[wideresnet101-IMAGENET1K_V2] -git-tree-sha1 = "bb4d432f5499dbd822b2f79a8b54ec0e423792d5" +git-tree-sha1 = "eefc23da10b3efb737f6bc8af96463ee1f850acd" lazy = true [[wideresnet101-IMAGENET1K_V2.download]] - sha256 = "0bee2a5fd8ceb2f47a8d1d18cd29347afd2280a85bf823afac662c25bd0969da" + sha256 = "93fe311bbdc16d5e21d4e8a513de9b617b872b6ba2c10344903275badef8a0fa" url = "https://huggingface.co/FluxML/wideresnet101/resolve/main/wideresnet101-IMAGENET1K_V2.tar.gz" [wideresnet50] @@ -319,17 +319,17 @@ lazy = true url = "https://huggingface.co/FluxML/wideresnet50/resolve/main/wideresnet50.tar.gz" [wideresnet50-IMAGENET1K_V1] -git-tree-sha1 = "56eaedf7a1febc5b47c43da3ea55e76d03b80a30" +git-tree-sha1 = "7b60fe69709f6b4ec419cf84e5626b97b2511681" lazy = true [[wideresnet50-IMAGENET1K_V1.download]] - sha256 = "bb927f66a35d00b61a81abaa679d0503caff83651cbd07d6ad0fe34aefcd3661" + sha256 = "3ad68fc8a8f9fb7d9f3cfcd892ceda4f2f123d0a78ff521b305f53d3b39773aa" url = "https://huggingface.co/FluxML/wideresnet50/resolve/main/wideresnet50-IMAGENET1K_V1.tar.gz" [wideresnet50-IMAGENET1K_V2] -git-tree-sha1 = "f0c535e705a8289d30cb311f3808d9a169219520" +git-tree-sha1 = "49b624a932935c426b75dcf3b2446b6d1051de22" lazy = true [[wideresnet50-IMAGENET1K_V2.download]] - sha256 = "684854df432e7c97aa6729dc9d4686130bdb20fa44d4dbbbb0a709ec8050ba06" + sha256 = "b273f62a28f54e7d94f15aba4fbbf219fb313b3613afef07dac672808fff7232" url = "https://huggingface.co/FluxML/wideresnet50/resolve/main/wideresnet50-IMAGENET1K_V2.tar.gz" diff --git a/Project.toml b/Project.toml index 59192d83b..f3181ee87 100644 --- a/Project.toml +++ b/Project.toml @@ -9,6 +9,7 @@ CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c" Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196" +JLD2 = "033835bb-8acc-5ee8-8aae-3f567f8a3819" LazyArtifacts = "4af54fe1-eca0-43a8-85a7-787d91b784e3" MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54" NNlib = 
"872c559c-99b0-510c-b3b7-b6c96a88d5cd" @@ -19,11 +20,12 @@ Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" [compat] BSON = "0.3.2" -CUDA = "3, 4" +CUDA = "4" ChainRulesCore = "1" -Flux = "0.13" -Functors = "0.2, 0.3, 0.4" -MLUtils = "0.2.10, 0.3, 0.4" +Flux = "0.13.16" +Functors = "0.4" +JLD2 = "0.4" +MLUtils = "0.4" NNlib = "0.8" NNlibCUDA = "0.2" PartialFunctions = "1" diff --git a/README.md b/README.md index a5a8d7ded..0467ee8cf 100644 --- a/README.md +++ b/README.md @@ -14,28 +14,28 @@ julia> ]add Metalhead ## Available models -| Model Name | Function | Pre-trained? | +| Model Name | Constructor | Pre-trained? | |:-------------------------------------------------|:-----------------------------------------------------------------------------------------------|:------------:| -| [VGG](https://arxiv.org/abs/1409.1556) | [`VGG`](https://fluxml.ai/Metalhead.jl/stable/api/reference.html#Metalhead.VGG) | Y (w/o BN) | -| [ResNet](https://arxiv.org/abs/1512.03385) | [`ResNet`](https://fluxml.ai/Metalhead.jl/stable/api/reference.html#Metalhead.ResNet) | Y | -| [WideResNet](https://arxiv.org/abs/1605.07146) | [`WideResNet`](https://fluxml.ai/Metalhead.jl/stable/api/reference.html#Metalhead.WideResNet) | Y | +| [ConvNeXt](https://arxiv.org/abs/2201.03545) | [`ConvNeXt`](https://fluxml.ai/Metalhead.jl/stable/api/reference.html#Metalhead.ConvNeXt) | N | +| [ConvMixer](https://arxiv.org/abs/2201.09792) | [`ConvMixer`](https://fluxml.ai/Metalhead.jl/stable/api/reference.html#Metalhead.ConvMixer) | N | +| [DenseNet](https://arxiv.org/abs/1608.06993) | [`DenseNet`](https://fluxml.ai/Metalhead.jl/stable/api/reference.html#Metalhead.DenseNet) | N | +| [EfficientNet](https://arxiv.org/abs/1905.11946) | [`EfficientNet`](https://fluxml.ai/Metalhead.jl/stable/api/reference.html#Metalhead.EfficientNet) | N | +| [gMLP](https://arxiv.org/abs/2105.08050) | [`gMLP`](https://fluxml.ai/Metalhead.jl/stable/api/reference.html#Metalhead.gMLP) | N | | 
[GoogLeNet](https://arxiv.org/abs/1409.4842) | [`GoogLeNet`](https://fluxml.ai/Metalhead.jl/stable/api/reference.html#Metalhead.GoogLeNet) | N | | [Inception-v3](https://arxiv.org/abs/1512.00567) | [`Inceptionv3`](https://fluxml.ai/Metalhead.jl/stable/api/reference.html#Metalhead.Inceptionv3) | N | | [Inception-v4](https://arxiv.org/abs/1602.07261) | [`Inceptionv4`](https://fluxml.ai/Metalhead.jl/stable/api/reference.html#Metalhead.Inceptionv4) | N | | [InceptionResNet-v2](https://arxiv.org/abs/1602.07261) | [`Inceptionv3`](https://fluxml.ai/Metalhead.jl/stable/api/reference.html#Metalhead.InceptionResNetv2) | N | -| [SqueezeNet](https://arxiv.org/abs/1602.07360) | [`SqueezeNet`](https://fluxml.ai/Metalhead.jl/stable/api/reference.html#Metalhead.SqueezeNet) | Y | -| [DenseNet](https://arxiv.org/abs/1608.06993) | [`DenseNet`](https://fluxml.ai/Metalhead.jl/stable/api/reference.html#Metalhead.DenseNet) | N | -| [ResNeXt](https://arxiv.org/abs/1611.05431) | [`ResNeXt`](https://fluxml.ai/Metalhead.jl/stable/api/reference.html#Metalhead.ResNeXt) | Y | +| [MLPMixer](https://arxiv.org/pdf/2105.01601) | [`MLPMixer`](https://fluxml.ai/Metalhead.jl/stable/api/reference.html#Metalhead.MLPMixer) | N | | [MobileNetv1](https://arxiv.org/abs/1704.04861) | [`MobileNetv1`](https://fluxml.ai/Metalhead.jl/stable/api/reference.html#Metalhead.MobileNetv1) | N | | [MobileNetv2](https://arxiv.org/abs/1801.04381) | [`MobileNetv2`](https://fluxml.ai/Metalhead.jl/stable/api/reference.html#Metalhead.MobileNetv2) | N | | [MobileNetv3](https://arxiv.org/abs/1905.02244) | [`MobileNetv3`](https://fluxml.ai/Metalhead.jl/stable/api/reference.html#Metalhead.MobileNetv3) | N | -| [EfficientNet](https://arxiv.org/abs/1905.11946) | [`EfficientNet`](https://fluxml.ai/Metalhead.jl/stable/api/reference.html#Metalhead.EfficientNet) | N | -| [MLPMixer](https://arxiv.org/pdf/2105.01601) | [`MLPMixer`](https://fluxml.ai/Metalhead.jl/stable/api/reference.html#Metalhead.MLPMixer) | N | | 
[ResMLP](https://arxiv.org/abs/2105.03404) | [`ResMLP`](https://fluxml.ai/Metalhead.jl/stable/api/reference.html#Metalhead.ResMLP) | N | -| [gMLP](https://arxiv.org/abs/2105.08050) | [`gMLP`](https://fluxml.ai/Metalhead.jl/stable/api/reference.html#Metalhead.gMLP) | N | -| [ViT](https://arxiv.org/abs/2010.11929) | [`ViT`](https://fluxml.ai/Metalhead.jl/stable/api/reference.html#Metalhead.ViT) | N | -| [ConvNeXt](https://arxiv.org/abs/2201.03545) | [`ConvNeXt`](https://fluxml.ai/Metalhead.jl/stable/api/reference.html#Metalhead.ConvNeXt) | N | -| [ConvMixer](https://arxiv.org/abs/2201.09792) | [`ConvMixer`](https://fluxml.ai/Metalhead.jl/stable/api/reference.html#Metalhead.ConvMixer) | N | +| [ResNet](https://arxiv.org/abs/1512.03385) | [`ResNet`](https://fluxml.ai/Metalhead.jl/stable/api/reference.html#Metalhead.ResNet) | Y | +| [ResNeXt](https://arxiv.org/abs/1611.05431) | [`ResNeXt`](https://fluxml.ai/Metalhead.jl/stable/api/reference.html#Metalhead.ResNeXt) | Y | +| [SqueezeNet](https://arxiv.org/abs/1602.07360) | [`SqueezeNet`](https://fluxml.ai/Metalhead.jl/stable/api/reference.html#Metalhead.SqueezeNet) | Y | +| [WideResNet](https://arxiv.org/abs/1605.07146) | [`WideResNet`](https://fluxml.ai/Metalhead.jl/stable/api/reference.html#Metalhead.WideResNet) | Y | +| [VGG](https://arxiv.org/abs/1409.1556) | [`VGG`](https://fluxml.ai/Metalhead.jl/stable/api/reference.html#Metalhead.VGG) | Y | +| [ViT](https://arxiv.org/abs/2010.11929) | [`ViT`](https://fluxml.ai/Metalhead.jl/stable/api/reference.html#Metalhead.ViT) | Y | To contribute new models, see our [contributing docs](https://fluxml.ai/Metalhead.jl/dev/contributing/). 
diff --git a/scripts/manage_huggingface_org.jl b/scripts/manage_huggingface_org.jl index 371c21273..f795851e4 100644 --- a/scripts/manage_huggingface_org.jl +++ b/scripts/manage_huggingface_org.jl @@ -44,7 +44,7 @@ function create_model_artifacts(; force=false) weight_name = basename(weight_folder) artifact_path = joinpath(model_folder, "$(weight_name).tar.gz") if !isfile(artifact_path) || force - run(`tar -czvf $(artifact_path) $(weight_folder)`) + run(`tar -czvf $(artifact_path) -C $(weight_folder) .`) end push!(artifacts, (model_name, weight_name, artifact_path)) end @@ -74,10 +74,15 @@ end ### Create artifacts and upload to HuggingFace repos ############ # hfhub.login(ENV["HUGGINGFACE_TOKEN"]) # model_artifacts = create_model_artifacts(force=false) -# # model_artifacts = filter(x -> startswith(x[1], "wideresnet"), model_artifacts) +# model_artifacts = filter(model_artifacts) do x +# !startswith(x[1], "resnet") && !startswith(x[1], "resnext") +# end # upload_artifacts_to_hf(model_artifacts) ### Generate Artifacts.toml from HuggingFace repos ############# fluxml_model_repos = list_fluxml_models() -# # fluxml_model_repos = filter(x -> true, fluxml_model_repos) +# fluxml_model_repos = filter(fluxml_model_repos) do repo +# name = split(repo[:id], "/")[2] +# startswith(name, "resnet") +# end generate_artifacts_toml(fluxml_model_repos) diff --git a/src/Metalhead.jl b/src/Metalhead.jl index 8dedbd656..d7db8d4c3 100644 --- a/src/Metalhead.jl +++ b/src/Metalhead.jl @@ -4,6 +4,7 @@ using Flux using Flux: Zygote, outputsize using Functors using BSON +using JLD2 using Artifacts, LazyArtifacts using Statistics using MLUtils diff --git a/src/convnets/alexnet.jl b/src/convnets/alexnet.jl index fba6749e4..ad4324728 100644 --- a/src/convnets/alexnet.jl +++ b/src/convnets/alexnet.jl @@ -55,10 +55,11 @@ end function AlexNet(; pretrain::Bool = false, inchannels::Integer = 3, nclasses::Integer = 1000) layers = alexnet(; inchannels, nclasses) + model = AlexNet(layers) if pretrain - 
loadpretrain!(layers, "AlexNet") + loadpretrain!(model, "alexnet") end - return AlexNet(layers) + return model end (m::AlexNet)(x) = m.layers(x) diff --git a/src/convnets/densenet.jl b/src/convnets/densenet.jl index 6ffcd42c2..de326decc 100644 --- a/src/convnets/densenet.jl +++ b/src/convnets/densenet.jl @@ -148,10 +148,12 @@ function DenseNet(config::Int; pretrain::Bool = false, growth_rate::Int = 32, _checkconfig(config, keys(DENSENET_CONFIGS)) layers = densenet(DENSENET_CONFIGS[config]; growth_rate, reduction, inchannels, nclasses) + model = DenseNet(layers) if pretrain - loadpretrain!(layers, string("densenet", config)) + artifact_name = string("densenet", config) + loadpretrain!(model, artifact_name) end - return DenseNet(layers) + return model end (m::DenseNet)(x) = m.layers(x) diff --git a/src/convnets/efficientnets/efficientnet.jl b/src/convnets/efficientnets/efficientnet.jl index 2657a3884..3948048f7 100644 --- a/src/convnets/efficientnets/efficientnet.jl +++ b/src/convnets/efficientnets/efficientnet.jl @@ -86,10 +86,11 @@ end function EfficientNet(config::Symbol; pretrain::Bool = false, inchannels::Integer = 3, nclasses::Integer = 1000) layers = efficientnet(config; inchannels, nclasses) + model = EfficientNet(layers) if pretrain - loadpretrain!(layers, string("efficientnet-", config)) + loadpretrain!(model, string("efficientnet_", config)) end - return EfficientNet(layers) + return model end (m::EfficientNet)(x) = m.layers(x) diff --git a/src/convnets/efficientnets/efficientnetv2.jl b/src/convnets/efficientnets/efficientnetv2.jl index dab159e68..9c0bb3c1d 100644 --- a/src/convnets/efficientnets/efficientnetv2.jl +++ b/src/convnets/efficientnets/efficientnetv2.jl @@ -89,10 +89,11 @@ end function EfficientNetv2(config::Symbol; pretrain::Bool = false, inchannels::Integer = 3, nclasses::Integer = 1000) layers = efficientnetv2(config; inchannels, nclasses) + model = EfficientNetv2(layers) if pretrain - loadpretrain!(layers, string("efficientnetv2-", config)) 
+ loadpretrain!(model, string("efficientnet_v2_", config)) end - return EfficientNetv2(layers) + return model end (m::EfficientNetv2)(x) = m.layers(x) diff --git a/src/convnets/hybrid/convmixer.jl b/src/convnets/hybrid/convmixer.jl index 83d8d4638..fd226d2af 100644 --- a/src/convnets/hybrid/convmixer.jl +++ b/src/convnets/hybrid/convmixer.jl @@ -70,10 +70,11 @@ function ConvMixer(config::Symbol; pretrain::Bool = false, inchannels::Integer = _checkconfig(config, keys(CONVMIXER_CONFIGS)) layers = convmixer(CONVMIXER_CONFIGS[config][1]...; CONVMIXER_CONFIGS[config][2]..., inchannels, nclasses) + model = ConvMixer(layers) if pretrain - loadpretrain!(layers, "convmixer$config") + loadpretrain!(model, "convmixer_$config") end - return ConvMixer(layers) + return model end (m::ConvMixer)(x) = m.layers(x) diff --git a/src/convnets/hybrid/convnext.jl b/src/convnets/hybrid/convnext.jl index e11abaea1..9b9b3f326 100644 --- a/src/convnets/hybrid/convnext.jl +++ b/src/convnets/hybrid/convnext.jl @@ -127,10 +127,11 @@ function ConvNeXt(config::Symbol; pretrain::Bool = false, inchannels::Integer = nclasses::Integer = 1000) _checkconfig(config, keys(CONVNEXT_CONFIGS)) layers = convnext(config; inchannels, nclasses) + model = ConvNeXt(layers) if pretrain - layers = loadpretrain!(layers, "convnext_$config") + loadpretrain!(model, "convnext_$config") end - return ConvNeXt(layers) + return model end (m::ConvNeXt)(x) = m.layers(x) diff --git a/src/convnets/mobilenets/mnasnet.jl b/src/convnets/mobilenets/mnasnet.jl index 98cd9d759..9950a069c 100644 --- a/src/convnets/mobilenets/mnasnet.jl +++ b/src/convnets/mobilenets/mnasnet.jl @@ -100,10 +100,11 @@ function MNASNet(config::Symbol; width_mult::Real = 1, pretrain::Bool = false, inchannels::Integer = 3, nclasses::Integer = 1000) _checkconfig(config, keys(MNASNET_CONFIGS)) layers = mnasnet(config; width_mult, inchannels, nclasses) + model = MNASNet(layers) if pretrain - loadpretrain!(layers, "mnasnet$(width_mult)") + loadpretrain!(model,
"mnasnet$(width_mult)") end - return MNASNet(layers) + return model end (m::MNASNet)(x) = m.layers(x) diff --git a/src/convnets/mobilenets/mobilenetv1.jl b/src/convnets/mobilenets/mobilenetv1.jl index ca17dc4ac..0d0aef404 100644 --- a/src/convnets/mobilenets/mobilenetv1.jl +++ b/src/convnets/mobilenets/mobilenetv1.jl @@ -65,10 +65,11 @@ end function MobileNetv1(width_mult::Real = 1; pretrain::Bool = false, inchannels::Integer = 3, nclasses::Integer = 1000) layers = mobilenetv1(width_mult; inchannels, nclasses) + model = MobileNetv1(layers) if pretrain - loadpretrain!(layers, string("MobileNetv1")) + loadpretrain!(model, "mobilenet_v1") end - return MobileNetv1(layers) + return model end (m::MobileNetv1)(x) = m.layers(x) diff --git a/src/convnets/mobilenets/mobilenetv2.jl b/src/convnets/mobilenets/mobilenetv2.jl index 6d0973130..ed5e6f40c 100644 --- a/src/convnets/mobilenets/mobilenetv2.jl +++ b/src/convnets/mobilenets/mobilenetv2.jl @@ -71,10 +71,11 @@ end function MobileNetv2(width_mult::Real = 1; pretrain::Bool = false, inchannels::Integer = 3, nclasses::Integer = 1000) layers = mobilenetv2(width_mult; inchannels, nclasses) + model = MobileNetv2(layers) if pretrain - loadpretrain!(layers, string("MobileNetv2")) + loadpretrain!(model, "mobilenet_v2") end - return MobileNetv2(layers) + return model end (m::MobileNetv2)(x) = m.layers(x) diff --git a/src/convnets/mobilenets/mobilenetv3.jl b/src/convnets/mobilenets/mobilenetv3.jl index 07e4501eb..0027125a6 100644 --- a/src/convnets/mobilenets/mobilenetv3.jl +++ b/src/convnets/mobilenets/mobilenetv3.jl @@ -91,10 +91,11 @@ end function MobileNetv3(config::Symbol; width_mult::Real = 1, pretrain::Bool = false, inchannels::Integer = 3, nclasses::Integer = 1000) layers = mobilenetv3(config; width_mult, inchannels, nclasses) + model = MobileNetv3(layers) if pretrain - loadpretrain!(layers, string("MobileNetv3", config)) + loadpretrain!(model, "mobilenet_v3") end - return MobileNetv3(layers) + return model end
(m::MobileNetv3)(x) = m.layers(x) diff --git a/src/convnets/resnets/res2net.jl b/src/convnets/resnets/res2net.jl index 33b9fb961..df7f2c98d 100644 --- a/src/convnets/resnets/res2net.jl +++ b/src/convnets/resnets/res2net.jl @@ -33,10 +33,13 @@ function Res2Net(depth::Integer; pretrain::Bool = false, scale::Integer = 4, _checkconfig(depth, keys(LRESNET_CONFIGS)) layers = resnet(bottle2neck, LRESNET_CONFIGS[depth][2]; base_width, scale, inchannels, nclasses) + + model = Res2Net(layers) if pretrain - loadpretrain!(layers, string("Res2Net", depth, "_", base_width, "x", scale)) + artifact_name = string("res2net", depth, "_", base_width, "x", scale) + loadpretrain!(model, artifact_name) end - return Res2Net(layers) + return model end (m::Res2Net)(x) = m.layers(x) @@ -80,12 +83,12 @@ function Res2NeXt(depth::Integer; pretrain::Bool = false, scale::Integer = 4, _checkconfig(depth, keys(LRESNET_CONFIGS)) layers = resnet(bottle2neck, LRESNET_CONFIGS[depth][2]; base_width, scale, cardinality, inchannels, nclasses) + model = Res2NeXt(layers) if pretrain - loadpretrain!(layers, - string("Res2NeXt", depth, "_", base_width, "x", cardinality, - "x", scale)) + artifact_name = string("res2next", depth, "_", base_width, "x", scale, "x", cardinality) + loadpretrain!(model, artifact_name) end - return Res2NeXt(layers) + return model end (m::Res2NeXt)(x) = m.layers(x) diff --git a/src/convnets/resnets/resnet.jl b/src/convnets/resnets/resnet.jl index 6fbc5436a..5010164aa 100644 --- a/src/convnets/resnets/resnet.jl +++ b/src/convnets/resnets/resnet.jl @@ -22,10 +22,17 @@ function ResNet(depth::Integer; pretrain::Bool = false, inchannels::Integer = 3, nclasses::Integer = 1000) _checkconfig(depth, keys(RESNET_CONFIGS)) layers = resnet(RESNET_CONFIGS[depth]...; inchannels, nclasses) + model = ResNet(layers) if pretrain - loadpretrain!(layers, string("resnet", depth)) + artifact_name = "resnet$(depth)" + if depth ∈ [18, 34] + artifact_name *= "-IMAGENET1K_V1" + elseif depth ∈ [50, 101] +
artifact_name *= "-IMAGENET1K_V2" + end + loadpretrain!(model, artifact_name) end - return ResNet(layers) + return model end (m::ResNet)(x) = m.layers(x) @@ -59,10 +66,15 @@ function WideResNet(depth::Integer; pretrain::Bool = false, inchannels::Integer nclasses::Integer = 1000) _checkconfig(depth, keys(LRESNET_CONFIGS)) layers = resnet(LRESNET_CONFIGS[depth]...; base_width = 128, inchannels, nclasses) + model = WideResNet(layers) if pretrain - loadpretrain!(layers, string("wideresnet", depth)) + artifact_name = "wideresnet$(depth)" + if depth ∈ [50, 101] + artifact_name *= "-IMAGENET1K_V2" + end + loadpretrain!(model, artifact_name) end - return WideResNet(layers) + return model end (m::WideResNet)(x) = m.layers(x) diff --git a/src/convnets/resnets/resnext.jl b/src/convnets/resnets/resnext.jl index c8d03d8e1..f30f69c7c 100644 --- a/src/convnets/resnets/resnext.jl +++ b/src/convnets/resnets/resnext.jl @@ -32,11 +32,19 @@ function ResNeXt(depth::Integer; pretrain::Bool = false, cardinality::Integer = _checkconfig(depth, keys(LRESNET_CONFIGS)) layers = resnet(LRESNET_CONFIGS[depth]...; inchannels, nclasses, cardinality, base_width) + model = ResNeXt(layers) if pretrain - loadpretrain!(layers, - string("resnext", depth, "_", cardinality, "x", base_width, "d")) + artifact_name = string("resnext", depth, "_", cardinality, "x", base_width, "d") + if depth == 50 && cardinality == 32 && base_width == 4 + artifact_name *= "-IMAGENET1K_V2" + elseif depth == 101 && cardinality == 32 && base_width == 8 + artifact_name *= "-IMAGENET1K_V2" + elseif depth == 101 && cardinality == 64 && base_width == 4 + artifact_name *= "-IMAGENET1K_V1" + end + loadpretrain!(model, artifact_name) end - return ResNeXt(layers) + return model end (m::ResNeXt)(x) = m.layers(x) diff --git a/src/convnets/resnets/seresnet.jl b/src/convnets/resnets/seresnet.jl index 44e32083d..4be519720 100644 --- a/src/convnets/resnets/seresnet.jl +++ b/src/convnets/resnets/seresnet.jl @@ -27,10 +27,12 @@ function
SEResNet(depth::Integer; pretrain::Bool = false, inchannels::Integer = _checkconfig(depth, keys(RESNET_CONFIGS)) layers = resnet(RESNET_CONFIGS[depth]...; inchannels, nclasses, attn_fn = squeeze_excite) + model = SEResNet(layers) if pretrain - loadpretrain!(layers, string("seresnet", depth)) + artifact_name = "seresnet$(depth)" + loadpretrain!(model, artifact_name) end - return SEResNet(layers) + return model end (m::SEResNet)(x) = m.layers(x) diff --git a/src/convnets/squeezenet.jl b/src/convnets/squeezenet.jl index 1e90dd55c..8629daadc 100644 --- a/src/convnets/squeezenet.jl +++ b/src/convnets/squeezenet.jl @@ -73,10 +73,11 @@ end function SqueezeNet(; pretrain::Bool = false, inchannels::Integer = 3, nclasses::Integer = 1000) layers = squeezenet(; inchannels, nclasses) + model = SqueezeNet(layers) if pretrain - loadpretrain!(layers, "squeezenet") + loadpretrain!(model, "squeezenet") end - return SqueezeNet(layers) + return model end (m::SqueezeNet)(x) = m.layers(x) diff --git a/src/convnets/unet.jl b/src/convnets/unet.jl index 91200c0ce..75f52267a 100644 --- a/src/convnets/unet.jl +++ b/src/convnets/unet.jl @@ -115,11 +115,12 @@ end function UNet(imsize::Dims{2} = (256, 256), inchannels::Integer = 3, outplanes::Integer = 3, encoder_backbone = Metalhead.backbone(DenseNet(121)); pretrain::Bool = false) layers = unet(encoder_backbone, (imsize..., inchannels, 1), outplanes) - + model = UNet(layers) if pretrain - loadpretrain!(layers, string("UNet")) + artifact_name = "UNet" + loadpretrain!(model, artifact_name) end - return UNet(layers) + return model end (m::UNet)(x::AbstractArray) = m.layers(x) diff --git a/src/convnets/vgg.jl b/src/convnets/vgg.jl index 35d01e0b5..29336021a 100644 --- a/src/convnets/vgg.jl +++ b/src/convnets/vgg.jl @@ -163,10 +163,12 @@ function VGG(depth::Integer; pretrain::Bool = false, batchnorm::Bool = false, _checkconfig(depth, keys(VGG_CONFIGS)) model = VGG((224, 224); config = VGG_CONV_CONFIGS[VGG_CONFIGS[depth]], batchnorm, inchannels,
nclasses) - if pretrain && !batchnorm - loadpretrain!(model, string("vgg", depth)) - elseif pretrain - loadpretrain!(model, "vgg$(depth)-bn)") + if pretrain + artifact_name = string("vgg", depth) + if batchnorm + artifact_name *= "-bn" + end + loadpretrain!(model, artifact_name) end return model end diff --git a/src/pretrain.jl b/src/pretrain.jl index 3ca4c153f..2d2c637f4 100644 --- a/src/pretrain.jl +++ b/src/pretrain.jl @@ -1,26 +1,46 @@ """ - weights(model) + loadweights(artifact_name) Load the pre-trained weights for `model` using the stored artifacts. """ -function weights(model) - path = try - joinpath(@artifact_str(model), "$model.bson") +function loadweights(artifact_name) + artifact_dir = try + @artifact_str(artifact_name) catch e - throw(ArgumentError("No pre-trained weights available for $model.")) + throw(ArgumentError("No pre-trained weights available for $artifact_name.")) end - - artifact = BSON.load(path, @__MODULE__) - if haskey(artifact, :model) - return artifact[:model] + file_name = readdir(artifact_dir)[1] + file_path = joinpath(artifact_dir, file_name) + + if endswith(file_name, ".bson") + artifact = BSON.load(file_path, @__MODULE__) + if haskey(artifact, :model_state) + return artifact[:model_state] + elseif haskey(artifact, :model) + return artifact[:model] + else + throw(ErrorException("Found weight artifact for $artifact_name but the weights are not saved under the key :model_state or :model.")) + end + elseif endswith(file_path, ".jld2") + artifact = JLD2.load(file_path) + if haskey(artifact, "model_state") + return artifact["model_state"] + elseif haskey(artifact, "model") + return artifact["model"] + else + throw(ErrorException("Found weight artifact for $artifact_name but the weights are not saved under the key \"model_state\" or \"model\".")) + end else - throw(ErrorException("Found weight artifact for $model but the weights are not saved under the key :model.")) + throw(ErrorException("Found weight artifact for $artifact_name but 
only jld2 and bson serialization format are supported.")) end end """ - loadpretrain!(model, name) + loadpretrain!(model, artifact_name) Load the pre-trained weight artifacts matching `.bson` into `model`. """ -loadpretrain!(model, name) = Flux.loadmodel!(model, weights(name)) +function loadpretrain!(model, artifact_name) + m = loadweights(artifact_name) + Flux.loadmodel!(model, m) +end diff --git a/src/vit-based/vit.jl b/src/vit-based/vit.jl index 92ff9a406..7eb26b10b 100644 --- a/src/vit-based/vit.jl +++ b/src/vit-based/vit.jl @@ -105,10 +105,12 @@ function ViT(config::Symbol; imsize::Dims{2} = (224, 224), patch_size::Dims{2} = pretrain::Bool = false, inchannels::Integer = 3, nclasses::Integer = 1000) _checkconfig(config, keys(VIT_CONFIGS)) layers = vit(imsize; inchannels, patch_size, nclasses, VIT_CONFIGS[config]...) + model = ViT(layers) if pretrain - loadpretrain!(layers, string("vit", config)) + artifact_name = "vit_$(string(config)[1])_$(patch_size[1])-IMAGENET1K_V1" + loadpretrain!(model, artifact_name) end - return ViT(layers) + return model end (m::ViT)(x) = m.layers(x) diff --git a/test/convnets.jl b/test/convnets.jl index 2f220b634..5a1392206 100644 --- a/test/convnets.jl +++ b/test/convnets.jl @@ -26,7 +26,7 @@ end m = ResNet(sz) @test size(m(x_224)) == (1000, 1) if (ResNet, sz) in PRETRAINED_MODELS - @test_broken acctest(ResNet(sz, pretrain = true)) + @test acctest(ResNet(sz, pretrain = true)) else @test_throws ArgumentError ResNet(sz, pretrain = true) end @@ -63,7 +63,7 @@ end @test gradtest(m, x_224) _gc() if (WideResNet, sz) in PRETRAINED_MODELS - @test_broken acctest(WideResNet(sz, pretrain = true)) + @test acctest(WideResNet(sz, pretrain = true)) else @test_throws ArgumentError WideResNet(sz, pretrain = true) end @@ -78,7 +78,7 @@ end m = ResNeXt(depth; cardinality, base_width) @test size(m(x_224)) == (1000, 1) if (ResNeXt, depth, cardinality, base_width) in PRETRAINED_MODELS - @test_broken acctest(ResNeXt(depth; cardinality, base_width, pretrain 
= true)) + @test acctest(ResNeXt(depth; cardinality, base_width, pretrain = true)) else @test_throws ArgumentError ResNeXt(depth; cardinality, base_width, pretrain = true) end From 9c65110c1a12036585041bdd03363917b2a155c2 Mon Sep 17 00:00:00 2001 From: Carlo Lucibello Date: Sun, 7 May 2023 11:49:32 +0200 Subject: [PATCH 02/10] readme --- README.md | 1 + 1 file changed, 1 insertion(+) diff --git a/README.md b/README.md index 0467ee8cf..85975de75 100644 --- a/README.md +++ b/README.md @@ -42,3 +42,4 @@ To contribute new models, see our [contributing docs](https://fluxml.ai/Metalhea ## Getting Started You can find the Metalhead.jl getting started guide [here](https://fluxml.ai/Metalhead.jl/dev/tutorials/quickstart/). +; \ No newline at end of file From 1812c22a252ee13491309e4ae17462c5c3dba4a2 Mon Sep 17 00:00:00 2001 From: Carlo Lucibello Date: Sun, 7 May 2023 12:28:58 +0200 Subject: [PATCH 03/10] readme fixes --- docs/make.jl | 28 ++++++++++------- docs/src/api/efficientnet.md | 10 ++++++ docs/src/index.md | 60 ++++++++++++++++++------------------ 3 files changed, 57 insertions(+), 41 deletions(-) create mode 100644 docs/src/api/efficientnet.md diff --git a/docs/make.jl b/docs/make.jl index 08d667b88..7959b0428 100644 --- a/docs/make.jl +++ b/docs/make.jl @@ -2,6 +2,11 @@ using Documenter, Metalhead, Artifacts, LazyArtifacts, Images, DataAugmentation, DocMeta.setdocmeta!(Metalhead, :DocTestSetup, :(using Metalhead); recursive = true) +# copy readme into index.md +open(joinpath(@__DIR__, "src", "index.md"), "w") do io + write(io, read(joinpath(@__DIR__, "..", "README.md"), String)) +end + makedocs(; modules = [Metalhead, Artifacts, LazyArtifacts, Images, DataAugmentation, Flux], sitename = "Metalhead.jl", doctest = false, @@ -10,27 +15,28 @@ makedocs(; modules = [Metalhead, Artifacts, LazyArtifacts, Images, DataAugmentat "tutorials/quickstart.md", "tutorials/pretrained.md", ], + "Guides" => [ + "howto/resnet.md", + ], + "Contributing to Metalhead" => 
"contributing.md", "API reference" => [ "Convolutional Neural Networks" => [ - "api/others.md", - "api/inception.md", "api/resnet.md", "api/densenet.md", + "api/efficientnet.md", + "api/inception.md", "api/hybrid.md", - "api/layers.md", - ], + "api/others.md", + ], "Mixers" => [ "api/mixers.md", - ], + ], "Vision Transformers" => [ "api/vit.md", - ], - "api/utilities.md" + ], + "Layers" => "api/layers.md", + "Utilities" => "api/utilities.md", ], - "How To" => [ - "howto/resnet.md", - ], - "Contributing to Metalhead" => "contributing.md", ], format = Documenter.HTML(; canonical = "https://fluxml.ai/Metalhead.jl/stable/", # analytics = "UA-36890222-9", diff --git a/docs/src/api/efficientnet.md b/docs/src/api/efficientnet.md new file mode 100644 index 000000000..ac88d9656 --- /dev/null +++ b/docs/src/api/efficientnet.md @@ -0,0 +1,10 @@ +# Efficient Networks + +```@docs +EfficientNet +EfficientNetv2 +MobileNetv1 +MobileNetv2 +MobileNetv3 +MNASNet +``` \ No newline at end of file diff --git a/docs/src/index.md b/docs/src/index.md index f65893a7a..e90209061 100644 --- a/docs/src/index.md +++ b/docs/src/index.md @@ -1,9 +1,9 @@ -```@meta -CurrentModule = Metalhead -``` - # Metalhead +[![Dev](https://img.shields.io/badge/docs-dev-blue.svg)](https://fluxml.github.io/Metalhead.jl/dev) +[![CI](https://github.com/FluxML/Metalhead.jl/actions/workflows/CI.yml/badge.svg)](https://github.com/FluxML/Metalhead.jl/actions/workflows/CI.yml) +[![Coverage](https://codecov.io/gh/FluxML/Metalhead.jl/branch/master/graph/badge.svg)](https://codecov.io/gh/FluxML/Metalhead.jl) + [Metalhead.jl](https://github.com/FluxML/Metalhead.jl) provides standard machine learning vision models for use with [Flux.jl](https://fluxml.ai). The architectures in this package make use of pure Flux layers, and they represent the best-practices for creating modules like residual blocks, inception blocks, etc. in Flux. Metalhead also provides some building blocks for more complex models in the Layers module. 
## Installation @@ -14,32 +14,32 @@ julia> ]add Metalhead ## Available models -| Model Name | Function | Pre-trained? | -|:-------------------------------------------------------|:----------------------------|:------------:| -| [VGG](https://arxiv.org/abs/1409.1556) | [`VGG`](@ref) | Y (w/o BN) | -| [ResNet](https://arxiv.org/abs/1512.03385) | [`ResNet`](@ref) | Y | -| [WideResNet](https://arxiv.org/abs/1605.07146) | [`WideResNet`](@ref) | Y | -| [GoogLeNet](https://arxiv.org/abs/1409.4842) | [`GoogLeNet`](@ref) | N | -| [Inception-v3](https://arxiv.org/abs/1512.00567) | [`Inceptionv3`](@ref) | N | -| [Inception-v4](https://arxiv.org/abs/1602.07261) | [`Inceptionv4`](@ref) | N | -| [InceptionResNet-v2](https://arxiv.org/abs/1602.07261) | [`Inceptionv3`](@ref) | N | -| [SqueezeNet](https://arxiv.org/abs/1602.07360) | [`SqueezeNet`](@ref) | Y | -| [DenseNet](https://arxiv.org/abs/1608.06993) | [`DenseNet`](@ref) | N | -| [ResNeXt](https://arxiv.org/abs/1611.05431) | [`ResNeXt`](@ref) | Y | -| [MobileNetv1](https://arxiv.org/abs/1704.04861) | [`MobileNetv1`](@ref) | N | -| [MobileNetv2](https://arxiv.org/abs/1801.04381) | [`MobileNetv2`](@ref) | N | -| [MobileNetv3](https://arxiv.org/abs/1905.02244) | [`MobileNetv3`](@ref) | N | -| [EfficientNet](https://arxiv.org/abs/1905.11946) | [`EfficientNet`](@ref) | N | -| [MLPMixer](https://arxiv.org/pdf/2105.01601) | [`MLPMixer`](@ref) | N | -| [ResMLP](https://arxiv.org/abs/2105.03404) | [`ResMLP`](@ref) | N | -| [gMLP](https://arxiv.org/abs/2105.08050) | [`gMLP`](@ref) | N | -| [ViT](https://arxiv.org/abs/2010.11929) | [`ViT`](@ref) | N | -| [ConvNeXt](https://arxiv.org/abs/2201.03545) | [`ConvNeXt`](@ref) | N | -| [ConvMixer](https://arxiv.org/abs/2201.09792) | [`ConvMixer`](@ref) | N | -| [UNet](https://arxiv.org/abs/1505.04597v1) | [`UNet`](@ref) | N | - -To contribute new models, see our [contributing docs](@ref Contributing-to-Metalhead.jl). +| Model Name | Constructor | Pre-trained? 
| +|:-------------------------------------------------|:-----------------------------------------------------------------------------------------------|:------------:| +| [ConvMixer](https://arxiv.org/abs/2201.09792) | [`ConvMixer`](https://fluxml.ai/Metalhead.jl/stable/api/reference.html#Metalhead.ConvMixer) | N | +| [ConvNeXt](https://arxiv.org/abs/2201.03545) | [`ConvNeXt`](https://fluxml.ai/Metalhead.jl/stable/api/reference.html#Metalhead.ConvNeXt) | N | +| [DenseNet](https://arxiv.org/abs/1608.06993) | [`DenseNet`](https://fluxml.ai/Metalhead.jl/stable/api/reference.html#Metalhead.DenseNet) | N | +| [EfficientNet](https://arxiv.org/abs/1905.11946) | [`EfficientNet`](https://fluxml.ai/Metalhead.jl/stable/api/reference.html#Metalhead.EfficientNet) | N | +| [gMLP](https://arxiv.org/abs/2105.08050) | [`gMLP`](https://fluxml.ai/Metalhead.jl/stable/api/reference.html#Metalhead.gMLP) | N | +| [GoogLeNet](https://arxiv.org/abs/1409.4842) | [`GoogLeNet`](https://fluxml.ai/Metalhead.jl/stable/api/reference.html#Metalhead.GoogLeNet) | N | +| [Inception-v3](https://arxiv.org/abs/1512.00567) | [`Inceptionv3`](https://fluxml.ai/Metalhead.jl/stable/api/reference.html#Metalhead.Inceptionv3) | N | +| [Inception-v4](https://arxiv.org/abs/1602.07261) | [`Inceptionv4`](https://fluxml.ai/Metalhead.jl/stable/api/reference.html#Metalhead.Inceptionv4) | N | +| [InceptionResNet-v2](https://arxiv.org/abs/1602.07261) | [`Inceptionv3`](https://fluxml.ai/Metalhead.jl/stable/api/reference.html#Metalhead.InceptionResNetv2) | N | +| [MLPMixer](https://arxiv.org/pdf/2105.01601) | [`MLPMixer`](https://fluxml.ai/Metalhead.jl/stable/api/reference.html#Metalhead.MLPMixer) | N | +| [MobileNetv1](https://arxiv.org/abs/1704.04861) | [`MobileNetv1`](https://fluxml.ai/Metalhead.jl/stable/api/reference.html#Metalhead.MobileNetv1) | N | +| [MobileNetv2](https://arxiv.org/abs/1801.04381) | [`MobileNetv2`](https://fluxml.ai/Metalhead.jl/stable/api/reference.html#Metalhead.MobileNetv2) | N | +| 
[MobileNetv3](https://arxiv.org/abs/1905.02244) | [`MobileNetv3`](https://fluxml.ai/Metalhead.jl/stable/api/reference.html#Metalhead.MobileNetv3) | N | +| [ResMLP](https://arxiv.org/abs/2105.03404) | [`ResMLP`](https://fluxml.ai/Metalhead.jl/stable/api/reference.html#Metalhead.ResMLP) | N | +| [ResNet](https://arxiv.org/abs/1512.03385) | [`ResNet`](https://fluxml.ai/Metalhead.jl/stable/api/reference.html#Metalhead.ResNet) | Y | +| [ResNeXt](https://arxiv.org/abs/1611.05431) | [`ResNeXt`](https://fluxml.ai/Metalhead.jl/stable/api/reference.html#Metalhead.ResNeXt) | Y | +| [SqueezeNet](https://arxiv.org/abs/1602.07360) | [`SqueezeNet`](https://fluxml.ai/Metalhead.jl/stable/api/reference.html#Metalhead.SqueezeNet) | Y | +| [WideResNet](https://arxiv.org/abs/1605.07146) | [`WideResNet`](https://fluxml.ai/Metalhead.jl/stable/api/reference.html#Metalhead.WideResNet) | Y | +| [VGG](https://arxiv.org/abs/1409.1556) | [`VGG`](https://fluxml.ai/Metalhead.jl/stable/api/reference.html#Metalhead.VGG) | Y | +| [ViT](https://arxiv.org/abs/2010.11929) | [`ViT`](https://fluxml.ai/Metalhead.jl/stable/api/reference.html#Metalhead.ViT) | Y | + +To contribute new models, see our [contributing docs](https://fluxml.ai/Metalhead.jl/dev/contributing/). ## Getting Started -You can find the Metalhead.jl getting started guide [here](@ref Quickstart). +You can find the Metalhead.jl getting started guide [here](https://fluxml.ai/Metalhead.jl/dev/tutorials/quickstart/). 
+; \ No newline at end of file From edfef9077552a5dba9dab877dddd45c158114182 Mon Sep 17 00:00:00 2001 From: Carlo Lucibello Date: Sun, 7 May 2023 12:34:15 +0200 Subject: [PATCH 04/10] cleanup --- docs/src/index.md | 45 --------------------------------------------- 1 file changed, 45 deletions(-) delete mode 100644 docs/src/index.md diff --git a/docs/src/index.md b/docs/src/index.md deleted file mode 100644 index e90209061..000000000 --- a/docs/src/index.md +++ /dev/null @@ -1,45 +0,0 @@ -# Metalhead - -[![Dev](https://img.shields.io/badge/docs-dev-blue.svg)](https://fluxml.github.io/Metalhead.jl/dev) -[![CI](https://github.com/FluxML/Metalhead.jl/actions/workflows/CI.yml/badge.svg)](https://github.com/FluxML/Metalhead.jl/actions/workflows/CI.yml) -[![Coverage](https://codecov.io/gh/FluxML/Metalhead.jl/branch/master/graph/badge.svg)](https://codecov.io/gh/FluxML/Metalhead.jl) - -[Metalhead.jl](https://github.com/FluxML/Metalhead.jl) provides standard machine learning vision models for use with [Flux.jl](https://fluxml.ai). The architectures in this package make use of pure Flux layers, and they represent the best-practices for creating modules like residual blocks, inception blocks, etc. in Flux. Metalhead also provides some building blocks for more complex models in the Layers module. - -## Installation - -```julia -julia> ]add Metalhead -``` - -## Available models - -| Model Name | Constructor | Pre-trained? 
| -|:-------------------------------------------------|:-----------------------------------------------------------------------------------------------|:------------:| -| [ConvMixer](https://arxiv.org/abs/2201.09792) | [`ConvMixer`](https://fluxml.ai/Metalhead.jl/stable/api/reference.html#Metalhead.ConvMixer) | N | -| [ConvNeXt](https://arxiv.org/abs/2201.03545) | [`ConvNeXt`](https://fluxml.ai/Metalhead.jl/stable/api/reference.html#Metalhead.ConvNeXt) | N | -| [DenseNet](https://arxiv.org/abs/1608.06993) | [`DenseNet`](https://fluxml.ai/Metalhead.jl/stable/api/reference.html#Metalhead.DenseNet) | N | -| [EfficientNet](https://arxiv.org/abs/1905.11946) | [`EfficientNet`](https://fluxml.ai/Metalhead.jl/stable/api/reference.html#Metalhead.EfficientNet) | N | -| [gMLP](https://arxiv.org/abs/2105.08050) | [`gMLP`](https://fluxml.ai/Metalhead.jl/stable/api/reference.html#Metalhead.gMLP) | N | -| [GoogLeNet](https://arxiv.org/abs/1409.4842) | [`GoogLeNet`](https://fluxml.ai/Metalhead.jl/stable/api/reference.html#Metalhead.GoogLeNet) | N | -| [Inception-v3](https://arxiv.org/abs/1512.00567) | [`Inceptionv3`](https://fluxml.ai/Metalhead.jl/stable/api/reference.html#Metalhead.Inceptionv3) | N | -| [Inception-v4](https://arxiv.org/abs/1602.07261) | [`Inceptionv4`](https://fluxml.ai/Metalhead.jl/stable/api/reference.html#Metalhead.Inceptionv4) | N | -| [InceptionResNet-v2](https://arxiv.org/abs/1602.07261) | [`Inceptionv3`](https://fluxml.ai/Metalhead.jl/stable/api/reference.html#Metalhead.InceptionResNetv2) | N | -| [MLPMixer](https://arxiv.org/pdf/2105.01601) | [`MLPMixer`](https://fluxml.ai/Metalhead.jl/stable/api/reference.html#Metalhead.MLPMixer) | N | -| [MobileNetv1](https://arxiv.org/abs/1704.04861) | [`MobileNetv1`](https://fluxml.ai/Metalhead.jl/stable/api/reference.html#Metalhead.MobileNetv1) | N | -| [MobileNetv2](https://arxiv.org/abs/1801.04381) | [`MobileNetv2`](https://fluxml.ai/Metalhead.jl/stable/api/reference.html#Metalhead.MobileNetv2) | N | -| 
[MobileNetv3](https://arxiv.org/abs/1905.02244) | [`MobileNetv3`](https://fluxml.ai/Metalhead.jl/stable/api/reference.html#Metalhead.MobileNetv3) | N | -| [ResMLP](https://arxiv.org/abs/2105.03404) | [`ResMLP`](https://fluxml.ai/Metalhead.jl/stable/api/reference.html#Metalhead.ResMLP) | N | -| [ResNet](https://arxiv.org/abs/1512.03385) | [`ResNet`](https://fluxml.ai/Metalhead.jl/stable/api/reference.html#Metalhead.ResNet) | Y | -| [ResNeXt](https://arxiv.org/abs/1611.05431) | [`ResNeXt`](https://fluxml.ai/Metalhead.jl/stable/api/reference.html#Metalhead.ResNeXt) | Y | -| [SqueezeNet](https://arxiv.org/abs/1602.07360) | [`SqueezeNet`](https://fluxml.ai/Metalhead.jl/stable/api/reference.html#Metalhead.SqueezeNet) | Y | -| [WideResNet](https://arxiv.org/abs/1605.07146) | [`WideResNet`](https://fluxml.ai/Metalhead.jl/stable/api/reference.html#Metalhead.WideResNet) | Y | -| [VGG](https://arxiv.org/abs/1409.1556) | [`VGG`](https://fluxml.ai/Metalhead.jl/stable/api/reference.html#Metalhead.VGG) | Y | -| [ViT](https://arxiv.org/abs/2010.11929) | [`ViT`](https://fluxml.ai/Metalhead.jl/stable/api/reference.html#Metalhead.ViT) | Y | - -To contribute new models, see our [contributing docs](https://fluxml.ai/Metalhead.jl/dev/contributing/). - -## Getting Started - -You can find the Metalhead.jl getting started guide [here](https://fluxml.ai/Metalhead.jl/dev/tutorials/quickstart/). 
-; \ No newline at end of file From 8c7f7f699debc9540c5ae4e04e0089cb591b1311 Mon Sep 17 00:00:00 2001 From: Carlo Lucibello Date: Sun, 7 May 2023 15:38:59 +0200 Subject: [PATCH 05/10] bson artifact for resnet152 --- .gitignore | 2 ++ 1 file changed, 2 insertions(+) diff --git a/.gitignore b/.gitignore index 1a4d377b4..455452ca1 100644 --- a/.gitignore +++ b/.gitignore @@ -7,6 +7,8 @@ .CondaPkg/ scripts/weights/ scripts/Artifacts.toml +# copied from README.md +docs/src/index.md # manifests docs/Manifest.toml From b0756b63d224baa934db44888e1af97c5e6c18bc Mon Sep 17 00:00:00 2001 From: Carlo Lucibello Date: Sun, 7 May 2023 16:03:32 +0200 Subject: [PATCH 06/10] fixes --- Artifacts.toml | 16 ++++++ README.md | 46 +++++++++------- scripts/manage_huggingface_org.jl | 10 ++-- scripts/port_torchvision.jl | 87 ++++++++++++++++++++----------- src/convnets/resnets/resnet.jl | 2 +- src/convnets/vgg.jl | 4 +- src/layers/drop.jl | 2 +- test/convnets.jl | 23 ++++---- test/runtests.jl | 6 ++- test/vits.jl | 4 +- 10 files changed, 129 insertions(+), 71 deletions(-) diff --git a/Artifacts.toml b/Artifacts.toml index a2013b125..c92fa531d 100644 --- a/Artifacts.toml +++ b/Artifacts.toml @@ -62,6 +62,22 @@ lazy = true sha256 = "a8d30a735ef5649ec40a74a0515ee3d6774499267be06f5f2b372259c5ced8d6" url = "https://huggingface.co/FluxML/resnet152/resolve/main/resnet152.tar.gz" +[resnet152-IMAGENET1K_V1] +git-tree-sha1 = "e03bd8f6ae55ec3fa854b8c987d44801cbd832bc" +lazy = true + + [[resnet152-IMAGENET1K_V1.download]] + sha256 = "01f3f9b30e9ef240885a5e369d4435e47c1c4da0c083af6d96f29ff4e5b5a722" + url = "https://huggingface.co/FluxML/resnet152/resolve/main/resnet152-IMAGENET1K_V1.tar.gz" + +[resnet152-IMAGENET1K_V2] +git-tree-sha1 = "64af7840cd730142074e95bd978acf1f2d99bf2a" +lazy = true + + [[resnet152-IMAGENET1K_V2.download]] + sha256 = "0e3a9ca6b5b48dae81f6e5bab49dba2c0d92e9b2800f041f3767315862e18b3c" + url = "https://huggingface.co/FluxML/resnet152/resolve/main/resnet152-IMAGENET1K_V2.tar.gz" + 
[resnet18] git-tree-sha1 = "4ced5a0338c0f0293940f1deb63e1c463125a6ff" lazy = true diff --git a/README.md b/README.md index 85975de75..7ea9fdba8 100644 --- a/README.md +++ b/README.md @@ -14,28 +14,34 @@ julia> ]add Metalhead ## Available models +### Image Classification + | Model Name | Constructor | Pre-trained? | |:-------------------------------------------------|:-----------------------------------------------------------------------------------------------|:------------:| -| [ConvNeXt](https://arxiv.org/abs/2201.03545) | [`ConvNeXt`](https://fluxml.ai/Metalhead.jl/stable/api/reference.html#Metalhead.ConvNeXt) | N | -| [ConvMixer](https://arxiv.org/abs/2201.09792) | [`ConvMixer`](https://fluxml.ai/Metalhead.jl/stable/api/reference.html#Metalhead.ConvMixer) | N | -| [DenseNet](https://arxiv.org/abs/1608.06993) | [`DenseNet`](https://fluxml.ai/Metalhead.jl/stable/api/reference.html#Metalhead.DenseNet) | N | -| [EfficientNet](https://arxiv.org/abs/1905.11946) | [`EfficientNet`](https://fluxml.ai/Metalhead.jl/stable/api/reference.html#Metalhead.EfficientNet) | N | -| [gMLP](https://arxiv.org/abs/2105.08050) | [`gMLP`](https://fluxml.ai/Metalhead.jl/stable/api/reference.html#Metalhead.gMLP) | N | -| [GoogLeNet](https://arxiv.org/abs/1409.4842) | [`GoogLeNet`](https://fluxml.ai/Metalhead.jl/stable/api/reference.html#Metalhead.GoogLeNet) | N | -| [Inception-v3](https://arxiv.org/abs/1512.00567) | [`Inceptionv3`](https://fluxml.ai/Metalhead.jl/stable/api/reference.html#Metalhead.Inceptionv3) | N | -| [Inception-v4](https://arxiv.org/abs/1602.07261) | [`Inceptionv4`](https://fluxml.ai/Metalhead.jl/stable/api/reference.html#Metalhead.Inceptionv4) | N | -| [InceptionResNet-v2](https://arxiv.org/abs/1602.07261) | [`Inceptionv3`](https://fluxml.ai/Metalhead.jl/stable/api/reference.html#Metalhead.InceptionResNetv2) | N | -| [MLPMixer](https://arxiv.org/pdf/2105.01601) | [`MLPMixer`](https://fluxml.ai/Metalhead.jl/stable/api/reference.html#Metalhead.MLPMixer) | N | -| 
[MobileNetv1](https://arxiv.org/abs/1704.04861) | [`MobileNetv1`](https://fluxml.ai/Metalhead.jl/stable/api/reference.html#Metalhead.MobileNetv1) | N | -| [MobileNetv2](https://arxiv.org/abs/1801.04381) | [`MobileNetv2`](https://fluxml.ai/Metalhead.jl/stable/api/reference.html#Metalhead.MobileNetv2) | N | -| [MobileNetv3](https://arxiv.org/abs/1905.02244) | [`MobileNetv3`](https://fluxml.ai/Metalhead.jl/stable/api/reference.html#Metalhead.MobileNetv3) | N | -| [ResMLP](https://arxiv.org/abs/2105.03404) | [`ResMLP`](https://fluxml.ai/Metalhead.jl/stable/api/reference.html#Metalhead.ResMLP) | N | -| [ResNet](https://arxiv.org/abs/1512.03385) | [`ResNet`](https://fluxml.ai/Metalhead.jl/stable/api/reference.html#Metalhead.ResNet) | Y | -| [ResNeXt](https://arxiv.org/abs/1611.05431) | [`ResNeXt`](https://fluxml.ai/Metalhead.jl/stable/api/reference.html#Metalhead.ResNeXt) | Y | -| [SqueezeNet](https://arxiv.org/abs/1602.07360) | [`SqueezeNet`](https://fluxml.ai/Metalhead.jl/stable/api/reference.html#Metalhead.SqueezeNet) | Y | -| [WideResNet](https://arxiv.org/abs/1605.07146) | [`WideResNet`](https://fluxml.ai/Metalhead.jl/stable/api/reference.html#Metalhead.WideResNet) | Y | -| [VGG](https://arxiv.org/abs/1409.1556) | [`VGG`](https://fluxml.ai/Metalhead.jl/stable/api/reference.html#Metalhead.VGG) | Y | -| [ViT](https://arxiv.org/abs/2010.11929) | [`ViT`](https://fluxml.ai/Metalhead.jl/stable/api/reference.html#Metalhead.ViT) | Y | +| [AlexNet](https://papers.nips.cc/paper/2012/file/c399862d3b9d6b76c8436e924a68c45b-Paper.pdf) | [`AlexNet`](https://fluxml.ai/Metalhead.jl/dev/api/other/#Metalhead.AlexNet) | N | +| [ConvMixer](https://arxiv.org/abs/2201.09792) | [`ConvMixer`](https://fluxml.ai/Metalhead.jl/dev/api/hybrid/#Metalhead.ConvMixer) | N | +| [ConvNeXt](https://arxiv.org/abs/2201.03545) | [`ConvNeXt`](https://fluxml.ai/Metalhead.jl/dev/api/hybrid/#Metalhead.ConvNeXt) | N | +| [DenseNet](https://arxiv.org/abs/1608.06993) | 
[`DenseNet`](https://fluxml.ai/Metalhead.jl/dev/api/densenet/#Metalhead.DenseNet) | N | +| [EfficientNet](https://arxiv.org/abs/1905.11946) | [`EfficientNet`](https://fluxml.ai/Metalhead.jl/dev/api/efficientnet/#Metalhead.EfficientNet) | N | +| [EfficientNetv2](https://arxiv.org/abs/2104.00298) | [`EfficientNetv2`](https://fluxml.ai/Metalhead.jl/dev/api/efficientnet/#Metalhead.EfficientNetv2) | N | +| [gMLP](https://arxiv.org/abs/2105.08050) | [`gMLP`](https://fluxml.ai/Metalhead.jl/dev/api/mixers/#Metalhead.gMLP) | N | +| [GoogLeNet](https://arxiv.org/abs/1409.4842) | [`GoogLeNet`](https://fluxml.ai/Metalhead.jl/dev/api/inception/l#Metalhead.GoogLeNet) | N | +| [Inception-v3](https://arxiv.org/abs/1512.00567) | [`Inceptionv3`](https://fluxml.ai/Metalhead.jl/dev/api/inception/#Metalhead.Inceptionv3) | N | +| [Inception-v4](https://arxiv.org/abs/1602.07261) | [`Inceptionv4`](https://fluxml.ai/Metalhead.jl/dev/api/inception/#Metalhead.Inceptionv4) | N | +| [InceptionResNet-v2](https://arxiv.org/abs/1602.07261) | [`InceptionResNetv2`](https://fluxml.ai/Metalhead.jl/dev/api/inception/#Metalhead.InceptionResNetv2) | N | +| [MLPMixer](https://arxiv.org/pdf/2105.01601) | [`MLPMixer`](https://fluxml.ai/Metalhead.jl/dev/api/mixer/#Metalhead.MLPMixer) | N | +| [MobileNetv1](https://arxiv.org/abs/1704.04861) | [`MobileNetv1`](https://fluxml.ai/Metalhead.jl/dev/api/efficientnet/#Metalhead.MobileNetv1) | N | +| [MobileNetv2](https://arxiv.org/abs/1801.04381) | [`MobileNetv2`](https://fluxml.ai/Metalhead.jl/dev/api/efficientnet/#Metalhead.MobileNetv2) | N | +| [MobileNetv3](https://arxiv.org/abs/1905.02244) | [`MobileNetv3`](https://fluxml.ai/Metalhead.jl/dev/api/efficientnet/#Metalhead.MobileNetv3) | N | +| [MNASNet](https://arxiv.org/abs/1807.11626) | [`MNASNet`](https://fluxml.ai/Metalhead.jl/dev/api/efficientnet/#Metalhead.MNASNet) | N | +| [ResMLP](https://arxiv.org/abs/2105.03404) | [`ResMLP`](https://fluxml.ai/Metalhead.jl/dev/api/mixers/#Metalhead.ResMLP) | N | +| 
[ResNet](https://arxiv.org/abs/1512.03385) | [`ResNet`](https://fluxml.ai/Metalhead.jl/dev/api/resnet/#Metalhead.ResNet) | Y | +| [ResNeXt](https://arxiv.org/abs/1611.05431) | [`ResNeXt`](https://fluxml.ai/Metalhead.jl/dev/api/resnet/#Metalhead.ResNeXt) | Y | +| [SqueezeNet](https://arxiv.org/abs/1602.07360) | [`SqueezeNet`](https://fluxml.ai/Metalhead.jl/dev/api/others/#Metalhead.SqueezeNet) | Y | +| [Xception](https://arxiv.org/abs/1610.02357) | [`Xception`](https://fluxml.ai/Metalhead.jl/dev/api/inception/#Metalhead.Xception) | N | +| [WideResNet](https://arxiv.org/abs/1605.07146) | [`WideResNet`](https://fluxml.ai/Metalhead.jl/dev/api/resnet/#Metalhead.WideResNet) | Y | +| [VGG](https://arxiv.org/abs/1409.1556) | [`VGG`](https://fluxml.ai/Metalhead.jl/dev/api/others/#Metalhead.VGG) | Y | +| [Vision Transformer](https://arxiv.org/abs/2010.11929) | [`ViT`](https://fluxml.ai/Metalhead.jl/dev/api/vit/#Metalhead.ViT) | Y | To contribute new models, see our [contributing docs](https://fluxml.ai/Metalhead.jl/dev/contributing/). 
diff --git a/scripts/manage_huggingface_org.jl b/scripts/manage_huggingface_org.jl index f795851e4..22146fc9d 100644 --- a/scripts/manage_huggingface_org.jl +++ b/scripts/manage_huggingface_org.jl @@ -75,14 +75,14 @@ end # hfhub.login(ENV["HUGGINGFACE_TOKEN"]) # model_artifacts = create_model_artifacts(force=false) # model_artifacts = filter(model_artifacts) do x -# !startswith(x[1], "resnet") && !startswith(x[1], "resnext") +# startswith(x[1], "resnet152") # end # upload_artifacts_to_hf(model_artifacts) ### Generate Artifacts.toml from HuggingFace repos ############# fluxml_model_repos = list_fluxml_models() -# fluxml_model_repos = filter(fluxml_model_repos) do repo -# name = split(repo[:id], "/")[2] -# startswith(name, "resnet") -# end +fluxml_model_repos = filter(fluxml_model_repos) do repo + name = split(repo[:id], "/")[2] + startswith(name, "resnet152") +end generate_artifacts_toml(fluxml_model_repos) diff --git a/scripts/port_torchvision.jl b/scripts/port_torchvision.jl index cbdfc75d5..9d7c0bf29 100644 --- a/scripts/port_torchvision.jl +++ b/scripts/port_torchvision.jl @@ -8,10 +8,10 @@ const tvmodels = pyimport("torchvision.models") # name, weight, jlconstructor, pyconstructor model_list = [ - ("vgg11", "IMAGENET1K_V1", () -> VGG(11), weights -> tvmodels.vgg11(; weights)), - ("vgg13", "IMAGENET1K_V1", () -> VGG(13), weights -> tvmodels.vgg13(; weights)), - ("vgg16", "IMAGENET1K_V1", () -> VGG(16), weights -> tvmodels.vgg16(; weights)), - ("vgg19", "IMAGENET1K_V1", () -> VGG(19), weights -> tvmodels.vgg19(; weights)), + ("vgg11", "IMAGENET1K_V1", () -> VGG(11, batchnorm=false), weights -> tvmodels.vgg11(; weights)), + ("vgg13", "IMAGENET1K_V1", () -> VGG(13, batchnorm=false), weights -> tvmodels.vgg13(; weights)), + ("vgg16", "IMAGENET1K_V1", () -> VGG(16, batchnorm=false), weights -> tvmodels.vgg16(; weights)), + ("vgg19", "IMAGENET1K_V1", () -> VGG(19, batchnorm=false), weights -> tvmodels.vgg19(; weights)), ("resnet18", "IMAGENET1K_V1", () -> ResNet(18), 
weights -> tvmodels.resnet18(; weights)), ("resnet34", "IMAGENET1K_V1", () -> ResNet(34), weights -> tvmodels.resnet34(; weights)), ("resnet50", "IMAGENET1K_V1", () -> ResNet(50), weights -> tvmodels.resnet50(; weights)), @@ -23,6 +23,8 @@ model_list = [ ("resnext101_32x8d", "IMAGENET1K_V1", () -> ResNeXt(101; cardinality=32, base_width=8), weights -> tvmodels.resnext101_32x8d(; weights)), ("resnext101_32x8d", "IMAGENET1K_V2", () -> ResNeXt(101; cardinality=32, base_width=8), weights -> tvmodels.resnext101_32x8d(; weights)), ("resnext101_64x4d", "IMAGENET1K_V1", () -> ResNeXt(101; cardinality=64, base_width=4), weights -> tvmodels.resnext101_64x4d(; weights)), + ("resnet152", "IMAGENET1K_V1", () -> ResNet(152), weights -> tvmodels.resnet152(; weights)), + ("resnet152", "IMAGENET1K_V2", () -> ResNet(152), weights -> tvmodels.resnet152(; weights)), ("wideresnet50", "IMAGENET1K_V1", () -> WideResNet(50), weights -> tvmodels.wide_resnet50_2(; weights)), ("wideresnet50", "IMAGENET1K_V2", () -> WideResNet(50), weights -> tvmodels.wide_resnet50_2(; weights)), ("wideresnet101", "IMAGENET1K_V1", () -> WideResNet(101), weights -> tvmodels.wide_resnet101_2(; weights)), @@ -32,40 +34,67 @@ model_list = [ ("vit_l_16", "IMAGENET1K_V1", () -> ViT(:large), weights -> tvmodels.vit_l_16(; weights)), ("vit_l_32", "IMAGENET1K_V1", () -> ViT(:large, patch_size=(32,32)), weights -> tvmodels.vit_l_32(; weights)), ## NOT WORKING: + # ("densenet121", "IMAGENET1K_V1", () -> DenseNet(121), weights -> tvmodels.densenet121(; weights)), + # ("squeezenet1_0", "IMAGENET1K_V1", () -> SqueezeNet(), weights -> tvmodels.squeezenet1_0(; weights)), # ("vit_h_14", "IMAGENET1K_SWAG_E2E_V1", () -> ViT(:huge, imsize=(224,224), patch_size=(14,14), qkv_bias=true), weights -> tvmodels.vit_h_14(; weights)), # ("vit_h_14", "IMAGENET1K_SWAG_LINEAR_V1", () -> ViT(:huge, imsize=(224,224), patch_size=(14,14), qkv_bias=true), weights -> tvmodels.vit_h_14(; weights)), - # ("resnet152", "IMAGENET1K_V2", () -> 
ResNet(152), weights -> tvmodels.resnet152(; weights)), - # ("resnet152", "IMAGENET1K_V1", () -> ResNet(152), weights -> tvmodels.resnet152(; weights)), - # ("squeezenet1_0", "IMAGENET1K_V1", () -> SqueezeNet(), weights -> tvmodels.squeezenet1_0(; weights)), - # ("densenet121", "IMAGENET1K_V1", () -> DenseNet(121), weights -> tvmodels.densenet121(; weights)), + # ("vgg11_bn", "IMAGENET1K_V1", () -> VGG(11, batchnorm=true), weights -> tvmodels.vgg11_bn(; weights)), + # ("vgg13_bn", "IMAGENET1K_V1", () -> VGG(13, batchnorm=true), weights -> tvmodels.vgg13_bn(; weights)), + # ("vgg16_bn", "IMAGENET1K_V1", () -> VGG(16, batchnorm=true), weights -> tvmodels.vgg16_bn(; weights)), + # ("vgg19_bn", "IMAGENET1K_V1", () -> VGG(19, batchnorm=true), weights -> tvmodels.vgg19_bn(; weights)), ] +function save_model_state(filename, model) + mkpath(dirname(filename)) + if endswith(filename, ".jld2") + JLD2.jldsave(filename, model_state = Flux.state(model)) + elseif endswith(filename, ".bson") + BSON.@save filename model_state=Flux.state(model) + else + error("Unknown file extension") + end +end + +function load_model_state(filename) + if endswith(filename, ".jld2") + return JLD2.load(filename)["model_state"] + elseif endswith(filename, ".bson") + return BSON.load(filename)[:model_state] + else + error("Unknown file extension") + end +end + function convert_models() # name, weights, jlconstructor, pyconstructor = first(model_list) - for (name, weights, jlconstructor, pyconstructor) in model_list - # CONSTRUCT MODELS - jlmodel = jlconstructor() - pymodel = pyconstructor(weights) + for (name, weights, jlconstructor, pyconstructor) in model_list + # CONSTRUCT MODELS + jlmodel = jlconstructor() + pymodel = pyconstructor(weights) - # LOAD WEIGHTS FROM PYTORCH TO JULIA - pytorch2flux!(jlmodel, pymodel) - rtol = startswith(name, "vit") ? 
1e-2 : 1e-4 # TODO investigate why ViT is less accurate - compare_pytorch(jlmodel, pymodel; rtol) - - # SAVE WEIGHTS - artifact_name = "$(name)-$weights" - filename = joinpath(@__DIR__, "weights", name, artifact_name, "$(artifact_name).jld2") - mkpath(dirname(filename)) - JLD2.jldsave(filename, model_state = Flux.state(jlmodel)) - println("Saved $filename") + # LOAD WEIGHTS FROM PYTORCH TO JULIA + pytorch2flux!(jlmodel, pymodel) + rtol = startswith(name, "vit") ? 1e-2 : 1e-4 # TODO investigate why ViT is less accurate + compare_pytorch(jlmodel, pymodel; rtol) + + # SAVE WEIGHTS + artifact_name = "$(name)-$weights" + filename = joinpath(@__DIR__, "weights", name, artifact_name, "$(artifact_name)") + if name != "resnet152" + filename *= ".jld2" + else + filename *= ".bson" # TODO: fix resnet152.jld2, not sure why it's not working + end + save_model_state(filename, jlmodel) + println("Saved $filename") - # LOAD WEIGHTS AND TEST AGAIN - jlmodel2 = jlconstructor() - model_state = JLD2.load(filename, "model_state") - Flux.loadmodel!(jlmodel2, model_state) - compare_pytorch(jlmodel2, pymodel; rtol) - end + # LOAD WEIGHTS AND TEST AGAIN + jlmodel2 = jlconstructor() + model_state = load_model_state(filename) + Flux.loadmodel!(jlmodel2, model_state) + compare_pytorch(jlmodel2, pymodel; rtol) + end end convert_models() diff --git a/src/convnets/resnets/resnet.jl b/src/convnets/resnets/resnet.jl index 5010164aa..e25574c63 100644 --- a/src/convnets/resnets/resnet.jl +++ b/src/convnets/resnets/resnet.jl @@ -27,7 +27,7 @@ function ResNet(depth::Integer; pretrain::Bool = false, inchannels::Integer = 3, artifact_name = "resnet$(depth)" if depth ∈ [18, 34] artifact_name *= "-IMAGENET1K_V1" - elseif depth ∈ [50, 101] + elseif depth ∈ [50, 101, 152] artifact_name *= "-IMAGENET1K_V2" end loadpretrain!(model, artifact_name) diff --git a/src/convnets/vgg.jl b/src/convnets/vgg.jl index 29336021a..5e5feed62 100644 --- a/src/convnets/vgg.jl +++ b/src/convnets/vgg.jl @@ -166,7 +166,9 @@ 
function VGG(depth::Integer; pretrain::Bool = false, batchnorm::Bool = false, if pretrain artifact_name = string("vgg", depth) if batchnorm - artifact_name *= "-bn" + artifact_name *= "_bn" + else + artifact_name *= "-IMAGENET1K_V1" end loadpretrain!(model, artifact_name) end diff --git a/src/layers/drop.jl b/src/layers/drop.jl index 15f8e7533..752668c92 100644 --- a/src/layers/drop.jl +++ b/src/layers/drop.jl @@ -97,7 +97,7 @@ ChainRulesCore.@non_differentiable _dropblock_checks(x, drop_block_prob, gamma_s function (m::DropBlock)(x) _dropblock_checks(x, m.drop_block_prob, m.gamma_scale) - return Flux._isactive(m) ? + return Flux._isactive(m, x) ? dropblock(m.rng, x, m.drop_block_prob, m.block_size, m.gamma_scale) : x end diff --git a/test/convnets.jl b/test/convnets.jl index 5a1392206..774a44b25 100644 --- a/test/convnets.jl +++ b/test/convnets.jl @@ -55,18 +55,19 @@ end end end end +end - @testset "WideResNet" begin - @testset "WideResNet($sz)" for sz in [50, 101] - m = WideResNet(sz) - @test size(m(x_224)) == (1000, 1) - @test gradtest(m, x_224) - _gc() - if (WideResNet, sz) in PRETRAINED_MODELS - @test acctest(WideResNet(sz, pretrain = true)) - else - @test_throws ArgumentError WideResNet(sz, pretrain = true) - end + +@testset "WideResNet" begin + @testset "WideResNet($sz)" for sz in [50, 101] + m = WideResNet(sz) + @test size(m(x_224)) == (1000, 1) + @test gradtest(m, x_224) + _gc() + if (WideResNet, sz) in PRETRAINED_MODELS + @test acctest(WideResNet(sz, pretrain = true)) + else + @test_throws ArgumentError WideResNet(sz, pretrain = true) end end end diff --git a/test/runtests.jl b/test/runtests.jl index cd5c8ab99..1a77f9a77 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -18,7 +18,11 @@ const PRETRAINED_MODELS = [ (WideResNet, 101), (ResNeXt, 50, 32, 4), (ResNeXt, 101, 64, 4), - (ResNeXt, 101, 32, 8) + (ResNeXt, 101, 32, 8), + (ViT, :base, (16, 16)), + (ViT, :base, (32, 32)), + (ViT, :large, (16, 16)), + (ViT, :large, (32, 32)), ] function _gc() diff 
--git a/test/vits.jl b/test/vits.jl index 7561cfdb5..76a606bc9 100644 --- a/test/vits.jl +++ b/test/vits.jl @@ -1,8 +1,8 @@ @testset "ViT" begin for config in [:tiny, :small, :base, :large, :huge] # :giant, :gigantic] m = ViT(config) - @test size(m(x_256)) == (1000, 1) - @test gradtest(m, x_256) + @test size(m(x_224)) == (1000, 1) + @test gradtest(m, x_224) _gc() end end From 1e645a860342515cb26f475d63eababbdd04a10e Mon Sep 17 00:00:00 2001 From: Carlo Lucibello Date: Sun, 7 May 2023 16:13:56 +0200 Subject: [PATCH 07/10] fix wideresnet --- src/convnets/hybrid/convnext.jl | 2 +- src/convnets/inceptions/googlenet.jl | 5 +++-- src/convnets/inceptions/inceptionresnetv2.jl | 5 +++-- src/convnets/inceptions/inceptionv3.jl | 5 +++-- src/convnets/inceptions/inceptionv4.jl | 5 +++-- src/convnets/inceptions/xception.jl | 5 +++-- src/convnets/resnets/res2net.jl | 4 ++-- src/convnets/resnets/resnet.jl | 2 +- src/convnets/resnets/seresnet.jl | 7 ++++--- src/mixers/gmlp.jl | 5 +++-- src/mixers/mlpmixer.jl | 5 +++-- src/mixers/resmlp.jl | 5 +++-- test/convnets.jl | 2 +- 13 files changed, 33 insertions(+), 24 deletions(-) diff --git a/src/convnets/hybrid/convnext.jl b/src/convnets/hybrid/convnext.jl index 9b9b3f326..d7f56855a 100644 --- a/src/convnets/hybrid/convnext.jl +++ b/src/convnets/hybrid/convnext.jl @@ -129,7 +129,7 @@ function ConvNeXt(config::Symbol; pretrain::Bool = false, inchannels::Integer = layers = convnext(config; inchannels, nclasses) model = ConvNeXt(layers) if pretrain - loadpretrain!(layers, "convnext_$config") + loadpretrain!(model, "convnext_$config") end return model end diff --git a/src/convnets/inceptions/googlenet.jl b/src/convnets/inceptions/googlenet.jl index a84af56c4..6ffa2a8ac 100644 --- a/src/convnets/inceptions/googlenet.jl +++ b/src/convnets/inceptions/googlenet.jl @@ -99,10 +99,11 @@ end function GoogLeNet(; pretrain::Bool = false, inchannels::Integer = 3, nclasses::Integer = 1000, batchnorm::Bool = false, bias::Bool = true) layers = googlenet(; 
inchannels, nclasses, batchnorm, bias) + model = GoogLeNet(layers) if pretrain - loadpretrain!(layers, "GoogLeNet") + loadpretrain!(model, "GoogLeNet") end - return GoogLeNet(layers) + return model end (m::GoogLeNet)(x) = m.layers(x) diff --git a/src/convnets/inceptions/inceptionresnetv2.jl b/src/convnets/inceptions/inceptionresnetv2.jl index 98d686062..180de9574 100644 --- a/src/convnets/inceptions/inceptionresnetv2.jl +++ b/src/convnets/inceptions/inceptionresnetv2.jl @@ -122,10 +122,11 @@ end function InceptionResNetv2(; pretrain::Bool = false, inchannels::Integer = 3, nclasses::Integer = 1000) layers = inceptionresnetv2(; inchannels, nclasses) + model = InceptionResNetv2(layers) if pretrain - loadpretrain!(layers, "InceptionResNetv2") + loadpretrain!(model, "InceptionResNetv2") end - return InceptionResNetv2(layers) + return model end (m::InceptionResNetv2)(x) = m.layers(x) diff --git a/src/convnets/inceptions/inceptionv3.jl b/src/convnets/inceptions/inceptionv3.jl index 41d7ae18e..b295bf3fd 100644 --- a/src/convnets/inceptions/inceptionv3.jl +++ b/src/convnets/inceptions/inceptionv3.jl @@ -183,10 +183,11 @@ end function Inceptionv3(; pretrain::Bool = false, inchannels::Integer = 3, nclasses::Integer = 1000) layers = inceptionv3(; inchannels, nclasses) + model = Inceptionv3(layers) if pretrain - loadpretrain!(layers, "Inceptionv3") + loadpretrain!(model, "Inceptionv3") end - return Inceptionv3(layers) + return model end (m::Inceptionv3)(x) = m.layers(x) diff --git a/src/convnets/inceptions/inceptionv4.jl b/src/convnets/inceptions/inceptionv4.jl index 964afc362..e00b66523 100644 --- a/src/convnets/inceptions/inceptionv4.jl +++ b/src/convnets/inceptions/inceptionv4.jl @@ -137,10 +137,11 @@ end function Inceptionv4(; pretrain::Bool = false, inchannels::Integer = 3, nclasses::Integer = 1000) layers = inceptionv4(; inchannels, nclasses) + model = Inceptionv4(layers) if pretrain - loadpretrain!(layers, "Inceptionv4") + loadpretrain!(model, "Inceptionv4") end - return 
Inceptionv4(layers) + return model end (m::Inceptionv4)(x) = m.layers(x) diff --git a/src/convnets/inceptions/xception.jl b/src/convnets/inceptions/xception.jl index 9dfd73f86..bf7a50816 100644 --- a/src/convnets/inceptions/xception.jl +++ b/src/convnets/inceptions/xception.jl @@ -93,10 +93,11 @@ end function Xception(; pretrain::Bool = false, inchannels::Integer = 3, nclasses::Integer = 1000) layers = xception(; inchannels, nclasses) + model = Xception(layers) if pretrain - loadpretrain!(layers, "xception") + loadpretrain!(model, "xception") end - return Xception(layers) + return model end (m::Xception)(x) = m.layers(x) diff --git a/src/convnets/resnets/res2net.jl b/src/convnets/resnets/res2net.jl index df7f2c98d..94bdb7bcd 100644 --- a/src/convnets/resnets/res2net.jl +++ b/src/convnets/resnets/res2net.jl @@ -37,7 +37,7 @@ function Res2Net(depth::Integer; pretrain::Bool = false, scale::Integer = 4, model = Res2Net(layers) if pretrain artifact_name = string("res2net", depth, "_", base_width, "x", scale) - loadpretrain!(layers, artifact_name) + loadpretrain!(model, artifact_name) end return model end @@ -86,7 +86,7 @@ function Res2NeXt(depth::Integer; pretrain::Bool = false, scale::Integer = 4, model = Res2NeXt(layers) if pretrain artifact_name = string("res2next", depth, "_", base_width, "x", scale, "x", cardinality) - loadpretrain!(layers, artifact_name) + loadpretrain!(model, artifact_name) end return model end diff --git a/src/convnets/resnets/resnet.jl b/src/convnets/resnets/resnet.jl index e25574c63..8f15d0471 100644 --- a/src/convnets/resnets/resnet.jl +++ b/src/convnets/resnets/resnet.jl @@ -72,7 +72,7 @@ function WideResNet(depth::Integer; pretrain::Bool = false, inchannels::Integer if depth ∈ [50, 101] artifact_name *= "-IMAGENET1K_V2" end - loadpretrain!(layers, artifact_name) + loadpretrain!(model, artifact_name) end return model end diff --git a/src/convnets/resnets/seresnet.jl b/src/convnets/resnets/seresnet.jl index 4be519720..883ca64be 100644 --- 
a/src/convnets/resnets/seresnet.jl +++ b/src/convnets/resnets/seresnet.jl @@ -30,7 +30,7 @@ function SEResNet(depth::Integer; pretrain::Bool = false, inchannels::Integer = model = SEResNet(layers) if pretrain artifact_name = "seresnet$(depth)" - loadpretrain!(layers, artifact_name) + loadpretrain!(model, artifact_name) end return model end @@ -74,10 +74,11 @@ function SEResNeXt(depth::Integer; pretrain::Bool = false, cardinality::Integer layers = resnet(LRESNET_CONFIGS[depth]...; inchannels, nclasses, cardinality, base_width, attn_fn = squeeze_excite) + model = SEResNeXt(layers) if pretrain - loadpretrain!(layers, string("seresnext", depth, "_", cardinality, "x", base_width)) + loadpretrain!(model, string("seresnext", depth, "_", cardinality, "x", base_width)) end - return SEResNeXt(layers) + return model end (m::SEResNeXt)(x) = m.layers(x) diff --git a/src/mixers/gmlp.jl b/src/mixers/gmlp.jl index 6319bfef6..e2b9e62f5 100644 --- a/src/mixers/gmlp.jl +++ b/src/mixers/gmlp.jl @@ -89,10 +89,11 @@ function gMLP(config::Symbol; imsize::Dims{2} = (224, 224), patch_size::Dims{2} _checkconfig(config, keys(MIXER_CONFIGS)) layers = mlpmixer(spatialgatingblock, imsize; mlp_layer = gated_mlp_block, patch_size, MIXER_CONFIGS[config]..., inchannels, nclasses) + model = gMLP(layers) if pretrain - loadpretrain!(layers, string("gmlp", config)) + loadpretrain!(model, string("gmlp", config)) end - return gMLP(layers) + return model end (m::gMLP)(x) = m.layers(x) diff --git a/src/mixers/mlpmixer.jl b/src/mixers/mlpmixer.jl index 24656bd63..1625b12d4 100644 --- a/src/mixers/mlpmixer.jl +++ b/src/mixers/mlpmixer.jl @@ -62,10 +62,11 @@ function MLPMixer(config::Symbol; imsize::Dims{2} = (224, 224), _checkconfig(config, keys(MIXER_CONFIGS)) layers = mlpmixer(mixerblock, imsize; patch_size, MIXER_CONFIGS[config]..., inchannels, nclasses) + model = MLPMixer(layers) if pretrain - loadpretrain!(layers, string("mlpmixer", config)) + loadpretrain!(model, string("mlpmixer", config)) end - 
return MLPMixer(layers) + return model end (m::MLPMixer)(x) = m.layers(x) diff --git a/src/mixers/resmlp.jl b/src/mixers/resmlp.jl index b1ff44ea1..e659f4650 100644 --- a/src/mixers/resmlp.jl +++ b/src/mixers/resmlp.jl @@ -61,10 +61,11 @@ function ResMLP(config::Symbol; imsize::Dims{2} = (224, 224), _checkconfig(config, keys(MIXER_CONFIGS)) layers = mlpmixer(resmixerblock, imsize; mlp_ratio = 4.0, patch_size, MIXER_CONFIGS[config]..., inchannels, nclasses) + model = ResMLP(layers) if pretrain - loadpretrain!(layers, string(resmlp, config)) + loadpretrain!(model, string("resmlp", config)) end - return ResMLP(layers) + return model end (m::ResMLP)(x) = m.layers(x) diff --git a/test/convnets.jl b/test/convnets.jl index 774a44b25..182373cf5 100644 --- a/test/convnets.jl +++ b/test/convnets.jl @@ -59,7 +59,7 @@ end @testset "WideResNet" begin - @testset "WideResNet($sz)" for sz in [50, 101] + @testset "WideResNet($sz)" for sz in [50] m = WideResNet(sz) @test size(m(x_224)) == (1000, 1) @test gradtest(m, x_224) From ee9a26be6aa6a2ec9ca42aadaa8e3b5097fc175c Mon Sep 17 00:00:00 2001 From: Carlo Lucibello Date: Sun, 7 May 2023 16:23:21 +0200 Subject: [PATCH 08/10] add UNet to readme --- README.md | 15 ++++++++++----- docs/src/api/others.md | 2 ++ docs/src/contributing.md | 3 ++- 3 files changed, 14 insertions(+), 6 deletions(-) diff --git a/README.md b/README.md index 7ea9fdba8..406aa4454 100644 --- a/README.md +++ b/README.md @@ -12,8 +12,14 @@ julia> ]add Metalhead ``` +## Getting Started + +You can find the Metalhead.jl getting started guide [here](https://fluxml.ai/Metalhead.jl/dev/tutorials/quickstart/). + ## Available models +To contribute new models, see our [contributing docs](https://fluxml.ai/Metalhead.jl/dev/contributing/). + ### Image Classification | Model Name | Constructor | Pre-trained? 
| @@ -43,9 +49,8 @@ julia> ]add Metalhead | [VGG](https://arxiv.org/abs/1409.1556) | [`VGG`](https://fluxml.ai/Metalhead.jl/dev/api/others/#Metalhead.VGG) | Y | | [Vision Transformer](https://arxiv.org/abs/2010.11929) | [`ViT`](https://fluxml.ai/Metalhead.jl/dev/api/vit/#Metalhead.ViT) | Y | -To contribute new models, see our [contributing docs](https://fluxml.ai/Metalhead.jl/dev/contributing/). - -## Getting Started +### Other Models -You can find the Metalhead.jl getting started guide [here](https://fluxml.ai/Metalhead.jl/dev/tutorials/quickstart/). -; \ No newline at end of file +| Model Name | Constructor | Pre-trained? | +|:-------------------------------------------------|:-----------------------------------------------------------------------------------------------|:------------:| +| [UNet](https://arxiv.org/abs/1505.04597) | [`UNet`](https://fluxml.ai/Metalhead.jl/dev/api/others/#Metalhead.UNet) | N | \ No newline at end of file diff --git a/docs/src/api/others.md b/docs/src/api/others.md index 9cadfaee4..ebfda7f41 100644 --- a/docs/src/api/others.md +++ b/docs/src/api/others.md @@ -8,6 +8,7 @@ This is the API reference for some of the other models supported by Metalhead.jl AlexNet VGG SqueezeNet +UNet ``` ## The mid-level functions @@ -16,4 +17,5 @@ SqueezeNet Metalhead.alexnet Metalhead.vgg Metalhead.squeezenet +Metalhead.unet ``` diff --git a/docs/src/contributing.md b/docs/src/contributing.md index 8f8f5376f..5ca816b46 100644 --- a/docs/src/contributing.md +++ b/docs/src/contributing.md @@ -28,7 +28,6 @@ To add pre-trained weights for an existing model or new model, you can [open a P All Metalhead.jl model artifacts are hosted using HuggingFace. You can find the FluxML account [here](https://huggingface.co/FluxML). This [documentation from HuggingFace](https://huggingface.co/docs/hub/models) will provide you with an introduction to their ModelHub. In short, the Model Hub is a collection of Git repositories, similar to Julia packages on GitHub. 
This means you can [make a pull request to our HuggingFace repositories](https://huggingface.co/docs/hub/repositories-pull-requests-discussions) to upload updated weight artifacts just like you would make a PR on GitHub to upload code. 1. Train your model or port the weights from another framework. - - In the `scripts/` folder you can find some code to help you port weights from other frameworks. 2. Save the model using [BSON.jl](https://github.com/JuliaIO/BSON.jl) with `BSON.@save "modelname.bson" model`. It is important that your model is saved under the key `model`. Note that due to the way this process works, to maintain compatibility with different Julia versions, the model must be saved using the LTS version of Julia (currently 1.6). 3. Compress the saved model as a tarball using `tar -cvzf modelname.tar.gz modelname.bson`. @@ -40,3 +39,5 @@ process works, to maintain compatibility with different Julia versions, the mode 9. If the tests pass for your weights, we will merge your PR! Your model should pass the `acctest` function in the Metalhead.jl test suite. If your model already exists in the repo, then these tests are already in place, and you can add your model configuration to the `PRETRAINED_MODELS` list in the `runtests.jl` file. Please refer to the ResNet tests as an example. If you want to fix existing weights, then you can follow the same set of steps. + +See the [scripts/](https://github.com/FluxML/Metalhead.jl/tree/master/scripts) folder in the repo for some helpful scripts that can be used to automate some of these steps. 
From bf9b0aab6ff4d81b2681ef0c935d4422203ba537 Mon Sep 17 00:00:00 2001 From: Carlo Lucibello Date: Sun, 7 May 2023 19:08:47 +0200 Subject: [PATCH 09/10] load old densenet and squeezenet weights --- README.md | 2 +- scripts/port_torchvision.jl | 54 ++++++++++++++++++------------------- scripts/pytorch2flux.jl | 5 ++-- src/convnets/densenet.jl | 8 +++++- src/convnets/squeezenet.jl | 5 ++++ src/pretrain.jl | 14 +++++++++- test/convnets.jl | 2 +- test/runtests.jl | 18 ++++++++----- 8 files changed, 68 insertions(+), 40 deletions(-) diff --git a/README.md b/README.md index 406aa4454..3a4b7365c 100644 --- a/README.md +++ b/README.md @@ -27,7 +27,7 @@ To contribute new models, see our [contributing docs](https://fluxml.ai/Metalhea | [AlexNet](https://papers.nips.cc/paper/2012/file/c399862d3b9d6b76c8436e924a68c45b-Paper.pdf) | [`AlexNet`](https://fluxml.ai/Metalhead.jl/dev/api/other/#Metalhead.AlexNet) | N | | [ConvMixer](https://arxiv.org/abs/2201.09792) | [`ConvMixer`](https://fluxml.ai/Metalhead.jl/dev/api/hybrid/#Metalhead.ConvMixer) | N | | [ConvNeXt](https://arxiv.org/abs/2201.03545) | [`ConvNeXt`](https://fluxml.ai/Metalhead.jl/dev/api/hybrid/#Metalhead.ConvNeXt) | N | -| [DenseNet](https://arxiv.org/abs/1608.06993) | [`DenseNet`](https://fluxml.ai/Metalhead.jl/dev/api/densenet/#Metalhead.DenseNet) | N | +| [DenseNet](https://arxiv.org/abs/1608.06993) | [`DenseNet`](https://fluxml.ai/Metalhead.jl/dev/api/densenet/#Metalhead.DenseNet) | Y | | [EfficientNet](https://arxiv.org/abs/1905.11946) | [`EfficientNet`](https://fluxml.ai/Metalhead.jl/dev/api/efficientnet/#Metalhead.EfficientNet) | N | | [EfficientNetv2](https://arxiv.org/abs/2104.00298) | [`EfficientNetv2`](https://fluxml.ai/Metalhead.jl/dev/api/efficientnet/#Metalhead.EfficientNetv2) | N | | [gMLP](https://arxiv.org/abs/2105.08050) | [`gMLP`](https://fluxml.ai/Metalhead.jl/dev/api/mixers/#Metalhead.gMLP) | N | diff --git a/scripts/port_torchvision.jl b/scripts/port_torchvision.jl index 9d7c0bf29..76e26b7df 
100644 --- a/scripts/port_torchvision.jl +++ b/scripts/port_torchvision.jl @@ -8,33 +8,33 @@ const tvmodels = pyimport("torchvision.models") # name, weight, jlconstructor, pyconstructor model_list = [ - ("vgg11", "IMAGENET1K_V1", () -> VGG(11, batchnorm=false), weights -> tvmodels.vgg11(; weights)), - ("vgg13", "IMAGENET1K_V1", () -> VGG(13, batchnorm=false), weights -> tvmodels.vgg13(; weights)), - ("vgg16", "IMAGENET1K_V1", () -> VGG(16, batchnorm=false), weights -> tvmodels.vgg16(; weights)), - ("vgg19", "IMAGENET1K_V1", () -> VGG(19, batchnorm=false), weights -> tvmodels.vgg19(; weights)), - ("resnet18", "IMAGENET1K_V1", () -> ResNet(18), weights -> tvmodels.resnet18(; weights)), - ("resnet34", "IMAGENET1K_V1", () -> ResNet(34), weights -> tvmodels.resnet34(; weights)), - ("resnet50", "IMAGENET1K_V1", () -> ResNet(50), weights -> tvmodels.resnet50(; weights)), - ("resnet50", "IMAGENET1K_V2", () -> ResNet(50), weights -> tvmodels.resnet50(; weights)), - ("resnet101", "IMAGENET1K_V1", () -> ResNet(101), weights -> tvmodels.resnet101(; weights)), - ("resnet101", "IMAGENET1K_V2", () -> ResNet(101), weights -> tvmodels.resnet101(; weights)), - ("resnext50_32x4d", "IMAGENET1K_V1", () -> ResNeXt(50; cardinality=32, base_width=4), weights -> tvmodels.resnext50_32x4d(; weights)), - ("resnext50_32x4d", "IMAGENET1K_V2", () -> ResNeXt(50; cardinality=32, base_width=4), weights -> tvmodels.resnext50_32x4d(; weights)), - ("resnext101_32x8d", "IMAGENET1K_V1", () -> ResNeXt(101; cardinality=32, base_width=8), weights -> tvmodels.resnext101_32x8d(; weights)), - ("resnext101_32x8d", "IMAGENET1K_V2", () -> ResNeXt(101; cardinality=32, base_width=8), weights -> tvmodels.resnext101_32x8d(; weights)), - ("resnext101_64x4d", "IMAGENET1K_V1", () -> ResNeXt(101; cardinality=64, base_width=4), weights -> tvmodels.resnext101_64x4d(; weights)), - ("resnet152", "IMAGENET1K_V1", () -> ResNet(152), weights -> tvmodels.resnet152(; weights)), - ("resnet152", "IMAGENET1K_V2", () -> ResNet(152), 
weights -> tvmodels.resnet152(; weights)), - ("wideresnet50", "IMAGENET1K_V1", () -> WideResNet(50), weights -> tvmodels.wide_resnet50_2(; weights)), - ("wideresnet50", "IMAGENET1K_V2", () -> WideResNet(50), weights -> tvmodels.wide_resnet50_2(; weights)), - ("wideresnet101", "IMAGENET1K_V1", () -> WideResNet(101), weights -> tvmodels.wide_resnet101_2(; weights)), - ("wideresnet101", "IMAGENET1K_V2", () -> WideResNet(101), weights -> tvmodels.wide_resnet101_2(; weights)), - ("vit_b_16", "IMAGENET1K_V1", () -> ViT(:base), weights -> tvmodels.vit_b_16(; weights)), - ("vit_b_32", "IMAGENET1K_V1", () -> ViT(:base, patch_size=(32,32)), weights -> tvmodels.vit_b_32(; weights)), - ("vit_l_16", "IMAGENET1K_V1", () -> ViT(:large), weights -> tvmodels.vit_l_16(; weights)), - ("vit_l_32", "IMAGENET1K_V1", () -> ViT(:large, patch_size=(32,32)), weights -> tvmodels.vit_l_32(; weights)), + # ("vgg11", "IMAGENET1K_V1", () -> VGG(11, batchnorm=false), weights -> tvmodels.vgg11(; weights)), + # ("vgg13", "IMAGENET1K_V1", () -> VGG(13, batchnorm=false), weights -> tvmodels.vgg13(; weights)), + # ("vgg16", "IMAGENET1K_V1", () -> VGG(16, batchnorm=false), weights -> tvmodels.vgg16(; weights)), + # ("vgg19", "IMAGENET1K_V1", () -> VGG(19, batchnorm=false), weights -> tvmodels.vgg19(; weights)), + # ("resnet18", "IMAGENET1K_V1", () -> ResNet(18), weights -> tvmodels.resnet18(; weights)), + # ("resnet34", "IMAGENET1K_V1", () -> ResNet(34), weights -> tvmodels.resnet34(; weights)), + # ("resnet50", "IMAGENET1K_V1", () -> ResNet(50), weights -> tvmodels.resnet50(; weights)), + # ("resnet50", "IMAGENET1K_V2", () -> ResNet(50), weights -> tvmodels.resnet50(; weights)), + # ("resnet101", "IMAGENET1K_V1", () -> ResNet(101), weights -> tvmodels.resnet101(; weights)), + # ("resnet101", "IMAGENET1K_V2", () -> ResNet(101), weights -> tvmodels.resnet101(; weights)), + # ("resnext50_32x4d", "IMAGENET1K_V1", () -> ResNeXt(50; cardinality=32, base_width=4), weights -> tvmodels.resnext50_32x4d(; 
weights)), + # ("resnext50_32x4d", "IMAGENET1K_V2", () -> ResNeXt(50; cardinality=32, base_width=4), weights -> tvmodels.resnext50_32x4d(; weights)), + # ("resnext101_32x8d", "IMAGENET1K_V1", () -> ResNeXt(101; cardinality=32, base_width=8), weights -> tvmodels.resnext101_32x8d(; weights)), + # ("resnext101_32x8d", "IMAGENET1K_V2", () -> ResNeXt(101; cardinality=32, base_width=8), weights -> tvmodels.resnext101_32x8d(; weights)), + # ("resnext101_64x4d", "IMAGENET1K_V1", () -> ResNeXt(101; cardinality=64, base_width=4), weights -> tvmodels.resnext101_64x4d(; weights)), + # ("resnet152", "IMAGENET1K_V1", () -> ResNet(152), weights -> tvmodels.resnet152(; weights)), + # ("resnet152", "IMAGENET1K_V2", () -> ResNet(152), weights -> tvmodels.resnet152(; weights)), + # ("wideresnet50", "IMAGENET1K_V1", () -> WideResNet(50), weights -> tvmodels.wide_resnet50_2(; weights)), + # ("wideresnet50", "IMAGENET1K_V2", () -> WideResNet(50), weights -> tvmodels.wide_resnet50_2(; weights)), + # ("wideresnet101", "IMAGENET1K_V1", () -> WideResNet(101), weights -> tvmodels.wide_resnet101_2(; weights)), + # ("wideresnet101", "IMAGENET1K_V2", () -> WideResNet(101), weights -> tvmodels.wide_resnet101_2(; weights)), + # ("vit_b_16", "IMAGENET1K_V1", () -> ViT(:base), weights -> tvmodels.vit_b_16(; weights)), + # ("vit_b_32", "IMAGENET1K_V1", () -> ViT(:base, patch_size=(32,32)), weights -> tvmodels.vit_b_32(; weights)), + # ("vit_l_16", "IMAGENET1K_V1", () -> ViT(:large), weights -> tvmodels.vit_l_16(; weights)), + # ("vit_l_32", "IMAGENET1K_V1", () -> ViT(:large, patch_size=(32,32)), weights -> tvmodels.vit_l_32(; weights)), ## NOT WORKING: - # ("densenet121", "IMAGENET1K_V1", () -> DenseNet(121), weights -> tvmodels.densenet121(; weights)), + ("densenet121", "IMAGENET1K_V1", () -> DenseNet(121), weights -> tvmodels.densenet121(; weights)), # ("squeezenet1_0", "IMAGENET1K_V1", () -> SqueezeNet(), weights -> tvmodels.squeezenet1_0(; weights)), # ("vit_h_14", "IMAGENET1K_SWAG_E2E_V1", () 
-> ViT(:huge, imsize=(224,224), patch_size=(14,14), qkv_bias=true), weights -> tvmodels.vit_h_14(; weights)), # ("vit_h_14", "IMAGENET1K_SWAG_LINEAR_V1", () -> ViT(:huge, imsize=(224,224), patch_size=(14,14), qkv_bias=true), weights -> tvmodels.vit_h_14(; weights)), @@ -67,7 +67,7 @@ function load_model_state(filename) end function convert_models() - # name, weights, jlconstructor, pyconstructor = first(model_list) + name, weights, jlconstructor, pyconstructor = first(model_list) for (name, weights, jlconstructor, pyconstructor) in model_list # CONSTRUCT MODELS jlmodel = jlconstructor() diff --git a/scripts/pytorch2flux.jl b/scripts/pytorch2flux.jl index d3358905d..92f686474 100644 --- a/scripts/pytorch2flux.jl +++ b/scripts/pytorch2flux.jl @@ -145,8 +145,9 @@ function pytorch2flux!(jlmodel, pymodel; verb=false) end for ((flux_key, flux_param), (pytorch_key, pytorch_param)) in zip(jlstate, pystate) - # @show flux_key size(flux_param) pytorch_key size(pytorch_param) - # @show size(flux_param) == size(pytorch_param) + println("##") + @show flux_key size(flux_param) pytorch_key size(pytorch_param) + @show size(flux_param) == size(pytorch_param) param_name = split(flux_key, ".")[end] diff --git a/src/convnets/densenet.jl b/src/convnets/densenet.jl index de326decc..71cd44bc7 100644 --- a/src/convnets/densenet.jl +++ b/src/convnets/densenet.jl @@ -151,7 +151,7 @@ function DenseNet(config::Int; pretrain::Bool = false, growth_rate::Int = 32, model = DenseNet(layers) if pretrain artifact_name = string("densenet", config) - loadpretrain!(model, artifact_name) + loadpretrain!(model, artifact_name) # see also HACK below end return model end @@ -160,3 +160,9 @@ end backbone(m::DenseNet) = m.layers[1] classifier(m::DenseNet) = m.layers[2] + +## HACK TO LOAD OLD WEIGHTS, remove when we have a new artifact +function Flux.loadmodel!(m::DenseNet, src) + Flux.loadmodel!(m.layers[1], src.layers[1]) + Flux.loadmodel!(m.layers[2], src.layers[2]) +end diff --git 
a/src/convnets/squeezenet.jl b/src/convnets/squeezenet.jl index 8629daadc..f6b620f97 100644 --- a/src/convnets/squeezenet.jl +++ b/src/convnets/squeezenet.jl @@ -84,3 +84,8 @@ end backbone(m::SqueezeNet) = m.layers[1] classifier(m::SqueezeNet) = m.layers[2:end] + +function Flux.loadmodel!(model::SqueezeNet, w) + Flux.loadmodel!(model.layers[1], w.layers[1]) + Flux.loadmodel!(model.layers[2], w.layers[2]) +end diff --git a/src/pretrain.jl b/src/pretrain.jl index 2d2c637f4..4403d6341 100644 --- a/src/pretrain.jl +++ b/src/pretrain.jl @@ -9,7 +9,19 @@ function loadweights(artifact_name) catch e throw(ArgumentError("No pre-trained weights available for $artifact_name.")) end - file_name = readdir(artifact_dir)[1] + if length(readdir(artifact_dir)) > 1 + # @warn("Found multiple files in $artifact_dir.") + files = readdir(artifact_dir) + files = filter!(x -> endswith(x, ".bson") || endswith(x, ".jld2"), files) + files = filter!(x -> !startswith(x, "."), files) + if length(files) > 1 + throw(ErrorException("Found multiple weight artifacts for $artifact_name.")) + end + file_name = files[1] + else + file_name = readdir(artifact_dir)[1] + end + file_path = joinpath(artifact_dir, file_name) if endswith(file_name, ".bson") diff --git a/test/convnets.jl b/test/convnets.jl index 182373cf5..774a44b25 100644 --- a/test/convnets.jl +++ b/test/convnets.jl @@ -59,7 +59,7 @@ end @testset "WideResNet" begin - @testset "WideResNet($sz)" for sz in [50] + @testset "WideResNet($sz)" for sz in [50, 101] m = WideResNet(sz) @test size(m(x_224)) == (1000, 1) @test gradtest(m, x_224) diff --git a/test/runtests.jl b/test/runtests.jl index 1a77f9a77..82fd9f42c 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -4,25 +4,29 @@ using Flux: Zygote using Images const PRETRAINED_MODELS = [ - (VGG, 11, false), - (VGG, 13, false), - (VGG, 16, false), - (VGG, 19, false), - SqueezeNet, + (DenseNet, 121), + (DenseNet, 161), + (DenseNet, 169), + (DenseNet, 201), (ResNet, 18), (ResNet, 34), (ResNet, 
50), (ResNet, 101), (ResNet, 152), - (WideResNet, 50), - (WideResNet, 101), (ResNeXt, 50, 32, 4), (ResNeXt, 101, 64, 4), (ResNeXt, 101, 32, 8), + SqueezeNet, + (WideResNet, 50), + (WideResNet, 101), (ViT, :base, (16, 16)), (ViT, :base, (32, 32)), (ViT, :large, (16, 16)), (ViT, :large, (32, 32)), + (VGG, 11, false), + (VGG, 13, false), + (VGG, 16, false), + (VGG, 19, false), ] function _gc() From 1ba4967b0dd750944e257f0be380c7900ff6c11b Mon Sep 17 00:00:00 2001 From: Carlo Lucibello Date: Tue, 9 May 2023 08:00:18 +0200 Subject: [PATCH 10/10] fix resnext101_64x4 --- src/convnets/resnets/resnext.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/convnets/resnets/resnext.jl b/src/convnets/resnets/resnext.jl index f30f69c7c..e4cd554ad 100644 --- a/src/convnets/resnets/resnext.jl +++ b/src/convnets/resnets/resnext.jl @@ -40,7 +40,7 @@ function ResNeXt(depth::Integer; pretrain::Bool = false, cardinality::Integer = elseif depth == 101 && cardinality == 32 && base_width == 8 artifact_name *= "-IMAGENET1K_V2" elseif depth == 101 && cardinality == 64 && base_width == 4 - artifact_name *= "-IMAGENET1K_V2" + artifact_name *= "-IMAGENET1K_V1" end loadpretrain!(model, artifact_name) end