Converting a model into PyTorch #372
Hello,
To my knowledge, there is no such feature. However, I would recommend first trying with a simpler network (e.g. on CartPole-v1), because convolutions make it tricky: PyTorch and TensorFlow use different conventions. EDIT: I made it work with CartPole and will share the notebook soon.
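One concrete example of those conventions: a TensorFlow `conv2d` kernel is stored as (height, width, in_channels, out_channels), while a PyTorch `nn.Conv2d` weight is (out_channels, in_channels, height, width). The sketch below assumes the TF kernels have already been extracted as NumPy arrays; `tf_conv_kernel_to_torch` is a hypothetical helper, not part of either library.

```python
import numpy as np

def tf_conv_kernel_to_torch(kernel: np.ndarray) -> np.ndarray:
    """Convert a TF conv2d kernel (H, W, C_in, C_out) to PyTorch layout (C_out, C_in, H, W)."""
    return np.ascontiguousarray(kernel.transpose(3, 2, 0, 1))

# Example: an 8x8 kernel with 4 input channels and 32 output channels (sizes are illustrative)
tf_kernel = np.random.rand(8, 8, 4, 32).astype(np.float32)
pt_kernel = tf_conv_kernel_to_torch(tf_kernel)
print(pt_kernel.shape)  # (32, 4, 8, 8)
```

The converted array could then be copied into the layer with `torch.from_numpy`. Getting this permutation wrong produces no error, only a silently different network, which is why a simple MLP is easier to debug first.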
Update: I made it work with CartPole; you can find the notebook here: https://colab.research.google.com/drive/1R-wHO2gLQScx46EIjqj7Sj6MjK-i5Hey I will try to make it work with the CNN if I have some time this weekend.
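For an MLP policy like the CartPole one, the only layout change is in the dense layers: TF stores a dense kernel as (in_features, out_features), while `nn.Linear.weight` is (out_features, in_features). This NumPy sketch uses random stand-in weights (not values from the notebook), an assumed 64-unit layer, and tanh as an example activation. It checks that a TF-style layer `x @ W + b` and a PyTorch-style layer `x @ W_pt.T + b` agree once `W_pt` is the transposed TF kernel.

```python
import numpy as np

rng = np.random.default_rng(0)

# TF dense kernel: (in_features, out_features); CartPole observations have 4 dimensions
w_tf = rng.standard_normal((4, 64)).astype(np.float32)
b = rng.standard_normal(64).astype(np.float32)

# PyTorch nn.Linear weight: (out_features, in_features)
w_pt = w_tf.T.copy()

obs = rng.standard_normal((1, 4)).astype(np.float32)
out_tf = np.tanh(obs @ w_tf + b)    # TF dense layer: x @ kernel + bias
out_pt = np.tanh(obs @ w_pt.T + b)  # nn.Linear: x @ weight.T + bias
print(np.allclose(out_tf, out_pt))  # True
```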
Update: I'm working on the CNN now. It seems that the problem comes from the first fully connected layer (the conv layer outputs the right thing).
The problem comes from the reshape (from the conv layer to the FC layer).
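To see why the reshape matters: TF flattens an NHWC tensor while PyTorch flattens an NCHW one. The same activations therefore come out in a different order, and each weight of the first fully connected layer multiplies the wrong input. A small NumPy illustration with a made-up 3-channel, 2x2 feature map:

```python
import numpy as np

# One feature map in PyTorch layout: (batch, channels, height, width) = (1, 3, 2, 2)
x_nchw = np.arange(12).reshape(1, 3, 2, 2)

# The same activations in TF layout: (batch, height, width, channels)
x_nhwc = x_nchw.transpose(0, 2, 3, 1)

flat_tf = x_nhwc.reshape(1, -1)  # order the TF model feeds to its first dense layer
flat_pt = x_nchw.reshape(1, -1)  # order a naive x.view(x.size(0), -1) produces

print(flat_tf[0].tolist())  # [0, 4, 8, 1, 5, 9, 2, 6, 10, 3, 7, 11]
print(flat_pt[0].tolist())  # [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]
```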
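@p-christ, I solved the issue by doing this before flattening:

```python
# Shape before flattening:
#   TF (NHWC):      (?, 7, 7, 64)
#   PyTorch (NCHW): [1, 64, 7, 7]
# Move channels last so the flattened order matches the TF model's dense-layer input
x = x.permute(0, 2, 3, 1).contiguous()
x = x.view(x.size(0), -1)
```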
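An alternative to permuting the activations on every forward pass is to reorder the rows of the first fully connected layer's TF kernel once, at conversion time, so the PyTorch model can keep a plain flatten. The sketch below assumes that kernel has shape (7 * 7 * 64, out_features) with rows in (H, W, C) order, matching the shapes above; `reorder_first_fc` is a hypothetical helper, 512 output units is an assumed size, and NumPy arrays stand in for the real weights.

```python
import numpy as np

def reorder_first_fc(w_tf: np.ndarray, height: int, width: int, channels: int) -> np.ndarray:
    """Reorder a TF dense kernel whose rows follow NHWC flattening so it accepts
    NCHW-flattened input, and return it in nn.Linear (out, in) layout."""
    out_features = w_tf.shape[1]
    k = w_tf.reshape(height, width, channels, out_features)  # rows indexed by (h, w, c)
    k = k.transpose(2, 0, 1, 3)                                # rows indexed by (c, h, w)
    return np.ascontiguousarray(k.reshape(-1, out_features).T)

rng = np.random.default_rng(0)
x_nchw = rng.standard_normal((1, 64, 7, 7))    # conv output in PyTorch layout
w_tf = rng.standard_normal((7 * 7 * 64, 512))  # TF kernel of the first dense layer

# Reference: permute the activations (the fix above), then apply the TF kernel
ref = x_nchw.transpose(0, 2, 3, 1).reshape(1, -1) @ w_tf
# Alternative: plain NCHW flatten with the reordered weight, nn.Linear style (x @ W.T)
alt = x_nchw.reshape(1, -1) @ reorder_first_fc(w_tf, 7, 7, 64).T
print(np.allclose(ref, alt))  # True
```

Both give the same result; the trade-off is a one-time weight shuffle versus a `permute` plus `contiguous` copy on every forward pass.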
@p-christ I made a working Colab notebook: https://colab.research.google.com/drive/1XwCWeZPnogjz7SLW2kLFXEJGmynQPI-4 The problem came from differences between TensorFlow and PyTorch, not from Stable Baselines (SB). Closing the issue.
Thanks a lot!
Could the lessons from this be documented? Also, would it be possible to add a function that does the same in SB3?
Hi,
I'm trying to load a pre-trained model and convert it into a PyTorch model, but I can't get it to work and was wondering if someone could help me.
I'm able to load the pre-trained model using Stable Baselines and copy the weights over to a PyTorch model. However, when I play the game, the PyTorch agent does not get the same score as the Stable Baselines agent, and I am not sure why. Could it be that the Stable Baselines agent does some extra preprocessing behind the scenes besides normalising the state to the 0-1 range? Is anyone able to help?
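One preprocessing detail that is easy to miss besides scaling pixels to [0, 1]: the observations themselves arrive channels-last (for frame-stacked Atari environments, typically a `uint8` batch of shape (batch, 84, 84, 4)), while PyTorch's `nn.Conv2d` expects channels-first. A minimal NumPy sketch, assuming that observation shape; `obs_to_torch_layout` is a hypothetical helper, and its output would still need `torch.from_numpy` before being fed to the network.

```python
import numpy as np

def obs_to_torch_layout(obs: np.ndarray) -> np.ndarray:
    """Scale uint8 NHWC observations to float32 in [0, 1] and move channels first (NCHW)."""
    scaled = obs.astype(np.float32) / 255.0
    return np.ascontiguousarray(scaled.transpose(0, 3, 1, 2))

# Random stand-in for one frame-stacked Atari observation (assumed shape)
obs = np.random.randint(0, 256, size=(1, 84, 84, 4), dtype=np.uint8)
x = obs_to_torch_layout(obs)
print(x.shape, x.dtype)  # (1, 4, 84, 84) float32
```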
I've made a Colab notebook to demonstrate the problem here: https://colab.research.google.com/drive/1-IIjA1oKUjg5eoctajpl06OoHzU-5-_9