Skip to content

Commit

Permalink
Merge pull request #119 from firemankoxd/dict-list-support
Browse files Browse the repository at this point in the history
Support for different reward functions for different traffic lights
  • Loading branch information
LucasAlegre authored Oct 31, 2022
2 parents ab41de0 + 2526bc9 commit 6928a93
Showing 1 changed file with 55 additions and 20 deletions.
75 changes: 55 additions & 20 deletions sumo_rl/environment/env.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ class SumoEnvironment(gym.Env):
:param min_green: (int) Minimum green time in a phase
:param max_green: (int) Max green time in a phase
:single_agent: (bool) If true, it behaves like a regular gym.Env. Else, it behaves like a MultiagentEnv (https://github.com/ray-project/ray/blob/master/python/ray/rllib/env/multi_agent_env.py)
:reward_fn: (str/function) String with the name of the reward function used by the agents, or a reward function.
:reward_fn: (str/function/dict) String with the name of the reward function used by the agents, a reward function, or a dictionary that assigns a reward function to each traffic light, keyed by traffic light ID
:add_system_info: (bool) If true, it computes system metrics (total queue, total waiting time, average speed) in the info dictionary
:add_per_agent_info: (bool) If true, it computes per-agent (per-traffic signal) metrics (average accumulated waiting time, average queue) in the info dictionary
:sumo_seed: (int/string) Random seed for sumo. If 'random' it uses a randomly chosen seed.
Expand Down Expand Up @@ -81,7 +81,7 @@ def __init__(
min_green: int = 5,
max_green: int = 50,
single_agent: bool = False,
reward_fn: Union[str,Callable] = 'diff-waiting-time',
reward_fn: Union[str,Callable,dict] = 'diff-waiting-time',
add_system_info: bool = True,
add_per_agent_info: bool = True,
sumo_seed: Union[str,int] = 'random',
Expand Down Expand Up @@ -133,15 +133,34 @@ def __init__(
traci.start([sumolib.checkBinary('sumo'), '-n', self._net], label='init_connection'+self.label)
conn = traci.getConnection('init_connection'+self.label)
self.ts_ids = list(conn.trafficlight.getIDList())
self.traffic_signals = {ts: TrafficSignal(self,
ts,
self.delta_time,
self.yellow_time,
self.min_green,
self.max_green,
self.begin_time,
self.reward_fn,
conn) for ts in self.ts_ids}

if isinstance(self.reward_fn, dict):
self.traffic_signals = dict()
for key, reward_fn_value in self.reward_fn.items():
self.traffic_signals[key] = TrafficSignal(
self,
key,
self.delta_time,
self.yellow_time,
self.min_green,
self.max_green,
self.begin_time,
reward_fn_value,
conn
)
else:
self.traffic_signals = {
ts: TrafficSignal(self,
ts,
self.delta_time,
self.yellow_time,
self.min_green,
self.max_green,
self.begin_time,
self.reward_fn,
conn) for ts in self.ts_ids
}

conn.close()

self.vehicles = dict()
Expand Down Expand Up @@ -203,15 +222,31 @@ def reset(self, seed: Optional[int] = None, **kwargs):
self.sumo_seed = seed
self._start_simulation()

self.traffic_signals = {ts: TrafficSignal(self,
ts,
self.delta_time,
self.yellow_time,
self.min_green,
self.max_green,
self.begin_time,
self.reward_fn,
self.sumo) for ts in self.ts_ids}
if isinstance(self.reward_fn, dict):
self.traffic_signals = dict()
for key, reward_fn_value in self.reward_fn.items():
self.traffic_signals[key] = TrafficSignal(
self,
key,
self.delta_time,
self.yellow_time,
self.min_green,
self.max_green,
self.begin_time,
reward_fn_value,
self.sumo
)
else:
self.traffic_signals = {ts: TrafficSignal(self,
ts,
self.delta_time,
self.yellow_time,
self.min_green,
self.max_green,
self.begin_time,
self.reward_fn,
self.sumo) for ts in self.ts_ids}

self.vehicles = dict()

if self.single_agent:
Expand Down

0 comments on commit 6928a93

Please sign in to comment.