
Commit

reintroduced test_hyperparameter_config_rl
jannikgro committed May 26, 2022
1 parent 7cef3e7 commit b6a4945
Showing 3 changed files with 110 additions and 74 deletions.
2 changes: 1 addition & 1 deletion recommerce/configuration/hyperparameter_config.py
@@ -84,5 +84,5 @@ def load(cls, filename: str) -> AttrDict:
assert issubclass(config['class'], JSONConfigurable), f"The class {config['class']} must be a subclass of JSONConfigurable."
print(config)
HyperparameterConfigValidator.validate_config(config)
config.pop_key('class')
config.pop('class')
return AttrDict(config)
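Aside, not part of the commit: plain Python dictionaries provide pop(), not pop_key(), so the original call would typically fail with an AttributeError on the dict that load builds before wrapping it in AttrDict. A minimal sketch with placeholder keys and values:

# Illustrative sketch only, not repository code; the keys and values are placeholders.
config = {'class': 'QLearningAgent', 'gamma': 0.99, 'learning_rate': 1e-4}
removed = config.pop('class')   # removes the key and returns its value
print(removed)                  # 'QLearningAgent'
print(config)                   # {'gamma': 0.99, 'learning_rate': 1e-4}
# config.pop_key('class') would raise AttributeError: 'dict' object has no attribute 'pop_key'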
148 changes: 77 additions & 71 deletions tests/test_hyperparameter_config_rl.py
@@ -1,11 +1,15 @@
# import json
import json
import os
# from importlib import reload
# from unittest.mock import mock_open, patch
from unittest.mock import mock_open, patch

# import pytest
# import utils_tests as ut_t
import pytest
import utils_tests as ut_t

# import recommerce.configuration.hyperparameter_config as hyperparameter_config
import recommerce.configuration.hyperparameter_config as hyperparameter_config
from recommerce.configuration.path_manager import PathManager

rl_config_file = os.path.join(PathManager.user_path, 'configuration_files', 'rl_config.json')


# def teardown_module(module):
@@ -85,69 +89,71 @@
# assert config.learning_rate == 1e-4


# # The following variables are input mock-json strings for the test_invalid_values test
# # These tests have invalid values in their input file, the import should throw a specific error message
# learning_rate_larger_one = (ut_t.create_hyperparameter_mock_dict_rl(learning_rate=1.5),
# 'learning_rate should be between 0 and 1 (excluded)')
# negative_learning_rate = (ut_t.create_hyperparameter_mock_dict_rl(learning_rate=0), 'learning_rate should be between 0 and 1 (excluded)')
# large_gamma = (ut_t.create_hyperparameter_mock_dict_rl(gamma=1.0), 'gamma should be between 0 (included) and 1 (excluded)')
# negative_gamma = ((ut_t.create_hyperparameter_mock_dict_rl(gamma=-1.0), 'gamma should be between 0 (included) and 1 (excluded)'))
# negative_batch_size = (ut_t.create_hyperparameter_mock_dict_rl(batch_size=-5), 'batch_size should be greater than 0')
# negative_replay_size = (ut_t.create_hyperparameter_mock_dict_rl(replay_size=-5),
# 'replay_size should be greater than 0')
# negative_sync_target_frames = (ut_t.create_hyperparameter_mock_dict_rl(sync_target_frames=-5),
# 'sync_target_frames should be greater than 0')
# negative_replay_start_size = (ut_t.create_hyperparameter_mock_dict_rl(replay_start_size=-5), 'replay_start_size should be greater than 0')
# negative_epsilon_decay_last_frame = (ut_t.create_hyperparameter_mock_dict_rl(epsilon_decay_last_frame=-5),
# 'epsilon_decay_last_frame should not be negative')


# # These tests are missing a line in the config file, the import should throw a specific error message
# missing_gamma = (ut_t.remove_key('gamma', ut_t.create_hyperparameter_mock_dict_rl()), 'your config_rl is missing gamma')
# missing_batch_size = (ut_t.remove_key('batch_size', ut_t.create_hyperparameter_mock_dict_rl()), 'your config_rl is missing batch_size')
# missing_replay_size = (ut_t.remove_key('replay_size', ut_t.create_hyperparameter_mock_dict_rl()), 'your config_rl is missing replay_size')
# missing_learning_rate = (ut_t.remove_key('learning_rate', ut_t.create_hyperparameter_mock_dict_rl()),
# 'your config_rl is missing learning_rate')
# missing_sync_target_frames = (ut_t.remove_key('sync_target_frames', ut_t.create_hyperparameter_mock_dict_rl()),
# 'your config_rl is missing sync_target_frames')
# missing_replay_start_size = (ut_t.remove_key('replay_start_size', ut_t.create_hyperparameter_mock_dict_rl()),
# 'your config_rl is missing replay_start_size')
# missing_epsilon_decay_last_frame = (ut_t.remove_key('epsilon_decay_last_frame', ut_t.create_hyperparameter_mock_dict_rl()),
# 'your config_rl is missing epsilon_decay_last_frame')
# missing_epsilon_start = (ut_t.remove_key('epsilon_start', ut_t.create_hyperparameter_mock_dict_rl()),
# 'your config_rl is missing epsilon_start')
# missing_epsilon_final = (ut_t.remove_key('epsilon_final', ut_t.create_hyperparameter_mock_dict_rl()),
# 'your config_rl is missing epsilon_final')


# invalid_values_testcases = [
# missing_gamma,
# missing_batch_size,
# missing_replay_size,
# missing_learning_rate,
# missing_sync_target_frames,
# missing_replay_start_size,
# missing_epsilon_decay_last_frame,
# missing_epsilon_start,
# missing_epsilon_final,
# learning_rate_larger_one,
# negative_learning_rate,
# large_gamma,
# negative_gamma,
# negative_batch_size,
# negative_replay_size,
# negative_sync_target_frames,
# negative_replay_start_size,
# negative_epsilon_decay_last_frame
# ]


# # Test that checks that an invalid/broken config.json gets detected correctly
# @pytest.mark.parametrize('rl_json, expected_message', invalid_values_testcases)
# def test_invalid_values(rl_json, expected_message):
# mock_json = json.dumps(ut_t.create_hyperparameter_mock_dict(rl=rl_json))
# with patch('builtins.open', mock_open(read_data=mock_json)) as mock_file:
# ut_t.check_mock_file(mock_file, mock_json)
# with pytest.raises(AssertionError) as assertion_message:
# hyperparameter_config.HyperparameterConfigLoader.load('hyperparameter_config')
# assert expected_message in str(assertion_message.value)
# The following variables are input mock-json strings for the test_invalid_values test
# These tests have invalid values in their input file, the import should throw a specific error message
negative_learning_rate = (ut_t.replace_field_in_dict(ut_t.load_json(rl_config_file), 'learning_rate', 0.0),
'learning_rate should be positive')
large_gamma = (ut_t.replace_field_in_dict(ut_t.load_json(rl_config_file), 'gamma', 1.1),
'gamma should be between 0 (included) and 1 (included)')
negative_gamma = ((ut_t.replace_field_in_dict(ut_t.load_json(rl_config_file), 'gamma', -1.0),
'gamma should be between 0 (included) and 1 (included)'))
negative_batch_size = (ut_t.replace_field_in_dict(ut_t.load_json(rl_config_file), 'batch_size', -5),
'batch_size should be positive')
negative_replay_size = (ut_t.replace_field_in_dict(ut_t.load_json(rl_config_file), 'replay_size', -5),
'replay_size should be positive')
negative_sync_target_frames = (ut_t.replace_field_in_dict(ut_t.load_json(rl_config_file), 'sync_target_frames', -5),
'sync_target_frames should be positive')
negative_replay_start_size = (ut_t.replace_field_in_dict(ut_t.load_json(rl_config_file), 'replay_start_size', -5),
'replay_start_size should be positive')
negative_epsilon_decay_last_frame = (ut_t.replace_field_in_dict(ut_t.load_json(rl_config_file), 'epsilon_decay_last_frame', -5),
'epsilon_decay_last_frame should be positive')


# These tests are missing a line in the config file, the import should throw a specific error message
missing_gamma = (ut_t.remove_key('gamma', ut_t.load_json(rl_config_file)), 'your config is missing gamma')
missing_batch_size = (ut_t.remove_key('batch_size', ut_t.load_json(rl_config_file)), 'your config is missing batch_size')
missing_replay_size = (ut_t.remove_key('replay_size', ut_t.load_json(rl_config_file)), 'your config is missing replay_size')
missing_learning_rate = (ut_t.remove_key('learning_rate', ut_t.load_json(rl_config_file)),
'your config is missing learning_rate')
missing_sync_target_frames = (ut_t.remove_key('sync_target_frames', ut_t.load_json(rl_config_file)),
'your config is missing sync_target_frames')
missing_replay_start_size = (ut_t.remove_key('replay_start_size', ut_t.load_json(rl_config_file)),
'your config is missing replay_start_size')
missing_epsilon_decay_last_frame = (ut_t.remove_key('epsilon_decay_last_frame', ut_t.load_json(rl_config_file)),
'your config is missing epsilon_decay_last_frame')
missing_epsilon_start = (ut_t.remove_key('epsilon_start', ut_t.load_json(rl_config_file)),
'your config is missing epsilon_start')
missing_epsilon_final = (ut_t.remove_key('epsilon_final', ut_t.load_json(rl_config_file)),
'your config is missing epsilon_final')


invalid_values_testcases = [
missing_gamma,
missing_batch_size,
missing_replay_size,
missing_learning_rate,
missing_sync_target_frames,
missing_replay_start_size,
missing_epsilon_decay_last_frame,
missing_epsilon_start,
missing_epsilon_final,
negative_learning_rate,
large_gamma,
negative_gamma,
negative_batch_size,
negative_replay_size,
negative_sync_target_frames,
negative_replay_start_size,
negative_epsilon_decay_last_frame
]


# Test that checks that an invalid/broken config.json gets detected correctly
@pytest.mark.parametrize('rl_json, expected_message', invalid_values_testcases)
def test_invalid_values(rl_json, expected_message):
mock_json = json.dumps(rl_json)
with patch('builtins.open', mock_open(read_data=mock_json)) as mock_file:
ut_t.check_mock_file(mock_file, mock_json)
with pytest.raises(AssertionError) as assertion_message:
hyperparameter_config.HyperparameterConfigLoader.load('hyperparameter_config')
assert expected_message in str(assertion_message.value)
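Aside, not from this commit: the expected messages above imply that HyperparameterConfigValidator raises plain AssertionErrors, one per field. A minimal sketch of that assertion pattern, with the function name and the selection of fields assumed for illustration only:

# Illustrative sketch of the assertion pattern the test expects; not the repository's actual validator.
def validate_rl_ranges(config: dict) -> None:
    assert config['learning_rate'] > 0, 'learning_rate should be positive'
    assert 0 <= config['gamma'] <= 1, 'gamma should be between 0 (included) and 1 (included)'
    assert config['batch_size'] > 0, 'batch_size should be positive'
    assert 'epsilon_start' in config, 'your config is missing epsilon_start'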
34 changes: 32 additions & 2 deletions tests/utils_tests.py
@@ -1,4 +1,6 @@
# import json
# from recommerce.configuration.hyperparameter_config import HyperparameterConfigLoader
import json
from typing import Tuple, Union

import recommerce.market.circular.circular_sim_market as circular_market
@@ -9,8 +11,6 @@

# from attrdict import AttrDict

# from recommerce.configuration.hyperparameter_config import HyperparameterConfigLoader


def create_hyperparameter_mock_dict_rl(gamma: float = 0.99,
batch_size: int = 32,
@@ -85,6 +85,36 @@ def create_hyperparameter_mock_dict_sim_market(
}


def load_json(path: str):
"""
Load a json file.
Args:
path (str): The path to the json file.
Returns:
dict: The json file as a dictionary.
"""
with open(path) as file:
return json.load(file)


def replace_field_in_dict(initial_dict: dict, key: str, value: Union[str, int, float]) -> dict:
"""
Replace a field in a dictionary with a new value.
Args:
initial_dict (dict): The dictionary in which to replace the field.
key (str): The key of the field to be replaced.
value (Union[str, int, float]): The new value of the field.
Returns:
dict: The dictionary with the field replaced.
"""
initial_dict[key] = value
return initial_dict


# def create_hyperparameter_mock_dict(rl: dict = create_hyperparameter_mock_dict_rl(),
# sim_market: dict = create_hyperparameter_mock_dict_sim_market()) -> dict:
# """
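Usage note, not part of the commit: these two new helpers are what the test module above combines to build each invalid case, reloading the JSON before every modification because replace_field_in_dict mutates and returns the same dictionary rather than a copy. A short sketch using the ut_t alias and the rl_config_file path defined at the top of the test file:

# Illustrative usage, mirroring how test_hyperparameter_config_rl.py builds its cases.
config = ut_t.load_json(rl_config_file)                        # read the real rl_config.json from disk
broken = ut_t.replace_field_in_dict(config, 'batch_size', -5)  # overwrite one field in place
assert broken is config and broken['batch_size'] == -5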
