It is recommended to give an example of off policy using the feature extractor #982
Comments
Next time, please help us to help you by taking the time to fill in the issue template. As the error suggests, your features extractor does not take the feature dimension as an argument. Try def __init__(self, observation_space, features_dim):
The features extractor for off-policy algorithms is the same as for on-policy ones and is already documented: https://stable-baselines3.readthedocs.io/en/master/guide/custom_policy.html#custom-feature-extractor As @qgallouedec wrote, the error you get is because you pass an argument to the features extractor (
First of all, thank you very much for your help. I modified the class CustomCNN as follows: NOISE = { class CustomCNN(BaseFeaturesExtractor):
I defined it before: self.observation_space = spaces.Box(low=-np.inf, high=np.inf, shape=(self.state_space,) However, there is another error. File "C:\ProgramData\Anaconda3\lib\site-packages\torch\nn\modules\linear.py", line 96, in __init__
I'm sorry to ask so many questions. However, I just want to use a CNN feature extractor to extract features for the MlpPolicy policy network. I think this error comes from Stable-Baselines3. There is no complete example for off-policy algorithms, especially DDPG, TD3, and SAC.
Can you try to provide a minimal and functional code example to reproduce the error. (Remove all your |
Sincerely, thank you for your help. The problem with the off-policy network has been bothering me for several days. ··· from typing import Any MODELS = {"a2c": A2C, "ddpg": DDPG, "td3": TD3, "sac": SAC, "ppo": PPO} NOISE = { class CustomCNN(BaseFeaturesExtractor):
I would really like to help you, but you should at least take into consideration the remarks I give you. I need:
import numpy as np
a = np.ones(2)
b = np.ones(2)
c = a / 0
and this code is not functional because the imports are missing:
a = np.ones(2)
c = a / 0
Your code is neither minimal nor functional, so I am not able to reproduce your error.
From what I can see, it appears to be a shape-related error. You may have made a mistake in the network specification.
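As an illustration of how shape mismatches between a convolutional stack and the linear layer after it are commonly avoided, the sketch below infers the flattened size with one dummy forward pass instead of hard-coding it. It uses plain PyTorch only; the observation length of 10, the Conv1d settings, and the output size of 16 are hypothetical and are not taken from the code in this thread.

```python
import torch as th
from torch import nn

n_features = 10  # hypothetical length of a 1D observation vector

# Small 1D convolutional stack; the input layout is (batch, channels, length).
cnn = nn.Sequential(
    nn.Conv1d(in_channels=1, out_channels=4, kernel_size=3),
    nn.ReLU(),
    nn.Flatten(),
)

# Run one dummy observation through the stack to measure the flattened size.
with th.no_grad():
    dummy = th.zeros(1, 1, n_features)
    n_flatten = cnn(dummy).shape[1]

# The linear layer is built from the measured size, so the shapes always agree.
linear = nn.Linear(n_flatten, 16)

batch = th.zeros(5, 1, n_features)
out = linear(cnn(batch))
```

Here the convolution shortens the length from 10 to 8 and produces 4 channels, so n_flatten is 32 and out has shape (5, 16).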
Sincerely, thank you for your help. The problem with the off-policy network has been bothering me for several days. ··· from typing import Any MODELS = {"a2c": A2C, "ddpg": DDPG, "td3": TD3, "sac": SAC, "ppo": PPO} NOISE = { class CustomCNN(BaseFeaturesExtractor): def __init__(self, observation_space: gym.spaces.Box, features_dim: int = 1):
··· def forward(self, observations: th.Tensor) -> th.Tensor:
····· policy_kwargs = dict( def get_model( def train_model( if __name__ == "__main__":
··· The error prompt is as follows: ···
Closing as basic rules for asking for help were not followed despite asking multiple times (#982 (comment))
Important Note: We do not do technical support, nor consulting and don't answer personal questions per email.
Please post your question on the RL Discord, Reddit or Stack Overflow in that case.
If your issue is related to a custom gym environment, please use the custom gym env template.
🐛 Bug
I want to customize the feature extractor. Following the program written in the example, I get the errors below. I have seen issue #425, "too many errors when customizing policy, a full example for off policy algorithms should be added in user guide", which mentioned
The off policy network should also use the feature extractor. It is recommended to give an example of off policy using the feature extractor. Thank you!
class CustomCombinedExtractor(BaseFeaturesExtractor):
    def __init__(self, observation_space: gym.spaces.Dict):
        # We do not know features-dim here before going over all the items,
        # so put something dummy for now. PyTorch requires calling
        # nn.Module.__init__ before adding modules
        super(CustomCombinedExtractor, self).__init__(observation_space, features_dim=1)
Traceback (most recent call last):
  File "C:/Users/Administrator/PycharmProjects/demo/utils/models.py", line 419, in <module>
    model_sac = agent.get_model("sac", model_kwargs=SAC_PARAMS)
  File "C:/Users/Administrator/PycharmProjects/demo/utils/models.py", line 328, in get_model
    model = MODELS[model_name](
  File "C:\ProgramData\Anaconda3\lib\site-packages\stable_baselines3\sac\sac.py", line 144, in __init__
    self._setup_model()
  File "C:\ProgramData\Anaconda3\lib\site-packages\stable_baselines3\sac\sac.py", line 147, in _setup_model
    super(SAC, self)._setup_model()
  File "C:\ProgramData\Anaconda3\lib\site-packages\stable_baselines3\common\off_policy_algorithm.py", line 216, in _setup_model
    self.policy = self.policy_class(  # pytype:disable=not-instantiable
  File "C:\ProgramData\Anaconda3\lib\site-packages\stable_baselines3\sac\policies.py", line 498, in __init__
    super(MultiInputPolicy, self).__init__(
  File "C:\ProgramData\Anaconda3\lib\site-packages\stable_baselines3\sac\policies.py", line 292, in __init__
    self._build(lr_schedule)
  File "C:\ProgramData\Anaconda3\lib\site-packages\stable_baselines3\sac\policies.py", line 295, in _build
    self.actor = self.make_actor()
  File "C:\ProgramData\Anaconda3\lib\site-packages\stable_baselines3\sac\policies.py", line 348, in make_actor
    actor_kwargs = self._update_features_extractor(self.actor_kwargs, features_extractor)
  File "C:\ProgramData\Anaconda3\lib\site-packages\stable_baselines3\common\policies.py", line 112, in _update_features_extractor
    features_extractor = self.make_features_extractor()
  File "C:\ProgramData\Anaconda3\lib\site-packages\stable_baselines3\common\policies.py", line 118, in make_features_extractor
    return self.features_extractor_class(self.observation_space, **self.features_extractor_kwargs)
TypeError: __init__() got an unexpected keyword argument 'features_dim'
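The last frame of the traceback shows the extractor class being called with the observation space plus the unpacked features_extractor_kwargs. The library-free sketch below reproduces that failure mode with hypothetical stand-in classes (BaseExtractor, BrokenExtractor, FixedExtractor are not Stable-Baselines3 classes): a constructor that does not declare features_dim raises TypeError as soon as that key appears in the kwargs, and one that declares it does not.

```python
class BaseExtractor:
    """Hypothetical stand-in for a base features extractor class."""

    def __init__(self, observation_space, features_dim=0):
        self.observation_space = observation_space
        self.features_dim = features_dim


class BrokenExtractor(BaseExtractor):
    # __init__ takes only observation_space, like the class in the report.
    def __init__(self, observation_space):
        super().__init__(observation_space, features_dim=1)


class FixedExtractor(BaseExtractor):
    # Declaring features_dim lets the caller pass it as a keyword argument.
    def __init__(self, observation_space, features_dim=64):
        super().__init__(observation_space, features_dim=features_dim)


def make_features_extractor(extractor_class, observation_space, extractor_kwargs):
    # Mirrors the call pattern in the last traceback frame.
    return extractor_class(observation_space, **extractor_kwargs)


kwargs = {"features_dim": 128}

try:
    make_features_extractor(BrokenExtractor, "obs", kwargs)
    error_message = None
except TypeError as exc:
    error_message = str(exc)

fixed = make_features_extractor(FixedExtractor, "obs", kwargs)
```

An alternative that also avoids the error in this sketch is to leave features_dim out of the kwargs when the constructor does not accept it.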
A clear and concise description of what the bug is.
To Reproduce
Steps to reproduce the behavior.
Please try to provide a minimal example to reproduce the bug. Error messages and stack traces are also helpful.
Please use the markdown code blocks
for both code and stack traces.
Expected behavior
A clear and concise description of what you expected to happen.
System Info
Describe the characteristic of your environment:
You can use sb3.get_system_info() to print relevant packages info:
Additional context
Add any other context about the problem here.
Checklist