From b6a49450f5ce017746b0405c37b431bf6e7f55e8 Mon Sep 17 00:00:00 2001 From: Jan Niklas Groeneveld Date: Thu, 26 May 2022 15:52:19 +0200 Subject: [PATCH] reintroduced test_hyperparameter_config_rl --- .../configuration/hyperparameter_config.py | 2 +- tests/test_hyperparameter_config_rl.py | 148 +++++++++--------- tests/utils_tests.py | 34 +++- 3 files changed, 110 insertions(+), 74 deletions(-) diff --git a/recommerce/configuration/hyperparameter_config.py b/recommerce/configuration/hyperparameter_config.py index c27bc79e..f9f9b50f 100644 --- a/recommerce/configuration/hyperparameter_config.py +++ b/recommerce/configuration/hyperparameter_config.py @@ -84,5 +84,5 @@ def load(cls, filename: str) -> AttrDict: assert issubclass(config['class'], JSONConfigurable), f"The class {config['class']} must be a subclass of JSONConfigurable." print(config) HyperparameterConfigValidator.validate_config(config) - config.pop_key('class') + config.pop('class') return AttrDict(config) diff --git a/tests/test_hyperparameter_config_rl.py b/tests/test_hyperparameter_config_rl.py index efba93b4..fd19e7a0 100644 --- a/tests/test_hyperparameter_config_rl.py +++ b/tests/test_hyperparameter_config_rl.py @@ -1,11 +1,15 @@ -# import json +import json +import os # from importlib import reload -# from unittest.mock import mock_open, patch +from unittest.mock import mock_open, patch -# import pytest -# import utils_tests as ut_t +import pytest +import utils_tests as ut_t -# import recommerce.configuration.hyperparameter_config as hyperparameter_config +import recommerce.configuration.hyperparameter_config as hyperparameter_config +from recommerce.configuration.path_manager import PathManager + +rl_config_file = os.path.join(PathManager.user_path, 'configuration_files', 'rl_config.json') # def teardown_module(module): @@ -85,69 +89,71 @@ # assert config.learning_rate == 1e-4 -# # The following variables are input mock-json strings for the test_invalid_values test -# # These tests have invalid values in their input file, the import should throw a specific error message -# learning_rate_larger_one = (ut_t.create_hyperparameter_mock_dict_rl(learning_rate=1.5), -# 'learning_rate should be between 0 and 1 (excluded)') -# negative_learning_rate = (ut_t.create_hyperparameter_mock_dict_rl(learning_rate=0), 'learning_rate should be between 0 and 1 (excluded)') -# large_gamma = (ut_t.create_hyperparameter_mock_dict_rl(gamma=1.0), 'gamma should be between 0 (included) and 1 (excluded)') -# negative_gamma = ((ut_t.create_hyperparameter_mock_dict_rl(gamma=-1.0), 'gamma should be between 0 (included) and 1 (excluded)')) -# negative_batch_size = (ut_t.create_hyperparameter_mock_dict_rl(batch_size=-5), 'batch_size should be greater than 0') -# negative_replay_size = (ut_t.create_hyperparameter_mock_dict_rl(replay_size=-5), -# 'replay_size should be greater than 0') -# negative_sync_target_frames = (ut_t.create_hyperparameter_mock_dict_rl(sync_target_frames=-5), -# 'sync_target_frames should be greater than 0') -# negative_replay_start_size = (ut_t.create_hyperparameter_mock_dict_rl(replay_start_size=-5), 'replay_start_size should be greater than 0') -# negative_epsilon_decay_last_frame = (ut_t.create_hyperparameter_mock_dict_rl(epsilon_decay_last_frame=-5), -# 'epsilon_decay_last_frame should not be negative') - - -# # These tests are missing a line in the config file, the import should throw a specific error message -# missing_gamma = (ut_t.remove_key('gamma', ut_t.create_hyperparameter_mock_dict_rl()), 'your config_rl is missing gamma') -# missing_batch_size = (ut_t.remove_key('batch_size', ut_t.create_hyperparameter_mock_dict_rl()), 'your config_rl is missing batch_size') -# missing_replay_size = (ut_t.remove_key('replay_size', ut_t.create_hyperparameter_mock_dict_rl()), 'your config_rl is missing replay_size') -# missing_learning_rate = (ut_t.remove_key('learning_rate', ut_t.create_hyperparameter_mock_dict_rl()), -# 'your config_rl is missing learning_rate') -# missing_sync_target_frames = (ut_t.remove_key('sync_target_frames', ut_t.create_hyperparameter_mock_dict_rl()), -# 'your config_rl is missing sync_target_frames') -# missing_replay_start_size = (ut_t.remove_key('replay_start_size', ut_t.create_hyperparameter_mock_dict_rl()), -# 'your config_rl is missing replay_start_size') -# missing_epsilon_decay_last_frame = (ut_t.remove_key('epsilon_decay_last_frame', ut_t.create_hyperparameter_mock_dict_rl()), -# 'your config_rl is missing epsilon_decay_last_frame') -# missing_epsilon_start = (ut_t.remove_key('epsilon_start', ut_t.create_hyperparameter_mock_dict_rl()), -# 'your config_rl is missing epsilon_start') -# missing_epsilon_final = (ut_t.remove_key('epsilon_final', ut_t.create_hyperparameter_mock_dict_rl()), -# 'your config_rl is missing epsilon_final') - - -# invalid_values_testcases = [ -# missing_gamma, -# missing_batch_size, -# missing_replay_size, -# missing_learning_rate, -# missing_sync_target_frames, -# missing_replay_start_size, -# missing_epsilon_decay_last_frame, -# missing_epsilon_start, -# missing_epsilon_final, -# learning_rate_larger_one, -# negative_learning_rate, -# large_gamma, -# negative_gamma, -# negative_batch_size, -# negative_replay_size, -# negative_sync_target_frames, -# negative_replay_start_size, -# negative_epsilon_decay_last_frame -# ] - - -# # Test that checks that an invalid/broken config.json gets detected correctly -# @pytest.mark.parametrize('rl_json, expected_message', invalid_values_testcases) -# def test_invalid_values(rl_json, expected_message): -# mock_json = json.dumps(ut_t.create_hyperparameter_mock_dict(rl=rl_json)) -# with patch('builtins.open', mock_open(read_data=mock_json)) as mock_file: -# ut_t.check_mock_file(mock_file, mock_json) -# with pytest.raises(AssertionError) as assertion_message: -# hyperparameter_config.HyperparameterConfigLoader.load('hyperparameter_config') -# assert expected_message in str(assertion_message.value) +# The following variables are input mock-json strings for the test_invalid_values test +# These tests have invalid values in their input file, the import should throw a specific error message +negative_learning_rate = (ut_t.replace_field_in_dict(ut_t.load_json(rl_config_file), 'learning_rate', 0.0), + 'learning_rate should be positive') +large_gamma = (ut_t.replace_field_in_dict(ut_t.load_json(rl_config_file), 'gamma', 1.1), + 'gamma should be between 0 (included) and 1 (included)') +negative_gamma = ((ut_t.replace_field_in_dict(ut_t.load_json(rl_config_file), 'gamma', -1.0), + 'gamma should be between 0 (included) and 1 (included)')) +negative_batch_size = (ut_t.replace_field_in_dict(ut_t.load_json(rl_config_file), 'batch_size', -5), + 'batch_size should be positive') +negative_replay_size = (ut_t.replace_field_in_dict(ut_t.load_json(rl_config_file), 'replay_size', -5), + 'replay_size should be positive') +negative_sync_target_frames = (ut_t.replace_field_in_dict(ut_t.load_json(rl_config_file), 'sync_target_frames', -5), + 'sync_target_frames should be positive') +negative_replay_start_size = (ut_t.replace_field_in_dict(ut_t.load_json(rl_config_file), 'replay_start_size', -5), + 'replay_start_size should be positive') +negative_epsilon_decay_last_frame = (ut_t.replace_field_in_dict(ut_t.load_json(rl_config_file), 'epsilon_decay_last_frame', -5), + 'epsilon_decay_last_frame should be positive') + + +# These tests are missing a line in the config file, the import should throw a specific error message +missing_gamma = (ut_t.remove_key('gamma', ut_t.load_json(rl_config_file)), 'your config is missing gamma') +missing_batch_size = (ut_t.remove_key('batch_size', ut_t.load_json(rl_config_file)), 'your config is missing batch_size') +missing_replay_size = (ut_t.remove_key('replay_size', ut_t.load_json(rl_config_file)), 'your config is missing replay_size') +missing_learning_rate = (ut_t.remove_key('learning_rate', ut_t.load_json(rl_config_file)), + 'your config is missing learning_rate') +missing_sync_target_frames = (ut_t.remove_key('sync_target_frames', ut_t.load_json(rl_config_file)), + 'your config is missing sync_target_frames') +missing_replay_start_size = (ut_t.remove_key('replay_start_size', ut_t.load_json(rl_config_file)), + 'your config is missing replay_start_size') +missing_epsilon_decay_last_frame = (ut_t.remove_key('epsilon_decay_last_frame', ut_t.load_json(rl_config_file)), + 'your config is missing epsilon_decay_last_frame') +missing_epsilon_start = (ut_t.remove_key('epsilon_start', ut_t.load_json(rl_config_file)), + 'your config is missing epsilon_start') +missing_epsilon_final = (ut_t.remove_key('epsilon_final', ut_t.load_json(rl_config_file)), + 'your config is missing epsilon_final') + + +invalid_values_testcases = [ + missing_gamma, + missing_batch_size, + missing_replay_size, + missing_learning_rate, + missing_sync_target_frames, + missing_replay_start_size, + missing_epsilon_decay_last_frame, + missing_epsilon_start, + missing_epsilon_final, + negative_learning_rate, + large_gamma, + negative_gamma, + negative_batch_size, + negative_replay_size, + negative_sync_target_frames, + negative_replay_start_size, + negative_epsilon_decay_last_frame +] + + +# Test that checks that an invalid/broken config.json gets detected correctly +@pytest.mark.parametrize('rl_json, expected_message', invalid_values_testcases) +def test_invalid_values(rl_json, expected_message): + mock_json = json.dumps(rl_json) + with patch('builtins.open', mock_open(read_data=mock_json)) as mock_file: + ut_t.check_mock_file(mock_file, mock_json) + with pytest.raises(AssertionError) as assertion_message: + hyperparameter_config.HyperparameterConfigLoader.load('hyperparameter_config') + assert expected_message in str(assertion_message.value) diff --git a/tests/utils_tests.py b/tests/utils_tests.py index 7233f96e..b645a771 100644 --- a/tests/utils_tests.py +++ b/tests/utils_tests.py @@ -1,4 +1,6 @@ # import json +# from recommerce.configuration.hyperparameter_config import HyperparameterConfigLoader +import json from typing import Tuple, Union import recommerce.market.circular.circular_sim_market as circular_market @@ -9,8 +11,6 @@ # from attrdict import AttrDict -# from recommerce.configuration.hyperparameter_config import HyperparameterConfigLoader - def create_hyperparameter_mock_dict_rl(gamma: float = 0.99, batch_size: int = 32, @@ -85,6 +85,36 @@ def create_hyperparameter_mock_dict_sim_market( } +def load_json(path: str): + """ + Load a json file. + + Args: + path (str): The path to the json file. + + Returns: + dict: The json file as a dictionary. + """ + with open(path) as file: + return json.load(file) + + +def replace_field_in_dict(initial_dict: dict, key: str, value: Union[str, int, float]) -> dict: + """ + Replace a field in a dictionary with a new value. + + Args: + initial_dict (dict): The dictionary in which to replace the field. + key (str): The key of the field to be replaced. + value (Union[str, int, float]): The new value of the field. + + Returns: + dict: The dictionary with the field replaced. + """ + initial_dict[key] = value + return initial_dict + + # def create_hyperparameter_mock_dict(rl: dict = create_hyperparameter_mock_dict_rl(), # sim_market: dict = create_hyperparameter_mock_dict_sim_market()) -> dict: # """