From bd8ef1d4cd68c4862ec4f925fd6e7ea0f7b2815d Mon Sep 17 00:00:00 2001 From: Filip Kalus <69034726+firemankoxd@users.noreply.github.com> Date: Fri, 24 Feb 2023 20:56:37 +0100 Subject: [PATCH] Support for custom observations (both observation function and observation space) (#133) * Minor style changes to my previous implementation (to match the author's style) * Support for custom observation function, preparation for drq_norm * Working support for registerable observation functions, minor implementation changes * Moved checking the reward function to TrafficSignal constructor * Observations moved to another file * Minor fixes regarding observations implementation --- sumo_rl/__init__.py | 2 +- sumo_rl/environment/env.py | 7 +++-- sumo_rl/environment/observations.py | 45 +++++++++++++++++++++++++++ sumo_rl/environment/traffic_signal.py | 23 ++------------ 4 files changed, 53 insertions(+), 24 deletions(-) create mode 100644 sumo_rl/environment/observations.py diff --git a/sumo_rl/__init__.py b/sumo_rl/__init__.py index 804c0204..20563346 100755 --- a/sumo_rl/__init__.py +++ b/sumo_rl/__init__.py @@ -1,3 +1,3 @@ -from sumo_rl.environment.env import SumoEnvironment, TrafficSignal +from sumo_rl.environment.env import SumoEnvironment, TrafficSignal, ObservationFunction from sumo_rl.environment.env import env, parallel_env from sumo_rl.environment.resco_envs import grid4x4, arterial4x4, ingolstadt1, ingolstadt7, ingolstadt21, cologne1, cologne3, cologne8 \ No newline at end of file diff --git a/sumo_rl/environment/env.py b/sumo_rl/environment/env.py index f575c506..e431aa70 100755 --- a/sumo_rl/environment/env.py +++ b/sumo_rl/environment/env.py @@ -19,6 +19,7 @@ from pettingzoo.utils.conversions import parallel_wrapper_fn from .traffic_signal import TrafficSignal +from .observations import ObservationFunction, DefaultObservationFunction LIBSUMO = 'LIBSUMO_AS_TRACI' in os.environ @@ -50,7 +51,7 @@ class SumoEnvironment(gym.Env): :param max_green: (int) Max green time in a phase :single_agent: (bool) If true, it behaves like a regular gym.Env. Else, it behaves like a MultiagentEnv (https://github.com/ray-project/ray/blob/master/python/ray/rllib/env/multi_agent_env.py) :reward_fn: (str/function/dict) String with the name of the reward function used by the agents, a reward function, or dictionary with reward functions assigned to individual traffic lights by their keys - :observation_fn: (str/function) String with the name of the observation function or a callable observation function itself + :observation_class: (ObservationFunction) Inherited class which has both the observation function and observation space :add_system_info: (bool) If true, it computes system metrics (total queue, total waiting time, average speed) in the info dictionary :add_per_agent_info: (bool) If true, it computes per-agent (per-traffic signal) metrics (average accumulated waiting time, average queue) in the info dictionary :sumo_seed: (int/string) Random seed for sumo. If 'random' it uses a randomly chosen seed. @@ -83,7 +84,7 @@ def __init__( max_green: int = 50, single_agent: bool = False, reward_fn: Union[str,Callable,dict] = 'diff-waiting-time', - observation_fn: Union[str,Callable] = 'default', + observation_class: ObservationFunction = DefaultObservationFunction, add_system_info: bool = True, add_per_agent_info: bool = True, sumo_seed: Union[str,int] = 'random', @@ -135,7 +136,7 @@ def __init__( traci.start([sumolib.checkBinary('sumo'), '-n', self._net], label='init_connection'+self.label) conn = traci.getConnection('init_connection'+self.label) self.ts_ids = list(conn.trafficlight.getIDList()) - self.observation_fn = observation_fn + self.observation_class = observation_class if isinstance(self.reward_fn, dict): self.traffic_signals = {ts: TrafficSignal(self, diff --git a/sumo_rl/environment/observations.py b/sumo_rl/environment/observations.py new file mode 100644 index 00000000..99e0a8db --- /dev/null +++ b/sumo_rl/environment/observations.py @@ -0,0 +1,45 @@ +from .traffic_signal import TrafficSignal +from abc import abstractmethod +from gymnasium import spaces +import numpy as np + +class ObservationFunction: + """ + Abstract base class for observation functions. + """ + def __init__(self, ts: TrafficSignal): + self.ts = ts + + @abstractmethod + def __call__(self): + """ + Subclasses must override this method. + """ + pass + + @abstractmethod + def observation_space(self): + """ + Subclasses must override this method. + """ + pass + + +class DefaultObservationFunction(ObservationFunction): + def __init__(self, ts: TrafficSignal): + super().__init__(ts) + + def __call__(self): + phase_id = [1 if self.ts.green_phase == i else 0 for i in range(self.ts.num_green_phases)] # one-hot encoding + min_green = [0 if self.ts.time_since_last_phase_change < self.ts.min_green + self.ts.yellow_time else 1] + density = self.ts.get_lanes_density() + queue = self.ts.get_lanes_queue() + observation = np.array(phase_id + min_green + density + queue, dtype=np.float32) + return observation + + def observation_space(self): + return spaces.Box( + low=np.zeros(self.ts.num_green_phases+1+2*len(self.ts.lanes), dtype=np.float32), + high=np.ones(self.ts.num_green_phases+1+2*len(self.ts.lanes), dtype=np.float32) + ) + diff --git a/sumo_rl/environment/traffic_signal.py b/sumo_rl/environment/traffic_signal.py index 2d05b6ba..934927e5 100755 --- a/sumo_rl/environment/traffic_signal.py +++ b/sumo_rl/environment/traffic_signal.py @@ -60,13 +60,7 @@ def __init__(self, else: raise NotImplementedError(f'Reward function {self.reward_fn} not implemented') - if isinstance(self.env.observation_fn, Callable): - self.observation_fn = self.env.observation_fn - else: - if self.env.observation_fn in TrafficSignal.observation_fns.keys(): - self.observation_fn = TrafficSignal.observation_fns[self.env.observation_fn] - else: - raise NotImplementedError(f'Observation function {self.env.observation_fn} not implemented') + self.observation_fn = self.env.observation_class(self) self.build_phases() @@ -75,7 +69,7 @@ def __init__(self, self.out_lanes = list(set(self.out_lanes)) self.lanes_lenght = {lane: self.sumo.lane.getLength(lane) for lane in self.lanes + self.out_lanes} - self.observation_space = spaces.Box(low=np.zeros(self.num_green_phases+1+2*len(self.lanes), dtype=np.float32), high=np.ones(self.num_green_phases+1+2*len(self.lanes), dtype=np.float32)) + self.observation_space = self.observation_fn.observation_space() self.discrete_observation_space = spaces.Tuple(( spaces.Discrete(self.num_green_phases), # Green Phase spaces.Discrete(2), # Binary variable active if min_green seconds already elapsed @@ -148,7 +142,7 @@ def set_next_phase(self, new_phase): self.time_since_last_phase_change = 0 def compute_observation(self): - return self.observation_fn(self) + return self.observation_fn() def compute_reward(self): self.last_reward = self.reward_fn(self) @@ -233,20 +227,9 @@ def register_reward_fn(cls, fn): cls.reward_fns[fn.__name__] = fn - @classmethod - def register_observation_fn(cls, fn): - if fn.__name__ in cls.observation_fns.keys(): - raise KeyError(f'Observation function {fn.__name__} already exists') - - cls.observation_fns[fn.__name__] = fn - reward_fns = { 'diff-waiting-time': _diff_waiting_time_reward, 'average-speed': _average_speed_reward, 'queue': _queue_reward, 'pressure': _pressure_reward - } - - observation_fns = { - 'default': _observation_fn_default } \ No newline at end of file