-
Notifications
You must be signed in to change notification settings - Fork 206
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Support for custom observations (both observation function and observ…
…ation space) (#133) * Minor style changes to my previous implementation (to match the author's style) * Support for custom observation function, preparation for drq_norm * Working support for registerable observation functions, minor implementation changes * Moved checking the reward function to TrafficSignal constructor * Observations moved to another file * Minor fixes regarding observations implementation
- Loading branch information
1 parent
5a2d67c
commit bd8ef1d
Showing
4 changed files
with
53 additions
and
24 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,3 +1,3 @@ | ||
from sumo_rl.environment.env import SumoEnvironment, TrafficSignal | ||
from sumo_rl.environment.env import SumoEnvironment, TrafficSignal, ObservationFunction | ||
from sumo_rl.environment.env import env, parallel_env | ||
from sumo_rl.environment.resco_envs import grid4x4, arterial4x4, ingolstadt1, ingolstadt7, ingolstadt21, cologne1, cologne3, cologne8 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,45 @@ | ||
from .traffic_signal import TrafficSignal | ||
from abc import abstractmethod | ||
from gymnasium import spaces | ||
import numpy as np | ||
|
||
class ObservationFunction: | ||
""" | ||
Abstract base class for observation functions. | ||
""" | ||
def __init__(self, ts: TrafficSignal): | ||
self.ts = ts | ||
|
||
@abstractmethod | ||
def __call__(self): | ||
""" | ||
Subclasses must override this method. | ||
""" | ||
pass | ||
|
||
@abstractmethod | ||
def observation_space(self): | ||
""" | ||
Subclasses must override this method. | ||
""" | ||
pass | ||
|
||
|
||
class DefaultObservationFunction(ObservationFunction): | ||
def __init__(self, ts: TrafficSignal): | ||
super().__init__(ts) | ||
|
||
def __call__(self): | ||
phase_id = [1 if self.ts.green_phase == i else 0 for i in range(self.ts.num_green_phases)] # one-hot encoding | ||
min_green = [0 if self.ts.time_since_last_phase_change < self.ts.min_green + self.ts.yellow_time else 1] | ||
density = self.ts.get_lanes_density() | ||
queue = self.ts.get_lanes_queue() | ||
observation = np.array(phase_id + min_green + density + queue, dtype=np.float32) | ||
return observation | ||
|
||
def observation_space(self): | ||
return spaces.Box( | ||
low=np.zeros(self.ts.num_green_phases+1+2*len(self.ts.lanes), dtype=np.float32), | ||
high=np.ones(self.ts.num_green_phases+1+2*len(self.ts.lanes), dtype=np.float32) | ||
) | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters