Skip to content

Commit

Permalink
Fixed config validation for webserver
Browse files Browse the repository at this point in the history
  • Loading branch information
NikkelM committed May 27, 2022
1 parent 4107aaf commit c90a8d6
Show file tree
Hide file tree
Showing 4 changed files with 61 additions and 32 deletions.
71 changes: 43 additions & 28 deletions recommerce/configuration/config_validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,16 +30,32 @@ def validate_config(config: dict, config_is_final: bool) -> tuple:
else:
# try to split the config. If any keys are unknown, an AssertionError will be thrown
hyperparameter_config, environment_config = split_mixed_config(config)
# then validate that all given values have the correct types
# check_config_types(hyperparameter_config, environment_config, config_is_final)

# Following PR #485, all hyperparameter-configs need a class attribute from which we get the required/allowed fields
if 'rl' in hyperparameter_config:
assert 'class' in hyperparameter_config['rl'], 'You need to specify a class in the rl_config'
hyperparameter_config['rl']['class'] = QLearningAgent # This is a dirty fix
HyperparameterConfigValidator.validate_config(hyperparameter_config['rl'])
hyperparameter_config['rl'].pop('class')

if 'sim_market' in hyperparameter_config:
assert 'class' in hyperparameter_config['sim_market'], 'You need to specify a class in the market_config'
hyperparameter_config['sim_market']['class'] = CircularEconomyRebuyPriceDuopoly # This is a dirty fix
HyperparameterConfigValidator.validate_config(hyperparameter_config['sim_market'])

# then validate that all given values have the correct types
check_config_types(hyperparameter_config, environment_config, config_is_final)

# if 'rl' in hyperparameter_config:
# hyperparameter_config['rl']['class'] = QLearningAgent # This is a dirty fix
# HyperparameterConfigValidator.validate_config(hyperparameter_config['rl'])
# hyperparameter_config['rl'].pop('class')
# if 'sim_market' in hyperparameter_config:
# hyperparameter_config['sim_market']['class'] = CircularEconomyRebuyPriceDuopoly # This is a dirty fix
# HyperparameterConfigValidator.validate_config(hyperparameter_config['sim_market'])
# hyperparameter_config['sim_market'].pop('class')

if 'rl' in hyperparameter_config:
hyperparameter_config['rl'].pop('class')

if 'sim_market' in hyperparameter_config:
hyperparameter_config['sim_market'].pop('class')

return True, (hyperparameter_config, environment_config)
Expand Down Expand Up @@ -117,29 +133,28 @@ def split_mixed_config(config: dict) -> tuple:
return hyperparameter_config, environment_config


# def check_config_types(hyperparameter_config: dict, environment_config: dict, must_contain: bool = False) -> None:
# """
# Utility function that checks (incomplete) config dictionaries for their correct types.

# Args:
# hyperparameter_config (dict): The config containing hyperparameter_config-keys.
# environment_config (dict): The config containing environment_config-keys.
# must_contain (bool): Whether or not the configuration should contain all required keys.

# Raises:
# AssertionError: If one of the values has the wrong type.
# """
# # check types for hyperparameter_config
# # @NikkelM Why was this here?
# # HyperparameterConfigValidator.check_types(hyperparameter_config, 'top-dict', must_contain)
# if 'rl' in hyperparameter_config:
# HyperparameterConfigValidator.check_types(hyperparameter_config['rl'], 'rl', must_contain)
# if 'sim_market' in hyperparameter_config:
# HyperparameterConfigValidator.check_types(hyperparameter_config['sim_market'], 'sim_market', must_contain)

# # check types for environment_config
# task = environment_config['task'] if must_contain else 'None'
# EnvironmentConfig.check_types(environment_config, task, False, must_contain)
def check_config_types(hyperparameter_config: dict, environment_config: dict, must_contain: bool = False) -> None:
"""
Utility function that checks (incomplete) config dictionaries for their correct types.
Args:
hyperparameter_config (dict): The config containing hyperparameter_config-keys.
environment_config (dict): The config containing environment_config-keys.
must_contain (bool): Whether or not the configuration should contain all required keys.
Raises:
AssertionError: If one of the values has the wrong type.
"""
if 'rl' in hyperparameter_config:
HyperparameterConfigValidator._check_types(hyperparameter_config['rl'],
hyperparameter_config['rl']['class'].get_configurable_fields(), must_contain)
if 'sim_market' in hyperparameter_config:
HyperparameterConfigValidator._check_types(hyperparameter_config['sim_market'],
hyperparameter_config['sim_market']['class'].get_configurable_fields(), must_contain)

# check types for environment_config
task = environment_config['task'] if must_contain else 'None'
EnvironmentConfig.check_types(environment_config, task, False, must_contain)


# if __name__ == '__main__': # pragma: no cover
Expand Down
18 changes: 14 additions & 4 deletions recommerce/configuration/hyperparameter_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ def get_required_fields(cls, dict_key) -> dict:
return {'rl': True, 'sim_market': True}
elif dict_key == 'rl':
return {
'class': False,
'gamma': False,
'batch_size': False,
'replay_size': False,
Expand All @@ -30,6 +31,7 @@ def get_required_fields(cls, dict_key) -> dict:
}
elif dict_key == 'sim_market':
return {
'class': False,
'max_storage': False,
'episode_length': False,
'max_price': False,
Expand Down Expand Up @@ -77,20 +79,28 @@ def _check_demanded_config_is_subset_of_given(cls, config: dict, demanded_fields
assert key in config, f'your config is missing {key}'

@classmethod
def _check_types(cls, config: dict, configurable_fields: list) -> None:
def _check_types(cls, config: dict, configurable_fields: list, must_contain: bool = True) -> None:
for field_name, type, _ in configurable_fields:
assert isinstance(config[field_name], type), f'{field_name} must be a {type} but was {type(config[field_name])}'
try:
assert isinstance(config[field_name], type), f'{field_name} must be a {type} but was {type(config[field_name])}'
except KeyError as error:
if must_contain:
raise KeyError(f'Your config is missing the following required key: {field_name}') from error

@classmethod
def _check_rules(cls, config: dict, configurable_fields: list) -> None:
def _check_rules(cls, config: dict, configurable_fields: list, must_contain: bool = True) -> None:
for field_name, _, rule in configurable_fields:
if rule is not None:
if not isinstance(rule, tuple):
assert callable(rule)
check_method, error_string = rule(field_name)
else:
check_method, error_string = rule
assert check_method(config[field_name]), error_string
try:
assert check_method(config[field_name]), error_string
except KeyError as error:
if must_contain:
raise KeyError(f'Your config is missing the following required key: {field_name}') from error


class HyperparameterConfigLoader():
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
{
"rl": {
"class" : "recommerce.rl.q_learning.q_learning_agent.QLearningAgent",
"gamma" : 0.99,
"batch_size" : 32,
"replay_size" : 100000,
Expand All @@ -11,6 +12,7 @@
"epsilon_final" : 0.1
},
"sim_market": {
"class": "recommerce.market.circular.circular_sim_market.CircularEconomyRebuyPriceMonopoly",
"max_storage": 100,
"episode_length": 50,
"max_price": 10,
Expand Down
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
{
"task": "training",
"sim_market": {
"class": "recommerce.market.circular.circular_sim_market.CircularEconomyRebuyPriceMonopoly",
"max_storage": 100,
"episode_length": 50
},
"enable_live_draw": false,
"rl": {
"class" : "recommerce.rl.q_learning.q_learning_agent.QLearningAgent",
"gamma" : 0.99,
"batch_size" : 32
},
Expand Down

0 comments on commit c90a8d6

Please sign in to comment.