
[BUG] Severe performance degradation of SliceSampler for large buffers #2669

Closed
nicklashansen opened this issue Dec 19, 2024 · 12 comments
Labels
bug Something isn't working

Comments

@nicklashansen

nicklashansen commented Dec 19, 2024

Describe the bug

Sampling speed with the replay buffer's SliceSampler is heavily dependent on buffer capacity: increasing the capacity renders the replay buffer practically unusable. I have compiled some timings of sampling speed for the TD-MPC2 codebase with varying numbers of stored transitions and total buffer capacities. All timings are for low-dimensional observations/actions and a relatively small batch size of 256 sequences, each of length 4. The issue seems directly related to the SliceSampler class, since I did not encounter it before SliceSampler was introduced.

Precise sample() timings:

5M transitions, 6M capacity: 0.0048s
5M transitions, 600M capacity: 0.0361s
346M transitions, 600M capacity: 1.6531s

To Reproduce

The replay buffer implementation is available here and mostly relies on torchrl features.

I'm also happy to create a minimal example that does not depend on the TD-MPC2 codebase if you think that would be helpful.
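
For illustration, a rough sketch of what such a standalone reproduction could look like (sizes scaled down here; this assumes torchrl's ReplayBuffer/LazyTensorStorage/SliceSampler API and identifies trajectories via an integer "episode" key, which may differ from the exact TD-MPC2 setup):

import time

import torch
from tensordict import TensorDict
from torchrl.data import LazyTensorStorage, ReplayBuffer
from torchrl.data.replay_buffers.samplers import SliceSampler

CAPACITY = 1_000_000        # scale toward 600M to reproduce the slowdown
NUM_TRANSITIONS = 500_000   # transitions actually stored
EP_LEN = 500                # assumed episode length for the fake data

# Sample 256 sequences of length 4, identifying trajectories by episode ID.
sampler = SliceSampler(slice_len=4, traj_key="episode")
rb = ReplayBuffer(
    storage=LazyTensorStorage(CAPACITY),
    sampler=sampler,
    batch_size=256 * 4,
)

data = TensorDict(
    {
        "obs": torch.randn(NUM_TRANSITIONS, 24),
        "action": torch.randn(NUM_TRANSITIONS, 6),
        "reward": torch.randn(NUM_TRANSITIONS),
        "episode": torch.arange(NUM_TRANSITIONS, dtype=torch.int64) // EP_LEN,
    },
    batch_size=[NUM_TRANSITIONS],
)
rb.extend(data)

start = time.perf_counter()
rb.sample()
print(f"sample() took {time.perf_counter() - start:.4f}s")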

System info

I use the conda environment available here: https://github.com/nicklashansen/tdmpc2/blob/main/docker/environment.yaml

Checklist

  • I have checked that there is no similar issue in the repo (required)
  • I have read the documentation (required)
  • I have provided a minimal working example to reproduce the bug (required)
@nicklashansen nicklashansen added the bug Something isn't working label Dec 19, 2024
@vmoens
Contributor

vmoens commented Dec 19, 2024

Oh yeah that's pretty bad, let me see what I can do

Quick question: where do you store your data? By that I mean the trajectory count / done states.
I suspect that storing them on CUDA and compiling should help (?)

@vmoens
Contributor

vmoens commented Dec 19, 2024

Found it
https://github.com/nicklashansen/tdmpc2/blob/df8a465c8e137c652a142f6ad6cdf540d3a6a39a/tdmpc2/common/buffer.py#L61

@nicklashansen
Author

That would be appreciated! Worst case, a possible workaround would be to maintain two replay buffer implementations and select the more appropriate one based on capacity, but I would be interested in better solutions.

@nicklashansen
Author

Data is stored in RAM for these larger buffers, since they are in the 50-100 GB range.

@vmoens
Contributor

vmoens commented Dec 19, 2024

Ok I'll create a sandbox to check this carefully and post it here for reference

@vmoens
Contributor

vmoens commented Dec 19, 2024

https://gist.github.com/vmoens/6a860ba376ce99737dfdf5637c7eaee7

Can you let me know how to edit the fake data to make it more suited?

RE caching: caching is useful if you're sampling more than once after every extension. This can drastically speed things up because you don't need to recompute the indices of the trajectories.
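
As a concrete sketch (assuming the cache_values flag of SliceSampler; check the docstring for the exact name/semantics):

from torchrl.data.replay_buffers.samplers import SliceSampler

# Cache the computed trajectory boundaries between sample() calls.
# This only pays off when several batches are sampled per extend(),
# since extending the buffer invalidates the cached indices.
sampler = SliceSampler(
    slice_len=4,
    traj_key="episode",
    cache_values=True,  # assumed flag name; see the SliceSampler docs
)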

@vmoens
Contributor

vmoens commented Dec 19, 2024

Throwing another datapoint here:
On my cluster, running the code with the end signal (the done state) instead of the trajectory indicator is about 2x faster (500 ms vs. 1 s for a 600M-capacity buffer filled to 50%), presumably because trajectories can be identified with booleans rather than integers.
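
In code terms, the two variants compared here are roughly (assuming SliceSampler's traj_key/end_key arguments):

from torchrl.data.replay_buffers.samplers import SliceSampler

# Variant A: trajectories identified by an integer episode ID stored per transition.
sampler_from_ids = SliceSampler(slice_len=4, traj_key="episode")

# Variant B: trajectories identified by a boolean end-of-trajectory flag;
# this was the ~2x faster option in the timing above.
sampler_from_done = SliceSampler(slice_len=4, end_key="done")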

@nicklashansen
Author

Oh, I was not aware of the caching argument. I tried it just now with the 346M-transition, 600M-capacity buffer, and enabling caching speeds up sampling by 500x on my machine. That's really good to know!

Running the code with the end signal (the done state) instead of the trajectory indicator is about 2x faster (500 ms vs. 1 s for a 600M-capacity buffer filled to 50%), presumably because trajectories can be identified with booleans rather than integers.

That makes a lot of sense to me. I can try rewriting my code a bit to operate on a done signal rather than an episode ID.
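
For example, a minimal sketch of deriving a done flag from the episode IDs (assuming episodes are stored contiguously, back-to-back):

import torch

def episode_ids_to_done(episode: torch.Tensor) -> torch.Tensor:
    """Turn contiguous integer episode IDs into a boolean end-of-episode flag."""
    done = torch.zeros_like(episode, dtype=torch.bool)
    done[:-1] = episode[1:] != episode[:-1]  # flag the last step of each episode
    done[-1] = True                          # the final stored step closes a trajectory
    return done

# episodes [0, 0, 0, 1, 1, 2] -> done [False, False, True, False, True, True]
done = episode_ids_to_done(torch.tensor([0, 0, 0, 1, 1, 2]))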

Thanks a lot for your help!

@nicklashansen
Author

Can you let me know how to edit the fake data to make it more suited?

Regarding this: the data used in offline TD-MPC2 training has the following structure:

TensorDict(
    fields={
        action: Tensor(shape=torch.Size([345690000, 6]), device=cpu, dtype=torch.float32, is_shared=False),
        episode: Tensor(shape=torch.Size([345690000]), device=cpu, dtype=torch.int64, is_shared=False),
        obs: Tensor(shape=torch.Size([345690000, 24]), device=cpu, dtype=torch.float32, is_shared=False),
        reward: Tensor(shape=torch.Size([345690000]), device=cpu, dtype=torch.float32, is_shared=False),
        task: Tensor(shape=torch.Size([345690000]), device=cpu, dtype=torch.int32, is_shared=False)},
    batch_size=torch.Size([345690000]),
    device=cpu,
    is_shared=False)
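
For the gist, fake data with the same layout could be generated roughly like this (sizes reduced; the fixed episode length and zero task IDs are just placeholders):

import torch
from tensordict import TensorDict

N = 1_000_000   # scaled down from ~345.69M transitions
EP_LEN = 500    # placeholder episode length

fake_data = TensorDict(
    {
        "action": torch.randn(N, 6),
        "episode": torch.arange(N, dtype=torch.int64) // EP_LEN,
        "obs": torch.randn(N, 24),
        "reward": torch.randn(N),
        "task": torch.zeros(N, dtype=torch.int32),
    },
    batch_size=[N],
)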

@vmoens
Contributor

vmoens commented Dec 19, 2024

Gotcha

I will also land #2672 and #2671 once the tests pass (and once I document the features a bit more!).

Together they give me a speedup of about 2-3x when the cache is disabled.

@nicklashansen
Author

That's very impressive! Let me know when the new features are ready and I'll be more than happy to give them a try with the tdmpc2 repo.

@vmoens
Contributor

vmoens commented Dec 20, 2024

Closing this issue thanks to #2670, #2671 and #2672

@vmoens vmoens closed this as completed Dec 20, 2024