Skip to content

Commit

Permalink
Tests for config_validation.py (#404)
Browse files Browse the repository at this point in the history
* Don't get default data (shouldn't be necessary)

* Unpack default data

* Refactored utils functions to return dicts instead of strings

* Adapted to new mock format

* Adapted to new mock format

* Some first tests

* More tests

* Fixed testcase-names

* Moved file-endings to initial function call

* Fixed tests

* More asserts

* More tests

* More tests

* More tests

* `validate_sub_keys`-tests
  • Loading branch information
NikkelM authored Apr 11, 2022
1 parent f8bc162 commit 091eee6
Show file tree
Hide file tree
Showing 16 changed files with 693 additions and 201 deletions.
8 changes: 4 additions & 4 deletions badges/docstring_coverage.svg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
16 changes: 9 additions & 7 deletions recommerce/configuration/config_validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ def validate_config(config: dict, config_is_final: bool) -> tuple:
raise AssertionError('If your config contains one of "environment" or "hyperparameter" it must also contain the other')
else:
# try to split the config. If any keys are unknown, an AssertionError will be thrown
hyperparameter_config, environment_config = split_combined_config(config)
hyperparameter_config, environment_config = split_mixed_config(config)
# then validate that all given values have the correct types
check_config_types(hyperparameter_config, environment_config, config_is_final)

Expand Down Expand Up @@ -57,9 +57,11 @@ def validate_sub_keys(config_class: HyperparameterConfig or EnvironmentConfig, c
for key, _ in config.items():
# we need to separately check agents, since it is a list of dictionaries
if key == 'agents':
assert isinstance(config['agents'], list), f'The "agents" key must have a value of type list, but was {type(config["agents"])}'
for agent in config['agents']:
assert isinstance(agent, dict), f'All agents must be of type dict, but this one was {type(agent)}'
assert all(agent_key in {'name', 'agent_class', 'argument'} for agent_key in agent.keys()), \
f'an invalid key for agents was provided: {agent.keys()}'
f'An invalid key for agents was provided: {agent.keys()}'
# the key is key of a dictionary in the config
elif top_level_keys[key]:
assert isinstance(config[key], dict), f'The value of this key must be of type dict: {key}, but was {type(config[key])}'
Expand All @@ -73,13 +75,13 @@ def validate_sub_keys(config_class: HyperparameterConfig or EnvironmentConfig, c
validate_sub_keys(config_class, config[key], key_fields)


def split_combined_config(config: dict) -> tuple:
def split_mixed_config(config: dict) -> tuple:
"""
Utility function that splits a potentially combined config of hyperparameters and environment-variables
Utility function that splits a potentially mixed config of hyperparameters and environment-variables
into two dictionaries for the two configurations.
Args:
config (dict): The potentially combined configuration.
config (dict): The potentially mixed configuration.
Returns:
dict: The hyperparameter_config
Expand Down Expand Up @@ -118,7 +120,7 @@ def check_config_types(hyperparameter_config: dict, environment_config: dict, mu
must_contain (bool): Whether or not the configuration should contain all required keys.
Raises:
AssertionError: If one of the values has the wring type.
AssertionError: If one of the values has the wrong type.
"""
# check types for hyperparameter_config
HyperparameterConfig.check_types(hyperparameter_config, 'top-dict', must_contain)
Expand Down Expand Up @@ -167,5 +169,5 @@ def check_config_types(hyperparameter_config: dict, environment_config: dict, mu
}
]
}
hyper, env = split_combined_config(test_config)
hyper, env = split_mixed_config(test_config)
check_config_types(hyper, env)
5 changes: 0 additions & 5 deletions recommerce/configuration/environment_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,11 +66,6 @@ def get_required_fields(cls, dict_key) -> dict:
'marketplace': False,
'agents': False
}
elif dict_key == 'agents':
return {
'agent_class': False,
'argument': False
}
else:
raise AssertionError(f'The given level does not exist in an environment-config: {dict_key}')

Expand Down
3 changes: 1 addition & 2 deletions recommerce/monitoring/agent_monitoring/am_evaluation.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ def evaluate_session(self, rewards: list, episode_numbers: list = None):
print(f'All plots were saved to {os.path.abspath(self.configurator.folder_path)}')

# visualize metrics
def create_histogram(self, rewards: list, is_last_histogram: bool, filename: str = 'default') -> None:
def create_histogram(self, rewards: list, is_last_histogram: bool, filename: str = 'default_histogram.svg') -> None:
"""
Create a histogram sorting rewards into bins of 1000.
Expand All @@ -64,7 +64,6 @@ def create_histogram(self, rewards: list, is_last_histogram: bool, filename: str
return
assert all(len(curr_reward) == len(rewards[0]) for curr_reward in rewards), 'all rewards-arrays must be of the same size'

filename += '.svg'
plt.clf()
plt.xlabel('Reward', fontsize='18')
plt.ylabel('Episodes', fontsize='18')
Expand Down
4 changes: 2 additions & 2 deletions recommerce/monitoring/agent_monitoring/am_monitoring.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,10 +70,10 @@ def run_marketplace(self) -> list:
rewards[current_agent_index] += [episode_reward]

if (episode % self.configurator.plot_interval) == 0:
self.evaluator.create_histogram(rewards, False, f'episode_{episode}')
self.evaluator.create_histogram(rewards, False, f'episode_{episode}.svg')

# only one histogram after the whole monitoring process
self.evaluator.create_histogram(rewards, True, 'Cumulative_rewards_per_episode')
self.evaluator.create_histogram(rewards, True, 'Cumulative_rewards_per_episode.svg')

return rewards

Expand Down
2 changes: 1 addition & 1 deletion tests/test_agent_monitoring/test_am_evaluation.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ def test_create_histogram(agents, rewards, plot_bins, agent_color, lower_upper_r
hist_mock.assert_called_once_with(rewards, bins=plot_bins, color=agent_color, rwidth=0.9, range=lower_upper_range, edgecolor='black')
legend_mock.assert_called_once_with(name_list)
draw_mock.assert_called_once()
save_mock.assert_called_once_with(fname=os.path.join(monitor.configurator.folder_path, 'default.svg'))
save_mock.assert_called_once_with(fname=os.path.join(monitor.configurator.folder_path, 'default_histogram.svg'))


def test_create_histogram_without_saving_to_directory():
Expand Down
Loading

0 comments on commit 091eee6

Please sign in to comment.