docs: update mnist node example
avik-pal committed Aug 27, 2024
1 parent a7dbec9 commit 14862e6
Showing 1 changed file with 52 additions and 63 deletions: docs/src/examples/mnist_neural_ode.md
Training a classifier for **MNIST** using a neural ordinary differential equation.

(Step-by-step description below)

```@example mnist_full
using DiffEqFlux, CUDA, Zygote, NNlib, OrdinaryDiffEq, Lux, Statistics, ComponentArrays,
Random, Optimization, OptimizationOptimisers, LuxCUDA, MLUtils, OneHotArrays
using MLDatasets: MNIST
CUDA.allowscalar(false)
ENV["DATADEPS_ALWAYS_ACCEPT"] = true
const cdev = cpu_device()
const gdev = gpu_device()
logitcrossentropy = CrossEntropyLoss(; logits = Val(true))

function loadmnist(batchsize)
# Load MNIST
dataset = MNIST(; split = :train)
imgs = dataset.features
labels_raw = dataset.targets
x_data = Float32.(reshape(imgs, size(imgs, 1), size(imgs, 2), 1, size(imgs, 3)))
y_data = onehotbatch(labels_raw, 0:9)
return DataLoader(mapobs(gdev, (x_data, y_data)); batchsize, shuffle = true)
end
# Main
dataloader = loadmnist(128)
down = Chain(FlattenLayer(), Dense(784, 20, tanh))
nn = Chain(Dense(20, 10, tanh), Dense(10, 10, tanh), Dense(10, 20, tanh))
fc = Dense(20, 10)
nn_ode = NeuralODE(nn, (0.0f0, 1.0f0), Tsit5(); save_everystep = false,
reltol = 1e-3, abstol = 1e-3, save_start = false)
solution_to_array(sol) = sol.u[end]
# Build our overall model topology
m = Chain(; down, nn_ode, convert = WrappedFunction(solution_to_array), fc)
ps, st = Lux.setup(Xoshiro(0), m);
ps = ComponentArray(ps) |> gdev;
st = st |> gdev;
# We can also build the model topology without a NN-ODE
m_no_ode = Chain(; down, nn, fc)
ps_no_ode, st_no_ode = Lux.setup(Xoshiro(0), m_no_ode);
ps_no_ode = ComponentArray(ps_no_ode) |> gdev;
st_no_ode = st_no_ode |> gdev;
x_train1, y_train1 = first(dataloader)
classify(x) = argmax.(eachcol(x))

function accuracy(model, data, ps, st; n_batches = 100)
total_correct = 0
total = 0
st = Lux.testmode(st)
for (x, y) in collect(data)[1:min(n_batches, length(data))]
target_class = classify(cdev(y))
predicted_class = classify(cdev(first(model(x, ps, st))))
total_correct += sum(target_class .== predicted_class)
total += length(target_class)
end
return total_correct / total
end

opt_prob = OptimizationProblem(opt_func, ps)
function callback(ps, l, pred)
global iter += 1
# Monitor that the weights do in fact update
iter % 10 == 0 &&
@info "[MNIST GPU] Accuracy: $(accuracy(m, dataloader, ps.u, st))"
return false
end
# Train the NN-ODE and monitor the loss and weights.
res = Optimization.solve(opt_prob, opt, dataloader; callback, maxiters = 5)
@assert accuracy(m, dataloader, res.u, st) > 0.8
```

## Step-by-Step Description

### Load Packages

```@example mnist
using DiffEqFlux, CUDA, Zygote, NNlib, OrdinaryDiffEq, Lux, Statistics, ComponentArrays,
Random, Optimization, OptimizationOptimisers, LuxCUDA, MLUtils, OneHotArrays
using MLDatasets: MNIST
```

A good trick used here:
```@example mnist
CUDA.allowscalar(false)
ENV["DATADEPS_ALWAYS_ACCEPT"] = true
const cdev = cpu_device()
const gdev = gpu_device()
```
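
To see what `allowscalar(false)` protects against: scalar indexing pulls single elements across the host-device boundary, and disallowing it turns that silent slowdown into an immediate error. A minimal sketch of our own, guarded with `CUDA.functional()` so it is also safe on CPU-only machines:

```julia
if CUDA.functional()
    xg = CUDA.rand(Float32, 4)
    try
        xg[1]  # scalar indexing on a GPU array: now an error, not a slow copy
    catch err
        @info "Scalar indexing blocked" err
    end
    sum(xg)  # whole-array kernels are unaffected
end
```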
The MNIST dataset is split into 60,000 train and 10,000 test images, ensuring a balanced ratio of labels.

The preprocessing is done in `loadmnist` where the raw MNIST data is split into features `x` and labels `y`.
Features are reshaped into the format **[Height, Width, Color, Samples]**, in the case of the train set **[28, 28, 1, 60000]**.
Using the `onehotbatch` function from OneHotArrays, the labels (numbers 0 to 9) are one-hot encoded, resulting in a **[10, 60000]** `OneHotMatrix`.

Features and labels are then passed to MLUtils' `DataLoader`. This automatically
minibatches both the images and labels using the specified `batchsize`, meaning that every
minibatch will contain 128 images with a single color channel of 28x28 pixels. The
`mapobs(gdev, ...)` call moves each batch to the GPU as it is fetched.
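
A small self-contained illustration of both steps, using fake data so the shapes are easy to check (the names here are ours, not part of the example):

```julia
using MLUtils, OneHotArrays

yy = onehotbatch([3, 7, 1], 0:9)  # 10×3 OneHotMatrix
onecold(yy, 0:9)                  # recovers [3, 7, 1]

xx = rand(Float32, 28, 28, 1, 256)  # 256 fake "images"
loader = DataLoader((xx, onehotbatch(rand(0:9, 256), 0:9));
    batchsize = 128, shuffle = true)
length(loader)  # 2 minibatches of 128 observations each
```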

```@example mnist
logitcrossentropy = CrossEntropyLoss(; logits = Val(true))

function loadmnist(batchsize)
# Load MNIST
dataset = MNIST(; split = :train)
imgs = dataset.features
labels_raw = dataset.targets
x_data = Float32.(reshape(imgs, size(imgs, 1), size(imgs, 2), 1, size(imgs, 3)))
y_data = onehotbatch(labels_raw, 0:9)
return DataLoader(mapobs(gdev, (x_data, y_data)); batchsize, shuffle = true)
end
```

and then loaded from main:

```@example mnist
# Main
dataloader = loadmnist(128)
```
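
A quick sanity check on the loader (the sizes follow from the `batchsize` of 128 above; with a functional GPU, `mapobs(gdev, ...)` means both arrays come back as `CuArray`s):

```julia
x1, y1 = first(dataloader)
size(x1)  # (28, 28, 1, 128)
size(y1)  # (10, 128)
```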

### Layers
The Neural Network requires passing inputs sequentially through multiple layers. We use
`Chain`, which connects the layers so that the output of one layer feeds into
the next. Four different sets of layers are used here:

```@example mnist
down = Chain(FlattenLayer(), Dense(784, 20, tanh))
nn = Chain(Dense(20, 10, tanh), Dense(10, 10, tanh), Dense(10, 20, tanh))
fc = Dense(20, 10)
```

`down`: This layer downsamples our images into a 20 dimensional feature vector.
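
For intuition, here is a sketch of what `down` alone does to a small batch; the standalone `Lux.setup` call is only for this check and is not part of the example's training path:

```julia
ps_down, st_down = Lux.setup(Xoshiro(0), down)
h, _ = down(rand(Float32, 28, 28, 1, 8), ps_down, st_down)
size(h)  # (20, 8): one 20-dimensional feature vector per image
```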
Expand All @@ -210,7 +202,7 @@ a Matrix (CuArray), and reduces the matrix from 3 to 2 dimensions for use in the
nn_ode = NeuralODE(nn, (0.0f0, 1.0f0), Tsit5(); save_everystep = false,
reltol = 1e-3, abstol = 1e-3, save_start = false)
solution_to_array(sol) = sol.u[end]
```

For CPU: If this function does not automatically fall back to CPU when no GPU is present, we can
change `gdev(x)` to `Array(x)`.
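
To make the conversion concrete: with `save_everystep = false` and `save_start = false`, the solution returned by `nn_ode` stores exactly one saved state, so `sol.u[end]` is the hidden state at the final time. A standalone sketch, with the setup call again only for illustration:

```julia
ps_node, st_node = Lux.setup(Xoshiro(0), nn_ode)
sol, _ = nn_ode(rand(Float32, 20, 8), ps_node, st_node)
length(sol.u)                 # 1: only the final state was saved
size(solution_to_array(sol))  # (20, 8), ready for the `fc` layer
```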
Next, we connect all layers together in a single chain:

```@example mnist
# Build our overall model topology
m = Chain(; down, nn_ode, convert = WrappedFunction(solution_to_array), fc)
ps, st = Lux.setup(Xoshiro(0), m);
ps = ComponentArray(ps) |> gdev;
st = st |> gdev;
```
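
With the chain assembled, one forward pass maps a batch of images to ten logits per image. Using the first minibatch fetched from the loader:

```julia
ŷ, _ = m(x_train1, ps, st)
size(ŷ)  # (10, 128): raw logits, consumed by the loss below
```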

### Prediction
```@example mnist
function accuracy(model, data, ps, st; n_batches = 100)
total_correct = 0
total = 0
st = Lux.testmode(st)
for (x, y) in collect(data)[1:min(n_batches, length(data))]
target_class = classify(cdev(y))
predicted_class = classify(cdev(first(model(x, ps, st))))
total_correct += sum(target_class .== predicted_class)
total += length(target_class)
end
return total_correct / total
end
accuracy(m, ((x_train1, y_train1),), ps, st) # burn in accuracy
```
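
`accuracy` compares integer class indices; assuming the `classify(x) = argmax.(eachcol(x))` definition from the full example above, a tiny check of the comparison logic:

```julia
classify(x) = argmax.(eachcol(x))
classify(Float32[0.1 0.9; 0.9 0.1])  # [2, 1]: argmax of each column
```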

### Training Parameters
```@example mnist
function loss_function(ps, x, y)
pred, _ = m(x, ps, st)
return logitcrossentropy(pred, y), pred
end
loss_function(ps, x_train1, y_train1) # burn in loss
```
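
`CrossEntropyLoss(; logits = Val(true))` applies `logsoftmax` internally before the cross-entropy, which is exactly the `logitcrossentropy` this example previously defined by hand. A quick equivalence sketch:

```julia
ŷ = randn(Float32, 10, 4)
yy = onehotbatch([0, 1, 2, 3], 0:9)
manual = mean(-sum(yy .* logsoftmax(ŷ; dims = 1); dims = 1))
logitcrossentropy(ŷ, yy) ≈ manual  # true
```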

#### Optimizer
```@example mnist
opt_prob = OptimizationProblem(opt_func, ps)
function callback(ps, l, pred)
global iter += 1
# Monitor that the weights do in fact update
iter % 10 == 0 &&
@info "[MNIST GPU] Accuracy: $(accuracy(m, dataloader, ps.u, st))"
return false
end
```
The parameters for the Neural ODE are given by `nn_ode.p`:

```@example mnist
# Train the NN-ODE and monitor the loss and weights.
res = Optimization.solve(opt_prob, opt, dataloader; callback, maxiters = 5)
acc = accuracy(m, dataloader, res.u, st)
@assert acc > 0.8 # hide
acc # hide
```
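
The committed example measures accuracy on the training loader only. A held-out check can point the same pipeline at MLDatasets' test split; this sketch is ours, not part of the commit:

```julia
test_data = MNIST(; split = :test)
x_test = Float32.(reshape(test_data.features, 28, 28, 1, :))
y_test = onehotbatch(test_data.targets, 0:9)
test_loader = DataLoader(mapobs(gdev, (x_test, y_test)); batchsize = 128)
accuracy(m, test_loader, res.u, st)
```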
