[fix] Fix different batches per epoch in NoDuplicatesBatchSampler #3073

Merged · 2 commits · Nov 20, 2024

Conversation

tomaarsen (Collaborator)

Resolves #3069

Hello!

Pull Request overview

  • Ensure that different epochs have different batches & sample orders when using NoDuplicatesBatchSampler

Details

As described in #3069, the set in

remaining_indices = set(torch.randperm(len(self.dataset), generator=self.generator).tolist())

will always "undo" the torch.randperm results: the dataset indices are small non-negative integers, and since hash(i) == i for those, a set containing them tends to iterate in ascending order, discarding the shuffle. As a result, we want a different data structure here, one that:

  1. Allows for cheap removal of elements
  2. Preserves the order of elements, i.e. remains random
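
A minimal sketch of the idea, as a hypothetical illustration (plain Python + torch, not the actual sampler code; assumes Python 3.7+, where dicts preserve insertion order):

import torch

generator = torch.Generator().manual_seed(12)
dataset_size = 1_000

# Old approach: a set of small non-negative ints iterates in ascending order
# (hash(i) == i), which throws away the random order from torch.randperm.
remaining_set = set(torch.randperm(dataset_size, generator=generator).tolist())
print(list(remaining_set)[:8])  # [0, 1, 2, 3, 4, 5, 6, 7]

# New approach: a dict keyed by the shuffled indices preserves insertion order
# and still allows O(1) removal of sampled indices.
remaining_indices = dict.fromkeys(torch.randperm(dataset_size, generator=generator).tolist())
print(list(remaining_indices)[:8])  # a random prefix of the permutation

# Removing already-sampled indices is cheap and keeps the remaining order random.
for index in list(remaining_indices)[:8]:
    del remaining_indices[index]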

We brainstormed a bit in #3069 and think that a dict should work, as sketched above. Here's a benchmarking script:

import time
from datasets import load_dataset
import torch
from sentence_transformers.sampler import NoDuplicatesBatchSampler

dataset = load_dataset("sentence-transformers/gooaq", split="train")

sampler = NoDuplicatesBatchSampler(dataset, batch_size=16, drop_last=False, generator=torch.Generator(), seed=12)
for epoch in range(5):
    sampler.set_epoch(epoch)
    start_t = time.time()
    iterator = iter(sampler)
    first_batch = next(iterator)
    # Sample 999 more batches after the first, i.e. 1,000 in total;
    # drop this loop to time only the first batch.
    for _ in range(999):
        batch = next(iterator)
    print(f"Time: {time.time() - start_t}")
    print(f"Epoch {epoch}: {first_batch}")

Baseline

Time for sampling 1 batch:

Time: 0.2091832160949707
Time: 0.2834758758544922
Time: 0.2873103618621826
Time: 0.2814059257507324
Time: 0.2850480079650879

Time for sampling 1k batches:

Time: 0.5647664070129395
Time: 0.567711591720581
Time: 0.5873234272003174
Time: 0.5906450748443604
Time: 0.5707104206085205

The first batch of each of the first 5 epochs:

Epoch 0: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]
Epoch 1: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]
Epoch 2: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]
Epoch 3: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]
Epoch 4: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]

This PR

Time for sampling 1 batch:

Time: 0.32395291328430176
Time: 0.3055605888366699
Time: 0.3509399890899658
Time: 0.31282591819763184
Time: 0.3069179058074951

Time for sampling 1k batches:

Time: 0.6388254165649414
Time: 0.6368660926818848
Time: 0.6400094032287598
Time: 0.6855106353759766
Time: 0.6412439346313477

The first batch of each of the first 5 epochs:

Epoch 0: [2387739, 560984, 308054, 615495, 1244897, 695998, 2900825, 1969931, 2349236, 2061716, 2644210, 1275724, 511408, 636195, 1957309, 1979379]
Epoch 1: [2360850, 2702978, 2008784, 937054, 571222, 430388, 2675002, 1273641, 1981604, 1595060, 1483721, 1990202, 981360, 2722024, 2068398, 167249]
Epoch 2: [2222763, 1083739, 950218, 2017130, 2968363, 650640, 1423998, 1324413, 1798492, 2944449, 833372, 1511964, 1693704, 1242533, 2366878, 404897]
Epoch 3: [524072, 2819725, 166172, 318777, 1526367, 709051, 1363460, 336155, 1942671, 405301, 1906733, 1735506, 468715, 1113390, 290235, 1249274]
Epoch 4: [1054153, 2167344, 2661437, 1187269, 438521, 2968, 47040, 1363342, 899735, 1113497, 154815, 1283876, 2809544, 513471, 1312820, 1280375]

List

If we instead use a list and remove sampled elements with the expensive del list[index], we get this (a short sketch of this variant follows the timings):

Time for sampling 1k batches:

Time: 14.504975080490112
Time: 10.611507415771484
Time: 11.411868572235107
Time: 11.611666917800903
...
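
For reference, a rough sketch of what that list variant looks like (hypothetical illustration, not the PR's code): the shuffled order is preserved, but removing an element shifts every later element one slot to the left, so each deletion costs O(n).

import torch

generator = torch.Generator().manual_seed(12)
remaining_indices = torch.randperm(3_000_000, generator=generator).tolist()

# Deleting from the middle of a list shifts all later elements, i.e. O(n) per
# removal, which is what makes this variant so slow over thousands of batches.
del remaining_indices[1_500_000]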

Deleting while iterating

We can also delete indices while iterating over remaining_indices, but then we have to iterate over a copy, i.e. list(remaining_indices.keys()), and we get this:

Time for sampling 1k batches:

Time: 22.72933268547058
Time: 21.940544605255127
Time: 21.886675357818604
...

Metrics

So, we're looking at a ~10% slowdown, and the dictionary will use roughly 25% more memory than the set. However, I think this is acceptable given that the alternative (i.e. keeping it as-is) is pretty awful: it can really hamstring the performance of multi-epoch training runs.
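
As a rough sanity check on the memory claim, a hypothetical measurement sketch (container overhead only, ignoring the shared int objects; not part of the PR):

import sys
import torch

indices = torch.randperm(3_000_000).tolist()

# sys.getsizeof measures only the container itself, not the elements it references.
set_bytes = sys.getsizeof(set(indices))
dict_bytes = sys.getsizeof(dict.fromkeys(indices))
print(f"set:  {set_bytes / 1e6:.1f} MB")
print(f"dict: {dict_bytes / 1e6:.1f} MB")
print(f"dict / set: {dict_bytes / set_bytes:.2f}x")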

Note

This also relies on huggingface/accelerate#3246 to update the epoch; otherwise, the batches will still be the same across epochs.

cc @antigregory curious about your thoughts on the code & performance hit

  • Tom Aarsen

@tomaarsen (Collaborator, Author)

With dict.fromkeys (thanks @Wauplin for the suggestion), the performance hit is reduced a bit.
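
As a rough sketch of the two construction styles (hypothetical comparison, assuming the ordered dict was previously built with a comprehension; this is not the actual diff):

import timeit
import torch

shuffled = torch.randperm(3_000_000).tolist()

# dict.fromkeys builds the insertion-ordered mapping in a single C-level pass,
# whereas a Python-level comprehension inserts the keys one at a time.
print(timeit.timeit(lambda: dict.fromkeys(shuffled), number=5))
print(timeit.timeit(lambda: {index: None for index in shuffled}, number=5))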

This PR (new)

Time for sampling 1 batch:

Time: 0.23999476432800293
Time: 0.2646820545196533
Time: 0.2566261291503906
Time: 0.2643446922302246
Time: 0.2761116027832031

(down from ~0.31s avg)

Time for sampling 1k batches:

Time: 0.604529619216919
Time: 0.6064455509185791
Time: 0.6098651885986328
Time: 0.6006977558135986
Time: 0.5981855392456055

(down from ~0.635s avg)

The 1 batch case actually seems faster than the baseline - awesome!

  • Tom Aarsen

@antigregory

Thanks for the PR @tomaarsen!
As you said, the performance & memory overheads seem very reasonable.

However, I'm not sure I understood which alternative you were benchmarking in the paragraph Deleting while iterating (though it's not very important given the poor results of that alternative).

@tomaarsen (Collaborator, Author)

However, I'm not sure I understood which alternative you were benchmarking in the paragraph Deleting while iterating (though it's not very important given the poor results of that alternative).

Indeed, I didn't explain it very well, apologies. I'll clarify (regardless of the horrible results, hah!).
In this PR, I'm performing

            for index in batch_indices:
                del remaining_indices[index]

after that batch has been sampled.

I could also perform this del remaining_indices[index] during the iteration itself, e.g.:

                batch_indices.append(index)
                del remaining_indices[index]
                if len(batch_indices) == self.batch_size:
                    yield batch_indices
                    break

The idea is that you then don't need to iterate over batch_indices a second time.

But dictionaries raise a RuntimeError if their size changes while you're iterating over them, so you have to iterate over a copy of the keys. So, instead of

            for index in remaining_indices:

I tried

            for index in list(remaining_indices.keys()):

Needless to say - that didn't help.
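
For completeness, a tiny standalone illustration of the failure mode and the copy workaround (illustration only, not the sampler code):

remaining_indices = dict.fromkeys([4, 2, 7, 1])

try:
    for index in remaining_indices:
        del remaining_indices[index]  # mutating the dict mid-iteration
except RuntimeError as error:
    print(error)  # "dictionary changed size during iteration"

# Iterating over a snapshot of the keys avoids the error, at the cost of a copy.
for index in list(remaining_indices.keys()):
    del remaining_indices[index]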

  • Tom Aarsen

@tomaarsen tomaarsen requested a review from Copilot November 20, 2024 09:42


Copilot reviewed 1 out of 1 changed files in this pull request and generated no suggestions.

Comments skipped due to low confidence (1):

sentence_transformers/sampler.py:215

  • Ensure that the new behavior of deleting elements from the dictionary is covered by tests.

    for index in batch_indices:
        del remaining_indices[index]
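
A hypothetical test along those lines (sketch only, mirroring the benchmark setup above; not necessarily the repository's actual test suite):

import torch
from datasets import Dataset
from sentence_transformers.sampler import NoDuplicatesBatchSampler

def test_first_batch_differs_across_epochs():
    # Tiny synthetic dataset with unique texts so no batch gets filtered for duplicates.
    dataset = Dataset.from_dict({"text": [f"sample {i}" for i in range(256)]})
    sampler = NoDuplicatesBatchSampler(
        dataset, batch_size=16, drop_last=False, generator=torch.Generator(), seed=12
    )

    sampler.set_epoch(0)
    first_epoch_batch = next(iter(sampler))
    sampler.set_epoch(1)
    second_epoch_batch = next(iter(sampler))

    assert first_epoch_batch != second_epoch_batch
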
@antigregory

Thanks a lot for a very clear explanation @tomaarsen 👍

@tomaarsen (Collaborator, Author)

I'll merge this, but it won't be fully fixed until huggingface/accelerate#3246 is also merged and released.

  • Tom Aarsen

@tomaarsen tomaarsen merged commit 8fabce0 into UKPLab:master Nov 20, 2024
9 checks passed
Successfully merging this pull request may close these issues:

  • Same order of training samples with NoDuplicatesBatchSampler