-
Notifications
You must be signed in to change notification settings - Fork 64
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #696 from LuxDL/ap/more_boltz_updates
More API updates
- Loading branch information
Showing
19 changed files
with
97 additions
and
123 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -2,10 +2,6 @@ | |
|
||
Backend for Lux.jl | ||
|
||
```@meta | ||
CurrentModule = LuxLib | ||
``` | ||
|
||
## Index | ||
|
||
```@index | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,80 +1,9 @@ | ||
```@meta | ||
CurrentModule = Boltz | ||
``` | ||
|
||
# Boltz | ||
|
||
Accelerate ⚡ your ML research using pre-built Deep Learning Models with Lux. | ||
|
||
## Index | ||
|
||
```@index | ||
Pages = ["Boltz.md"] | ||
``` | ||
|
||
## `Layers` API | ||
|
||
```@docs | ||
Layers.ConvBatchNormActivation | ||
Layers.ClassTokens | ||
Layers.MultiHeadSelfAttention | ||
Layers.ViPosEmbedding | ||
Layers.VisionTransformerEncoder | ||
``` | ||
|
||
## Computer Vision Models (`Vision` API) | ||
|
||
### Native Lux Models | ||
|
||
```@docs | ||
Vision.VGG | ||
Vision.VisionTransformer | ||
``` | ||
|
||
### Imported from Metalhead.jl | ||
|
||
!!! tip | ||
|
||
You need to load `Flux` and `Metalhead` before using these models. | ||
|
||
```@docs | ||
Vision.AlexNet | ||
Vision.ConvMixer | ||
Vision.DenseNet | ||
Vision.GoogLeNet | ||
Vision.MobileNet | ||
Vision.ResNet | ||
Vision.ResNeXt | ||
``` | ||
|
||
### Pretrained Models | ||
|
||
!!! tip | ||
|
||
Pass `pretrained=true` to the model constructor to load the pretrained weights. | ||
|
||
|
||
| MODEL | TOP 1 ACCURACY (%) | TOP 5 ACCURACY (%) | | ||
| :------------------------ | :----------------: | :----------------: | | ||
| `AlexNet()` | 54.48 | 77.72 | | ||
| `VGG(11)` | 67.35 | 87.91 | | ||
| `VGG(13)` | 68.40 | 88.48 | | ||
| `VGG(16)` | 70.24 | 89.80 | | ||
| `VGG(19)` | 71.09 | 90.27 | | ||
| `VGG(11; batchnorm=true)` | 69.09 | 88.94 | | ||
| `VGG(13; batchnorm=true)` | 69.66 | 89.49 | | ||
| `VGG(16; batchnorm=true)` | 72.11 | 91.02 | | ||
| `VGG(19; batchnorm=true)` | 72.95 | 91.32 | | ||
|
||
#### Preprocessing | ||
|
||
All the pretrained models require that the images be normalized with the parameters | ||
`mean = [0.485f0, 0.456f0, 0.406f0]` and `std = [0.229f0, 0.224f0, 0.225f0]`. | ||
|
||
## Non-Public API | ||
|
||
```@docs | ||
Boltz._seconddimmean | ||
Boltz._fast_chunk | ||
Boltz._flatten_spatial | ||
Pages = ["Boltz.md", "Boltz_Layers.md", "Boltz_Vision.md"] | ||
``` |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,32 @@ | ||
# `Boltz.Layers` and `Boltz.Basis` API Reference | ||
|
||
## `Layers` API | ||
|
||
```@docs | ||
Layers.ConvBatchNormActivation | ||
Layers.ConvNormActivation | ||
Layers.ClassTokens | ||
Layers.HamiltonianNN | ||
Layers.MultiHeadSelfAttention | ||
Layers.MLP | ||
Layers.TensorProductLayer | ||
Layers.ViPosEmbedding | ||
Layers.VisionTransformerEncoder | ||
``` | ||
|
||
## Basis Functions | ||
|
||
!!! warning | ||
|
||
The function calls for these basis functions should be considered experimental and are | ||
subject to change without deprecation. However, the functions themselves are stable | ||
and can be freely used in combination with the other Layers and Models. | ||
|
||
```@docs | ||
Basis.Cos | ||
Basis.Chebyshev | ||
Basis.Fourier | ||
Basis.Legendre | ||
Basis.Polynomial | ||
Basis.Sin | ||
``` |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,8 @@ | ||
# Boltz.jl Private API | ||
|
||
```@docs | ||
Boltz._seconddimmean | ||
Boltz._should_type_assert | ||
Boltz._fast_chunk | ||
Boltz._flatten_spatial | ||
``` |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,48 @@ | ||
# Computer Vision Models (`Vision` API) | ||
|
||
## Native Lux Models | ||
|
||
```@docs | ||
Vision.VGG | ||
Vision.VisionTransformer | ||
``` | ||
|
||
## Imported from Metalhead.jl | ||
|
||
!!! tip | ||
|
||
You need to load `Flux` and `Metalhead` before using these models. | ||
|
||
```@docs | ||
Vision.AlexNet | ||
Vision.ConvMixer | ||
Vision.DenseNet | ||
Vision.GoogLeNet | ||
Vision.MobileNet | ||
Vision.ResNet | ||
Vision.ResNeXt | ||
``` | ||
|
||
## Pretrained Models | ||
|
||
!!! tip | ||
|
||
Pass `pretrained=true` to the model constructor to load the pretrained weights. | ||
|
||
|
||
| MODEL | TOP 1 ACCURACY (%) | TOP 5 ACCURACY (%) | | ||
| :------------------------ | :----------------: | :----------------: | | ||
| `AlexNet()` | 54.48 | 77.72 | | ||
| `VGG(11)` | 67.35 | 87.91 | | ||
| `VGG(13)` | 68.40 | 88.48 | | ||
| `VGG(16)` | 70.24 | 89.80 | | ||
| `VGG(19)` | 71.09 | 90.27 | | ||
| `VGG(11; batchnorm=true)` | 69.09 | 88.94 | | ||
| `VGG(13; batchnorm=true)` | 69.66 | 89.49 | | ||
| `VGG(16; batchnorm=true)` | 72.11 | 91.02 | | ||
| `VGG(19; batchnorm=true)` | 72.95 | 91.32 | | ||
|
||
### Preprocessing | ||
|
||
All the pretrained models require that the images be normalized with the parameters | ||
`mean = [0.485f0, 0.456f0, 0.406f0]` and `std = [0.229f0, 0.224f0, 0.225f0]`. |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,9 +1,5 @@ | ||
# Built-In Layers | ||
|
||
```@meta | ||
CurrentModule = Lux | ||
``` | ||
|
||
## Index | ||
|
||
```@index | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,9 +1,5 @@ | ||
# Utilities | ||
|
||
```@meta | ||
CurrentModule = Lux | ||
``` | ||
|
||
## Index | ||
|
||
```@index | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,7 +1,3 @@ | ||
```@meta | ||
CurrentModule = LuxTestUtils | ||
``` | ||
|
||
# LuxTestUtils | ||
|
||
!!! warning | ||
|
edda7ec
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Benchmark Results
Dense(2 => 2)/cpu/reverse/ReverseDiff (compiled)/(2, 128)
3650.5
ns3700.75
ns0.99
Dense(2 => 2)/cpu/reverse/Zygote/(2, 128)
7205.4
ns7173.416666666666
ns1.00
Dense(2 => 2)/cpu/reverse/Tracker/(2, 128)
20689
ns21129
ns0.98
Dense(2 => 2)/cpu/reverse/ReverseDiff/(2, 128)
9711.1
ns9760.75
ns0.99
Dense(2 => 2)/cpu/reverse/Flux/(2, 128)
9014.8
ns9085
ns0.99
Dense(2 => 2)/cpu/reverse/SimpleChains/(2, 128)
4457.125
ns4523.5
ns0.99
Dense(2 => 2)/cpu/reverse/Enzyme/(2, 128)
1170.89453125
ns1158.1818181818182
ns1.01
Dense(2 => 2)/cpu/forward/NamedTuple/(2, 128)
1179.6129032258063
ns1117.313725490196
ns1.06
Dense(2 => 2)/cpu/forward/ComponentArray/(2, 128)
1174.5255474452554
ns1190.6204379562043
ns0.99
Dense(2 => 2)/cpu/forward/Flux/(2, 128)
1802.16
ns1784.6923076923076
ns1.01
Dense(2 => 2)/cpu/forward/SimpleChains/(2, 128)
180.28067700987307
ns180.12934631432546
ns1.00
Dense(20 => 20)/cpu/reverse/ReverseDiff (compiled)/(20, 128)
17292
ns17322
ns1.00
Dense(20 => 20)/cpu/reverse/Zygote/(20, 128)
16891
ns17012
ns0.99
Dense(20 => 20)/cpu/reverse/Tracker/(20, 128)
39424
ns39274
ns1.00
Dense(20 => 20)/cpu/reverse/ReverseDiff/(20, 128)
29575
ns29480
ns1.00
Dense(20 => 20)/cpu/reverse/Flux/(20, 128)
20247
ns21670.5
ns0.93
Dense(20 => 20)/cpu/reverse/SimpleChains/(20, 128)
17543
ns17322
ns1.01
Dense(20 => 20)/cpu/reverse/Enzyme/(20, 128)
4343.857142857143
ns4363.857142857143
ns1.00
Dense(20 => 20)/cpu/forward/NamedTuple/(20, 128)
3877.25
ns3877.125
ns1.00
Dense(20 => 20)/cpu/forward/ComponentArray/(20, 128)
3976.125
ns3962.375
ns1.00
Dense(20 => 20)/cpu/forward/Flux/(20, 128)
4979.285714285715
ns4940.571428571428
ns1.01
Dense(20 => 20)/cpu/forward/SimpleChains/(20, 128)
1666.2
ns1673.2
ns1.00
Conv((3, 3), 3 => 3)/cpu/reverse/ReverseDiff (compiled)/(64, 64, 3, 128)
38674535
ns49957369
ns0.77
Conv((3, 3), 3 => 3)/cpu/reverse/Zygote/(64, 64, 3, 128)
57855464
ns57581648
ns1.00
Conv((3, 3), 3 => 3)/cpu/reverse/Tracker/(64, 64, 3, 128)
76267517
ns112009767.5
ns0.68
Conv((3, 3), 3 => 3)/cpu/reverse/ReverseDiff/(64, 64, 3, 128)
83618661.5
ns106965049
ns0.78
Conv((3, 3), 3 => 3)/cpu/reverse/Flux/(64, 64, 3, 128)
72679934
ns105768801
ns0.69
Conv((3, 3), 3 => 3)/cpu/reverse/SimpleChains/(64, 64, 3, 128)
11680641.5
ns11747527
ns0.99
Conv((3, 3), 3 => 3)/cpu/reverse/Enzyme/(64, 64, 3, 128)
17856952
ns17741547
ns1.01
Conv((3, 3), 3 => 3)/cpu/forward/NamedTuple/(64, 64, 3, 128)
7013553
ns7001933.5
ns1.00
Conv((3, 3), 3 => 3)/cpu/forward/ComponentArray/(64, 64, 3, 128)
6977149.5
ns7000243.5
ns1.00
Conv((3, 3), 3 => 3)/cpu/forward/Flux/(64, 64, 3, 128)
10041170
ns18531441
ns0.54
Conv((3, 3), 3 => 3)/cpu/forward/SimpleChains/(64, 64, 3, 128)
6389008
ns6394361
ns1.00
vgg16/cpu/reverse/Zygote/(32, 32, 3, 16)
746726487
ns727979542
ns1.03
vgg16/cpu/reverse/Zygote/(32, 32, 3, 64)
2585135934
ns2579077526
ns1.00
vgg16/cpu/reverse/Zygote/(32, 32, 3, 2)
145337500
ns131352834.5
ns1.11
vgg16/cpu/reverse/Tracker/(32, 32, 3, 16)
855270478
ns969586753
ns0.88
vgg16/cpu/reverse/Tracker/(32, 32, 3, 64)
3032822955
ns3263278361
ns0.93
vgg16/cpu/reverse/Tracker/(32, 32, 3, 2)
191513330
ns226498384
ns0.85
vgg16/cpu/reverse/Flux/(32, 32, 3, 16)
628663470
ns873641593
ns0.72
vgg16/cpu/reverse/Flux/(32, 32, 3, 64)
2431719966
ns3046764407
ns0.80
vgg16/cpu/reverse/Flux/(32, 32, 3, 2)
128430302
ns131007272.5
ns0.98
vgg16/cpu/forward/NamedTuple/(32, 32, 3, 16)
174231960
ns174047082
ns1.00
vgg16/cpu/forward/NamedTuple/(32, 32, 3, 64)
644942314.5
ns645727264.5
ns1.00
vgg16/cpu/forward/NamedTuple/(32, 32, 3, 2)
45780988.5
ns45534194
ns1.01
vgg16/cpu/forward/ComponentArray/(32, 32, 3, 16)
164658201
ns164324255
ns1.00
vgg16/cpu/forward/ComponentArray/(32, 32, 3, 64)
633987268
ns642792641
ns0.99
vgg16/cpu/forward/ComponentArray/(32, 32, 3, 2)
30500114
ns29977065
ns1.02
vgg16/cpu/forward/Flux/(32, 32, 3, 16)
188572145
ns201738658
ns0.93
vgg16/cpu/forward/Flux/(32, 32, 3, 64)
724357861.5
ns918807028
ns0.79
vgg16/cpu/forward/Flux/(32, 32, 3, 2)
35976345
ns40275483
ns0.89
Conv((3, 3), 64 => 64)/cpu/reverse/ReverseDiff (compiled)/(64, 64, 64, 128)
1237030180
ns1293544406.5
ns0.96
Conv((3, 3), 64 => 64)/cpu/reverse/Zygote/(64, 64, 64, 128)
1865644001
ns1861728870
ns1.00
Conv((3, 3), 64 => 64)/cpu/reverse/Tracker/(64, 64, 64, 128)
2355646541
ns2468521111
ns0.95
Conv((3, 3), 64 => 64)/cpu/reverse/ReverseDiff/(64, 64, 64, 128)
2519641730
ns2631348316
ns0.96
Conv((3, 3), 64 => 64)/cpu/reverse/Flux/(64, 64, 64, 128)
1871255667
ns1895453908
ns0.99
Conv((3, 3), 64 => 64)/cpu/reverse/Enzyme/(64, 64, 64, 128)
559261598
ns565985350
ns0.99
Conv((3, 3), 64 => 64)/cpu/forward/NamedTuple/(64, 64, 64, 128)
321769270
ns321376035
ns1.00
Conv((3, 3), 64 => 64)/cpu/forward/ComponentArray/(64, 64, 64, 128)
322333048
ns320502632
ns1.01
Conv((3, 3), 64 => 64)/cpu/forward/Flux/(64, 64, 64, 128)
355818057.5
ns486344616
ns0.73
Conv((3, 3), 1 => 1)/cpu/reverse/ReverseDiff (compiled)/(64, 64, 1, 128)
11750378
ns11994505.5
ns0.98
Conv((3, 3), 1 => 1)/cpu/reverse/Zygote/(64, 64, 1, 128)
17892421
ns17906521
ns1.00
Conv((3, 3), 1 => 1)/cpu/reverse/Tracker/(64, 64, 1, 128)
19139788
ns19110550.5
ns1.00
Conv((3, 3), 1 => 1)/cpu/reverse/ReverseDiff/(64, 64, 1, 128)
23859946.5
ns23813093
ns1.00
Conv((3, 3), 1 => 1)/cpu/reverse/Flux/(64, 64, 1, 128)
17922065
ns17831210
ns1.01
Conv((3, 3), 1 => 1)/cpu/reverse/SimpleChains/(64, 64, 1, 128)
1162839
ns1157282
ns1.00
Conv((3, 3), 1 => 1)/cpu/reverse/Enzyme/(64, 64, 1, 128)
5778842
ns5758236.5
ns1.00
Conv((3, 3), 1 => 1)/cpu/forward/NamedTuple/(64, 64, 1, 128)
2053111
ns2047310.5
ns1.00
Conv((3, 3), 1 => 1)/cpu/forward/ComponentArray/(64, 64, 1, 128)
2034766
ns2030949
ns1.00
Conv((3, 3), 1 => 1)/cpu/forward/Flux/(64, 64, 1, 128)
2073890
ns2077663
ns1.00
Conv((3, 3), 1 => 1)/cpu/forward/SimpleChains/(64, 64, 1, 128)
197649
ns201106
ns0.98
Dense(200 => 200)/cpu/reverse/ReverseDiff (compiled)/(200, 128)
295075.5
ns293318
ns1.01
Dense(200 => 200)/cpu/reverse/Zygote/(200, 128)
267590
ns265260.5
ns1.01
Dense(200 => 200)/cpu/reverse/Tracker/(200, 128)
372184.5
ns366414
ns1.02
Dense(200 => 200)/cpu/reverse/ReverseDiff/(200, 128)
411758
ns408473
ns1.01
Dense(200 => 200)/cpu/reverse/Flux/(200, 128)
275684
ns273792
ns1.01
Dense(200 => 200)/cpu/reverse/SimpleChains/(200, 128)
410075
ns406058
ns1.01
Dense(200 => 200)/cpu/reverse/Enzyme/(200, 128)
83125
ns83516
ns1.00
Dense(200 => 200)/cpu/forward/NamedTuple/(200, 128)
81833
ns81513
ns1.00
Dense(200 => 200)/cpu/forward/ComponentArray/(200, 128)
83385
ns81592
ns1.02
Dense(200 => 200)/cpu/forward/Flux/(200, 128)
87153
ns86632
ns1.01
Dense(200 => 200)/cpu/forward/SimpleChains/(200, 128)
104926
ns104725
ns1.00
Conv((3, 3), 16 => 16)/cpu/reverse/ReverseDiff (compiled)/(64, 64, 16, 128)
195097943
ns199636923
ns0.98
Conv((3, 3), 16 => 16)/cpu/reverse/Zygote/(64, 64, 16, 128)
326683151.5
ns324677516.5
ns1.01
Conv((3, 3), 16 => 16)/cpu/reverse/Tracker/(64, 64, 16, 128)
383270904
ns437484335
ns0.88
Conv((3, 3), 16 => 16)/cpu/reverse/ReverseDiff/(64, 64, 16, 128)
456636762
ns526587672
ns0.87
Conv((3, 3), 16 => 16)/cpu/reverse/Flux/(64, 64, 16, 128)
370731185
ns410366601.5
ns0.90
Conv((3, 3), 16 => 16)/cpu/reverse/SimpleChains/(64, 64, 16, 128)
314203441.5
ns323744332.5
ns0.97
Conv((3, 3), 16 => 16)/cpu/reverse/Enzyme/(64, 64, 16, 128)
101771532
ns102515100.5
ns0.99
Conv((3, 3), 16 => 16)/cpu/forward/NamedTuple/(64, 64, 16, 128)
44354247.5
ns43799747
ns1.01
Conv((3, 3), 16 => 16)/cpu/forward/ComponentArray/(64, 64, 16, 128)
44239473.5
ns43534417
ns1.02
Conv((3, 3), 16 => 16)/cpu/forward/Flux/(64, 64, 16, 128)
60347251
ns63896063
ns0.94
Conv((3, 3), 16 => 16)/cpu/forward/SimpleChains/(64, 64, 16, 128)
28310724.5
ns28176419
ns1.00
Dense(2000 => 2000)/cpu/reverse/ReverseDiff (compiled)/(2000, 128)
18982153.5
ns18931882
ns1.00
Dense(2000 => 2000)/cpu/reverse/Zygote/(2000, 128)
19570435.5
ns19487991
ns1.00
Dense(2000 => 2000)/cpu/reverse/Tracker/(2000, 128)
23634115
ns23372438.5
ns1.01
Dense(2000 => 2000)/cpu/reverse/ReverseDiff/(2000, 128)
24235439
ns24063046.5
ns1.01
Dense(2000 => 2000)/cpu/reverse/Flux/(2000, 128)
19730324
ns19611168.5
ns1.01
Dense(2000 => 2000)/cpu/reverse/Enzyme/(2000, 128)
6543374
ns6520318
ns1.00
Dense(2000 => 2000)/cpu/forward/NamedTuple/(2000, 128)
6559820
ns6498792
ns1.01
Dense(2000 => 2000)/cpu/forward/ComponentArray/(2000, 128)
6519284
ns6477998.5
ns1.01
Dense(2000 => 2000)/cpu/forward/Flux/(2000, 128)
6512090
ns6511497
ns1.00
This comment was automatically generated by workflow using github-action-benchmark.