Support for custom observations (both observation function and observation space) (#133)

* Minor style changes to my previous implementation (to match the author's style)

* Support for custom observation function, preparation for drq_norm

* Working support for registerable observation functions, minor implementation changes

* Moved checking the reward function to TrafficSignal constructor

* Observations moved to another file

* Minor fixes regarding observations implementation
firemankoxd authored Feb 24, 2023
1 parent 5a2d67c commit bd8ef1d
Showing 4 changed files with 53 additions and 24 deletions.
sumo_rl/__init__.py: 2 changes (1 addition & 1 deletion)
@@ -1,3 +1,3 @@
-from sumo_rl.environment.env import SumoEnvironment, TrafficSignal
+from sumo_rl.environment.env import SumoEnvironment, TrafficSignal, ObservationFunction
 from sumo_rl.environment.env import env, parallel_env
 from sumo_rl.environment.resco_envs import grid4x4, arterial4x4, ingolstadt1, ingolstadt7, ingolstadt21, cologne1, cologne3, cologne8
sumo_rl/environment/env.py: 7 changes (4 additions & 3 deletions)
@@ -19,6 +19,7 @@
 from pettingzoo.utils.conversions import parallel_wrapper_fn

 from .traffic_signal import TrafficSignal
+from .observations import ObservationFunction, DefaultObservationFunction

 LIBSUMO = 'LIBSUMO_AS_TRACI' in os.environ

@@ -50,7 +51,7 @@ class SumoEnvironment(gym.Env):
     :param max_green: (int) Max green time in a phase
     :single_agent: (bool) If true, it behaves like a regular gym.Env. Else, it behaves like a MultiagentEnv (https://github.com/ray-project/ray/blob/master/python/ray/rllib/env/multi_agent_env.py)
     :reward_fn: (str/function/dict) String with the name of the reward function used by the agents, a reward function, or dictionary with reward functions assigned to individual traffic lights by their keys
-    :observation_fn: (str/function) String with the name of the observation function or a callable observation function itself
+    :observation_class: (ObservationFunction) Inherited class which has both the observation function and observation space
     :add_system_info: (bool) If true, it computes system metrics (total queue, total waiting time, average speed) in the info dictionary
     :add_per_agent_info: (bool) If true, it computes per-agent (per-traffic signal) metrics (average accumulated waiting time, average queue) in the info dictionary
     :sumo_seed: (int/string) Random seed for sumo. If 'random' it uses a randomly chosen seed.
@@ -83,7 +84,7 @@ def __init__(
         max_green: int = 50,
         single_agent: bool = False,
         reward_fn: Union[str,Callable,dict] = 'diff-waiting-time',
-        observation_fn: Union[str,Callable] = 'default',
+        observation_class: ObservationFunction = DefaultObservationFunction,
         add_system_info: bool = True,
         add_per_agent_info: bool = True,
         sumo_seed: Union[str,int] = 'random',
@@ -135,7 +136,7 @@ def __init__(
             traci.start([sumolib.checkBinary('sumo'), '-n', self._net], label='init_connection'+self.label)
             conn = traci.getConnection('init_connection'+self.label)
         self.ts_ids = list(conn.trafficlight.getIDList())
-        self.observation_fn = observation_fn
+        self.observation_class = observation_class

         if isinstance(self.reward_fn, dict):
             self.traffic_signals = {ts: TrafficSignal(self,
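For context, here is how the renamed parameter would be used after this commit. A minimal sketch, assuming a single-agent setup; the network and route file paths are placeholders, and passing DefaultObservationFunction explicitly is redundant since it is the default:

    from sumo_rl import SumoEnvironment
    from sumo_rl.environment.observations import DefaultObservationFunction

    # Placeholder network/route files; observation_class replaces the old
    # observation_fn string/callable argument and expects a class, not an instance.
    env = SumoEnvironment(
        net_file='my.net.xml',
        route_file='my.rou.xml',
        single_agent=True,
        observation_class=DefaultObservationFunction,
    )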
sumo_rl/environment/observations.py: 45 changes (45 additions, new file)
@@ -0,0 +1,45 @@
+from .traffic_signal import TrafficSignal
+from abc import abstractmethod
+from gymnasium import spaces
+import numpy as np
+
+class ObservationFunction:
+    """
+    Abstract base class for observation functions.
+    """
+    def __init__(self, ts: TrafficSignal):
+        self.ts = ts
+
+    @abstractmethod
+    def __call__(self):
+        """
+        Subclasses must override this method.
+        """
+        pass
+
+    @abstractmethod
+    def observation_space(self):
+        """
+        Subclasses must override this method.
+        """
+        pass
+
+
+class DefaultObservationFunction(ObservationFunction):
+    def __init__(self, ts: TrafficSignal):
+        super().__init__(ts)
+
+    def __call__(self):
+        phase_id = [1 if self.ts.green_phase == i else 0 for i in range(self.ts.num_green_phases)]  # one-hot encoding
+        min_green = [0 if self.ts.time_since_last_phase_change < self.ts.min_green + self.ts.yellow_time else 1]
+        density = self.ts.get_lanes_density()
+        queue = self.ts.get_lanes_queue()
+        observation = np.array(phase_id + min_green + density + queue, dtype=np.float32)
+        return observation
+
+    def observation_space(self):
+        return spaces.Box(
+            low=np.zeros(self.ts.num_green_phases+1+2*len(self.ts.lanes), dtype=np.float32),
+            high=np.ones(self.ts.num_green_phases+1+2*len(self.ts.lanes), dtype=np.float32)
+        )
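
To show the extension point this file introduces, a hypothetical custom observation (not part of the commit), reusing the same TrafficSignal helpers as the default above:

    import numpy as np
    from gymnasium import spaces
    from sumo_rl import ObservationFunction  # re-exported in __init__.py above

    class QueueObservationFunction(ObservationFunction):
        """Hypothetical example: observe only the per-lane queue occupancy."""

        def __call__(self):
            # get_lanes_queue() returns one normalized value per incoming lane,
            # as used by DefaultObservationFunction
            return np.array(self.ts.get_lanes_queue(), dtype=np.float32)

        def observation_space(self):
            n_lanes = len(self.ts.lanes)
            return spaces.Box(low=np.zeros(n_lanes, dtype=np.float32),
                              high=np.ones(n_lanes, dtype=np.float32))

Such a class would then be passed as observation_class=QueueObservationFunction when constructing SumoEnvironment.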

sumo_rl/environment/traffic_signal.py: 23 changes (3 additions & 20 deletions)
@@ -60,13 +60,7 @@ def __init__(self,
         else:
             raise NotImplementedError(f'Reward function {self.reward_fn} not implemented')

-        if isinstance(self.env.observation_fn, Callable):
-            self.observation_fn = self.env.observation_fn
-        else:
-            if self.env.observation_fn in TrafficSignal.observation_fns.keys():
-                self.observation_fn = TrafficSignal.observation_fns[self.env.observation_fn]
-            else:
-                raise NotImplementedError(f'Observation function {self.env.observation_fn} not implemented')
+        self.observation_fn = self.env.observation_class(self)

         self.build_phases()

@@ -75,7 +69,7 @@ def __init__(self,
         self.out_lanes = list(set(self.out_lanes))
         self.lanes_lenght = {lane: self.sumo.lane.getLength(lane) for lane in self.lanes + self.out_lanes}

-        self.observation_space = spaces.Box(low=np.zeros(self.num_green_phases+1+2*len(self.lanes), dtype=np.float32), high=np.ones(self.num_green_phases+1+2*len(self.lanes), dtype=np.float32))
+        self.observation_space = self.observation_fn.observation_space()
         self.discrete_observation_space = spaces.Tuple((
             spaces.Discrete(self.num_green_phases),  # Green Phase
             spaces.Discrete(2),  # Binary variable active if min_green seconds already elapsed
@@ -148,7 +142,7 @@ def set_next_phase(self, new_phase):
             self.time_since_last_phase_change = 0

     def compute_observation(self):
-        return self.observation_fn(self)
+        return self.observation_fn()

     def compute_reward(self):
         self.last_reward = self.reward_fn(self)
@@ -233,20 +227,9 @@ def register_reward_fn(cls, fn):

         cls.reward_fns[fn.__name__] = fn

-    @classmethod
-    def register_observation_fn(cls, fn):
-        if fn.__name__ in cls.observation_fns.keys():
-            raise KeyError(f'Observation function {fn.__name__} already exists')
-
-        cls.observation_fns[fn.__name__] = fn
-
     reward_fns = {
         'diff-waiting-time': _diff_waiting_time_reward,
         'average-speed': _average_speed_reward,
         'queue': _queue_reward,
         'pressure': _pressure_reward
     }
-
-    observation_fns = {
-        'default': _observation_fn_default
-    }
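
The removals above complete the migration: the string-keyed observation_fns registry and register_observation_fn are gone, and a traffic signal now gets both its observation vector and its observation space from a single object. A standalone sketch of the resulting pattern, using simplified stand-in classes rather than the real sumo-rl types:

    import numpy as np
    from gymnasium import spaces

    class MiniObservation:
        """Stand-in for ObservationFunction: one instance bound to one signal."""
        def __init__(self, ts):
            self.ts = ts

        def __call__(self):
            return np.array(self.ts.queues, dtype=np.float32)

        def observation_space(self):
            return spaces.Box(low=0.0, high=1.0, shape=(len(self.ts.queues),), dtype=np.float32)

    class MiniSignal:
        """Stand-in for TrafficSignal: instantiates whatever class it is given."""
        def __init__(self, observation_class):
            self.queues = [0.0, 0.5, 1.0]
            self.observation_fn = observation_class(self)  # mirrors env.observation_class(self)
            self.observation_space = self.observation_fn.observation_space()

        def compute_observation(self):
            return self.observation_fn()  # was observation_fn(self) before this commit

    signal = MiniSignal(MiniObservation)
    assert signal.observation_space.contains(signal.compute_observation())

Because the space and the vector come from the same object, they cannot drift out of sync, which the old registry of bare functions could not guarantee.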
