[fix] Fix different batches per epoch in NoDuplicatesBatchSampler #3073
Resolves #3069
Hello!
Pull Request overview
Use a `dict` instead of a `set` for the remaining indices in `NoDuplicatesBatchSampler`, so batches differ across epochs
Details
As described in #3069, the `set` at line 187 of `sentence_transformers/sampler.py` (commit 0434450) ignores the `torch.randperm` results. As a result, we want a different data structure here: one that respects the shuffled order while still allowing cheap removals. We brainstormed a bit in #3069 and think that a `dict` should work (the iteration-order difference is sketched below, after the benchmark results for this PR). Here's a benchmarking script:

Baseline
Time for sampling 1 batch:
Time for sampling 1k batches:
The first batch of the first 5 epochs
This PR
Time for sampling 1 batch:
Time for sampling 1k batches:
The first batch of the first 5 epochs
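
To see why the `dict` behaves differently, recall that iterating a `set` follows hash order rather than insertion order, so the `torch.randperm` shuffle has no effect on it, while a `dict` iterates in insertion order. A minimal sketch of that difference (hypothetical sizes and seed, not the benchmark script itself):

```python
import torch

# Shuffle some indices, as the sampler does at the start of an epoch.
generator = torch.Generator().manual_seed(1)
shuffled = torch.randperm(8, generator=generator).tolist()

# A set iterates in hash order, so the shuffle is effectively discarded.
remaining_set = set(shuffled)
# A dict iterates in insertion order, so the shuffle is preserved.
remaining_dict = dict.fromkeys(shuffled)

print(shuffled)               # shuffled order, e.g. [5, 2, 0, ...] depending on the seed
print(list(remaining_set))    # [0, 1, 2, ..., 7] for small ints, regardless of the shuffle
print(list(remaining_dict))   # same order as `shuffled`
```

With the `dict`, a different permutation per epoch therefore yields different batches per epoch.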
List
If we use a list and use the expensive `del list[index]`, then we get this:
Time for sampling 1k batches:
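
For context on why that is slow: `del list[index]` has to shift every element after `index` left by one, so each removal is O(n), whereas deleting a key from a `dict` is O(1) on average. A tiny sketch (hypothetical sizes):

```python
import torch

# A list holding the shuffled indices.
remaining_indices = torch.randperm(100_000).tolist()

# Removing from the middle shifts the entire tail left by one element,
# so deleting sampled indices this way costs O(n) per removal.
del remaining_indices[50_000]

# By contrast, removing a key from a dict is O(1) on average.
remaining_dict = dict.fromkeys(torch.randperm(100_000).tolist())
del remaining_dict[42]
```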
Deleting while iterating
We can also delete indices while iterating over `remaining_indices`, but then we have to iterate over `list(remaining_indices.keys())` instead, and then we get this:
Time for sampling 1k batches:
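
Roughly, that variant looks like the sketch below: iterate over a snapshot of the keys so that deleting from the `dict` mid-loop is safe (deleting while iterating the `dict` directly raises a `RuntimeError`). This is illustrative only; the actual sampler additionally skips indices whose texts already appear in the batch, which is omitted here.

```python
import torch

batch_size = 4
remaining_indices = dict.fromkeys(torch.randperm(16).tolist())

batch = []
# Iterate over a snapshot of the keys so entries can be deleted from the
# dict inside the loop without a "dictionary changed size" error.
for index in list(remaining_indices.keys()):
    batch.append(index)  # the duplicate-text check is omitted in this sketch
    del remaining_indices[index]
    if len(batch) == batch_size:
        break

print(batch)                   # the first 4 indices in shuffled order
print(len(remaining_indices))  # 12 indices left for subsequent batches
```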
Metrics
So, we're looking at a 10% slowdown and the dictionary will use ~25% more memory than the set. However, I think this might be acceptable given that the alternative (i.e. keeping it as-is) is pretty awful: it can really hamstring the performance of multi-epoch training runs.
Note
This also relies on huggingface/accelerate#3246 to update the epoch; otherwise the batches will still be the same across epochs.
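
For context, a hedged sketch of what that means: the shuffle is typically seeded from the base seed plus the current epoch (the exact mechanism in the sampler and trainer may differ), so if nothing updates the epoch, every epoch reuses the same permutation.

```python
import torch

def shuffled_indices(num_samples: int, seed: int, epoch: int) -> list[int]:
    # Assumption for illustration: the epoch is folded into the seed,
    # which is the usual way set_epoch-style hooks change the shuffle.
    generator = torch.Generator()
    generator.manual_seed(seed + epoch)
    return torch.randperm(num_samples, generator=generator).tolist()

# If the epoch never changes, every epoch reuses the same permutation:
print(shuffled_indices(8, seed=12, epoch=0))
print(shuffled_indices(8, seed=12, epoch=0))  # identical
# Once the epoch is updated, the permutation changes:
print(shuffled_indices(8, seed=12, epoch=1))  # different
```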
cc @antigregory, curious about your thoughts on the code & the performance hit.