From d3262458eae5d33c96e81bfbb4872d0f99ca9273 Mon Sep 17 00:00:00 2001 From: Judith <39854388+felix-20@users.noreply.github.com> Date: Wed, 1 Jun 2022 17:26:50 +0200 Subject: [PATCH 1/3] fix prefill --- .../alpha_business_app/adjustable_fields.py | 51 ++++++ webserver/alpha_business_app/buttons.py | 2 + .../models/agents_config.py | 3 + .../models/environment_config.py | 12 ++ .../models/hyperparameter_config.py | 9 + .../alpha_business_app/static/js/custom.js | 9 +- .../tests/constant_tests.py | 1 - .../tests/test_adjustable_fields.py | 16 ++ .../tests/test_config_model.py | 161 ++++++++++-------- .../tests/test_config_model_parser.py | 21 ++- webserver/alpha_business_app/utils.py | 15 +- webserver/alpha_business_app/views.py | 6 +- .../templates/configuration_items/rl.html | 110 +----------- 13 files changed, 205 insertions(+), 211 deletions(-) create mode 100644 webserver/alpha_business_app/adjustable_fields.py create mode 100644 webserver/alpha_business_app/tests/test_adjustable_fields.py diff --git a/webserver/alpha_business_app/adjustable_fields.py b/webserver/alpha_business_app/adjustable_fields.py new file mode 100644 index 00000000..dd97900a --- /dev/null +++ b/webserver/alpha_business_app/adjustable_fields.py @@ -0,0 +1,51 @@ +from recommerce.configuration.utils import get_class + +from .utils import convert_python_type_to_input_type + + +def get_agent_hyperparameter(agent: str, formdata: dict) -> list: + + # get all fields that are possible for this agent + agent_class = get_class(agent) + agent_specs = agent_class.get_configurable_fields() + + # we want to keep values already inside the html, so we need to parse existing html + parameter_values = _convert_form_to_value_dict(formdata) + # convert parameter into special list format for view + all_parameter = [] + for spec in agent_specs: + this_parameter = {} + this_parameter['name'] = spec[0] + this_parameter['input_type'] = convert_python_type_to_input_type(spec[1]) + this_parameter['prefill'] = _get_value_from_dict(spec[0], parameter_values) + all_parameter += [this_parameter] + return all_parameter + + +def get_rl_parameter_prefill(prefill: dict, error: dict) -> list: + # returns list of dictionaries + all_parameter = [] + for key, value in prefill.items(): + this_parameter = {} + this_parameter['name'] = key + this_parameter['prefill'] = value if value else '' + this_parameter['error'] = error[key] if error[key] else '' + all_parameter += [this_parameter] + return all_parameter + + +def _convert_form_to_value_dict(config_form) -> dict: + final_values = {} + for index in range((len(config_form) - 2) // 2): + current_name = config_form[f'formdata[{index}][name]'] + current_value = config_form[f'formdata[{index}][value]'] + if 'hyperparameter-rl' in current_name: + final_values[current_name.replace('hyperparameter-rl-', '')] = current_value + return final_values + + +def _get_value_from_dict(key, value_dict) -> dict: + try: + return value_dict[key] + except KeyError: + return '' diff --git a/webserver/alpha_business_app/buttons.py b/webserver/alpha_business_app/buttons.py index a31da27d..dc9e9f71 100644 --- a/webserver/alpha_business_app/buttons.py +++ b/webserver/alpha_business_app/buttons.py @@ -4,6 +4,7 @@ from recommerce.configuration.config_validation import validate_config +from .adjustable_fields import get_rl_parameter_prefill from .config_merger import ConfigMerger from .config_parser import ConfigFlatDictParser from .container_parser import parse_response_to_database @@ -280,6 +281,7 @@ def _prefill(self) -> HttpResponse: return self._decide_rendering() merger = ConfigMerger() final_dict, error_dict = merger.merge_config_objects(post_request['config_id']) + final_dict['hyperparameter']['rl'] = get_rl_parameter_prefill(final_dict['hyperparameter']['rl'], error_dict['hyperparameter']['rl']) # set an id for each agent (necessary for view) for agent_index in range(len(final_dict['environment']['agents'])): final_dict['environment']['agents'][agent_index]['display_name'] = 'Agent' if agent_index == 0 else 'Competitor' diff --git a/webserver/alpha_business_app/models/agents_config.py b/webserver/alpha_business_app/models/agents_config.py index daee858b..ccb251e7 100644 --- a/webserver/alpha_business_app/models/agents_config.py +++ b/webserver/alpha_business_app/models/agents_config.py @@ -7,3 +7,6 @@ class AgentsConfig(AbstractConfig, models.Model): def as_list(self) -> dict: referencing_agents = self.agentconfig_set.all() return [agent.as_dict() for agent in referencing_agents] + + def as_dict(self) -> dict: + assert False, 'This should not be implemented as agents are a list.' diff --git a/webserver/alpha_business_app/models/environment_config.py b/webserver/alpha_business_app/models/environment_config.py index 6a7502ca..d6f7aa86 100644 --- a/webserver/alpha_business_app/models/environment_config.py +++ b/webserver/alpha_business_app/models/environment_config.py @@ -1,5 +1,6 @@ from django.db import models +from ..utils import remove_none_values_from_dict from .abstract_config import AbstractConfig @@ -10,3 +11,14 @@ class EnvironmentConfig(AbstractConfig, models.Model): plot_interval = models.IntegerField(null=True) marketplace = models.CharField(max_length=150, null=True) task = models.CharField(max_length=14, choices=((1, 'training'), (2, 'agent_monitoring'), (3, 'exampleprinter')), null=True) + + def as_dict(self) -> dict: + agents_list = self.agents.as_list() if self.agents is not None else None + return remove_none_values_from_dict({ + 'enable_live_draw': self.enable_live_draw, + 'episodes': self.episodes, + 'plot_interval': self.plot_interval, + 'marketplace': self.marketplace, + 'task': self.task, + 'agents': agents_list + }) diff --git a/webserver/alpha_business_app/models/hyperparameter_config.py b/webserver/alpha_business_app/models/hyperparameter_config.py index 54613459..a8cc7267 100644 --- a/webserver/alpha_business_app/models/hyperparameter_config.py +++ b/webserver/alpha_business_app/models/hyperparameter_config.py @@ -1,5 +1,6 @@ from django.db import models +from ..utils import remove_none_values_from_dict from .abstract_config import AbstractConfig from .rl_config import RlConfig from .sim_market_config import SimMarketConfig @@ -8,3 +9,11 @@ class HyperparameterConfig(AbstractConfig, models.Model): rl = models.ForeignKey('alpha_business_app.RLConfig', on_delete=models.CASCADE, null=True) sim_market = models.ForeignKey('alpha_business_app.SimMarketConfig', on_delete=models.CASCADE, null=True) + + def as_dict(self) -> dict: + sim_market_dict = self.sim_market.as_dict() if self.sim_market is not None else {'sim_market': None} + rl_dict = self.rl.as_dict() if self.rl is not None else {'rl': None} + return remove_none_values_from_dict({ + 'rl': rl_dict, + 'sim_market': sim_market_dict + }) diff --git a/webserver/alpha_business_app/static/js/custom.js b/webserver/alpha_business_app/static/js/custom.js index dcf575d7..0bf63564 100644 --- a/webserver/alpha_business_app/static/js/custom.js +++ b/webserver/alpha_business_app/static/js/custom.js @@ -63,19 +63,20 @@ $(document).ready(function() { function addChangeToAgent () { $("select.agent-agent-class").change(function () { - // will be called when another marketplace has been selected + // will be called when agent dropdown has changed, we need to change rl hyperparameter for that var self = $(this); - console.log('here') + var form = $("form.config-form"); + var formdata = form.serializeArray(); const csrftoken = getCookie("csrftoken"); $.ajax({ type: "POST", url: self.data("url"), data: { csrfmiddlewaretoken: csrftoken, - "agent": self.val() + "agent": self.val(), + formdata }, success: function (data) { - console.log(data) $("div.rl-parameter").empty().append(data) } }); diff --git a/webserver/alpha_business_app/tests/constant_tests.py b/webserver/alpha_business_app/tests/constant_tests.py index 5dc4a0e2..810b5b51 100644 --- a/webserver/alpha_business_app/tests/constant_tests.py +++ b/webserver/alpha_business_app/tests/constant_tests.py @@ -142,7 +142,6 @@ EXAMPLE_RL_DICT = { 'rl': { - 'class': 'recommerce.rl.q_learning.q_learning_agent.QLearningAgent', 'gamma': 0.99, 'batch_size': 32, 'replay_size': 100000, diff --git a/webserver/alpha_business_app/tests/test_adjustable_fields.py b/webserver/alpha_business_app/tests/test_adjustable_fields.py new file mode 100644 index 00000000..27a016e3 --- /dev/null +++ b/webserver/alpha_business_app/tests/test_adjustable_fields.py @@ -0,0 +1,16 @@ +from django.test import TestCase + +from ..adjustable_fields import get_rl_parameter_prefill + + +class AdjustableFieldsTests(TestCase): + def test_rl_hyperparameter_with_prefill(self): + prefill_dict = {'gamma': 0.9, 'learning_rate': 0.4, 'test': None} + error_dict = {'gamma': 'test', 'learning_rate': None, 'test': None} + expected_list = [ + {'name': 'gamma', 'prefill': 0.9, 'error': 'test'}, + {'name': 'learning_rate', 'prefill': 0.4, 'error': ''}, + {'name': 'test', 'prefill': '', 'error': ''} + ] + actual_list = get_rl_parameter_prefill(prefill_dict, error_dict) + assert actual_list == expected_list diff --git a/webserver/alpha_business_app/tests/test_config_model.py b/webserver/alpha_business_app/tests/test_config_model.py index aa2d4c51..c24988b1 100644 --- a/webserver/alpha_business_app/tests/test_config_model.py +++ b/webserver/alpha_business_app/tests/test_config_model.py @@ -5,7 +5,9 @@ from ..models.config import Config from ..models.container import Container from ..models.environment_config import EnvironmentConfig +from ..models.hyperparameter_config import HyperparameterConfig from ..models.rl_config import RlConfig +from ..models.sim_market_config import SimMarketConfig from ..utils import capitalize, remove_none_values_from_dict, to_config_class_name # from .constant_tests import EMPTY_STRUCTURE_CONFIG @@ -50,76 +52,76 @@ def test_is_referenced(self): assert test_config_not_referenced.is_referenced() is False assert test_config_referenced.is_referenced() is True - # def test_config_to_dict(self): - # # create a small valid config for this test - # agents_config = AgentsConfig.objects.create() - - # AgentConfig.objects.create(agent_class='recommerce.market.circular.circular_vendors.RuleBasedCERebuyAgent', - # argument='', agents_config=agents_config, name='Rule_Based Agent') - - # env_config = EnvironmentConfig.objects.create(agents=agents_config, - # marketplace='recommerce.market.circular.circular_sim_market.CircularEconomyRebuyPriceMonopoly', - # task='training') - - # rl_config = RlConfig.objects.create(gamma=0.99, - # batch_size=32, - # replay_size=100000, - # learning_rate=1e-6, - # sync_target_frames=1000, - # replay_start_size=10000, - # epsilon_decay_last_frame=75000, - # epsilon_start=1.0, - # epsilon_final=0.1) - - # sim_market_config = SimMarketConfig.objects.create(max_storage=100, - # episode_length=50, - # max_price=10, - # max_quality=50, - # number_of_customers=20, - # production_price=3, - # storage_cost_per_product=0.1) - - # hyperparameter_config = HyperparameterConfig.objects.create(sim_market=sim_market_config, - # rl=rl_config) - - # final_config = Config.objects.create(environment=env_config, - # hyperparameter=hyperparameter_config) - # expected_dict = { - # 'hyperparameter': { - # 'rl': { - # 'gamma': 0.99, - # 'batch_size': 32, - # 'replay_size': 100000, - # 'learning_rate': 1e-6, - # 'sync_target_frames': 1000, - # 'replay_start_size': 10000, - # 'epsilon_decay_last_frame': 75000, - # 'epsilon_start': 1.0, - # 'epsilon_final': 0.1 - # }, - # 'sim_market': { - # 'max_storage': 100, - # 'episode_length': 50, - # 'max_price': 10, - # 'max_quality': 50, - # 'number_of_customers': 20, - # 'production_price': 3, - # 'storage_cost_per_product': 0.1 - # } - # }, - # 'environment': { - # 'task': 'training', - # 'marketplace': 'recommerce.market.circular.circular_sim_market.CircularEconomyRebuyPriceMonopoly', - # 'agents': [ - # { - # 'name': 'Rule_Based Agent', - # 'agent_class': 'recommerce.market.circular.circular_vendors.RuleBasedCERebuyAgent', - # 'argument': '' - # } - # ] - # } - # } - # assert expected_dict == final_config.as_dict() + def test_config_to_dict(self): + # create a small valid config for this test + agents_config = AgentsConfig.objects.create() + + AgentConfig.objects.create(agent_class='recommerce.market.circular.circular_vendors.RuleBasedCERebuyAgent', + argument='', agents_config=agents_config, name='Rule_Based Agent') + + env_config = EnvironmentConfig.objects.create(agents=agents_config, + marketplace='recommerce.market.circular.circular_sim_market.CircularEconomyRebuyPriceMonopoly', + task='training') + + rl_config = RlConfig.objects.create(gamma=0.99, + batch_size=32, + replay_size=100000, + learning_rate=1e-6, + sync_target_frames=1000, + replay_start_size=10000, + epsilon_decay_last_frame=75000, + epsilon_start=1.0, + epsilon_final=0.1) + + sim_market_config = SimMarketConfig.objects.create(max_storage=100, + episode_length=50, + max_price=10, + max_quality=50, + number_of_customers=20, + production_price=3, + storage_cost_per_product=0.1) + + hyperparameter_config = HyperparameterConfig.objects.create(sim_market=sim_market_config, + rl=rl_config) + + final_config = Config.objects.create(environment=env_config, + hyperparameter=hyperparameter_config) + expected_dict = { + 'hyperparameter': { + 'rl': { + 'gamma': 0.99, + 'batch_size': 32, + 'replay_size': 100000, + 'learning_rate': 1e-6, + 'sync_target_frames': 1000, + 'replay_start_size': 10000, + 'epsilon_decay_last_frame': 75000, + 'epsilon_start': 1.0, + 'epsilon_final': 0.1 + }, + 'sim_market': { + 'max_storage': 100, + 'episode_length': 50, + 'max_price': 10, + 'max_quality': 50, + 'number_of_customers': 20, + 'production_price': 3, + 'storage_cost_per_product': 0.1 + } + }, + 'environment': { + 'task': 'training', + 'marketplace': 'recommerce.market.circular.circular_sim_market.CircularEconomyRebuyPriceMonopoly', + 'agents': [ + { + 'name': 'Rule_Based Agent', + 'agent_class': 'recommerce.market.circular.circular_vendors.RuleBasedCERebuyAgent', + 'argument': '' + } + ] + } + } + assert expected_dict == final_config.as_dict() def test_dict_representation_of_agent(self): test_agent = AgentConfig.objects.create(name='test_agent', agent_class='test_class', argument='1234') @@ -154,12 +156,21 @@ def test_list_representation_of_agents(self): ] assert expected_list == test_agents.as_list() - # def test_get_empty_structure_dict(self): - # actual_dict = Config.get_empty_structure_dict() - # assert EMPTY_STRUCTURE_CONFIG == actual_dict - def test_get_empty_structure_dict_for_rl(self): - expected_dict = {} + expected_dict = { + 'sync_target_frames': None, + 'testvalue2': None, + 'gamma': None, + 'epsilon_start': None, + 'replay_size': None, + 'stable_baseline_test': None, + 'threshold': None, + 'epsilon_decay_last_frame': None, + 'batch_size': None, + 'epsilon_final': None, + 'replay_start_size': None, + 'learning_rate': None + } assert expected_dict == RlConfig.get_empty_structure_dict() def test_remove_none_values_from_dict(self): diff --git a/webserver/alpha_business_app/tests/test_config_model_parser.py b/webserver/alpha_business_app/tests/test_config_model_parser.py index 7c617530..9be3515a 100644 --- a/webserver/alpha_business_app/tests/test_config_model_parser.py +++ b/webserver/alpha_business_app/tests/test_config_model_parser.py @@ -4,6 +4,7 @@ from ..models.agents_config import AgentsConfig from ..models.config import Config from ..models.environment_config import EnvironmentConfig +from ..models.hyperparameter_config import HyperparameterConfig from ..models.rl_config import RlConfig from ..models.sim_market_config import SimMarketConfig from .constant_tests import EXAMPLE_HIERARCHY_DICT, EXAMPLE_RL_DICT @@ -128,13 +129,23 @@ def test_parsing_agents(self): def test_parse_rl(self): test_dict = EXAMPLE_RL_DICT.copy() - final_config = self.parser.parse_config(test_dict) + final_config = self.parser.parse_config_dict_to_datastructure('hyperparameter', test_dict) - assert Config == type(final_config) + assert HyperparameterConfig == type(final_config) # assert all hyperparameters - hyperparameter_rl_config: RlConfig = final_config.hyperparameter.rl - hyperparameter_sim_market_config: SimMarketConfig = final_config.hyperparameter.sim_market - + hyperparameter_rl_config: RlConfig = final_config.rl + hyperparameter_sim_market_config: SimMarketConfig = final_config.sim_market assert hyperparameter_rl_config is not None assert hyperparameter_sim_market_config is None + + assert 0.99 == hyperparameter_rl_config.gamma + assert 32 == hyperparameter_rl_config.batch_size + assert 100000 == hyperparameter_rl_config.replay_size + assert 1e-6 == hyperparameter_rl_config.learning_rate + assert 1000 == hyperparameter_rl_config.sync_target_frames + assert 10000 == hyperparameter_rl_config.replay_start_size + assert 75000 == hyperparameter_rl_config.epsilon_decay_last_frame + assert 1.0 == hyperparameter_rl_config.epsilon_start + assert 0.1 == hyperparameter_rl_config.epsilon_final + assert hyperparameter_rl_config.stable_baseline_test is None diff --git a/webserver/alpha_business_app/utils.py b/webserver/alpha_business_app/utils.py index 54e98d6c..c2dc52ec 100644 --- a/webserver/alpha_business_app/utils.py +++ b/webserver/alpha_business_app/utils.py @@ -28,20 +28,7 @@ def get_recommerce_agents_for_marketplace(marketplace) -> list: return marketplace.get_possible_rl_agents() -def get_agent_hyperparameter(agent: str) -> list: - agent_class = get_class(agent) - agent_specs = agent_class.get_configurable_fields() - # name, input_type, error - all_parameter = [] - for spec in agent_specs: - this_parameter = {} - this_parameter['name'] = spec[0] - this_parameter['input_type'] = _convert_python_type_to_input_type(spec[1]) - all_parameter += [this_parameter] - return all_parameter - - -def _convert_python_type_to_input_type(to_convert) -> str: +def convert_python_type_to_input_type(to_convert) -> str: return 'number' if to_convert == float or to_convert == int else 'text' diff --git a/webserver/alpha_business_app/views.py b/webserver/alpha_business_app/views.py index 2b107fe4..80a5e66a 100644 --- a/webserver/alpha_business_app/views.py +++ b/webserver/alpha_business_app/views.py @@ -6,6 +6,7 @@ from recommerce.configuration.config_validation import validate_config +from .adjustable_fields import get_agent_hyperparameter from .buttons import ButtonHandler from .config_parser import ConfigFlatDictParser from .forms import UploadFileForm @@ -14,7 +15,6 @@ from .models.config import Config from .models.container import Container from .selection_manager import SelectionManager -from .utils import get_agent_hyperparameter selection_manager = SelectionManager() @@ -106,9 +106,9 @@ def new_agent(request) -> HttpResponse: def agent_changed(request) -> HttpResponse: if not request.user.is_authenticated: return HttpResponse('Unauthorized', status=401) - print(get_agent_hyperparameter(request.POST['agent'])) + # print(request.POST) return render(request, 'configuration_items/rl_parameter.html', - {'parameters': get_agent_hyperparameter(request.POST['agent'])}) + {'parameters': get_agent_hyperparameter(request.POST['agent'], request.POST.dict())}) def api_availability(request) -> HttpResponse: diff --git a/webserver/templates/configuration_items/rl.html b/webserver/templates/configuration_items/rl.html index ebab9862..cc2f781a 100644 --- a/webserver/templates/configuration_items/rl.html +++ b/webserver/templates/configuration_items/rl.html @@ -6,115 +6,7 @@

- {% load static %} -
-
- {% if error_dict.gamma %} - - {% endif %} - gamma -
-
- -
-
-
-
- {% if error_dict.batch_size %} - - {% endif %} - batch size -
-
- -
-
-
-
- {% if error_dict.replay_size %} - - {% endif %} - replay size -
-
- -
-
-
-
- {% if error_dict.learning_rate %} - - {% endif %} - learning rate -
-
- -
-
-
-
- {% if error_dict.sync_target_frames %} - - {% endif %} - sync target frames -
-
- -
-
-
-
- {% if error_dict.start_size %} - - {% endif %} - replay start size -
-
- -
-
-
-
- {% if error_dict.epsilon_decay_last_frame %} - - {% endif %} - epsilon decay last frame -
-
- -
-
-
-
- {% if error_dict.epsilon_start %} - - {% endif %} - epsilon start -
-
- -
-
-
-
- {% if error_dict.epsilon_final %} - - {% endif %} - epsilon final -
-
- -
-
+ {% include "configuration_items/rl_parameter.html" with parameters=prefill %}
From 135e8f300a5da0f853f07ddc4faec7ae3a13d0e9 Mon Sep 17 00:00:00 2001 From: Judith <39854388+felix-20@users.noreply.github.com> Date: Wed, 1 Jun 2022 20:52:12 +0200 Subject: [PATCH 2/3] fix webserver tests --- recommerce/configuration/config_validation.py | 2 +- tests/test_vendors.py | 1 - webserver/alpha_business_app/buttons.py | 2 +- webserver/alpha_business_app/config_merger.py | 7 +- ...eline_test_rlconfig_testvalue2_and_more.py | 68 +++++++++++++++++++ .../alpha_business_app/models/rl_config.py | 11 ++- .../tests/constant_tests.py | 12 ++-- .../alpha_business_app/tests/test_prefill.py | 21 ++++-- .../alpha_business_app/tests/test_utils.py | 6 +- 9 files changed, 102 insertions(+), 28 deletions(-) create mode 100644 webserver/alpha_business_app/migrations/0013_rlconfig_stable_baseline_test_rlconfig_testvalue2_and_more.py diff --git a/recommerce/configuration/config_validation.py b/recommerce/configuration/config_validation.py index 692cd0cc..8dcd18b2 100644 --- a/recommerce/configuration/config_validation.py +++ b/recommerce/configuration/config_validation.py @@ -34,7 +34,7 @@ def validate_config(config: dict, config_is_final: bool) -> tuple: if 'marketplace' in environment_config: market_class = get_class(environment_config['marketplace']) # the first agent is always the relevant one - if 'agents' in environment_config and len(environment_config['agents'] >= 1): + if 'agents' in environment_config and len(environment_config['agents']) >= 1: agent_class = get_class(environment_config['agents'][0]['agent_class']) # validate that all given values have the correct types diff --git a/tests/test_vendors.py b/tests/test_vendors.py index 4cd5ed80..f8d65dbe 100644 --- a/tests/test_vendors.py +++ b/tests/test_vendors.py @@ -103,7 +103,6 @@ def test_storage_evaluation_with_rebuy_price(state, expected_prices): changed_config.max_price = 10 changed_config.production_price = 2 agent = circular_vendors.RuleBasedCERebuyAgent(config_market=changed_config) - print('*********************************') assert expected_prices == agent.policy(state) diff --git a/webserver/alpha_business_app/buttons.py b/webserver/alpha_business_app/buttons.py index dc9e9f71..46c72e46 100644 --- a/webserver/alpha_business_app/buttons.py +++ b/webserver/alpha_business_app/buttons.py @@ -280,7 +280,7 @@ def _prefill(self) -> HttpResponse: if 'config_id' not in post_request: return self._decide_rendering() merger = ConfigMerger() - final_dict, error_dict = merger.merge_config_objects(post_request['config_id']) + final_dict, error_dict = merger.merge_config_objects(post_request['config_id'], post_request) final_dict['hyperparameter']['rl'] = get_rl_parameter_prefill(final_dict['hyperparameter']['rl'], error_dict['hyperparameter']['rl']) # set an id for each agent (necessary for view) for agent_index in range(len(final_dict['environment']['agents'])): diff --git a/webserver/alpha_business_app/config_merger.py b/webserver/alpha_business_app/config_merger.py index fc27e511..4a239d55 100644 --- a/webserver/alpha_business_app/config_merger.py +++ b/webserver/alpha_business_app/config_merger.py @@ -17,14 +17,13 @@ def merge_config_objects(self, config_object_ids: list) -> tuple: """ configuration_objects = [Config.objects.get(id=config_id) for config_id in config_object_ids] configuration_dicts = [config.as_dict() for config in configuration_objects] - for c in configuration_dicts: - print(c) - print('--------------------------------------') + # for c in configuration_dicts: + # print(c) + # print('--------------------------------------') # get initial empty dict to merge into final_config = Config.get_empty_structure_dict() for config in configuration_dicts: final_config = self._merge_config_into_base_config(final_config, config) - print(final_config, '\n*******************************\n') return final_config, self.error_dict def _merge_config_into_base_config(self, base_config: dict, merging_config: dict, current_config_path: str = '') -> dict: diff --git a/webserver/alpha_business_app/migrations/0013_rlconfig_stable_baseline_test_rlconfig_testvalue2_and_more.py b/webserver/alpha_business_app/migrations/0013_rlconfig_stable_baseline_test_rlconfig_testvalue2_and_more.py new file mode 100644 index 00000000..823fa036 --- /dev/null +++ b/webserver/alpha_business_app/migrations/0013_rlconfig_stable_baseline_test_rlconfig_testvalue2_and_more.py @@ -0,0 +1,68 @@ +# Generated by Django 4.0.1 on 2022-06-01 18:36 + +from django.db import migrations, models + + +class Migration(migrations.Migration): + + dependencies = [ + ('alpha_business_app', '0012_config_user_container_user'), + ] + + operations = [ + migrations.AddField( + model_name='rlconfig', + name='stable_baseline_test', + field=models.FloatField(default=None, null=True), + ), + migrations.AddField( + model_name='rlconfig', + name='testvalue2', + field=models.FloatField(default=None, null=True), + ), + migrations.AlterField( + model_name='rlconfig', + name='batch_size', + field=models.IntegerField(default=None, null=True), + ), + migrations.AlterField( + model_name='rlconfig', + name='epsilon_decay_last_frame', + field=models.IntegerField(default=None, null=True), + ), + migrations.AlterField( + model_name='rlconfig', + name='epsilon_final', + field=models.FloatField(default=None, null=True), + ), + migrations.AlterField( + model_name='rlconfig', + name='epsilon_start', + field=models.FloatField(default=None, null=True), + ), + migrations.AlterField( + model_name='rlconfig', + name='gamma', + field=models.FloatField(default=None, null=True), + ), + migrations.AlterField( + model_name='rlconfig', + name='learning_rate', + field=models.FloatField(default=None, null=True), + ), + migrations.AlterField( + model_name='rlconfig', + name='replay_size', + field=models.IntegerField(default=None, null=True), + ), + migrations.AlterField( + model_name='rlconfig', + name='replay_start_size', + field=models.IntegerField(default=None, null=True), + ), + migrations.AlterField( + model_name='rlconfig', + name='sync_target_frames', + field=models.IntegerField(default=None, null=True), + ), + ] diff --git a/webserver/alpha_business_app/models/rl_config.py b/webserver/alpha_business_app/models/rl_config.py index e66fcdf6..d281f39a 100644 --- a/webserver/alpha_business_app/models/rl_config.py +++ b/webserver/alpha_business_app/models/rl_config.py @@ -4,15 +4,14 @@ class RlConfig(AbstractConfig, models.Model): + stable_baseline_test = models.FloatField(null=True, default=None) + epsilon_decay_last_frame = models.IntegerField(null=True, default=None) epsilon_start = models.FloatField(null=True, default=None) sync_target_frames = models.IntegerField(null=True, default=None) replay_start_size = models.IntegerField(null=True, default=None) + testvalue2 = models.FloatField(null=True, default=None) + gamma = models.FloatField(null=True, default=None) + epsilon_final = models.FloatField(null=True, default=None) replay_size = models.IntegerField(null=True, default=None) batch_size = models.IntegerField(null=True, default=None) - stable_baseline_test = models.FloatField(null=True, default=None) learning_rate = models.FloatField(null=True, default=None) - epsilon_decay_last_frame = models.IntegerField(null=True, default=None) - testvalue2 = models.FloatField(null=True, default=None) - epsilon_final = models.FloatField(null=True, default=None) - threshold = models.FloatField(null=True, default=None) - gamma = models.FloatField(null=True, default=None) diff --git a/webserver/alpha_business_app/tests/constant_tests.py b/webserver/alpha_business_app/tests/constant_tests.py index 810b5b51..d636024a 100644 --- a/webserver/alpha_business_app/tests/constant_tests.py +++ b/webserver/alpha_business_app/tests/constant_tests.py @@ -7,8 +7,8 @@ 'environment-episodes': [''], 'environment-plot_interval': [''], 'environment-marketplace': ['recommerce.market.circular.circular_sim_market.CircularEconomyRebuyPriceMonopoly'], - 'environment-agents-name': ['Rule_Based Agent'], - 'environment-agents-agent_class': ['recommerce.market.circular.circular_vendors.RuleBasedCERebuyAgent'], + 'environment-agents-name': ['QLearning Agent'], + 'environment-agents-agent_class': ['recommerce.rl.q_learning.q_learning_agent.QLearningAgent'], 'environment-agents-argument': [''], 'hyperparameter-rl-gamma': ['0.99'], 'hyperparameter-rl-batch_size': ['32'], @@ -35,8 +35,8 @@ 'enable_live_draw': False, 'agents': [ { - 'name': 'Rule_Based Agent', - 'agent_class': 'recommerce.market.circular.circular_vendors.RuleBasedCERebuyAgent', + 'name': 'QLearning Agent', + 'agent_class': 'recommerce.rl.q_learning.q_learning_agent.QLearningAgent', 'argument': '' } ] @@ -126,7 +126,9 @@ 'replay_start_size': None, 'epsilon_decay_last_frame': None, 'epsilon_start': None, - 'epsilon_final': None + 'epsilon_final': None, + 'testvalue2': None, + 'stable_baseline_test': None }, 'sim_market': { 'max_storage': None, diff --git a/webserver/alpha_business_app/tests/test_prefill.py b/webserver/alpha_business_app/tests/test_prefill.py index 7e0f215a..aeb9b92a 100644 --- a/webserver/alpha_business_app/tests/test_prefill.py +++ b/webserver/alpha_business_app/tests/test_prefill.py @@ -8,7 +8,10 @@ # from ..buttons import ButtonHandler from ..config_merger import ConfigMerger from ..config_parser import ConfigModelParser -from ..models.config import Config, EnvironmentConfig, HyperparameterConfig, RlConfig +from ..models.config import Config +from ..models.environment_config import EnvironmentConfig +from ..models.hyperparameter_config import HyperparameterConfig +from ..models.rl_config import RlConfig from .constant_tests import EMPTY_STRUCTURE_CONFIG, EXAMPLE_HIERARCHY_DICT, EXAMPLE_HIERARCHY_DICT2 # from unittest.mock import patch @@ -35,6 +38,9 @@ def test_merge_one_config(self): expected_dict = copy.deepcopy(config_dict) expected_dict['environment']['episodes'] = None expected_dict['environment']['plot_interval'] = None + expected_dict['hyperparameter']['rl']['testvalue2'] = None + expected_dict['hyperparameter']['rl']['stable_baseline_test'] = None + empty_config = Config.get_empty_structure_dict() merger = ConfigMerger() actual_config = merger._merge_config_into_base_config(empty_config, config_dict) @@ -98,8 +104,8 @@ def test_merge_two_configs_with_conflicts(self): 'task': 'monitoring', 'agents': [ { - 'name': 'Rule_Based Agent', - 'agent_class': 'recommerce.market.circular.circular_vendors.RuleBasedCERebuyAgent', + 'name': 'QLearning Agent', + 'agent_class': 'recommerce.rl.q_learning.q_learning_agent.QLearningAgent', 'argument': '' }, { @@ -122,7 +128,10 @@ def test_merge_two_configs_with_conflicts(self): 'sync_target_frames': 100, 'replay_start_size': 1000, 'epsilon_decay_last_frame': 7500, - 'epsilon_start': 0.9, 'epsilon_final': 0.2 + 'epsilon_start': 0.9, + 'epsilon_final': 0.2, + 'testvalue2': None, + 'stable_baseline_test': None }, 'sim_market': { 'max_storage': 80, @@ -153,7 +162,9 @@ def test_merge_two_configs_with_conflicts(self): 'replay_start_size': 'changed hyperparameter-rl replay_start_size from 10000 to 1000', 'epsilon_decay_last_frame': 'changed hyperparameter-rl epsilon_decay_last_frame from 75000 to 7500', 'epsilon_start': 'changed hyperparameter-rl epsilon_start from 1.0 to 0.9', - 'epsilon_final': 'changed hyperparameter-rl epsilon_final from 0.1 to 0.2' + 'epsilon_final': 'changed hyperparameter-rl epsilon_final from 0.1 to 0.2', + 'testvalue2': None, + 'stable_baseline_test': None }, 'sim_market': { 'max_storage': 'changed hyperparameter-sim_market max_storage from 100 to 80', diff --git a/webserver/alpha_business_app/tests/test_utils.py b/webserver/alpha_business_app/tests/test_utils.py index 4da233c8..21345127 100644 --- a/webserver/alpha_business_app/tests/test_utils.py +++ b/webserver/alpha_business_app/tests/test_utils.py @@ -4,7 +4,7 @@ class UtilsTest(TestCase): - def test_get_structure_dict_for(self): + def test_get_structure_dict_for_config(self): expected_dict = { 'environment': { 'task': None, @@ -16,7 +16,6 @@ def test_get_structure_dict_for(self): }, 'hyperparameter': { 'sim_market': { - 'class': None, 'max_storage': None, 'episode_length': None, 'max_price': None, @@ -33,7 +32,6 @@ def test_get_structure_dict_for(self): 'testvalue2': None, 'sync_target_frames': None, 'batch_size': None, - 'threshold': None, 'epsilon_final': None, 'stable_baseline_test': None, 'gamma': None, @@ -52,7 +50,6 @@ def test_get_structure_dict_for_rl(self): 'testvalue2': None, 'sync_target_frames': None, 'batch_size': None, - 'threshold': None, 'epsilon_final': None, 'stable_baseline_test': None, 'gamma': None, @@ -62,7 +59,6 @@ def test_get_structure_dict_for_rl(self): def test_get_structure_dict_for_sim_market(self): expected_dict = { - 'class': None, 'max_storage': None, 'episode_length': None, 'max_price': None, From f50fe65fccbed1db759c42ffb6e3afd0e277fe4a Mon Sep 17 00:00:00 2001 From: Judith <39854388+felix-20@users.noreply.github.com> Date: Thu, 2 Jun 2022 11:21:10 +0200 Subject: [PATCH 3/3] implement new config validation for webserver --- recommerce/configuration/config_validation.py | 2 + webserver/alpha_business_app/handle_files.py | 30 ++++++--- .../tests/test_config_flat_dict_parser.py | 4 +- .../tests/test_config_model.py | 1 - .../tests/test_config_model_parser.py | 4 +- .../test_data/test_environment_config.json | 5 +- .../test_data/test_hyperparameter_config.json | 24 ------- .../tests/test_data/test_mixed_config.json | 15 ----- .../tests/test_data/test_rl_config.json | 12 ++++ .../test_data/test_sim_market_config.json | 11 ++++ .../tests/test_file_handling.py | 66 +++++++++---------- 11 files changed, 84 insertions(+), 90 deletions(-) delete mode 100644 webserver/alpha_business_app/tests/test_data/test_hyperparameter_config.json delete mode 100644 webserver/alpha_business_app/tests/test_data/test_mixed_config.json create mode 100644 webserver/alpha_business_app/tests/test_data/test_rl_config.json create mode 100644 webserver/alpha_business_app/tests/test_data/test_sim_market_config.json diff --git a/recommerce/configuration/config_validation.py b/recommerce/configuration/config_validation.py index 8dcd18b2..96de3a05 100644 --- a/recommerce/configuration/config_validation.py +++ b/recommerce/configuration/config_validation.py @@ -31,6 +31,8 @@ def validate_config(config: dict, config_is_final: bool) -> tuple: # try to split the config. If any keys are unknown, an AssertionError will be thrown hyperparameter_config, environment_config = split_mixed_config(config) + agent_class = None + market_class = None if 'marketplace' in environment_config: market_class = get_class(environment_config['marketplace']) # the first agent is always the relevant one diff --git a/webserver/alpha_business_app/handle_files.py b/webserver/alpha_business_app/handle_files.py index c7d095b6..6a523135 100644 --- a/webserver/alpha_business_app/handle_files.py +++ b/webserver/alpha_business_app/handle_files.py @@ -56,23 +56,37 @@ def handle_uploaded_file(request, uploaded_config) -> HttpResponse: except ValueError as value: return render(request, 'upload.html', {'error': str(value)}) + # Validate the config file using the recommerce validation functionality validate_status, validate_data = validate_config(content_as_dict, False) if not validate_status: return render(request, 'upload.html', {'error': validate_data}) - hyperparameter_config, environment_config = validate_data + config = validate_data + assert len(config.keys()) == 1, f'This config ({config} as multiple keys, should only be one ("environment", "rl" or "sim_market"))' + # recommerce returns either {'environment': {}}, {'rl': {}} or {'sim_markte': {}}} + # for parsing we need to know what the toplevel key is + top_level = config.keys()[0] + if top_level != 'environment': + top_level = 'hyperparameter' + + # parse config model to datastructure parser = ConfigModelParser() - web_hyperparameter_config = None - web_environment_config = None try: - web_hyperparameter_config = parser.parse_config_dict_to_datastructure('hyperparameter', hyperparameter_config) - web_environment_config = parser.parse_config_dict_to_datastructure('environment', environment_config) - except ValueError: - return render(request, 'upload.html', {'error': 'Your config is wrong'}) + resulting_config_part = parser.parse_config_dict_to_datastructure(top_level, config) + except ValueError as e: + return render(request, 'upload.html', {'error': f'Your config is wrong {e}'}) + + # Make it a real config object + environment_config = None + hyperparameter_config = None + if top_level == 'environment': + environment_config = resulting_config_part + else: + hyperparameter_config = resulting_config_part given_name = request.POST['config_name'] config_name = given_name if given_name else uploaded_config.name - Config.objects.create(environment=web_environment_config, hyperparameter=web_hyperparameter_config, name=config_name, user=request.user) + Config.objects.create(environment=environment_config, hyperparameter=hyperparameter_config, name=config_name, user=request.user) return redirect('/configurator', {'success': 'You successfully uploaded a config file'}) diff --git a/webserver/alpha_business_app/tests/test_config_flat_dict_parser.py b/webserver/alpha_business_app/tests/test_config_flat_dict_parser.py index 58d52fc8..cd150967 100644 --- a/webserver/alpha_business_app/tests/test_config_flat_dict_parser.py +++ b/webserver/alpha_business_app/tests/test_config_flat_dict_parser.py @@ -209,8 +209,8 @@ def test_parsing_config_dict(self): all_agents = environment_agents.agentconfig_set.all() assert 1 == len(all_agents) - assert 'recommerce.market.circular.circular_vendors.RuleBasedCERebuyAgent' == all_agents[0].agent_class - assert 'Rule_Based Agent' == all_agents[0].name + assert 'recommerce.rl.q_learning.q_learning_agent.QLearningAgent' == all_agents[0].agent_class + assert 'QLearning Agent' == all_agents[0].name assert '' == all_agents[0].argument def test_parsing_agents(self): diff --git a/webserver/alpha_business_app/tests/test_config_model.py b/webserver/alpha_business_app/tests/test_config_model.py index c24988b1..da8287f7 100644 --- a/webserver/alpha_business_app/tests/test_config_model.py +++ b/webserver/alpha_business_app/tests/test_config_model.py @@ -164,7 +164,6 @@ def test_get_empty_structure_dict_for_rl(self): 'epsilon_start': None, 'replay_size': None, 'stable_baseline_test': None, - 'threshold': None, 'epsilon_decay_last_frame': None, 'batch_size': None, 'epsilon_final': None, diff --git a/webserver/alpha_business_app/tests/test_config_model_parser.py b/webserver/alpha_business_app/tests/test_config_model_parser.py index 9be3515a..4132894e 100644 --- a/webserver/alpha_business_app/tests/test_config_model_parser.py +++ b/webserver/alpha_business_app/tests/test_config_model_parser.py @@ -98,8 +98,8 @@ def test_parsing_config_dict(self): all_agents = environment_agents.agentconfig_set.all() assert 1 == len(all_agents) - assert 'recommerce.market.circular.circular_vendors.RuleBasedCERebuyAgent' == all_agents[0].agent_class - assert 'Rule_Based Agent' == all_agents[0].name + assert 'recommerce.rl.q_learning.q_learning_agent.QLearningAgent' == all_agents[0].agent_class + assert 'QLearning Agent' == all_agents[0].name assert '' == all_agents[0].argument def test_parsing_agents(self): diff --git a/webserver/alpha_business_app/tests/test_data/test_environment_config.json b/webserver/alpha_business_app/tests/test_data/test_environment_config.json index edfa9644..b06c4d07 100644 --- a/webserver/alpha_business_app/tests/test_data/test_environment_config.json +++ b/webserver/alpha_business_app/tests/test_data/test_environment_config.json @@ -6,7 +6,7 @@ "marketplace": "recommerce.market.circular.circular_sim_market.CircularEconomyRebuyPriceMonopoly", "agents": [ { - "name": "Rule_Based Agent", + "name": "QLearning Agent", "agent_class": "recommerce.market.circular.circular_vendors.RuleBasedCERebuyAgent", "argument": "" }, @@ -15,5 +15,6 @@ "agent_class": "recommerce.rl.q_learning.q_learning_agent.QLearningAgent", "argument": "CircularEconomyRebuyPriceMonopoly_QLearningAgent.dat" } - ] + ], + "config_type": "environment" } diff --git a/webserver/alpha_business_app/tests/test_data/test_hyperparameter_config.json b/webserver/alpha_business_app/tests/test_data/test_hyperparameter_config.json deleted file mode 100644 index 11278d8b..00000000 --- a/webserver/alpha_business_app/tests/test_data/test_hyperparameter_config.json +++ /dev/null @@ -1,24 +0,0 @@ -{ - "rl": { - "class" : "recommerce.rl.q_learning.q_learning_agent.QLearningAgent", - "gamma" : 0.99, - "batch_size" : 32, - "replay_size" : 100000, - "learning_rate" : 1e-6, - "sync_target_frames" : 1000, - "replay_start_size" : 10000, - "epsilon_decay_last_frame" : 75000, - "epsilon_start" : 1.0, - "epsilon_final" : 0.1 - }, - "sim_market": { - "class": "recommerce.market.circular.circular_sim_market.CircularEconomyRebuyPriceMonopoly", - "max_storage": 100, - "episode_length": 50, - "max_price": 10, - "max_quality": 50, - "number_of_customers": 20, - "production_price": 3, - "storage_cost_per_product": 0.1 - } -} diff --git a/webserver/alpha_business_app/tests/test_data/test_mixed_config.json b/webserver/alpha_business_app/tests/test_data/test_mixed_config.json deleted file mode 100644 index c58239e5..00000000 --- a/webserver/alpha_business_app/tests/test_data/test_mixed_config.json +++ /dev/null @@ -1,15 +0,0 @@ -{ - "task": "training", - "sim_market": { - "class": "recommerce.market.circular.circular_sim_market.CircularEconomyRebuyPriceMonopoly", - "max_storage": 100, - "episode_length": 50 - }, - "enable_live_draw": false, - "rl": { - "class" : "recommerce.rl.q_learning.q_learning_agent.QLearningAgent", - "gamma" : 0.99, - "batch_size" : 32 - }, - "episodes": 50 -} diff --git a/webserver/alpha_business_app/tests/test_data/test_rl_config.json b/webserver/alpha_business_app/tests/test_data/test_rl_config.json new file mode 100644 index 00000000..7d1eea3b --- /dev/null +++ b/webserver/alpha_business_app/tests/test_data/test_rl_config.json @@ -0,0 +1,12 @@ +{ + "gamma" : 0.99, + "batch_size" : 32, + "replay_size" : 100000, + "learning_rate" : 1e-6, + "sync_target_frames" : 1000, + "replay_start_size" : 10000, + "epsilon_decay_last_frame" : 75000, + "epsilon_start" : 1.0, + "epsilon_final" : 0.1, + "config_type": "rl" +} diff --git a/webserver/alpha_business_app/tests/test_data/test_sim_market_config.json b/webserver/alpha_business_app/tests/test_data/test_sim_market_config.json new file mode 100644 index 00000000..e1642d9b --- /dev/null +++ b/webserver/alpha_business_app/tests/test_data/test_sim_market_config.json @@ -0,0 +1,11 @@ +{ + "class": "recommerce.market.circular.circular_sim_market.CircularEconomyRebuyPriceMonopoly", + "max_storage": 100, + "episode_length": 50, + "max_price": 10, + "max_quality": 50, + "number_of_customers": 20, + "production_price": 3, + "storage_cost_per_product": 0.1, + "config_type": "market" +} diff --git a/webserver/alpha_business_app/tests/test_file_handling.py b/webserver/alpha_business_app/tests/test_file_handling.py index 82aaa0b5..780ddda8 100644 --- a/webserver/alpha_business_app/tests/test_file_handling.py +++ b/webserver/alpha_business_app/tests/test_file_handling.py @@ -8,7 +8,11 @@ from ..config_parser import ConfigModelParser from ..handle_files import handle_uploaded_file -from ..models.config import * +from ..models.agents_config import AgentsConfig +from ..models.config import Config +from ..models.environment_config import EnvironmentConfig +from ..models.rl_config import RlConfig +from ..models.sim_market_config import SimMarketConfig class MockedResponse(): @@ -92,10 +96,10 @@ def test_objects_from_parse_dict(self): else: assert 32 == getattr(resulting_config.rl, name) - def test_parsing_with_only_hyperparameter(self): + def test_parsing_with_rl_hyperparameter(self): # get a test config to be parsed path_to_test_data = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'test_data') - with open(os.path.join(path_to_test_data, 'test_hyperparameter_config.json'), 'r') as file: + with open(os.path.join(path_to_test_data, 'test_rl_config.json'), 'r') as file: content = file.read() # mock uploaded file with test config test_uploaded_file = MockedUploadedFile('config.json', content.encode()) @@ -108,9 +112,9 @@ def test_parsing_with_only_hyperparameter(self): assert Config == type(final_config) assert final_config.environment is None assert final_config.hyperparameter is not None + assert final_config.hyperparameter.sim_market is None hyperparameter_rl_config: RlConfig = final_config.hyperparameter.rl - hyperparameter_sim_market_config: SimMarketConfig = final_config.hyperparameter.sim_market assert hyperparameter_rl_config is not None assert final_config.hyperparameter.sim_market is not None @@ -125,6 +129,28 @@ def test_parsing_with_only_hyperparameter(self): assert 1.0 == hyperparameter_rl_config.epsilon_start assert 0.1 == hyperparameter_rl_config.epsilon_final + def test_parsing_with_sim_market_hyperparameter(self): + # get a test config to be parsed + path_to_test_data = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'test_data') + with open(os.path.join(path_to_test_data, 'test_sim_market_config.json'), 'r') as file: + content = file.read() + # mock uploaded file with test config + test_uploaded_file = MockedUploadedFile('config.json', content.encode()) + # test method + with patch('alpha_business_app.handle_files.redirect') as redirect_mock: + handle_uploaded_file(self._setup_request(), test_uploaded_file) + redirect_mock.assert_called_once() + # assert the datastructure, that should be present afterwards + final_config: Config = Config.objects.all().first() + assert Config == type(final_config) + assert final_config.environment is None + assert final_config.hyperparameter is not None + assert final_config.hyperparameter.rl is None + + hyperparameter_sim_market_config: SimMarketConfig = final_config.hyperparameter.sim_market + + assert final_config.hyperparameter.sim_market is not None + assert 100 == hyperparameter_sim_market_config.max_storage assert 50 == hyperparameter_sim_market_config.episode_length assert 10 == hyperparameter_sim_market_config.max_price @@ -168,38 +194,6 @@ def test_parsing_with_only_environment(self): assert 'recommerce.rl.q_learning.q_learning_agent.QLearningAgent' == all_agents[1].agent_class assert 'CircularEconomyRebuyPriceMonopoly_QLearningAgent.dat' == all_agents[1].argument - def test_parsing_mixed_config(self): - # get a test config to be parsed - path_to_test_data = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'test_data') - with open(os.path.join(path_to_test_data, 'test_mixed_config.json'), 'r') as file: - content = file.read() - # mock uploaded file with test config - test_uploaded_file = MockedUploadedFile('config.json', content.encode()) - # test method - with patch('alpha_business_app.handle_files.redirect') as redirect_mock: - handle_uploaded_file(self._setup_request(), test_uploaded_file) - redirect_mock.assert_called_once() - # assert the datastructure, that should be present afterwards - final_config: Config = Config.objects.all().first() - assert Config == type(final_config) - assert final_config.environment is not None - assert final_config.hyperparameter is not None - - environment_config: EnvironmentConfig = final_config.environment - hyperparameter_config: HyperparameterConfig = final_config.hyperparameter - - assert 'training' == environment_config.task - assert environment_config.enable_live_draw is False - assert 50 == environment_config.episodes - - assert hyperparameter_config.sim_market is not None - assert hyperparameter_config.rl is not None - - assert 100 == hyperparameter_config.sim_market.max_storage - assert 50 == hyperparameter_config.sim_market.episode_length - assert 0.99 == hyperparameter_config.rl.gamma - assert 32 == hyperparameter_config.rl.batch_size - def test_parsing_invalid_rl_parameters(self): test_uploaded_file = MockedUploadedFile('config.json', b'{"rl": {"test":"bla"}}') with patch('alpha_business_app.handle_files.render') as render_mock: