# rps_rllib.py
from gym.spaces import Box
from ray.rllib.examples.env.rock_paper_scissors import RockPaperScissors
import numpy as np
class RPSNoise(RockPaperScissors):
    """Rock-paper-scissors environment whose observations are pure noise.

    Whatever observation dict the parent environment produces is replaced,
    key for key, by a freshly sampled noise vector. Rewards, done flags and
    infos coming from the parent are returned unchanged.
    """

    def __init__(self, config):
        """Build the environment.

        Args:
            config: Mapping of environment options. The optional key
                ``"noise_dim"`` sets the length of the noise vector
                (4 when the key is absent). The same mapping is forwarded
                to the parent constructor.
        """
        # Stored before the parent constructor runs.
        self.noise_dim = config.get("noise_dim", 4)
        super().__init__(config)

        # Replace the observation space with an unbounded box of length
        # ``noise_dim``; this assignment happens after the parent's setup.
        lower_bound = np.full(self.noise_dim, -np.inf)
        upper_bound = np.full(self.noise_dim, np.inf)
        self.observation_space = Box(lower_bound, upper_bound)

    def _sample_noise(self):
        """Return a new 1-D array of ``noise_dim`` standard-normal samples."""
        dim = self.noise_dim
        return np.random.randn(dim)

    def _transform_obs(self, obs):
        """Swap every value of ``obs`` for one shared noise vector.

        Args:
            obs: Observation dict returned by the parent environment; only
                its keys are used.

        Returns:
            A dict with the same keys as ``obs``. A single noise array is
            sampled per call and every key maps to that same array object.
        """
        return dict.fromkeys(obs, self._sample_noise())

    def reset(self):
        """Reset the parent environment and return noise observations."""
        return self._transform_obs(super().reset())

    def step(self, action_dict):
        """Advance the parent environment by one step.

        Args:
            action_dict: Actions, passed to the parent ``step`` unchanged.

        Returns:
            The parent's ``(obs, rew, done, info)`` tuple with ``obs``
            replaced by noise observations.
        """
        parent_obs, rewards, dones, infos = super().step(action_dict)
        return self._transform_obs(parent_obs), rewards, dones, infos