Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

obs bug fix #112

Open
wants to merge 8 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 20 additions & 2 deletions src/pymgrid/envs/base/base.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import numpy as np
import pandas as pd
import warnings

from collections import OrderedDict
from gym import Env
Expand Down Expand Up @@ -122,7 +123,19 @@ def _validate_observation_keys(self, keys):
if net_load_pos.size:
keys[[0, net_load_pos.item()]] = keys[[net_load_pos.item(), 0]]

return keys.tolist()
unique_keys, dupe_keys = [], []

for k in keys:
if k in unique_keys:
dupe_keys.append(k)
continue

unique_keys.append(k)

if dupe_keys:
warnings.warn(f'Found duplicated keys, will be dropped:\n\t{dupe_keys}')

return unique_keys

@abstractmethod
def _get_action_space(self, remove_redundant_actions=False):
Expand All @@ -133,7 +146,12 @@ def _get_observation_space(self):

state_series = self.state_series()

observation_keys = self.observation_keys or state_series.index.get_level_values(-1)
if self.observation_keys is None or len(self.observation_keys) == 0:
observation_keys = state_series.index.get_level_values(-1)
else:
observation_keys = pd.Index(self.observation_keys)

observation_keys = observation_keys.drop_duplicates()

if 'net_load' in observation_keys:
obs_space['general'] = Tuple([Box(low=-np.inf, high=1, shape=(1, ), dtype=np.float64)])
Expand Down
89 changes: 87 additions & 2 deletions tests/envs/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from tests.helpers.test_case import TestCase
from tests.helpers.modular_microgrid import get_modular_microgrid

from pymgrid.modules import BatteryModule
from pymgrid.envs import DiscreteMicrogridEnv, ContinuousMicrogridEnv, NetLoadContinuousMicrogridEnv
from pymgrid.envs.base import BaseMicrogridEnv

Expand Down Expand Up @@ -131,6 +132,65 @@ class ObsKeysWithNetLoadParent(ObsKeysNoNetLoadParent):
observation_keys = ['net_load', 'soc', 'load_current', 'export_price_current']


class ObsKeysDuplicateKeysParent(ObsKeysNoNetLoadParent):
observation_keys = ['net_load', 'soc', 'load_current', 'load_current', 'export_price_current']

@pass_if_parent
def test_get_obs_correct_keys_in_modules(self):
env = deepcopy(self.env)
obs = env._get_obs()

unique_obs_keys = pd.Index(self.observation_keys).drop_duplicates().tolist()

for module in env.modules.iterlist():
module_state_dict = module.state_dict(normalized=True)
matching_keys = [obs_key for obs_key in unique_obs_keys if obs_key in module.state_dict().keys()]
matching_values = [module_state_dict[k] for k in matching_keys]

with self.subTest(module=module.name, keys=matching_keys):
self.assertEqual(obs[np.isin(unique_obs_keys, matching_keys)], matching_values)


class ObsKeysDuplicateModulesParent(Parent):

@pass_if_parent
def setUp(self) -> None:
second_battery = BatteryModule(
min_capacity=0,
max_capacity=1000,
max_charge=500,
max_discharge=500,
efficiency=1.0,
init_soc=0.5,
normalized_action_bounds=(0, 1))

microgrid = get_modular_microgrid(
additional_modules=[second_battery],
)

self.env = self.env_class.from_microgrid(microgrid, observation_keys=self.observation_keys)

@pass_if_parent
def test_pre_reset_state_series_invariant_to_observation_keys(self):
env = deepcopy(self.env)

self.assertEqual(env.state_series().shape, (15, ))

@pass_if_parent
def test_state_series_values(self):
env = deepcopy(self.env)

expected_state_series = np.array([10., -60., 50., 1., 1., 0., 0., 0.5, 50., 0.5, 500, 1., 1., 1., 1.])
self.assertEqual(env.state_series(normalized=False).values, expected_state_series)

@pass_if_parent
def test_state_series_values_normalized(self):
env = deepcopy(self.env)

expected_state_series = np.array([1/6., 0., 1., 1., 1., 0., 0., 0.5, 0.5, 0.5, 0.5, 0., 0., 0., 0.])
self.assertEqual(env.state_series(normalized=True).values, expected_state_series)


class TestDiscrete(Parent):
env_class = DiscreteMicrogridEnv

Expand All @@ -155,9 +215,34 @@ class TestNetLoadContinuousObsKeysNoNetLoad(ObsKeysNoNetLoadParent):
env_class = NetLoadContinuousMicrogridEnv


class TestDiscreteObsDuplicateKeys(ObsKeysDuplicateKeysParent):
env_class = DiscreteMicrogridEnv


class TestContinuousObsDuplicateKeys(ObsKeysDuplicateKeysParent):
env_class = ContinuousMicrogridEnv


class TestNetLoadContinuousObsDuplicateKeys(ObsKeysDuplicateKeysParent):
env_class = NetLoadContinuousMicrogridEnv


class TestDiscreteDuplicateModules(ObsKeysDuplicateModulesParent):
env_class = DiscreteMicrogridEnv


class TestContinuousDuplicateModules(ObsKeysDuplicateModulesParent):
env_class = ContinuousMicrogridEnv


class TestNetLoadContinuousDuplicateModules(ObsKeysDuplicateModulesParent):
env_class = NetLoadContinuousMicrogridEnv


def flatten_nested_dict(nested_dict):
def extract_list(l):
assert len(l) == 1, 'reduction only works with length 1 lists'
return l[0].tolist()
# assert len(l) == 1, 'reduction only works with length 1 lists'
# return l[0].tolist()
return sum([_l.tolist() for _l in l], [])

return functools.reduce(lambda x, y: x + extract_list(y), nested_dict.values(), [])
Loading