[BUG] Severe performance degradation of SliceSampler for large buffers #2669
Comments
Oh yeah, that's pretty bad, let me see what I can do. Quick question: where do you store your data? By that I mean the trajectory count / done states.
That would be appreciated! Worst case, a possible workaround would be to have two replay buffer implementations and select the most appropriate one based on capacity, but I would be interested in better solutions.
Data is stored in RAM for these larger buffers, since they are in the 50-100GB range.
Ok, I'll create a sandbox to check this carefully and post it here for reference.
https://gist.github.com/vmoens/6a860ba376ce99737dfdf5637c7eaee7
Can you let me know how to edit the fake data to make it more suited?
RE caching: caching is useful if you're sampling more than once after every extension. This can drastically speed things up because you don't need to recompute the indices of the trajectories.
Throwing another datapoint here: |
Oh I was not aware of the caching argument. I tried it just now with the 346M transitions, 600M capacity buffer and enabling caching speeds up sampling by 500x on my machine. That's really good to know!
That makes a lot of sense to me. I can try rewriting my code a bit to operate with a done signal rather than an episode ID. Thanks a lot for your help!
Regarding this: the data used in offline TD-MPC2 training has this structure:
That's very impressive! Let me know when the new features are ready and I'll be more than happy to give them a try with the tdmpc2 repo. |
Describe the bug
Replay buffer sampling speed with `SliceSampler` is heavily dependent on buffer capacity: increasing the capacity renders the replay buffer practically unusable. I have compiled some timings on sampling speed for the TD-MPC2 codebase with a varying number of stored transitions and total buffer capacities. These timings are all for low-dimensional observations/actions and a relatively small batch size of 256 sequences, each of length 4. This issue seems to be directly related to the `SliceSampler` class, since I did not encounter it before the introduction of the `SliceSampler`.

Precise `sample()` timings:

To Reproduce
The replay buffer implementation is available here and mostly relies on `torchrl` features. I'm also happy to create a minimal example that does not depend on the TD-MPC2 codebase if you think that would be helpful.
System info
I use the conda environment available here: https://github.com/nicklashansen/tdmpc2/blob/main/docker/environment.yaml
Checklist