

Batch struct #84

Merged
merged 40 commits into from
Oct 25, 2023

Conversation

benedict-96 (Collaborator)

Introduced a new struct Batch that takes care of drawing new batches for each epoch during training. This was previously (sloppily) done by DataLoader. Some helper functions have also been added, most importantly optimize_for_one_epoch(::Optimizer, model, ps, dl::DataLoader, batch::Batch, loss). I also tried to add and improve the documentation for the routines I changed or introduced.

…onality in there. The DataLoader should purely handle loading data (as the name suggests), not drawing batches! The new function optimize_for_one_epoch optimizes for an entire epoch and outputs the average loss.
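As a rough illustration of what an epoch driver like optimize_for_one_epoch does (loop over the batches of one epoch, take an update step per batch, return the average loss), here is a self-contained toy sketch. Everything except the idea and the name optimize_for_one_epoch is hypothetical and is not the GeometricMachineLearning implementation:

```julia
# Toy sketch of the idea behind optimize_for_one_epoch: iterate over the
# batches of a single epoch, take one update step per batch, and return
# the average loss. All helpers below are hypothetical stand-ins.

sgd_step!(ps, grad; lr = 0.1) = (ps .-= lr .* grad)   # plain gradient descent

function optimize_for_one_epoch!(ps, batches, loss, grad)
    total = 0.0
    for batch in batches
        total += loss(ps, batch)        # loss before the update, per batch
        sgd_step!(ps, grad(ps, batch))  # parameter update for this batch
    end
    total / length(batches)             # average loss over the epoch
end

# Fit a single scalar parameter to data y ≈ 2 with a squared loss:
loss(ps, batch) = sum((ps[1] - y)^2 for y in batch) / length(batch)
grad(ps, batch) = [sum(2 * (ps[1] - y) for y in batch) / length(batch)]

ps = [0.0]
batches = [[2.0, 2.0], [2.0]]
avg = optimize_for_one_epoch!(ps, batches, loss, grad)  # ps moves toward 2
```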

codecov bot commented Oct 20, 2023

Codecov Report

Merging #84 (473c8ae) into main (15ea1f6) will decrease coverage by 1.37%.
The diff coverage is 76.92%.

@@            Coverage Diff             @@
##             main      #84      +/-   ##
==========================================
- Coverage   70.71%   69.35%   -1.37%     
==========================================
  Files          95       96       +1     
  Lines        2363     2382      +19     
==========================================
- Hits         1671     1652      -19     
- Misses        692      730      +38     
| Files | Coverage | Δ |
|---|---|---|
| src/GeometricMachineLearning.jl | 100.00% <ø> | (ø) |
| src/data_loader/mnist_utils.jl | 83.33% <100.00%> | (ø) |
| src/layers/classification.jl | 100.00% <ø> | (ø) |
| src/layers/transformer.jl | 100.00% <ø> | (ø) |
| src/optimizers/optimizer.jl | 100.00% <100.00%> | (ø) |
| src/data_loader/batch.jl | 93.10% <93.10%> | (ø) |
| src/data_loader/tensor_assign.jl | 35.29% <80.00%> | (-50.01%) ⬇️ |
| src/data_loader/data_loader.jl | 55.26% <54.83%> | (-9.33%) ⬇️ |

... and 1 file with indirect coverage changes


michakraus (Member)

There is some Julia v1.8 issue that needs to be fixed, plus a rebase.

benedict-96 (Collaborator, Author) commented Oct 25, 2023

Tests for v1.8 pass now. I replaced the outer constructors with inner constructors, and now it works:

struct Batch{seq_length}
    batch_size::Integer
    seq_length::Union{Nothing, Integer}
end

function Batch(batch_size, seq_length)
    Batch{true}(batch_size, seq_length)
end

function Batch(batch_size::Integer)
    Batch{false}(batch_size, nothing)
end

to

struct Batch{seq_length}
    batch_size::Integer
    seq_length::Union{Nothing, Integer}

    function Batch(batch_size, seq_length)
        new{true}(batch_size, seq_length)
    end

    function Batch(batch_size::Integer)
        new{false}(batch_size, nothing)
    end
end

michakraus (Member) commented Oct 25, 2023

I see two issues here:

  1. You use seq_length as both a type parameter and a field name.
  2. Your fields do not have concrete types.

I would rather write this along the following lines:

struct Batch{seq_type <: Union{Nothing, Integer}}
    batch_size::Int
    seq_length::seq_type

    function Batch(batch_size, seq_length = nothing)
        new{typeof(seq_length)}(batch_size, seq_length)
    end
end

hasseqlength(::Batch{<:Integer}) = true
hasseqlength(::Batch{<:Nothing}) = false
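For completeness, a quick check of how this suggested design behaves (the definition is repeated from the comment above so the snippet runs on its own):

```julia
# The suggested design: the type parameter is the *type* of seq_length,
# so both fields have concrete types and dispatch happens on seq_type.
struct Batch{seq_type <: Union{Nothing, Integer}}
    batch_size::Int
    seq_length::seq_type

    function Batch(batch_size, seq_length = nothing)
        new{typeof(seq_length)}(batch_size, seq_length)
    end
end

hasseqlength(::Batch{<:Integer}) = true
hasseqlength(::Batch{<:Nothing}) = false

b_plain = Batch(32)      # Batch{Nothing}: no sequence length
b_seq   = Batch(32, 10)  # Batch{Int64}: batches of sequences
```

Calling hasseqlength(b_plain) gives false and hasseqlength(b_seq) gives true, so downstream code can branch on whether a sequence length was supplied.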

benedict-96 (Collaborator, Author) commented Oct 25, 2023

Ok, I see: that's why the error on v1.8 was "function argument and static parameter names must be distinct". I copy-pasted the suggested change and pushed again.

michakraus (Member)

I guess you were a bit too quick and forgot to run your tests 😉 .

There are a few more changes required, e.g.,

function (batch::Batch{false})

needs to become

function (batch::Batch{Nothing})

etc.
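Concretely, with the type-of-field parameterisation the old true/false methods map onto Nothing/Integer ones. A hypothetical illustration (the Batch definition is repeated from the earlier suggestion so this stands alone, and batchmode is an invented helper, not from the PR):

```julia
struct Batch{seq_type <: Union{Nothing, Integer}}
    batch_size::Int
    seq_length::seq_type

    Batch(batch_size, seq_length = nothing) = new{typeof(seq_length)}(batch_size, seq_length)
end

# What used to dispatch on Batch{false} now dispatches on Batch{Nothing},
# and what used to dispatch on Batch{true} now dispatches on Batch{<:Integer}.
batchmode(::Batch{Nothing})   = :no_sequence
batchmode(::Batch{<:Integer}) = :sequence
```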

benedict-96 (Collaborator, Author)

Sorry 😅. It should be working now.

@michakraus michakraus merged commit c62b834 into main Oct 25, 2023
@michakraus michakraus deleted the transformer_description branch October 25, 2023 07:15