
Same order of training samples with NoDuplicatesBatchSampler #3069

Closed
antigregory opened this issue Nov 18, 2024 · 4 comments · Fixed by #3073
Labels
bug Something isn't working

Comments

@antigregory

It seems that NoDuplicatesBatchSampler produces the same set of batches in the same order regardless of the epoch index.
Indeed, in this piece of code, the order of the indices in remaining_indices does not depend on the random permutation torch.randperm(len(self.dataset), generator=self.generator), because wrapping the permutation in a set discards its order.

Moreover, the seed in line 185 does not change from one epoch to another (the set_epoch method does not seem to be used...)
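
A minimal standalone sketch of the symptom (not the library code itself): wrapping the shuffled indices in a set throws the permutation order away, so every epoch starts from the same sequence.

import torch

# Minimal repro sketch (not the library code): the permutation order is lost
# as soon as the indices are wrapped in a set; for small consecutive ints,
# set iteration comes back effectively sorted, so every epoch looks the same.
generator = torch.Generator().manual_seed(42)
perm = torch.randperm(10, generator=generator).tolist()
print(perm)             # a shuffled order, e.g. [2, 7, 4, ...]
print(list(set(perm)))  # [0, 1, 2, ..., 9] -- the shuffle has no effect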

tomaarsen added the bug (Something isn't working) label on Nov 18, 2024
@tomaarsen
Collaborator

Hello!

set_epoch is used via the transformers Trainer, which calls set_epoch on the accelerate-wrapped DataLoader, which should propagate it down into the sampler. However, it seems that accelerate only propagates it into the sampler, not the batch sampler: https://github.com/huggingface/accelerate/blob/8ade23cc6aec7c3bd3d80fef6378cafaade75bbe/src/accelerate/data_loader.py#L591-L592
Perhaps this warrants a feature request/pull request on accelerate @muellerzr

As for the set - well spotted. I was under the false impression that insertion order was preserved, much like for dict instances. I'd like to avoid converting remaining_indices into a list, as removing elements from a list is expensive. I'm open to suggestions here.
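
To make the cost concern concrete, a hypothetical illustration (not the library code): removing an arbitrary value from a list is linear in its length, while discarding from a set is constant time on average but ignores the shuffle order.

# Hypothetical illustration of the cost trade-off discussed above.
remaining_list = list(range(10))
remaining_list.remove(7)        # O(n): scans the list and shifts the tail

remaining_set = set(range(10))
remaining_set.discard(7)        # O(1) on average, but iteration ignores the shuffle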

- Tom Aarsen

@antigregory
Author

It might not be the most elegant solution, but maybe remaining_indices could be an OrderedDict?
remaining_indices = OrderedDict({k: None for k in torch.randperm(len(self.dataset), generator=self.generator).tolist()})

It probably won't be as fast as the set (because the OrderedDict needs to maintain its order after each deletion), but deleting an element from an OrderedDict is still constant time on average.
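
A self-contained sketch of this idea (hypothetical, not the merged fix): the shuffled indices become OrderedDict keys, so the shuffle order survives and deletion stays constant time on average.

from collections import OrderedDict

import torch

# Hypothetical sketch of the OrderedDict suggestion (not the merged fix):
# keys preserve the shuffled order, and deleting a key is O(1) on average.
generator = torch.Generator().manual_seed(42)
remaining_indices = OrderedDict(
    (k, None) for k in torch.randperm(10, generator=generator).tolist()
)
batch = list(remaining_indices)[:4]   # next indices, still in shuffled order
for idx in batch:
    del remaining_indices[idx]        # constant-time removal on average
print(list(remaining_indices))        # the surviving indices keep their order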

@tomaarsen
Collaborator

I was also considering a dict. Because we're at Python 3.7+ now, I think we can just use a normal dict:

the insertion-order preservation nature of dict objects has been declared to be an official part of the Python language spec.

from https://docs.python.org/3/whatsnew/3.7.html

If the performance hit is not too large, then this is an acceptable solution I think.
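
For illustration, a hypothetical sketch of the plain-dict variant (not necessarily the code that was merged):

import torch

# Hypothetical sketch of the plain-dict variant (Python 3.7+ guarantees that
# dict preserves insertion order); not necessarily the merged code.
generator = torch.Generator().manual_seed(42)
remaining_indices = dict.fromkeys(torch.randperm(10, generator=generator).tolist())
first = next(iter(remaining_indices))  # indices come back in shuffled order
del remaining_indices[first]           # O(1) average-case deletion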

I'll also look more into fixing the set_epoch issue in accelerate.

- Tom Aarsen

@antigregory
Author

Yes, I agree, a normal dictionary can also be considered.

The behavior might be slightly less predictable, though, because the declared "insertion-order preservation" does not necessarily guarantee that the order is preserved after some elements are deleted.
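
A quick check of that concern (observed CPython behavior, not a language-spec guarantee): the surviving keys do keep their relative order after a deletion.

# Observed CPython behavior, not a spec guarantee:
# remaining keys keep their relative insertion order after a deletion.
d = dict.fromkeys([5, 2, 9, 1, 7])
del d[9]
print(list(d))  # [5, 2, 1, 7]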
