Trained policy export to ONNX via PyTorch #922
This should be applicable.
I recommend learning more about how it works for continuous actions ;) See the resources in the doc: https://stable-baselines.readthedocs.io/en/master/guide/rl.html (especially Spinning Up). In the case of continuous actions, a Gaussian distribution is usually used, so the network will output a mean (the deterministic action) and a standard deviation that is used to sample actions.
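To make that concrete, here is a minimal, hypothetical PyTorch sketch of a diagonal-Gaussian policy head (not code from this repository): the network predicts a mean per action dimension plus a learned log standard deviation, and the deterministic action used at prediction time is simply the mean.

```python
import torch
import torch.nn as nn

class GaussianPolicyHead(nn.Module):
    """Minimal sketch of a diagonal-Gaussian policy head for continuous actions.

    The network predicts a mean per action dimension; a learned,
    state-independent log standard deviation is used only for sampling.
    """
    def __init__(self, latent_dim, action_dim):
        super().__init__()
        self.mean = nn.Linear(latent_dim, action_dim)
        self.log_std = nn.Parameter(torch.zeros(action_dim))

    def forward(self, latent, deterministic=True):
        mu = self.mean(latent)
        if deterministic:
            # At prediction/export time the deterministic action is the mean.
            return mu
        std = self.log_std.exp()
        return mu + std * torch.randn_like(mu)
```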
Pinging @pstansell, the member of our project with the deepest understanding of RL (I am more of a code monkey in this area).
Your best option now would be to use Stable-Baselines3 (which is written directly in PyTorch).
We did actually get this to work, and without PyTorch. In our case we wanted to export directly to a .mat file, and we ended up with something like this:
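The snippet referred to above is not preserved in this thread, but a minimal sketch of a direct .mat export might look like the following. It assumes a Stable-Baselines (TF1) model whose `get_parameters()` returns a dict of parameter names to NumPy arrays; the model and file names are placeholders.

```python
import scipy.io
from stable_baselines import PPO2

# Placeholder model path for illustration.
model = PPO2.load("ppo2_cartpole")

# get_parameters() returns an ordered dict mapping TF variable names
# (e.g. 'model/pi_fc0/w:0') to NumPy arrays.
params = model.get_parameters()

# MATLAB variable names cannot contain '/' or ':', so sanitise the keys.
mat_dict = {name.replace("/", "_").replace(":", "_"): value
            for name, value in params.items()}

scipy.io.savemat("policy_weights.mat", mat_dict)
```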
I am attempting to export a trained policy to the ONNX common interchange format for use in prediction only. I found a very useful discussion in issue #372, which describes how to convert a model to an equivalent PyTorch model; in turn, PyTorch supports export to ONNX. Using the code in the Colab notebook linked to in #372, I was able to create a script which trained a cartpole model, converted it to PyTorch, then exported to ONNX, as shown below.
I used the `PyTorchCnnPolicy` class and `copy_cnn_weights` function from the linked notebook. The ONNX network produced can be visualised with Netron, and is shown below. The results of this were good, or at least produced something that had the same numbers in it as the Stable-Baselines policy. However, what I would like to know is how applicable this is to other environments trained using PPO1 and PPO2 with the MlpPolicy, particularly for the case where the action space is not discrete but continuous. Are modifications required in this case?
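For the MlpPolicy case asked about above, a hedged sketch of the same idea (mirror the network in PyTorch, copy the weights, then call `torch.onnx.export`) might look like this. The layer sizes (two tanh hidden layers of 64 units) and the TF parameter names (`model/pi_fc0/...`, etc.) are assumptions about the default MlpPolicy and may need adjusting after inspecting `model.get_parameters().keys()`. Note that TensorFlow stores dense weights as (in, out) while `torch.nn.Linear` expects (out, in), hence the transpose. For a continuous action space the final layer simply outputs the action mean (with a separate log-std parameter used only for sampling), so the same export works for deterministic prediction.

```python
import torch
import torch.nn as nn

class MlpPolicyTorch(nn.Module):
    """Rough PyTorch mirror of the default stable-baselines MlpPolicy:
    two tanh hidden layers of 64 units plus a linear action head."""
    def __init__(self, obs_dim, act_dim, hidden=64):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(obs_dim, hidden), nn.Tanh(),
            nn.Linear(hidden, hidden), nn.Tanh(),
            nn.Linear(hidden, act_dim),  # logits (discrete) or mean (continuous)
        )

    def forward(self, obs):
        return self.net(obs)

def copy_mlp_weights(sb_params, torch_policy):
    """Copy weights from a stable-baselines parameter dict into the torch mirror.

    The parameter names below are assumptions about the default MlpPolicy;
    adjust them if print(sb_params.keys()) shows something different.
    TF dense weights are (in, out), so they are transposed for nn.Linear.
    """
    mapping = [
        ("model/pi_fc0", torch_policy.net[0]),
        ("model/pi_fc1", torch_policy.net[2]),
        ("model/pi",     torch_policy.net[4]),
    ]
    with torch.no_grad():
        for tf_name, layer in mapping:
            layer.weight.copy_(torch.tensor(sb_params[tf_name + "/w:0"].T))
            layer.bias.copy_(torch.tensor(sb_params[tf_name + "/b:0"]))
    return torch_policy

# Usage sketch (obs_dim=4, act_dim=2 are CartPole-like placeholders):
# model = PPO2.load("ppo2_cartpole")
# torch_policy = copy_mlp_weights(model.get_parameters(), MlpPolicyTorch(4, 2))
# dummy = torch.zeros(1, 4)
# torch.onnx.export(torch_policy, dummy, "policy.onnx",
#                   input_names=["obs"], output_names=["action"])
```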