diff --git a/badges/docstring_coverage.svg b/badges/docstring_coverage.svg index 20b60d7a..7192e4f2 100644 --- a/badges/docstring_coverage.svg +++ b/badges/docstring_coverage.svg @@ -1,5 +1,5 @@ - - interrogate: 56.8% + + interrogate: 56.3% @@ -15,8 +15,8 @@ interrogate - - 56.8% + + 56.3% diff --git a/recommerce/configuration/config_validation.py b/recommerce/configuration/config_validation.py index e350f431..473b3b6c 100644 --- a/recommerce/configuration/config_validation.py +++ b/recommerce/configuration/config_validation.py @@ -26,7 +26,7 @@ def validate_config(config: dict, config_is_final: bool) -> tuple: raise AssertionError('If your config contains one of "environment" or "hyperparameter" it must also contain the other') else: # try to split the config. If any keys are unknown, an AssertionError will be thrown - hyperparameter_config, environment_config = split_combined_config(config) + hyperparameter_config, environment_config = split_mixed_config(config) # then validate that all given values have the correct types check_config_types(hyperparameter_config, environment_config, config_is_final) @@ -57,9 +57,11 @@ def validate_sub_keys(config_class: HyperparameterConfig or EnvironmentConfig, c for key, _ in config.items(): # we need to separately check agents, since it is a list of dictionaries if key == 'agents': + assert isinstance(config['agents'], list), f'The "agents" key must have a value of type list, but was {type(config["agents"])}' for agent in config['agents']: + assert isinstance(agent, dict), f'All agents must be of type dict, but this one was {type(agent)}' assert all(agent_key in {'name', 'agent_class', 'argument'} for agent_key in agent.keys()), \ - f'an invalid key for agents was provided: {agent.keys()}' + f'An invalid key for agents was provided: {agent.keys()}' # the key is key of a dictionary in the config elif top_level_keys[key]: assert isinstance(config[key], dict), f'The value of this key must be of type dict: {key}, but was {type(config[key])}' @@ -73,13 +75,13 @@ def validate_sub_keys(config_class: HyperparameterConfig or EnvironmentConfig, c validate_sub_keys(config_class, config[key], key_fields) -def split_combined_config(config: dict) -> tuple: +def split_mixed_config(config: dict) -> tuple: """ - Utility function that splits a potentially combined config of hyperparameters and environment-variables + Utility function that splits a potentially mixed config of hyperparameters and environment-variables into two dictionaries for the two configurations. Args: - config (dict): The potentially combined configuration. + config (dict): The potentially mixed configuration. Returns: dict: The hyperparameter_config @@ -118,7 +120,7 @@ def check_config_types(hyperparameter_config: dict, environment_config: dict, mu must_contain (bool): Whether or not the configuration should contain all required keys. Raises: - AssertionError: If one of the values has the wring type. + AssertionError: If one of the values has the wrong type. 
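+
+	Example (an illustrative sketch, mirroring the __main__ block at the bottom of this file):
+		hyper, env = split_mixed_config({'rl': {'gamma': 0.99}, 'task': 'training'})
+		check_config_types(hyper, env)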
""" # check types for hyperparameter_config HyperparameterConfig.check_types(hyperparameter_config, 'top-dict', must_contain) @@ -167,5 +169,5 @@ def check_config_types(hyperparameter_config: dict, environment_config: dict, mu } ] } - hyper, env = split_combined_config(test_config) + hyper, env = split_mixed_config(test_config) check_config_types(hyper, env) diff --git a/recommerce/configuration/environment_config.py b/recommerce/configuration/environment_config.py index 9f033641..e17e0969 100644 --- a/recommerce/configuration/environment_config.py +++ b/recommerce/configuration/environment_config.py @@ -66,11 +66,6 @@ def get_required_fields(cls, dict_key) -> dict: 'marketplace': False, 'agents': False } - elif dict_key == 'agents': - return { - 'agent_class': False, - 'argument': False - } else: raise AssertionError(f'The given level does not exist in an environment-config: {dict_key}') diff --git a/recommerce/monitoring/agent_monitoring/am_evaluation.py b/recommerce/monitoring/agent_monitoring/am_evaluation.py index bb419671..819d3217 100644 --- a/recommerce/monitoring/agent_monitoring/am_evaluation.py +++ b/recommerce/monitoring/agent_monitoring/am_evaluation.py @@ -51,7 +51,7 @@ def evaluate_session(self, rewards: list, episode_numbers: list = None): print(f'All plots were saved to {os.path.abspath(self.configurator.folder_path)}') # visualize metrics - def create_histogram(self, rewards: list, is_last_histogram: bool, filename: str = 'default') -> None: + def create_histogram(self, rewards: list, is_last_histogram: bool, filename: str = 'default_histogram.svg') -> None: """ Create a histogram sorting rewards into bins of 1000. @@ -64,7 +64,6 @@ def create_histogram(self, rewards: list, is_last_histogram: bool, filename: str return assert all(len(curr_reward) == len(rewards[0]) for curr_reward in rewards), 'all rewards-arrays must be of the same size' - filename += '.svg' plt.clf() plt.xlabel('Reward', fontsize='18') plt.ylabel('Episodes', fontsize='18') diff --git a/recommerce/monitoring/agent_monitoring/am_monitoring.py b/recommerce/monitoring/agent_monitoring/am_monitoring.py index f3d8de2d..a1f9675f 100644 --- a/recommerce/monitoring/agent_monitoring/am_monitoring.py +++ b/recommerce/monitoring/agent_monitoring/am_monitoring.py @@ -70,10 +70,10 @@ def run_marketplace(self) -> list: rewards[current_agent_index] += [episode_reward] if (episode % self.configurator.plot_interval) == 0: - self.evaluator.create_histogram(rewards, False, f'episode_{episode}') + self.evaluator.create_histogram(rewards, False, f'episode_{episode}.svg') # only one histogram after the whole monitoring process - self.evaluator.create_histogram(rewards, True, 'Cumulative_rewards_per_episode') + self.evaluator.create_histogram(rewards, True, 'Cumulative_rewards_per_episode.svg') return rewards diff --git a/tests/test_agent_monitoring/test_am_evaluation.py b/tests/test_agent_monitoring/test_am_evaluation.py index 43111240..8eb68597 100644 --- a/tests/test_agent_monitoring/test_am_evaluation.py +++ b/tests/test_agent_monitoring/test_am_evaluation.py @@ -92,7 +92,7 @@ def test_create_histogram(agents, rewards, plot_bins, agent_color, lower_upper_r hist_mock.assert_called_once_with(rewards, bins=plot_bins, color=agent_color, rwidth=0.9, range=lower_upper_range, edgecolor='black') legend_mock.assert_called_once_with(name_list) draw_mock.assert_called_once() - save_mock.assert_called_once_with(fname=os.path.join(monitor.configurator.folder_path, 'default.svg')) + 
save_mock.assert_called_once_with(fname=os.path.join(monitor.configurator.folder_path, 'default_histogram.svg'))


def test_create_histogram_without_saving_to_directory():
diff --git a/tests/test_config_validation.py b/tests/test_config_validation.py
new file mode 100644
index 00000000..eda738e7
--- /dev/null
+++ b/tests/test_config_validation.py
@@ -0,0 +1,387 @@
+import pytest
+import utils_tests as ut_t
+
+import recommerce.configuration.config_validation as config_validation
+from recommerce.configuration.environment_config import EnvironmentConfig
+from recommerce.configuration.hyperparameter_config import HyperparameterConfig
+
+##########
+# Tests with already combined configs (== hyperparameter and/or environment key on the top-level)
+##########
+validate_config_valid_combined_final_testcases = [
+	ut_t.create_combined_mock_dict(),
+	ut_t.create_combined_mock_dict(hyperparameter=ut_t.create_hyperparameter_mock_dict(rl=ut_t.create_hyperparameter_mock_dict_rl(gamma=0.5))),
+	ut_t.create_combined_mock_dict(hyperparameter=ut_t.create_hyperparameter_mock_dict(
+		sim_market=ut_t.create_hyperparameter_mock_dict_sim_market(max_price=25))),
+	ut_t.create_combined_mock_dict(environment=ut_t.create_environment_mock_dict(task='exampleprinter')),
+	ut_t.create_combined_mock_dict(environment=ut_t.create_environment_mock_dict(agents=[
+		{
+			'name': 'Test_agent',
+			'agent_class': 'recommerce.rl.q_learning.q_learning_agent.QLearningCERebuyAgent',
+			'argument': ''
+		},
+		{
+			'name': 'Test_agent2',
+			'agent_class': 'recommerce.market.circular.circular_vendors.RuleBasedCERebuyAgent',
+			'argument': ''
+		}
+	])),
+]
+
+
+@pytest.mark.parametrize('config', validate_config_valid_combined_final_testcases)
+def test_validate_config_valid_combined_final(config):
+	# If the config is valid, the first member of the tuple returned will be True
+	validate_status, validate_data = config_validation.validate_config(config, True)
+	assert validate_status, validate_data
+	assert isinstance(validate_data, tuple)
+	assert 'rl' in validate_data[0]
+	assert 'sim_market' in validate_data[0]
+	assert 'gamma' in validate_data[0]['rl']
+	assert 'max_price' in validate_data[0]['sim_market']
+	assert 'task' in validate_data[1]
+	assert 'agents' in validate_data[1]
+
+
+# These testcases do not cover everything, nor should they; there are simply too many combinations
+validate_config_valid_combined_not_final_testcases = [
+	ut_t.create_combined_mock_dict(
+		hyperparameter=ut_t.remove_key('rl', ut_t.create_hyperparameter_mock_dict())),
+	ut_t.create_combined_mock_dict(
+		hyperparameter=ut_t.create_hyperparameter_mock_dict(
+			rl=ut_t.remove_key('learning_rate', ut_t.create_hyperparameter_mock_dict_rl(gamma=0.5)))),
+	ut_t.create_combined_mock_dict(
+		hyperparameter=ut_t.create_hyperparameter_mock_dict(
+			rl=ut_t.remove_key('epsilon_start', ut_t.remove_key('learning_rate', ut_t.create_hyperparameter_mock_dict_rl())))),
+	ut_t.create_combined_mock_dict(environment=ut_t.remove_key('task', ut_t.create_environment_mock_dict())),
+	ut_t.create_combined_mock_dict(environment=ut_t.remove_key('agents', ut_t.remove_key('task', ut_t.create_environment_mock_dict()))),
+] + validate_config_valid_combined_final_testcases
+
+
+@pytest.mark.parametrize('config', validate_config_valid_combined_not_final_testcases)
+def test_validate_config_valid_combined_not_final(config):
+	# If the config is valid, the first member of the returned tuple will be True
+	validate_status, validate_data = config_validation.validate_config(config, False)
+	assert validate_status, validate_data
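+
+# For reference, a combined config as produced by ut_t.create_combined_mock_dict() has this shape
+# (an illustrative sketch; the elided values follow the mock helpers in utils_tests.py):
+# {
+# 	'hyperparameter': {'rl': {...}, 'sim_market': {...}},
+# 	'environment': {'task': ..., 'agents': [...]}
+# }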
+
+
+validate_config_one_top_key_missing_testcases = [
+	(ut_t.create_combined_mock_dict(hyperparameter=None), True),
+	(ut_t.create_combined_mock_dict(environment=None), True),
+	(ut_t.create_combined_mock_dict(hyperparameter=None), False),
+	(ut_t.create_combined_mock_dict(environment=None), False)
+]
+
+
+@pytest.mark.parametrize('config, is_final', validate_config_one_top_key_missing_testcases)
+def test_validate_config_one_top_key_missing(config, is_final):
+	validate_status, validate_data = config_validation.validate_config(config, is_final)
+	assert not validate_status, validate_data
+	assert 'If your config contains one of "environment" or "hyperparameter" it must also contain the other' == validate_data
+
+
+validate_config_too_many_keys_testcases = [
+	True,
+	False
+]
+
+
+@pytest.mark.parametrize('is_final', validate_config_too_many_keys_testcases)
+def test_validate_config_too_many_keys(is_final):
+	test_config = ut_t.create_combined_mock_dict()
+	test_config['additional_key'] = "this shouldn't be allowed"
+	validate_status, validate_data = config_validation.validate_config(test_config, is_final)
+	assert not validate_status, validate_data
+	assert 'Your config should not contain keys other than "environment" and "hyperparameter"' == validate_data
+##########
+# End of tests with already combined configs (== hyperparameter and/or environment key on the top-level)
+##########
+
+
+##########
+# Tests without the already split top-level (config keys are mixed and need to be matched)
+##########
+# These are singular dicts that will get combined for the actual testcases
+validate_config_valid_not_final_dicts = [
+	{
+		'rl': {
+			'gamma': 0.5,
+			'epsilon_start': 0.9
+		}
+	},
+	{
+		'sim_market': {
+			'max_price': 40
+		}
+	},
+	{
+		'task': 'training'
+	},
+	{
+		'marketplace': 'recommerce.market.circular.circular_sim_market.CircularEconomyRebuyPriceMonopolyScenario'
+	},
+	{
+		'agents': [
+			{
+				'name': 'Rule_Based Agent',
+				'agent_class': 'recommerce.market.circular.circular_vendors.RuleBasedCERebuyAgent',
+				'argument': ''
+			},
+			{
+				'name': 'CE Rebuy Agent (QLearning)',
+				'agent_class': 'recommerce.rl.q_learning.q_learning_agent.QLearningCERebuyAgent',
+				'argument': 'CircularEconomyRebuyPriceMonopolyScenario_QLearningCERebuyAgent.dat'
+			}
+		]
+	},
+	{
+		'agents': [
+			{
+				'name': 'Rule_Based Agent',
+				'agent_class': 'recommerce.market.circular.circular_vendors.RuleBasedCERebuyAgent',
+				'argument': ''
+			}
+		]
+	}
+]
+
+
+# get all combinations of the dicts defined above to mix and match as much as possible
+mixed_configs = [
+	{**dict1, **dict2} for dict1 in validate_config_valid_not_final_dicts for dict2 in validate_config_valid_not_final_dicts
+]
+
+
+@pytest.mark.parametrize('config', mixed_configs)
+def test_validate_config_valid_not_final(config):
+	validate_status, validate_data = config_validation.validate_config(config, False)
+	assert validate_status, f'Test failed with error: {validate_data} on config: {config}'
+
+
+validate_config_valid_final_testcases = [
+	{**ut_t.create_hyperparameter_mock_dict(), **ut_t.create_environment_mock_dict()},
+	{**ut_t.create_hyperparameter_mock_dict(rl=ut_t.create_hyperparameter_mock_dict_rl(gamma=0.2)), **ut_t.create_environment_mock_dict()},
+	{**ut_t.create_hyperparameter_mock_dict(), **ut_t.create_environment_mock_dict(episodes=20)}
+]
+
+
+@pytest.mark.parametrize('config', validate_config_valid_final_testcases)
+def test_validate_config_valid_final(config):
+	validate_status, validate_data = 
config_validation.validate_config(config, True) + assert validate_status, f'Test failed with error: {validate_data} on config: {config}' + assert 'rl' in validate_data[0] + assert 'sim_market' in validate_data[0] + assert 'agents' in validate_data[1] + + +@pytest.mark.parametrize('config', mixed_configs) +def test_split_mixed_config_valid(config): + config_validation.split_mixed_config(config) + + +split_mixed_config_invalid_testcases = [ + { + 'invalid_key': 2 + }, + { + 'rl': { + 'gamma': 0.5 + }, + 'invalid_key': 2 + }, + { + 'agents': [ + { + 'name': 'test', + 'agent_class': 'recommerce.market.circular.circular_vendors.RuleBasedCERebuyAgent', + 'argument': '' + } + ], + 'invalid_key': 2 + } +] + + +@pytest.mark.parametrize('config', split_mixed_config_invalid_testcases) +def test_split_mixed_config_invalid(config): + with pytest.raises(AssertionError) as error_message: + config_validation.split_mixed_config(config) + assert 'Your config contains an invalid key:' in str(error_message.value) + + +validate_sub_keys_invalid_keys_hyperparameter_testcases = [ + { + 'rl': { + 'gamma': 0.5, + 'invalid_key': 2 + } + }, + { + 'sim_market': { + 'max_price': 50, + 'invalid_key': 2 + } + }, + { + 'rl': { + 'gamma': 0.5, + 'invalid_key': 2 + }, + 'sim_market': { + 'max_price': 50, + 'invalid_key': 2 + } + }, + { + 'rl': { + 'gamma': 0.5 + }, + 'sim_market': { + 'max_price': 50, + 'invalid_key': 2 + } + } +] + + +@pytest.mark.parametrize('config', validate_sub_keys_invalid_keys_hyperparameter_testcases) +def test_validate_sub_keys_invalid_keys_hyperparameter(config): + with pytest.raises(AssertionError) as error_message: + top_level_keys = HyperparameterConfig.get_required_fields('top-dict') + config_validation.validate_sub_keys(HyperparameterConfig, config, top_level_keys) + assert 'The key "invalid_key" should not exist within a HyperparameterConfig config' in str(error_message.value) + + +validate_sub_keys_agents_invalid_keys_testcases = [ + { + 'task': 'training', + 'agents': [ + { + 'name': 'name', + 'invalid_key': 2 + } + ] + }, + { + 'agents': [ + { + 'name': '', + 'argument': '', + 'invalid_key': 2 + } + ] + }, + { + 'agents': [ + { + 'argument': '' + }, + { + 'name': '', + 'agent_class': '', + 'argument': '', + 'invalid_key': 2 + } + ] + } +] + + +@pytest.mark.parametrize('config', validate_sub_keys_agents_invalid_keys_testcases) +def test_validate_sub_keys_agents_invalid_keys(config): + with pytest.raises(AssertionError) as error_message: + top_level_keys = EnvironmentConfig.get_required_fields('top-dict') + config_validation.validate_sub_keys(EnvironmentConfig, config, top_level_keys) + assert 'An invalid key for agents was provided:' in str(error_message.value) + + +validate_sub_keys_agents_wrong_type_testcases = [ + { + 'agents': 2 + }, + { + 'agents': 'string' + }, + { + 'agents': 2.0 + }, + { + 'agents': {} + } +] + + +@pytest.mark.parametrize('config', validate_sub_keys_agents_wrong_type_testcases) +def test_validate_sub_keys_agents_wrong_type(config): + with pytest.raises(AssertionError) as error_message: + top_level_keys = EnvironmentConfig.get_required_fields('top-dict') + config_validation.validate_sub_keys(EnvironmentConfig, config, top_level_keys) + assert 'The "agents" key must have a value of type list, but was' in str(error_message.value) + + +validate_sub_keys_agents_wrong_type_testcases = [ + { + 'agents': [ + 2 + ] + }, + { + 'agents': [ + 'string' + ] + }, + { + 'agents': [ + 2.0 + ] + }, + { + 'agents': [ + [] + ] + } +] + + +@pytest.mark.parametrize('config', 
validate_sub_keys_agents_wrong_type_testcases) +def test_validate_sub_keys_agents_wrong_subtype(config): + with pytest.raises(AssertionError) as error_message: + top_level_keys = EnvironmentConfig.get_required_fields('top-dict') + config_validation.validate_sub_keys(EnvironmentConfig, config, top_level_keys) + assert 'All agents must be of type dict, but this one was' in str(error_message.value) + + +validate_sub_keys_wrong_type_hyperparameter_testcases = [ + { + 'rl': [] + }, + { + 'sim_market': [] + }, + { + 'rl': 2 + }, + { + 'sim_market': 2 + }, + { + 'rl': 'string' + }, + { + 'sim_market': 'string' + }, + { + 'rl': 2.0 + }, + { + 'sim_market': 2.0 + }, +] + + +@pytest.mark.parametrize('config', validate_sub_keys_wrong_type_hyperparameter_testcases) +def test_validate_sub_keys_wrong_type_hyperparameter(config): + with pytest.raises(AssertionError) as error_message: + top_level_keys = HyperparameterConfig.get_required_fields('top-dict') + config_validation.validate_sub_keys(HyperparameterConfig, config, top_level_keys) + assert 'The value of this key must be of type dict:' in str(error_message.value) diff --git a/tests/test_environment_config.py b/tests/test_environment_config.py index 7df1c71a..b0367e10 100644 --- a/tests/test_environment_config.py +++ b/tests/test_environment_config.py @@ -16,6 +16,24 @@ def test_abstract_parent_class(): assert "Can't instantiate abstract class EnvironmentConfig" in str(error_message.value) +def test_get_required_fields_valid(): + fields = env_config.EnvironmentConfig.get_required_fields('top-dict') + assert fields == { + 'task': False, + 'enable_live_draw': False, + 'episodes': False, + 'plot_interval': False, + 'marketplace': False, + 'agents': False + } + + +def test_get_required_fields_invalid(): + with pytest.raises(AssertionError) as error_message: + env_config.EnvironmentConfig.get_required_fields('wrong_key') + assert 'The given level does not exist in an environment-config: wrong_key' in str(error_message.value) + + def test_str_representation(): test_dict = { 'task': 'training', diff --git a/tests/test_hyperparameter_config_rl.py b/tests/test_hyperparameter_config_rl.py index 9d1e7489..e705c2dd 100644 --- a/tests/test_hyperparameter_config_rl.py +++ b/tests/test_hyperparameter_config_rl.py @@ -1,3 +1,4 @@ +import json from importlib import reload from unittest.mock import mock_open, patch @@ -23,13 +24,56 @@ def import_config() -> hyperparameter_config.HyperparameterConfig: return hyperparameter_config.config +###### +# General tests for the HyperParameter parent class +##### +get_required_fields_valid_testcases = [ + ('top-dict', {'rl': True, 'sim_market': True}), + ('rl', { + 'gamma': False, + 'batch_size': False, + 'replay_size': False, + 'learning_rate': False, + 'sync_target_frames': False, + 'replay_start_size': False, + 'epsilon_decay_last_frame': False, + 'epsilon_start': False, + 'epsilon_final': False + }), + ('sim_market', { + 'max_storage': False, + 'episode_length': False, + 'max_price': False, + 'max_quality': False, + 'number_of_customers': False, + 'production_price': False, + 'storage_cost_per_product': False + }) +] + + +@pytest.mark.parametrize('level, expected_dict', get_required_fields_valid_testcases) +def test_get_required_fields_valid(level, expected_dict): + fields = hyperparameter_config.HyperparameterConfig.get_required_fields(level) + assert fields == expected_dict + + +def test_get_required_fields_invalid(): + with pytest.raises(AssertionError) as error_message: + 
hyperparameter_config.HyperparameterConfig.get_required_fields('wrong_key') + assert 'The given level does not exist in a hyperparameter-config: wrong_key' in str(error_message.value) +###### +# End general tests +##### + + # mock format taken from: # https://stackoverflow.com/questions/1289894/how-do-i-mock-an-open-used-in-a-with-statement-using-the-mock-framework-in-pyth # Test that checks if the config.json is read correctly def test_reading_file_values(): - json = ut_t.create_hyperparameter_mock_json() - with patch('builtins.open', mock_open(read_data=json)) as mock_file: - ut_t.check_mock_file(mock_file, json) + mock_json = json.dumps(ut_t.create_hyperparameter_mock_dict()) + with patch('builtins.open', mock_open(read_data=mock_json)) as mock_file: + ut_t.check_mock_file(mock_file, mock_json) config = import_config() @@ -44,9 +88,9 @@ def test_reading_file_values(): assert config.epsilon_final == 0.1 # Test a second time with other values to ensure that the values are read correctly - json2 = ut_t.create_hyperparameter_mock_json(rl=ut_t.create_hyperparameter_mock_json_rl(learning_rate='1e-4')) - with patch('builtins.open', mock_open(read_data=json2)) as mock_file: - ut_t.check_mock_file(mock_file, json2) + mock_json = json.dumps(ut_t.create_hyperparameter_mock_dict(rl=ut_t.create_hyperparameter_mock_dict_rl(learning_rate=1e-4))) + with patch('builtins.open', mock_open(read_data=mock_json)) as mock_file: + ut_t.check_mock_file(mock_file, mock_json) config = import_config() @@ -55,35 +99,37 @@ def test_reading_file_values(): # The following variables are input mock-json strings for the test_invalid_values test # These tests have invalid values in their input file, the import should throw a specific error message -learning_rate_larger_one = (ut_t.create_hyperparameter_mock_json_rl(learning_rate='1.5'), +learning_rate_larger_one = (ut_t.create_hyperparameter_mock_dict_rl(learning_rate=1.5), 'learning_rate should be between 0 and 1 (excluded)') -negative_learning_rate = (ut_t.create_hyperparameter_mock_json_rl(learning_rate='0'), 'learning_rate should be between 0 and 1 (excluded)') -large_gamma = (ut_t.create_hyperparameter_mock_json_rl(gamma='1.0'), 'gamma should be between 0 (included) and 1 (excluded)') -negative_gamma = ((ut_t.create_hyperparameter_mock_json_rl(gamma='-1.0'), 'gamma should be between 0 (included) and 1 (excluded)')) -negative_batch_size = (ut_t.create_hyperparameter_mock_json_rl(batch_size='-5'), 'batch_size should be greater than 0') -negative_replay_size = (ut_t.create_hyperparameter_mock_json_rl(replay_size='-5'), +negative_learning_rate = (ut_t.create_hyperparameter_mock_dict_rl(learning_rate=0), 'learning_rate should be between 0 and 1 (excluded)') +large_gamma = (ut_t.create_hyperparameter_mock_dict_rl(gamma=1.0), 'gamma should be between 0 (included) and 1 (excluded)') +negative_gamma = ((ut_t.create_hyperparameter_mock_dict_rl(gamma=-1.0), 'gamma should be between 0 (included) and 1 (excluded)')) +negative_batch_size = (ut_t.create_hyperparameter_mock_dict_rl(batch_size=-5), 'batch_size should be greater than 0') +negative_replay_size = (ut_t.create_hyperparameter_mock_dict_rl(replay_size=-5), 'replay_size should be greater than 0') -negative_sync_target_frames = (ut_t.create_hyperparameter_mock_json_rl(sync_target_frames='-5'), +negative_sync_target_frames = (ut_t.create_hyperparameter_mock_dict_rl(sync_target_frames=-5), 'sync_target_frames should be greater than 0') -negative_replay_start_size = 
(ut_t.create_hyperparameter_mock_json_rl(replay_start_size='-5'), 'replay_start_size should be greater than 0') -negative_epsilon_decay_last_frame = (ut_t.create_hyperparameter_mock_json_rl(epsilon_decay_last_frame='-5'), +negative_replay_start_size = (ut_t.create_hyperparameter_mock_dict_rl(replay_start_size=-5), 'replay_start_size should be greater than 0') +negative_epsilon_decay_last_frame = (ut_t.create_hyperparameter_mock_dict_rl(epsilon_decay_last_frame=-5), 'epsilon_decay_last_frame should not be negative') # These tests are missing a line in the config file, the import should throw a specific error message -missing_gamma = (ut_t.remove_line(0, ut_t.create_hyperparameter_mock_json_rl()), 'your config_rl is missing gamma') -missing_batch_size = (ut_t.remove_line(1, ut_t.create_hyperparameter_mock_json_rl()), 'your config_rl is missing batch_size') -missing_replay_size = (ut_t.remove_line(2, ut_t.create_hyperparameter_mock_json_rl()), 'your config_rl is missing replay_size') -missing_learning_rate = (ut_t.remove_line(3, ut_t.create_hyperparameter_mock_json_rl()), +missing_gamma = (ut_t.remove_key('gamma', ut_t.create_hyperparameter_mock_dict_rl()), 'your config_rl is missing gamma') +missing_batch_size = (ut_t.remove_key('batch_size', ut_t.create_hyperparameter_mock_dict_rl()), 'your config_rl is missing batch_size') +missing_replay_size = (ut_t.remove_key('replay_size', ut_t.create_hyperparameter_mock_dict_rl()), 'your config_rl is missing replay_size') +missing_learning_rate = (ut_t.remove_key('learning_rate', ut_t.create_hyperparameter_mock_dict_rl()), 'your config_rl is missing learning_rate') -missing_sync_target_frames = (ut_t.remove_line(4, ut_t.create_hyperparameter_mock_json_rl()), +missing_sync_target_frames = (ut_t.remove_key('sync_target_frames', ut_t.create_hyperparameter_mock_dict_rl()), 'your config_rl is missing sync_target_frames') -missing_replay_start_size = (ut_t.remove_line(5, ut_t.create_hyperparameter_mock_json_rl()), +missing_replay_start_size = (ut_t.remove_key('replay_start_size', ut_t.create_hyperparameter_mock_dict_rl()), 'your config_rl is missing replay_start_size') -missing_epsilon_decay_last_frame = (ut_t.remove_line(6, ut_t.create_hyperparameter_mock_json_rl()), +missing_epsilon_decay_last_frame = (ut_t.remove_key('epsilon_decay_last_frame', ut_t.create_hyperparameter_mock_dict_rl()), 'your config_rl is missing epsilon_decay_last_frame') -missing_epsilon_start = (ut_t.remove_line(7, ut_t.create_hyperparameter_mock_json_rl()), 'your config_rl is missing epsilon_start') -missing_epsilon_final = (ut_t.remove_line(8, ut_t.create_hyperparameter_mock_json_rl()), 'your config_rl is missing epsilon_final') +missing_epsilon_start = (ut_t.remove_key('epsilon_start', ut_t.create_hyperparameter_mock_dict_rl()), + 'your config_rl is missing epsilon_start') +missing_epsilon_final = (ut_t.remove_key('epsilon_final', ut_t.create_hyperparameter_mock_dict_rl()), + 'your config_rl is missing epsilon_final') invalid_values_testcases = [ @@ -111,9 +157,9 @@ def test_reading_file_values(): # Test that checks that an invalid/broken config.json gets detected correctly @pytest.mark.parametrize('rl_json, expected_message', invalid_values_testcases) def test_invalid_values(rl_json, expected_message): - json = ut_t.create_hyperparameter_mock_json(rl=rl_json) - with patch('builtins.open', mock_open(read_data=json)) as mock_file: - ut_t.check_mock_file(mock_file, json) + mock_json = json.dumps(ut_t.create_hyperparameter_mock_dict(rl=rl_json)) + with patch('builtins.open', 
mock_open(read_data=mock_json)) as mock_file: + ut_t.check_mock_file(mock_file, mock_json) with pytest.raises(AssertionError) as assertion_message: import_config() assert expected_message in str(assertion_message.value) diff --git a/tests/test_hyperparameter_config_sim_market.py b/tests/test_hyperparameter_config_sim_market.py index a59738ba..a086a097 100644 --- a/tests/test_hyperparameter_config_sim_market.py +++ b/tests/test_hyperparameter_config_sim_market.py @@ -1,3 +1,4 @@ +import json from importlib import reload from unittest.mock import mock_open, patch @@ -27,9 +28,9 @@ def import_config() -> hyperparameter_config.HyperparameterConfig: # https://stackoverflow.com/questions/1289894/how-do-i-mock-an-open-used-in-a-with-statement-using-the-mock-framework-in-pyth # Test that checks if the config.json is read correctly def test_reading_file_values(): - json = ut_t.create_hyperparameter_mock_json(sim_market=ut_t.create_hyperparameter_mock_json_sim_market()) - with patch('builtins.open', mock_open(read_data=json)) as mock_file: - ut_t.check_mock_file(mock_file, json) + mock_json = json.dumps(ut_t.create_hyperparameter_mock_dict(sim_market=ut_t.create_hyperparameter_mock_dict_sim_market())) + with patch('builtins.open', mock_open(read_data=mock_json)) as mock_file: + ut_t.check_mock_file(mock_file, mock_json) config = import_config() @@ -42,10 +43,10 @@ def test_reading_file_values(): assert config.storage_cost_per_product == 0.3 # Test a second time with other values to ensure, that the values are read correctly - json2 = ut_t.create_hyperparameter_mock_json( - sim_market=ut_t.create_hyperparameter_mock_json_sim_market('50', '50', '50', '80', '20', '10', '0.7')) - with patch('builtins.open', mock_open(read_data=json2)) as mock_file: - ut_t.check_mock_file(mock_file, json2) + mock_json = json.dumps(ut_t.create_hyperparameter_mock_dict( + sim_market=ut_t.create_hyperparameter_mock_dict_sim_market(50, 50, 50, 80, 20, 10, 0.7))) + with patch('builtins.open', mock_open(read_data=mock_json)) as mock_file: + ut_t.check_mock_file(mock_file, mock_json) config = import_config() @@ -60,29 +61,33 @@ def test_reading_file_values(): # The following variables are input mock-json strings for the test_invalid_values test # These tests have invalid values in their input file, the import should throw a specific error message -odd_number_of_customers = (ut_t.create_hyperparameter_mock_json_sim_market(number_of_customers='21'), +odd_number_of_customers = (ut_t.create_hyperparameter_mock_dict_sim_market(number_of_customers=21), 'number_of_customers should be even and positive') -negative_number_of_customers = (ut_t.create_hyperparameter_mock_json_sim_market('10', '50', '50', '80', '-10', '10', '0.15'), +negative_number_of_customers = (ut_t.create_hyperparameter_mock_dict_sim_market(10, 50, 50, 80, -10, 10, 0.15), 'number_of_customers should be even and positive') -prod_price_higher_max_price = (ut_t.create_hyperparameter_mock_json_sim_market('10', '50', '10', '80', '20', '50', '0.15'), +prod_price_higher_max_price = (ut_t.create_hyperparameter_mock_dict_sim_market(10, 50, 10, 80, 20, 50, 0.15), 'production_price needs to be smaller than max_price and >=0') -negative_production_price = (ut_t.create_hyperparameter_mock_json_sim_market('10', '50', '50', '80', '20', '-10', '0.15'), +negative_production_price = (ut_t.create_hyperparameter_mock_dict_sim_market(10, 50, 50, 80, 20, -10, 0.15), 'production_price needs to be smaller than max_price and >=0') -negative_max_quality = 
(ut_t.create_hyperparameter_mock_json_sim_market('10', '20', '15', '-80', '30', '5', '0.15'), +negative_max_quality = (ut_t.create_hyperparameter_mock_dict_sim_market(10, 20, 15, -80, 30, 5, 0.15), 'max_quality should be positive') -non_negative_storage_cost = (ut_t.create_hyperparameter_mock_json_sim_market('10', '20', '15', '80', '30', '5', '-3.5'), +non_negative_storage_cost = (ut_t.create_hyperparameter_mock_dict_sim_market(10, 20, 15, 80, 30, 5, -3.5), 'storage_cost_per_product should be non-negative') # These tests are missing a line in the config file, the import should throw a specific error message -missing_max_storage = (ut_t.remove_line(0, ut_t.create_hyperparameter_mock_json_sim_market()), 'your config is missing max_storage') -missing_episode_length = (ut_t.remove_line(1, ut_t.create_hyperparameter_mock_json_sim_market()), 'your config is missing episode_length') -missing_max_price = (ut_t.remove_line(2, ut_t.create_hyperparameter_mock_json_sim_market()), 'your config is missing max_price') -missing_max_quality = (ut_t.remove_line(3, ut_t.create_hyperparameter_mock_json_sim_market()), 'your config is missing max_quality') -missing_number_of_customers = (ut_t.remove_line(4, ut_t.create_hyperparameter_mock_json_sim_market()), +missing_max_storage = (ut_t.remove_key('max_storage', ut_t.create_hyperparameter_mock_dict_sim_market()), + 'your config is missing max_storage') +missing_episode_length = (ut_t.remove_key('episode_length', ut_t.create_hyperparameter_mock_dict_sim_market()), + 'your config is missing episode_length') +missing_max_price = (ut_t.remove_key('max_price', ut_t.create_hyperparameter_mock_dict_sim_market()), + 'your config is missing max_price') +missing_max_quality = (ut_t.remove_key('max_quality', ut_t.create_hyperparameter_mock_dict_sim_market()), + 'your config is missing max_quality') +missing_number_of_customers = (ut_t.remove_key('number_of_customers', ut_t.create_hyperparameter_mock_dict_sim_market()), 'your config is missing number_of_customers') -missing_production_price = (ut_t.remove_line(5, ut_t.create_hyperparameter_mock_json_sim_market()), +missing_production_price = (ut_t.remove_key('production_price', ut_t.create_hyperparameter_mock_dict_sim_market()), 'your config is missing production_price') -missing_storage_cost = (ut_t.remove_line(6, ut_t.create_hyperparameter_mock_json_sim_market()), +missing_storage_cost = (ut_t.remove_key('storage_cost_per_product', ut_t.create_hyperparameter_mock_dict_sim_market()), 'your config is missing storage_cost_per_product') # All pairs concerning themselves with invalid config.json values should be added to this array to get tested in test_invalid_values @@ -106,9 +111,9 @@ def test_reading_file_values(): # Test that checks that an invalid/broken config.json gets detected correctly @pytest.mark.parametrize('sim_market_json, expected_message', invalid_values_testcases) def test_invalid_values(sim_market_json, expected_message): - json = ut_t.create_hyperparameter_mock_json(sim_market=sim_market_json) - with patch('builtins.open', mock_open(read_data=json)) as mock_file: - ut_t.check_mock_file(mock_file, json) + mock_json = json.dumps(ut_t.create_hyperparameter_mock_dict(sim_market=sim_market_json)) + with patch('builtins.open', mock_open(read_data=mock_json)) as mock_file: + ut_t.check_mock_file(mock_file, mock_json) with pytest.raises(AssertionError) as assertion_message: import_config() assert expected_message in str(assertion_message.value) diff --git a/tests/test_q_learning_training.py 
b/tests/test_q_learning_training.py index 0d71bd31..87a3f774 100644 --- a/tests/test_q_learning_training.py +++ b/tests/test_q_learning_training.py @@ -1,3 +1,4 @@ +import json from importlib import reload from unittest.mock import mock_open, patch @@ -40,9 +41,10 @@ def import_config() -> hyperparameter_config.HyperparameterConfig: @pytest.mark.slow @pytest.mark.parametrize('market_class, agent_class', test_scenarios) def test_market_scenario(market_class, agent_class): - json = ut_t.create_hyperparameter_mock_json(rl=ut_t.create_hyperparameter_mock_json_rl(replay_start_size='500', sync_target_frames='100')) - with patch('builtins.open', mock_open(read_data=json)) as mock_file: - ut_t.check_mock_file(mock_file, json) + mock_json = json.dumps(ut_t.create_hyperparameter_mock_dict( + rl=ut_t.create_hyperparameter_mock_dict_rl(replay_start_size=500, sync_target_frames=100))) + with patch('builtins.open', mock_open(read_data=mock_json)) as mock_file: + ut_t.check_mock_file(mock_file, mock_json) config = import_config() with patch('recommerce.rl.training.SummaryWriter'): @@ -52,9 +54,10 @@ def test_market_scenario(market_class, agent_class): @pytest.mark.training @pytest.mark.slow def test_training_with_tensorboard(): - json = ut_t.create_hyperparameter_mock_json(rl=ut_t.create_hyperparameter_mock_json_rl(replay_start_size='500', sync_target_frames='100')) - with patch('builtins.open', mock_open(read_data=json)) as mock_file: - ut_t.check_mock_file(mock_file, json) + mock_json = json.dumps(ut_t.create_hyperparameter_mock_dict( + rl=ut_t.create_hyperparameter_mock_dict_rl(replay_start_size=500, sync_target_frames=100))) + with patch('builtins.open', mock_open(read_data=mock_json)) as mock_file: + ut_t.check_mock_file(mock_file, mock_json) config = import_config() market_class = linear_market.ClassicScenario agent_class = QLearningLEAgent diff --git a/tests/test_sim_market.py b/tests/test_sim_market.py index 8f035f94..24776cc6 100644 --- a/tests/test_sim_market.py +++ b/tests/test_sim_market.py @@ -18,4 +18,4 @@ def test_unique_output_dict(marketclass): market = marketclass() _, _, _, info_dict_1 = market.step(ut_t.create_mock_action(marketclass)) _, _, _, info_dict_2 = market.step(ut_t.create_mock_action(marketclass)) - assert id(info_dict_1) is not id(info_dict_2) + assert id(info_dict_1) != id(info_dict_2) diff --git a/tests/test_svg_manipulation.py b/tests/test_svg_manipulation.py index f24a3dcd..de254462 100644 --- a/tests/test_svg_manipulation.py +++ b/tests/test_svg_manipulation.py @@ -1,4 +1,6 @@ +import json import os +import pathlib from unittest.mock import mock_open, patch import pytest @@ -28,9 +30,9 @@ def test_correct_template(): assert correct_template == svg_manipulator.template_svg # run one exampleprinter and to make sure the template does not get changed - json = ut_t.create_hyperparameter_mock_json_sim_market(episode_length='3') - with patch('builtins.open', mock_open(read_data=json)) as utils_mock_file: - ut_t.check_mock_file(utils_mock_file, json) + mock_json = json.dumps(ut_t.create_hyperparameter_mock_dict_sim_market(episode_length=3)) + with patch('builtins.open', mock_open(read_data=mock_json)) as utils_mock_file: + ut_t.check_mock_file(utils_mock_file, mock_json) # initialize all functions to be mocked with patch('recommerce.monitoring.exampleprinter.ut.write_dict_to_tensorboard'), \ patch('recommerce.monitoring.svg_manipulation.os.path.isfile') as mock_isfile, \ @@ -74,8 +76,7 @@ def test_write_dict_to_svg(): test_dict[key] = 'test' 
svg_manipulator.write_dict_to_svg(test_dict) correct_svg = '' - with open(os.path.join(os.path.dirname(__file__), 'test_data', 'output_test_svg.svg')) as file: - correct_svg = file.read() + correct_svg = pathlib.Path(os.path.join(os.path.dirname(__file__), 'test_data', 'output_test_svg.svg')).read_text() assert correct_svg == svg_manipulator.output_svg assert test_dict == svg_manipulator.value_dictionary @@ -193,9 +194,9 @@ def test_time_not_int(): def test_one_exampleprinter_run(): # run only three episodes to be able to reuse the correct_html - json = ut_t.create_hyperparameter_mock_json_sim_market(episode_length='3') - with patch('builtins.open', mock_open(read_data=json)) as utils_mock_file: - ut_t.check_mock_file(utils_mock_file, json) + mock_json = json.dumps(ut_t.create_hyperparameter_mock_dict_sim_market(episode_length=3)) + with patch('builtins.open', mock_open(read_data=mock_json)) as utils_mock_file: + ut_t.check_mock_file(utils_mock_file, mock_json) # initialize all functions to be mocked with patch('recommerce.monitoring.exampleprinter.ut.write_dict_to_tensorboard'), \ patch('recommerce.monitoring.svg_manipulation.os.path.isfile') as mock_isfile, \ diff --git a/tests/test_utils.py b/tests/test_utils.py index c9dde563..c60fa8ce 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -1,3 +1,4 @@ +import json from importlib import reload from unittest.mock import Mock, mock_open, patch @@ -31,7 +32,8 @@ def import_config() -> hyperparameter_config.HyperparameterConfig: @pytest.mark.parametrize('max_quality', testcases_shuffle_quality) def test_shuffle_quality(max_quality: int): - mock_json = ut_t.create_hyperparameter_mock_json(sim_market=ut_t.create_hyperparameter_mock_json_sim_market(max_quality=str(max_quality))) + mock_json = json.dumps(ut_t.create_hyperparameter_mock_dict( + sim_market=ut_t.create_hyperparameter_mock_dict_sim_market(max_quality=max_quality))) with patch('builtins.open', mock_open(read_data=mock_json)) as mock_file: ut_t.check_mock_file(mock_file, mock_json) import_config() @@ -277,8 +279,8 @@ def test_write_content_of_dict_to_overview_svg( episode_dictionary: dict, cumulated_dictionary: dict, expected: dict): - mock_json = (ut_t.create_hyperparameter_mock_json( - sim_market=ut_t.create_hyperparameter_mock_json_sim_market(episode_length='50', number_of_customers='20', production_price='3'))) + mock_json = json.dumps((ut_t.create_hyperparameter_mock_dict( + sim_market=ut_t.create_hyperparameter_mock_dict_sim_market(episode_length=50, number_of_customers=20, production_price=3)))) with patch('builtins.open', mock_open(read_data=mock_json)) as mock_file: ut_t.check_mock_file(mock_file, mock_json) import_config() diff --git a/tests/test_vendors.py b/tests/test_vendors.py index 4031224a..6a8ce890 100644 --- a/tests/test_vendors.py +++ b/tests/test_vendors.py @@ -1,3 +1,4 @@ +import json from importlib import reload from unittest.mock import mock_open, patch @@ -106,10 +107,10 @@ def test_fixed_price_agent_observation_policy_pairs(agent, expected_result): @pytest.mark.parametrize('state, expected_prices', storage_evaluation_testcases) def test_storage_evaluation(state, expected_prices): - json = ut_t.create_hyperparameter_mock_json( - sim_market=ut_t.create_hyperparameter_mock_json_sim_market(max_price='10', production_price='2')) - with patch('builtins.open', mock_open(read_data=json)) as mock_file: - ut_t.check_mock_file(mock_file, json) + mock_json = json.dumps(ut_t.create_hyperparameter_mock_dict( + 
sim_market=ut_t.create_hyperparameter_mock_dict_sim_market(max_price=10, production_price=2))) + with patch('builtins.open', mock_open(read_data=mock_json)) as mock_file: + ut_t.check_mock_file(mock_file, mock_json) import_config() agent = circular_vendors.RuleBasedCEAgent() assert expected_prices == agent.policy(state) @@ -125,20 +126,20 @@ def test_storage_evaluation(state, expected_prices): @pytest.mark.parametrize('state, expected_prices', storage_evaluation_with_rebuy_price_testcases) def test_storage_evaluation_with_rebuy_price(state, expected_prices): - json = ut_t.create_hyperparameter_mock_json( - sim_market=ut_t.create_hyperparameter_mock_json_sim_market(max_price='10', production_price='2')) - with patch('builtins.open', mock_open(read_data=json)) as mock_file: - ut_t.check_mock_file(mock_file, json) + mock_json = json.dumps(ut_t.create_hyperparameter_mock_dict( + sim_market=ut_t.create_hyperparameter_mock_dict_sim_market(max_price=10, production_price=2))) + with patch('builtins.open', mock_open(read_data=mock_json)) as mock_file: + ut_t.check_mock_file(mock_file, mock_json) import_config() agent = circular_vendors.RuleBasedCERebuyAgent() assert expected_prices == agent.policy(state) def test_prices_are_not_higher_than_allowed(): - json = ut_t.create_hyperparameter_mock_json( - sim_market=ut_t.create_hyperparameter_mock_json_sim_market(max_price='10', production_price='9')) - with patch('builtins.open', mock_open(read_data=json)) as mock_file: - ut_t.check_mock_file(mock_file, json) + mock_json = json.dumps(ut_t.create_hyperparameter_mock_dict( + sim_market=ut_t.create_hyperparameter_mock_dict_sim_market(max_price=10, production_price=9))) + with patch('builtins.open', mock_open(read_data=mock_json)) as mock_file: + ut_t.check_mock_file(mock_file, mock_json) import_config() test_agent = circular_vendors.RuleBasedCEAgent() assert (9, 9) >= test_agent.policy([50, 60]) @@ -160,10 +161,10 @@ def random_offer(): # TODO: Update this test for all current competitors @pytest.mark.parametrize('competitor_class, state', policy_testcases) def test_policy(competitor_class, state): - json = ut_t.create_hyperparameter_mock_json( - sim_market=ut_t.create_hyperparameter_mock_json_sim_market(max_price='10', production_price='2')) - with patch('builtins.open', mock_open(read_data=json)) as mock_file: - ut_t.check_mock_file(mock_file, json) + mock_json = json.dumps(ut_t.create_hyperparameter_mock_dict( + sim_market=ut_t.create_hyperparameter_mock_dict_sim_market(max_price=10, production_price=2))) + with patch('builtins.open', mock_open(read_data=mock_json)) as mock_file: + ut_t.check_mock_file(mock_file, mock_json) import_config() competitor = competitor_class() @@ -181,10 +182,10 @@ def test_policy(competitor_class, state): # TODO: Update this test for all current competitors @pytest.mark.parametrize('competitor_class, state', policy_plus_one_testcases) def test_policy_plus_one(competitor_class, state): - json = ut_t.create_hyperparameter_mock_json( - sim_market=ut_t.create_hyperparameter_mock_json_sim_market(max_price='10', production_price='2')) - with patch('builtins.open', mock_open(read_data=json)) as mock_file: - ut_t.check_mock_file(mock_file, json) + mock_json = json.dumps(ut_t.create_hyperparameter_mock_dict( + sim_market=ut_t.create_hyperparameter_mock_dict_sim_market(max_price=10, production_price=2))) + with patch('builtins.open', mock_open(read_data=mock_json)) as mock_file: + ut_t.check_mock_file(mock_file, mock_json) import_config() competitor = competitor_class() diff --git 
a/tests/utils_tests.py b/tests/utils_tests.py index 0000350f..d6899012 100644 --- a/tests/utils_tests.py +++ b/tests/utils_tests.py @@ -1,100 +1,103 @@ -import os from typing import Tuple, Union import recommerce.market.circular.circular_sim_market as circular_market import recommerce.market.linear.linear_sim_market as linear_market -from recommerce.configuration.path_manager import PathManager -def create_hyperparameter_mock_json_rl(gamma='0.99', - batch_size='32', - replay_size='100000', - learning_rate='1e-6', - sync_target_frames='1000', - replay_start_size='10000', - epsilon_decay_last_frame='75000', - epsilon_start='1.0', - epsilon_final='0.1') -> str: +def create_hyperparameter_mock_dict_rl(gamma: float = 0.99, + batch_size: int = 32, + replay_size: int = 100000, + learning_rate: float = 1e-6, + sync_target_frames: int = 1000, + replay_start_size: int = 10000, + epsilon_decay_last_frame: int = 75000, + epsilon_start: float = 1.0, + epsilon_final: float = 0.1) -> dict: """ - Create a string in JSON format that can be used to mock the config_rl.json file. + Create dictionary that can be used to mock the rl part of the hyperparameter_config.json file by calling json.dumps() on it. Args: - gamma (str, optional): Defaults to '0.99'. - batch_size (str, optional): Defaults to '32'. - replay_size (str, optional): Defaults to '100000'. - learning_rate (str, optional): Defaults to '1e-6'. - sync_target_frames (str, optional): Defaults to '1000'. - replay_start_size (str, optional): Defaults to '10000'. - epsilon_decay_last_frame (str, optional): Defaults to '75000'. - epsilon_start (str, optional): Defaults to '1.0'. - epsilon_final (str, optional): Defaults to '0.1'. + gamma (float, optional): Defaults to 0.99. + batch_size (int, optional): Defaults to 32. + replay_size (int, optional): Defaults to 100000. + learning_rate (float, optional): Defaults to 1e-6. + sync_target_frames (int, optional): Defaults to 1000. + replay_start_size (int, optional): Defaults to 10000. + epsilon_decay_last_frame (int, optional): Defaults to 75000. + epsilon_start (float, optional): Defaults to 1.0. + epsilon_final (float, optional): Defaults to 0.1. Returns: - str: A string in JSON format. - """ - return '{\n\t\t"gamma": ' + gamma + ',\n' + \ - '\t\t"batch_size": ' + batch_size + ',\n' + \ - '\t\t"replay_size": ' + replay_size + ',\n' + \ - '\t\t"learning_rate": ' + learning_rate + ',\n' + \ - '\t\t"sync_target_frames": ' + sync_target_frames + ',\n' + \ - '\t\t"replay_start_size": ' + replay_start_size + ',\n' + \ - '\t\t"epsilon_decay_last_frame": ' + epsilon_decay_last_frame + ',\n' + \ - '\t\t"epsilon_start": ' + epsilon_start + ',\n' + \ - '\t\t"epsilon_final": ' + epsilon_final + '\n' + \ - '\t}' - - -def create_hyperparameter_mock_json_sim_market( - max_storage='20', - episode_length='20', - max_price='15', - max_quality='100', - number_of_customers='30', - production_price='5', - storage_cost_per_product='0.3') -> str: - """ - Create a string in JSON format that can be used to mock the config_sim_market.json file. + dict: The mock dictionary. 
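+
+	Example (illustrative; the invalid-config testcases build broken configs this way):
+		broken_rl = remove_key('learning_rate', create_hyperparameter_mock_dict_rl())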
+ """ + return { + 'gamma': gamma, + 'batch_size': batch_size, + 'replay_size': replay_size, + 'learning_rate': learning_rate, + 'sync_target_frames': sync_target_frames, + 'replay_start_size': replay_start_size, + 'epsilon_decay_last_frame': epsilon_decay_last_frame, + 'epsilon_start': epsilon_start, + 'epsilon_final': epsilon_final, + } + + +def create_hyperparameter_mock_dict_sim_market( + max_storage: int = 20, + episode_length: int = 20, + max_price: int = 15, + max_quality: int = 100, + number_of_customers: int = 30, + production_price: int = 5, + storage_cost_per_product: float = 0.3) -> dict: + """ + Create dictionary that can be used to mock the sim_market part of the hyperparameter_config.json file by calling json.dumps() on it. Args: - max_storage (str, optional): Defaults to '20'. - episode_length (str, optional): Defaults to '20'. - max_price (str, optional): Defaults to '15'. - max_quality (str, optional): Defaults to '100'. - number_of_customers (str, optional): Defaults to '30'. - production_price (str, optional): Defaults to '5'. - storage_cost_per_product (str, optional): Defaults to '0.3'. + max_storage (int, optional): Defaults to 20. + episode_length (int, optional): Defaults to 20. + max_price (int, optional): Defaults to 15. + max_quality (int, optional): Defaults to 100. + number_of_customers (int, optional): Defaults to 30. + production_price (int, optional): Defaults to 5. + storage_cost_per_product (float, optional): Defaults to 0.3. Returns: - str: A string in JSON format. + dict: The mock dictionary. """ - return '{\n\t\t"max_storage": ' + max_storage + ',\n' + \ - '\t\t"episode_length": ' + episode_length + ',\n' + \ - '\t\t"max_price": ' + max_price + ',\n' + \ - '\t\t"max_quality": ' + max_quality + ',\n' + \ - '\t\t"number_of_customers": ' + number_of_customers + ',\n' + \ - '\t\t"production_price": ' + production_price + ',\n' + \ - '\t\t"storage_cost_per_product": ' + storage_cost_per_product + '\n' + \ - '\t}' + return { + 'max_storage': max_storage, + 'episode_length': episode_length, + 'max_price': max_price, + 'max_quality': max_quality, + 'number_of_customers': number_of_customers, + 'production_price': production_price, + 'storage_cost_per_product': storage_cost_per_product, + } -def create_hyperparameter_mock_json(rl: str = create_hyperparameter_mock_json_rl(), - sim_market: str = create_hyperparameter_mock_json_sim_market()) -> str: +def create_hyperparameter_mock_dict(rl: dict = create_hyperparameter_mock_dict_rl(), + sim_market: dict = create_hyperparameter_mock_dict_sim_market()) -> dict: """ - Create a mock json in the format of the hyperparameter_config.json. + Create a dictionary in the format of the hyperparameter_config.json. + Call json.dumps() on the return value of this to mock the json file. Args: - rl (str, optional): The string that should be used for the rl-part. Defaults to create_hyperparameter_mock_json_rl(). - sim_market (str, optional): The string that should be used for the sim_market-part. - Defaults to create_hyperparameter_mock_json_sim_market(). + rl (dict, optional): The dictionary that should be used for the rl-part. Defaults to create_hyperparameter_mock_dict_rl(). + sim_market (dict, optional): The dictionary that should be used for the sim_market-part. + Defaults to create_hyperparameter_mock_dict_sim_market(). Returns: - str: The mock json. + dict: The mock dictionary. 
""" - return '{\n' + '\t"rl": ' + rl + ',\n' + '\t"sim_market": ' + sim_market + '\n}' + return { + 'rl': rl, + 'sim_market': sim_market + } -def create_environment_mock_dict( - task: str = 'agent_monitoring', +def create_environment_mock_dict(task: str = 'agent_monitoring', enable_live_draw: bool = False, episodes: int = 10, plot_interval: int = 5, @@ -110,18 +113,20 @@ def create_environment_mock_dict( plot_interval (int, optional): How often plots should be drawn. Defaults to 5. marketplace (str, optional): What marketplace to run on. Defaults to "recommerce.market.circular.circular_sim_market.CircularEconomyRebuyPriceMonopolyScenario". - agents (dict, optional): What agents to use. - Defaults to {"Fixed CE Rebuy Agent": {"class": "market.vendors.FixedPriceCERebuyAgent"}}. + agents (dict, optional): What agents to use. Defaults to + [{'name': 'Fixed CE Rebuy Agent', 'agent_class': 'recommerce.market.circular.circular_vendors.RuleBasedCERebuyAgent', 'argument': ''}]. Returns: dict: The mock dictionary. """ if agents is None: - agents = { - 'Fixed CE Rebuy Agent': { - 'class': 'market.vendors.FixedPriceCERebuyAgent' + agents = [ + { + 'name': 'Fixed CE Rebuy Agent', + 'agent_class': 'recommerce.market.circular.circular_vendors.RuleBasedCERebuyAgent', + 'argument': '' } - } + ] return { 'task': task, @@ -133,35 +138,63 @@ def create_environment_mock_dict( } -def check_mock_file(mock_file, json) -> None: +def create_combined_mock_dict(hyperparameter: dict or None = create_hyperparameter_mock_dict(), + environment: dict or None = create_environment_mock_dict()) -> dict: + """ + Create a mock dictionary in the format of a configuration file with both a hyperparameter and environment part. + If any of the two parameters is `None`, leave that key out of the resulting dictionary. + + Args: + hyperparameter (dict | None, optional): The hyperparameter part of the combined config. Defaults to create_hyperparameter_mock_dict(). + environment (dict | None, optional): The environment part of the combined config. Defaults to create_environment_mock_dict(). + + Returns: + dict: The mock dictionary. + """ + if hyperparameter is None and environment is None: + return {} + elif hyperparameter is None: + return { + 'environment': environment + } + elif environment is None: + return { + 'hyperparameter': hyperparameter + } + return { + 'hyperparameter': hyperparameter, + 'environment': environment + } + + +def check_mock_file(mock_file, mocked_file_content) -> None: """ - Confirm that a mock JSON for the config.json is read correctly. + Confirm that a mock JSON is read correctly. Args: mock_file (unittest.mock.MagicMock): The mocked file. - json (str): The mock JSON string to be checked. + mocked_file_content (str): The mocked_file_content to be checked. """ - path = os.path.join(PathManager.user_path, 'config.json') - assert open(path).read() == json, 'the mock did not work correctly, as the read file was not equal to the set mock-json' + path = 'some_path' + with open(path) as file: + assert file.read() == mocked_file_content, \ + 'the mock did not work correctly, as the read file was not equal to the set mocked_file_content' mock_file.assert_called_with(path) -def remove_line(number, json) -> str: +def remove_key(key: str, original_dict: dict) -> dict: """ - Remove the specified line from a mock JSON string. + Remove the specified key from a dictionary and return the dictionary. Args: - number (int): The line that should be removed. - json (str): The JSON string from which to remove the line. 
+		key (str): The key that should be removed.
+		original_dict (dict): The dictionary from which to remove the key.
 
 	Returns:
-		str: The JSON string with the missing line.
+		dict: The dictionary without the key.
 	"""
-	lines = json.split('\n')
-	final_lines = lines[:number + 1]
-	final_lines += lines[number + 2:]
-	final_lines[-2] = final_lines[-2].replace(',', '')
-	return '\n'.join(final_lines)
+	original_dict.pop(key)
+	return original_dict
 
 
 def create_mock_rewards(num_entries) -> list: