Add training benchmarking script #264

Draft · wants to merge 8 commits into master
25 changes: 25 additions & 0 deletions benchmarking/Project.toml
@@ -0,0 +1,25 @@
[deps]
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
Colors = "5ae59095-9a9b-59fe-a467-6f913c188581"
DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0"
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
GLMakie = "e9467ef8-e4e7-5192-8a1a-b1aee30e663a"
MLDatasets = "eb30cadb-4394-5ae3-aed4-317e484a6458"
Metalhead = "dbeba491-748d-5e0e-a39e-b530a07fa0cc"
Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2"
ProgressMeter = "92933f4c-e287-5a05-a399-4b506db050ca"
TimerOutputs = "a759f4b9-e2f1-59dc-863e-4aeb61b1ea8f"
cuDNN = "02a925ec-e4fe-4b08-9a7e-0d78e3d38ccd"

[compat]
Colors = "0.12"
CUDA = "5"
DataFrames = "1"
Flux = "0.14"
GLMakie = "0.20"
MLDatasets = "0.7"
Metalhead = "0.9"
Optimisers = "0.3"
ProgressMeter = "1.9"
TimerOutputs = "0.5"
cuDNN = "1.2"
105 changes: 105 additions & 0 deletions benchmarking/benchmark.jl
@@ -0,0 +1,105 @@

using Colors
using CUDA, cuDNN
using DataFrames
using Flux
using Flux: logitcrossentropy, onecold, onehotbatch
using GLMakie
using Metalhead
using MLDatasets
using Optimisers
using ProgressMeter
using TimerOutputs

include("tooling.jl")

function run()

    epochs = 45
    batchsize = 1000
    device = gpu
    CUDA.allowscalar(false)
    allow_skips = true

    train_loader, test_loader, labels = load_cifar10(; batchsize)
    nlabels = length(labels)
    firstbatch = first(first(train_loader))
    imsize = size(firstbatch)[1:2]

    @info "Benchmarking" epochs batchsize device imsize

    to = TimerOutput()

    common = "pretrain=false, inchannels=3, nclasses=$(nlabels)"

    # these should all be the smallest variant of each that is tested in `/test`
    modelstrings = (
        "AlexNet(; $common)",
        "VGG(11, batchnorm=true; $common)",
        "SqueezeNet(; $common)",
        "ResNet(18; $common)",
        "WideResNet(50; $common)",
        "ResNeXt(50, cardinality=32, base_width=4; $common)",
        "SEResNet(18; $common)",
        "SEResNeXt(50, cardinality=32, base_width=4; $common)",
        "Res2Net(50, base_width=26, scale=4; $common)",
        "Res2NeXt(50; $common)",
        "GoogLeNet(batchnorm=true; $common)",
        "DenseNet(121; $common)",
        "Inceptionv3(; $common)",
        "Inceptionv4(; $common)",
        "InceptionResNetv2(; $common)",
        "Xception(; $common)",
        "MobileNetv1(0.5; $common)",
        "MobileNetv2(0.5; $common)",
        "MobileNetv3(:small, width_mult=0.5; $common)",
        "MNASNet(:A1, width_mult=0.5; $common)",
        "EfficientNet(:b0; $common)",
        "EfficientNetv2(:small; $common)",
        "ConvMixer(:small; $common)",
        "ConvNeXt(:small; $common)",
        # "MLPMixer(; $common)", # no tests found
        # "ResMLP(; $common)", # no tests found
        # "gMLP(; $common)", # no tests found
        "ViT(:tiny; $common)",
        # "UNet(; $common)" # doesn't support kwargs "inchannels", "nclasses"
    )
    df = DataFrame(; model=String[], train_loss=Float64[], train_acc=Float64[], test_loss=Float64[], test_acc=Float64[])
    cols = distinguishable_colors(length(modelstrings), [RGB(1,1,1), RGB(0,0,0)], dropseed=true)
    f = Figure()
    ax = Axis(f[1, 1], title="CIFAR-10 Training on an Nvidia 3090, batch 1000\nTest accuracy vs. time over 45 epochs", xlabel="Time (s)", ylabel="Testset Accuracy")
    display(f)
    max_x = 0
    for (i, modstring) in enumerate(modelstrings)
        @timeit to "$modstring" begin
            @info "Evaluating $i/$(length(modelstrings)): $modstring"
            # The first construction triggers compilation that varies with whatever ran before it, so it is not timed.
            eval(Meta.parse(modstring))
            # The second construction approximates what a properly set-up pkgimage workload could achieve.
            @timeit to "Load" model=eval(Meta.parse(modstring))
            @timeit to "Training" ret = train(model,
                train_loader,
                test_loader;
                limit = 1,
                to,
                device
            )
            elapsed, train_loss, train_acc, test_loss, test_acc = if isnothing(ret)
                allow_skips || break
                missing, missing, missing, missing, missing
            else
                elapsed, train_loss_hist, train_acc_hist, test_loss_hist, test_acc_hist = ret
                max_x = max(maximum(elapsed), max_x)
                lines!(ax, elapsed, test_acc_hist, label=modstring, color=cols[i])
                elapsed, train_loss_hist[end], train_acc_hist[end], test_loss_hist[end], test_acc_hist[end]
            end
            push!(df, (modstring, train_loss, train_acc, test_loss, test_acc), promote=true)
        end
        GC.gc(true)
    end
    f[1, 2] = Legend(f, ax, "Models", framevisible = false)
    display(f)
    display(df)
    print_timer(to; sortby = :firstexec)
end
run()
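
The per-model sections in the final summary table come from nesting @timeit blocks against a single TimerOutput, as benchmark.jl does with its "Load" and "Training" sections. A minimal sketch of that pattern (the section names and sleep calls here are illustrative, not part of the PR):

using TimerOutputs

to = TimerOutput()
@timeit to "model section" begin      # one outer section per model string in run()
    @timeit to "Load" sleep(0.1)      # nested section, e.g. model construction
    @timeit to "Training" sleep(0.2)  # nested section, e.g. the training loop
end
print_timer(to; sortby = :firstexec)  # same reporting call used at the end of run()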
104 changes: 104 additions & 0 deletions benchmarking/tooling.jl
@@ -0,0 +1,104 @@
function loss_and_accuracy(data_loader, model, device; limit = nothing)
    acc = 0
    ls = 0.0f0
    num = 0
    i = 0
    for (x, y) in data_loader
        x, y = x |> device, y |> device
        ŷ = model(x)
        ls += logitcrossentropy(ŷ, y, agg=sum)
        acc += sum(onecold(ŷ) .== onecold(y))
        num += size(x)[end]
        if limit !== nothing
            i == limit && break
            i += 1
        end
    end
    return ls / num, acc / num
end

function load_cifar10(; batchsize=1000)
    @info "loading CIFAR-10 dataset"
    train_dataset, test_dataset = CIFAR10(split=:train), CIFAR10(split=:test)
    train_x, train_y = train_dataset[:]
    test_x, test_y = test_dataset[:]
    @assert train_dataset.metadata["class_names"] == test_dataset.metadata["class_names"]
    labels = train_dataset.metadata["class_names"]

    # CIFAR-10 labels are zero-indexed; shift them to 1-based for onehotbatch
    train_y .+= 1
    test_y .+= 1

    train_y_ohb = Flux.onehotbatch(train_y, eachindex(labels))
    test_y_ohb = Flux.onehotbatch(test_y, eachindex(labels))

    train_loader = Flux.DataLoader((data=train_x, labels=train_y_ohb); batchsize, shuffle=true)
    test_loader = Flux.DataLoader((data=test_x, labels=test_y_ohb); batchsize)

    return train_loader, test_loader, labels
end

function _train(model, train_loader, test_loader; epochs = 45,
                device = gpu, limit=nothing, gpu_gc=true, gpu_stats=false, show_plots=false, to=TimerOutput())

    model = model |> device

    opt = Optimisers.Adam()
    state = Optimisers.setup(opt, model)

    train_loss_hist, train_acc_hist = Float64[], Float64[]
    test_loss_hist, test_acc_hist = Float64[], Float64[]
    elapsed = Float64[]
    @showprogress "training" for epoch in 1:epochs
        i = 0
        for (x, y) in train_loader
            x, y = x |> device, y |> device
            @timeit to "batch step" begin
                gs, _ = gradient(model, x) do m, _x
                    logitcrossentropy(m(_x), y)
                end
                state, model = Optimisers.update!(state, model, gs)
            end

            device === gpu && gpu_stats && CUDA.memory_status()
            if limit !== nothing
                i == limit && break
                i += 1
            end
        end

        train_loss, train_acc = loss_and_accuracy(train_loader, model, device; limit)
        @timeit to "testing" test_loss, test_acc = loss_and_accuracy(test_loader, model, device; limit)
        push!(train_loss_hist, train_loss); push!(train_acc_hist, train_acc);
        push!(test_loss_hist, test_loss); push!(test_acc_hist, test_acc);
        push!(elapsed, time())
        if show_plots  # note: lineplot/lineplot! come from UnicodePlots.jl, which is not listed in Project.toml
            plt2 = lineplot(1:epoch, train_loss_hist, name = "train_loss", xlabel="epoch", ylabel="loss")
            lineplot!(plt2, 1:epoch, test_loss_hist, name = "test_loss")
            display(plt2)
            plt2 = lineplot(1:epoch, train_acc_hist, name = "train_acc", xlabel="epoch", ylabel="acc")
            lineplot!(plt2, 1:epoch, test_acc_hist, name = "test_acc")
            display(plt2)
        end
        if device === gpu && gpu_gc
            GC.gc() # GPU will OOM without this
        end
    end
    elapsed = elapsed .- elapsed[1]
    train_loss, train_acc, test_loss, test_acc = train_loss_hist[end], train_acc_hist[end], test_loss_hist[end], test_acc_hist[end]
    @info "results after $epochs epochs $(repr(map(x->round(x, digits=3), (; train_loss, train_acc, test_loss, test_acc))))"
    return elapsed, train_loss_hist, train_acc_hist, test_loss_hist, test_acc_hist
end

# Flux stacktraces can be extremely large on Julia < 1.10, so catch errors and print only the message instead of rethrowing
function train(args...; kwargs...)
    try
        _train(args...; kwargs...)
    catch ex
        # rethrow()
        println()
        println(sprint(showerror, ex))
        GC.gc()
        return nothing
    end
end
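
For a quick smoke test of a single model, the helpers in tooling.jl can also be called directly. A sketch under the assumption of a working CUDA setup and a working directory of benchmarking/; the model choice, batchsize, epochs, and limit values here are illustrative:

using CUDA, cuDNN, Flux, Metalhead, MLDatasets, Optimisers, ProgressMeter, TimerOutputs
using Flux: logitcrossentropy, onecold, onehotbatch

include("tooling.jl")

train_loader, test_loader, labels = load_cifar10(batchsize=256)
model = ResNet(18; pretrain=false, inchannels=3, nclasses=length(labels))
# `train` returns `nothing` if an error (e.g. GPU OOM) was caught, otherwise the elapsed times and history vectors
ret = train(model, train_loader, test_loader; epochs=2, limit=1, device=gpu)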