From d3262458eae5d33c96e81bfbb4872d0f99ca9273 Mon Sep 17 00:00:00 2001
From: Judith <39854388+felix-20@users.noreply.github.com>
Date: Wed, 1 Jun 2022 17:26:50 +0200
Subject: [PATCH 1/3] fix prefill
---
.../alpha_business_app/adjustable_fields.py | 51 ++++++
webserver/alpha_business_app/buttons.py | 2 +
.../models/agents_config.py | 3 +
.../models/environment_config.py | 12 ++
.../models/hyperparameter_config.py | 9 +
.../alpha_business_app/static/js/custom.js | 9 +-
.../tests/constant_tests.py | 1 -
.../tests/test_adjustable_fields.py | 16 ++
.../tests/test_config_model.py | 161 ++++++++++--------
.../tests/test_config_model_parser.py | 21 ++-
webserver/alpha_business_app/utils.py | 15 +-
webserver/alpha_business_app/views.py | 6 +-
.../templates/configuration_items/rl.html | 110 +-----------
13 files changed, 205 insertions(+), 211 deletions(-)
create mode 100644 webserver/alpha_business_app/adjustable_fields.py
create mode 100644 webserver/alpha_business_app/tests/test_adjustable_fields.py
diff --git a/webserver/alpha_business_app/adjustable_fields.py b/webserver/alpha_business_app/adjustable_fields.py
new file mode 100644
index 00000000..dd97900a
--- /dev/null
+++ b/webserver/alpha_business_app/adjustable_fields.py
@@ -0,0 +1,51 @@
+from recommerce.configuration.utils import get_class
+
+from .utils import convert_python_type_to_input_type
+
+
+def get_agent_hyperparameter(agent: str, formdata: dict) -> list:
+
+ # get all fields that are possible for this agent
+ agent_class = get_class(agent)
+ agent_specs = agent_class.get_configurable_fields()
+
+ # we want to keep values already inside the html, so we need to parse existing html
+ parameter_values = _convert_form_to_value_dict(formdata)
+ # convert parameter into special list format for view
+ all_parameter = []
+ for spec in agent_specs:
+ this_parameter = {}
+ this_parameter['name'] = spec[0]
+ this_parameter['input_type'] = convert_python_type_to_input_type(spec[1])
+ this_parameter['prefill'] = _get_value_from_dict(spec[0], parameter_values)
+ all_parameter += [this_parameter]
+ return all_parameter
+
+
+def get_rl_parameter_prefill(prefill: dict, error: dict) -> list:
+ # returns list of dictionaries
+ all_parameter = []
+ for key, value in prefill.items():
+ this_parameter = {}
+ this_parameter['name'] = key
+ this_parameter['prefill'] = value if value else ''
+ this_parameter['error'] = error[key] if error[key] else ''
+ all_parameter += [this_parameter]
+ return all_parameter
+
+
+def _convert_form_to_value_dict(config_form) -> dict:
+ final_values = {}
+ for index in range((len(config_form) - 2) // 2):
+ current_name = config_form[f'formdata[{index}][name]']
+ current_value = config_form[f'formdata[{index}][value]']
+ if 'hyperparameter-rl' in current_name:
+ final_values[current_name.replace('hyperparameter-rl-', '')] = current_value
+ return final_values
+
+
+def _get_value_from_dict(key, value_dict) -> dict:
+ try:
+ return value_dict[key]
+ except KeyError:
+ return ''
diff --git a/webserver/alpha_business_app/buttons.py b/webserver/alpha_business_app/buttons.py
index a31da27d..dc9e9f71 100644
--- a/webserver/alpha_business_app/buttons.py
+++ b/webserver/alpha_business_app/buttons.py
@@ -4,6 +4,7 @@
from recommerce.configuration.config_validation import validate_config
+from .adjustable_fields import get_rl_parameter_prefill
from .config_merger import ConfigMerger
from .config_parser import ConfigFlatDictParser
from .container_parser import parse_response_to_database
@@ -280,6 +281,7 @@ def _prefill(self) -> HttpResponse:
return self._decide_rendering()
merger = ConfigMerger()
final_dict, error_dict = merger.merge_config_objects(post_request['config_id'])
+ final_dict['hyperparameter']['rl'] = get_rl_parameter_prefill(final_dict['hyperparameter']['rl'], error_dict['hyperparameter']['rl'])
# set an id for each agent (necessary for view)
for agent_index in range(len(final_dict['environment']['agents'])):
final_dict['environment']['agents'][agent_index]['display_name'] = 'Agent' if agent_index == 0 else 'Competitor'
diff --git a/webserver/alpha_business_app/models/agents_config.py b/webserver/alpha_business_app/models/agents_config.py
index daee858b..ccb251e7 100644
--- a/webserver/alpha_business_app/models/agents_config.py
+++ b/webserver/alpha_business_app/models/agents_config.py
@@ -7,3 +7,6 @@ class AgentsConfig(AbstractConfig, models.Model):
def as_list(self) -> dict:
referencing_agents = self.agentconfig_set.all()
return [agent.as_dict() for agent in referencing_agents]
+
+ def as_dict(self) -> dict:
+ assert False, 'This should not be implemented as agents are a list.'
diff --git a/webserver/alpha_business_app/models/environment_config.py b/webserver/alpha_business_app/models/environment_config.py
index 6a7502ca..d6f7aa86 100644
--- a/webserver/alpha_business_app/models/environment_config.py
+++ b/webserver/alpha_business_app/models/environment_config.py
@@ -1,5 +1,6 @@
from django.db import models
+from ..utils import remove_none_values_from_dict
from .abstract_config import AbstractConfig
@@ -10,3 +11,14 @@ class EnvironmentConfig(AbstractConfig, models.Model):
plot_interval = models.IntegerField(null=True)
marketplace = models.CharField(max_length=150, null=True)
task = models.CharField(max_length=14, choices=((1, 'training'), (2, 'agent_monitoring'), (3, 'exampleprinter')), null=True)
+
+ def as_dict(self) -> dict:
+ agents_list = self.agents.as_list() if self.agents is not None else None
+ return remove_none_values_from_dict({
+ 'enable_live_draw': self.enable_live_draw,
+ 'episodes': self.episodes,
+ 'plot_interval': self.plot_interval,
+ 'marketplace': self.marketplace,
+ 'task': self.task,
+ 'agents': agents_list
+ })
diff --git a/webserver/alpha_business_app/models/hyperparameter_config.py b/webserver/alpha_business_app/models/hyperparameter_config.py
index 54613459..a8cc7267 100644
--- a/webserver/alpha_business_app/models/hyperparameter_config.py
+++ b/webserver/alpha_business_app/models/hyperparameter_config.py
@@ -1,5 +1,6 @@
from django.db import models
+from ..utils import remove_none_values_from_dict
from .abstract_config import AbstractConfig
from .rl_config import RlConfig
from .sim_market_config import SimMarketConfig
@@ -8,3 +9,11 @@
class HyperparameterConfig(AbstractConfig, models.Model):
rl = models.ForeignKey('alpha_business_app.RLConfig', on_delete=models.CASCADE, null=True)
sim_market = models.ForeignKey('alpha_business_app.SimMarketConfig', on_delete=models.CASCADE, null=True)
+
+ def as_dict(self) -> dict:
+ sim_market_dict = self.sim_market.as_dict() if self.sim_market is not None else {'sim_market': None}
+ rl_dict = self.rl.as_dict() if self.rl is not None else {'rl': None}
+ return remove_none_values_from_dict({
+ 'rl': rl_dict,
+ 'sim_market': sim_market_dict
+ })
diff --git a/webserver/alpha_business_app/static/js/custom.js b/webserver/alpha_business_app/static/js/custom.js
index dcf575d7..0bf63564 100644
--- a/webserver/alpha_business_app/static/js/custom.js
+++ b/webserver/alpha_business_app/static/js/custom.js
@@ -63,19 +63,20 @@ $(document).ready(function() {
function addChangeToAgent () {
$("select.agent-agent-class").change(function () {
- // will be called when another marketplace has been selected
+ // will be called when agent dropdown has changed, we need to change rl hyperparameter for that
var self = $(this);
- console.log('here')
+ var form = $("form.config-form");
+ var formdata = form.serializeArray();
const csrftoken = getCookie("csrftoken");
$.ajax({
type: "POST",
url: self.data("url"),
data: {
csrfmiddlewaretoken: csrftoken,
- "agent": self.val()
+ "agent": self.val(),
+ formdata
},
success: function (data) {
- console.log(data)
$("div.rl-parameter").empty().append(data)
}
});
diff --git a/webserver/alpha_business_app/tests/constant_tests.py b/webserver/alpha_business_app/tests/constant_tests.py
index 5dc4a0e2..810b5b51 100644
--- a/webserver/alpha_business_app/tests/constant_tests.py
+++ b/webserver/alpha_business_app/tests/constant_tests.py
@@ -142,7 +142,6 @@
EXAMPLE_RL_DICT = {
'rl': {
- 'class': 'recommerce.rl.q_learning.q_learning_agent.QLearningAgent',
'gamma': 0.99,
'batch_size': 32,
'replay_size': 100000,
diff --git a/webserver/alpha_business_app/tests/test_adjustable_fields.py b/webserver/alpha_business_app/tests/test_adjustable_fields.py
new file mode 100644
index 00000000..27a016e3
--- /dev/null
+++ b/webserver/alpha_business_app/tests/test_adjustable_fields.py
@@ -0,0 +1,16 @@
+from django.test import TestCase
+
+from ..adjustable_fields import get_rl_parameter_prefill
+
+
+class AdjustableFieldsTests(TestCase):
+ def test_rl_hyperparameter_with_prefill(self):
+ prefill_dict = {'gamma': 0.9, 'learning_rate': 0.4, 'test': None}
+ error_dict = {'gamma': 'test', 'learning_rate': None, 'test': None}
+ expected_list = [
+ {'name': 'gamma', 'prefill': 0.9, 'error': 'test'},
+ {'name': 'learning_rate', 'prefill': 0.4, 'error': ''},
+ {'name': 'test', 'prefill': '', 'error': ''}
+ ]
+ actual_list = get_rl_parameter_prefill(prefill_dict, error_dict)
+ assert actual_list == expected_list
diff --git a/webserver/alpha_business_app/tests/test_config_model.py b/webserver/alpha_business_app/tests/test_config_model.py
index aa2d4c51..c24988b1 100644
--- a/webserver/alpha_business_app/tests/test_config_model.py
+++ b/webserver/alpha_business_app/tests/test_config_model.py
@@ -5,7 +5,9 @@
from ..models.config import Config
from ..models.container import Container
from ..models.environment_config import EnvironmentConfig
+from ..models.hyperparameter_config import HyperparameterConfig
from ..models.rl_config import RlConfig
+from ..models.sim_market_config import SimMarketConfig
from ..utils import capitalize, remove_none_values_from_dict, to_config_class_name
# from .constant_tests import EMPTY_STRUCTURE_CONFIG
@@ -50,76 +52,76 @@ def test_is_referenced(self):
assert test_config_not_referenced.is_referenced() is False
assert test_config_referenced.is_referenced() is True
- # def test_config_to_dict(self):
- # # create a small valid config for this test
- # agents_config = AgentsConfig.objects.create()
-
- # AgentConfig.objects.create(agent_class='recommerce.market.circular.circular_vendors.RuleBasedCERebuyAgent',
- # argument='', agents_config=agents_config, name='Rule_Based Agent')
-
- # env_config = EnvironmentConfig.objects.create(agents=agents_config,
- # marketplace='recommerce.market.circular.circular_sim_market.CircularEconomyRebuyPriceMonopoly',
- # task='training')
-
- # rl_config = RlConfig.objects.create(gamma=0.99,
- # batch_size=32,
- # replay_size=100000,
- # learning_rate=1e-6,
- # sync_target_frames=1000,
- # replay_start_size=10000,
- # epsilon_decay_last_frame=75000,
- # epsilon_start=1.0,
- # epsilon_final=0.1)
-
- # sim_market_config = SimMarketConfig.objects.create(max_storage=100,
- # episode_length=50,
- # max_price=10,
- # max_quality=50,
- # number_of_customers=20,
- # production_price=3,
- # storage_cost_per_product=0.1)
-
- # hyperparameter_config = HyperparameterConfig.objects.create(sim_market=sim_market_config,
- # rl=rl_config)
-
- # final_config = Config.objects.create(environment=env_config,
- # hyperparameter=hyperparameter_config)
- # expected_dict = {
- # 'hyperparameter': {
- # 'rl': {
- # 'gamma': 0.99,
- # 'batch_size': 32,
- # 'replay_size': 100000,
- # 'learning_rate': 1e-6,
- # 'sync_target_frames': 1000,
- # 'replay_start_size': 10000,
- # 'epsilon_decay_last_frame': 75000,
- # 'epsilon_start': 1.0,
- # 'epsilon_final': 0.1
- # },
- # 'sim_market': {
- # 'max_storage': 100,
- # 'episode_length': 50,
- # 'max_price': 10,
- # 'max_quality': 50,
- # 'number_of_customers': 20,
- # 'production_price': 3,
- # 'storage_cost_per_product': 0.1
- # }
- # },
- # 'environment': {
- # 'task': 'training',
- # 'marketplace': 'recommerce.market.circular.circular_sim_market.CircularEconomyRebuyPriceMonopoly',
- # 'agents': [
- # {
- # 'name': 'Rule_Based Agent',
- # 'agent_class': 'recommerce.market.circular.circular_vendors.RuleBasedCERebuyAgent',
- # 'argument': ''
- # }
- # ]
- # }
- # }
- # assert expected_dict == final_config.as_dict()
+ def test_config_to_dict(self):
+ # create a small valid config for this test
+ agents_config = AgentsConfig.objects.create()
+
+ AgentConfig.objects.create(agent_class='recommerce.market.circular.circular_vendors.RuleBasedCERebuyAgent',
+ argument='', agents_config=agents_config, name='Rule_Based Agent')
+
+ env_config = EnvironmentConfig.objects.create(agents=agents_config,
+ marketplace='recommerce.market.circular.circular_sim_market.CircularEconomyRebuyPriceMonopoly',
+ task='training')
+
+ rl_config = RlConfig.objects.create(gamma=0.99,
+ batch_size=32,
+ replay_size=100000,
+ learning_rate=1e-6,
+ sync_target_frames=1000,
+ replay_start_size=10000,
+ epsilon_decay_last_frame=75000,
+ epsilon_start=1.0,
+ epsilon_final=0.1)
+
+ sim_market_config = SimMarketConfig.objects.create(max_storage=100,
+ episode_length=50,
+ max_price=10,
+ max_quality=50,
+ number_of_customers=20,
+ production_price=3,
+ storage_cost_per_product=0.1)
+
+ hyperparameter_config = HyperparameterConfig.objects.create(sim_market=sim_market_config,
+ rl=rl_config)
+
+ final_config = Config.objects.create(environment=env_config,
+ hyperparameter=hyperparameter_config)
+ expected_dict = {
+ 'hyperparameter': {
+ 'rl': {
+ 'gamma': 0.99,
+ 'batch_size': 32,
+ 'replay_size': 100000,
+ 'learning_rate': 1e-6,
+ 'sync_target_frames': 1000,
+ 'replay_start_size': 10000,
+ 'epsilon_decay_last_frame': 75000,
+ 'epsilon_start': 1.0,
+ 'epsilon_final': 0.1
+ },
+ 'sim_market': {
+ 'max_storage': 100,
+ 'episode_length': 50,
+ 'max_price': 10,
+ 'max_quality': 50,
+ 'number_of_customers': 20,
+ 'production_price': 3,
+ 'storage_cost_per_product': 0.1
+ }
+ },
+ 'environment': {
+ 'task': 'training',
+ 'marketplace': 'recommerce.market.circular.circular_sim_market.CircularEconomyRebuyPriceMonopoly',
+ 'agents': [
+ {
+ 'name': 'Rule_Based Agent',
+ 'agent_class': 'recommerce.market.circular.circular_vendors.RuleBasedCERebuyAgent',
+ 'argument': ''
+ }
+ ]
+ }
+ }
+ assert expected_dict == final_config.as_dict()
def test_dict_representation_of_agent(self):
test_agent = AgentConfig.objects.create(name='test_agent', agent_class='test_class', argument='1234')
@@ -154,12 +156,21 @@ def test_list_representation_of_agents(self):
]
assert expected_list == test_agents.as_list()
- # def test_get_empty_structure_dict(self):
- # actual_dict = Config.get_empty_structure_dict()
- # assert EMPTY_STRUCTURE_CONFIG == actual_dict
-
def test_get_empty_structure_dict_for_rl(self):
- expected_dict = {}
+ expected_dict = {
+ 'sync_target_frames': None,
+ 'testvalue2': None,
+ 'gamma': None,
+ 'epsilon_start': None,
+ 'replay_size': None,
+ 'stable_baseline_test': None,
+ 'threshold': None,
+ 'epsilon_decay_last_frame': None,
+ 'batch_size': None,
+ 'epsilon_final': None,
+ 'replay_start_size': None,
+ 'learning_rate': None
+ }
assert expected_dict == RlConfig.get_empty_structure_dict()
def test_remove_none_values_from_dict(self):
diff --git a/webserver/alpha_business_app/tests/test_config_model_parser.py b/webserver/alpha_business_app/tests/test_config_model_parser.py
index 7c617530..9be3515a 100644
--- a/webserver/alpha_business_app/tests/test_config_model_parser.py
+++ b/webserver/alpha_business_app/tests/test_config_model_parser.py
@@ -4,6 +4,7 @@
from ..models.agents_config import AgentsConfig
from ..models.config import Config
from ..models.environment_config import EnvironmentConfig
+from ..models.hyperparameter_config import HyperparameterConfig
from ..models.rl_config import RlConfig
from ..models.sim_market_config import SimMarketConfig
from .constant_tests import EXAMPLE_HIERARCHY_DICT, EXAMPLE_RL_DICT
@@ -128,13 +129,23 @@ def test_parsing_agents(self):
def test_parse_rl(self):
test_dict = EXAMPLE_RL_DICT.copy()
- final_config = self.parser.parse_config(test_dict)
+ final_config = self.parser.parse_config_dict_to_datastructure('hyperparameter', test_dict)
- assert Config == type(final_config)
+ assert HyperparameterConfig == type(final_config)
# assert all hyperparameters
- hyperparameter_rl_config: RlConfig = final_config.hyperparameter.rl
- hyperparameter_sim_market_config: SimMarketConfig = final_config.hyperparameter.sim_market
-
+ hyperparameter_rl_config: RlConfig = final_config.rl
+ hyperparameter_sim_market_config: SimMarketConfig = final_config.sim_market
assert hyperparameter_rl_config is not None
assert hyperparameter_sim_market_config is None
+
+ assert 0.99 == hyperparameter_rl_config.gamma
+ assert 32 == hyperparameter_rl_config.batch_size
+ assert 100000 == hyperparameter_rl_config.replay_size
+ assert 1e-6 == hyperparameter_rl_config.learning_rate
+ assert 1000 == hyperparameter_rl_config.sync_target_frames
+ assert 10000 == hyperparameter_rl_config.replay_start_size
+ assert 75000 == hyperparameter_rl_config.epsilon_decay_last_frame
+ assert 1.0 == hyperparameter_rl_config.epsilon_start
+ assert 0.1 == hyperparameter_rl_config.epsilon_final
+ assert hyperparameter_rl_config.stable_baseline_test is None
diff --git a/webserver/alpha_business_app/utils.py b/webserver/alpha_business_app/utils.py
index 54e98d6c..c2dc52ec 100644
--- a/webserver/alpha_business_app/utils.py
+++ b/webserver/alpha_business_app/utils.py
@@ -28,20 +28,7 @@ def get_recommerce_agents_for_marketplace(marketplace) -> list:
return marketplace.get_possible_rl_agents()
-def get_agent_hyperparameter(agent: str) -> list:
- agent_class = get_class(agent)
- agent_specs = agent_class.get_configurable_fields()
- # name, input_type, error
- all_parameter = []
- for spec in agent_specs:
- this_parameter = {}
- this_parameter['name'] = spec[0]
- this_parameter['input_type'] = _convert_python_type_to_input_type(spec[1])
- all_parameter += [this_parameter]
- return all_parameter
-
-
-def _convert_python_type_to_input_type(to_convert) -> str:
+def convert_python_type_to_input_type(to_convert) -> str:
return 'number' if to_convert == float or to_convert == int else 'text'
diff --git a/webserver/alpha_business_app/views.py b/webserver/alpha_business_app/views.py
index 2b107fe4..80a5e66a 100644
--- a/webserver/alpha_business_app/views.py
+++ b/webserver/alpha_business_app/views.py
@@ -6,6 +6,7 @@
from recommerce.configuration.config_validation import validate_config
+from .adjustable_fields import get_agent_hyperparameter
from .buttons import ButtonHandler
from .config_parser import ConfigFlatDictParser
from .forms import UploadFileForm
@@ -14,7 +15,6 @@
from .models.config import Config
from .models.container import Container
from .selection_manager import SelectionManager
-from .utils import get_agent_hyperparameter
selection_manager = SelectionManager()
@@ -106,9 +106,9 @@ def new_agent(request) -> HttpResponse:
def agent_changed(request) -> HttpResponse:
if not request.user.is_authenticated:
return HttpResponse('Unauthorized', status=401)
- print(get_agent_hyperparameter(request.POST['agent']))
+ # print(request.POST)
return render(request, 'configuration_items/rl_parameter.html',
- {'parameters': get_agent_hyperparameter(request.POST['agent'])})
+ {'parameters': get_agent_hyperparameter(request.POST['agent'], request.POST.dict())})
def api_availability(request) -> HttpResponse:
diff --git a/webserver/templates/configuration_items/rl.html b/webserver/templates/configuration_items/rl.html
index ebab9862..cc2f781a 100644
--- a/webserver/templates/configuration_items/rl.html
+++ b/webserver/templates/configuration_items/rl.html
@@ -6,115 +6,7 @@
- {% load static %}
-
-
- {% if error_dict.gamma %}
-
![{{error_dict.gamma}}]({% static 'icons/warning.svg' %})
- {% endif %}
- gamma
-
-
-
-
-
-
-
- {% if error_dict.batch_size %}
-
![{{error_dict.batch_size}}]({% static 'icons/warning.svg' %})
- {% endif %}
- batch size
-
-
-
-
-
-
-
- {% if error_dict.replay_size %}
-
![{{error_dict.replay_size}}]({% static 'icons/warning.svg' %})
- {% endif %}
- replay size
-
-
-
-
-
-
-
- {% if error_dict.learning_rate %}
-
![{{error_dict.learning_rate}}]({% static 'icons/warning.svg' %})
- {% endif %}
- learning rate
-
-
-
-
-
-
-
- {% if error_dict.sync_target_frames %}
-
![{{error_dict.sync_target_frames}}]({% static 'icons/warning.svg' %})
- {% endif %}
- sync target frames
-
-
-
-
-
-
-
- {% if error_dict.start_size %}
-
![{{error_dict.start_size}}]({% static 'icons/warning.svg' %})
- {% endif %}
- replay start size
-
-
-
-
-
-
-
- {% if error_dict.epsilon_decay_last_frame %}
-
![{{error_dict.epsilon_decay_last_frame}}]({% static 'icons/warning.svg' %})
- {% endif %}
- epsilon decay last frame
-
-
-
-
-
-
-
- {% if error_dict.epsilon_start %}
-
![{{error_dict.epsilon_start}}]({% static 'icons/warning.svg' %})
- {% endif %}
- epsilon start
-
-
-
-
-
-
-
- {% if error_dict.epsilon_final %}
-
![{{error_dict.epsilon_final}}]({% static 'icons/warning.svg' %})
- {% endif %}
- epsilon final
-
-
-
-
-
+ {% include "configuration_items/rl_parameter.html" with parameters=prefill %}
From 135e8f300a5da0f853f07ddc4faec7ae3a13d0e9 Mon Sep 17 00:00:00 2001
From: Judith <39854388+felix-20@users.noreply.github.com>
Date: Wed, 1 Jun 2022 20:52:12 +0200
Subject: [PATCH 2/3] fix webserver tests
---
recommerce/configuration/config_validation.py | 2 +-
tests/test_vendors.py | 1 -
webserver/alpha_business_app/buttons.py | 2 +-
webserver/alpha_business_app/config_merger.py | 7 +-
...eline_test_rlconfig_testvalue2_and_more.py | 68 +++++++++++++++++++
.../alpha_business_app/models/rl_config.py | 11 ++-
.../tests/constant_tests.py | 12 ++--
.../alpha_business_app/tests/test_prefill.py | 21 ++++--
.../alpha_business_app/tests/test_utils.py | 6 +-
9 files changed, 102 insertions(+), 28 deletions(-)
create mode 100644 webserver/alpha_business_app/migrations/0013_rlconfig_stable_baseline_test_rlconfig_testvalue2_and_more.py
diff --git a/recommerce/configuration/config_validation.py b/recommerce/configuration/config_validation.py
index 692cd0cc..8dcd18b2 100644
--- a/recommerce/configuration/config_validation.py
+++ b/recommerce/configuration/config_validation.py
@@ -34,7 +34,7 @@ def validate_config(config: dict, config_is_final: bool) -> tuple:
if 'marketplace' in environment_config:
market_class = get_class(environment_config['marketplace'])
# the first agent is always the relevant one
- if 'agents' in environment_config and len(environment_config['agents'] >= 1):
+ if 'agents' in environment_config and len(environment_config['agents']) >= 1:
agent_class = get_class(environment_config['agents'][0]['agent_class'])
# validate that all given values have the correct types
diff --git a/tests/test_vendors.py b/tests/test_vendors.py
index 4cd5ed80..f8d65dbe 100644
--- a/tests/test_vendors.py
+++ b/tests/test_vendors.py
@@ -103,7 +103,6 @@ def test_storage_evaluation_with_rebuy_price(state, expected_prices):
changed_config.max_price = 10
changed_config.production_price = 2
agent = circular_vendors.RuleBasedCERebuyAgent(config_market=changed_config)
- print('*********************************')
assert expected_prices == agent.policy(state)
diff --git a/webserver/alpha_business_app/buttons.py b/webserver/alpha_business_app/buttons.py
index dc9e9f71..46c72e46 100644
--- a/webserver/alpha_business_app/buttons.py
+++ b/webserver/alpha_business_app/buttons.py
@@ -280,7 +280,7 @@ def _prefill(self) -> HttpResponse:
if 'config_id' not in post_request:
return self._decide_rendering()
merger = ConfigMerger()
- final_dict, error_dict = merger.merge_config_objects(post_request['config_id'])
+ final_dict, error_dict = merger.merge_config_objects(post_request['config_id'], post_request)
final_dict['hyperparameter']['rl'] = get_rl_parameter_prefill(final_dict['hyperparameter']['rl'], error_dict['hyperparameter']['rl'])
# set an id for each agent (necessary for view)
for agent_index in range(len(final_dict['environment']['agents'])):
diff --git a/webserver/alpha_business_app/config_merger.py b/webserver/alpha_business_app/config_merger.py
index fc27e511..4a239d55 100644
--- a/webserver/alpha_business_app/config_merger.py
+++ b/webserver/alpha_business_app/config_merger.py
@@ -17,14 +17,13 @@ def merge_config_objects(self, config_object_ids: list) -> tuple:
"""
configuration_objects = [Config.objects.get(id=config_id) for config_id in config_object_ids]
configuration_dicts = [config.as_dict() for config in configuration_objects]
- for c in configuration_dicts:
- print(c)
- print('--------------------------------------')
+ # for c in configuration_dicts:
+ # print(c)
+ # print('--------------------------------------')
# get initial empty dict to merge into
final_config = Config.get_empty_structure_dict()
for config in configuration_dicts:
final_config = self._merge_config_into_base_config(final_config, config)
- print(final_config, '\n*******************************\n')
return final_config, self.error_dict
def _merge_config_into_base_config(self, base_config: dict, merging_config: dict, current_config_path: str = '') -> dict:
diff --git a/webserver/alpha_business_app/migrations/0013_rlconfig_stable_baseline_test_rlconfig_testvalue2_and_more.py b/webserver/alpha_business_app/migrations/0013_rlconfig_stable_baseline_test_rlconfig_testvalue2_and_more.py
new file mode 100644
index 00000000..823fa036
--- /dev/null
+++ b/webserver/alpha_business_app/migrations/0013_rlconfig_stable_baseline_test_rlconfig_testvalue2_and_more.py
@@ -0,0 +1,68 @@
+# Generated by Django 4.0.1 on 2022-06-01 18:36
+
+from django.db import migrations, models
+
+
+class Migration(migrations.Migration):
+
+ dependencies = [
+ ('alpha_business_app', '0012_config_user_container_user'),
+ ]
+
+ operations = [
+ migrations.AddField(
+ model_name='rlconfig',
+ name='stable_baseline_test',
+ field=models.FloatField(default=None, null=True),
+ ),
+ migrations.AddField(
+ model_name='rlconfig',
+ name='testvalue2',
+ field=models.FloatField(default=None, null=True),
+ ),
+ migrations.AlterField(
+ model_name='rlconfig',
+ name='batch_size',
+ field=models.IntegerField(default=None, null=True),
+ ),
+ migrations.AlterField(
+ model_name='rlconfig',
+ name='epsilon_decay_last_frame',
+ field=models.IntegerField(default=None, null=True),
+ ),
+ migrations.AlterField(
+ model_name='rlconfig',
+ name='epsilon_final',
+ field=models.FloatField(default=None, null=True),
+ ),
+ migrations.AlterField(
+ model_name='rlconfig',
+ name='epsilon_start',
+ field=models.FloatField(default=None, null=True),
+ ),
+ migrations.AlterField(
+ model_name='rlconfig',
+ name='gamma',
+ field=models.FloatField(default=None, null=True),
+ ),
+ migrations.AlterField(
+ model_name='rlconfig',
+ name='learning_rate',
+ field=models.FloatField(default=None, null=True),
+ ),
+ migrations.AlterField(
+ model_name='rlconfig',
+ name='replay_size',
+ field=models.IntegerField(default=None, null=True),
+ ),
+ migrations.AlterField(
+ model_name='rlconfig',
+ name='replay_start_size',
+ field=models.IntegerField(default=None, null=True),
+ ),
+ migrations.AlterField(
+ model_name='rlconfig',
+ name='sync_target_frames',
+ field=models.IntegerField(default=None, null=True),
+ ),
+ ]
diff --git a/webserver/alpha_business_app/models/rl_config.py b/webserver/alpha_business_app/models/rl_config.py
index e66fcdf6..d281f39a 100644
--- a/webserver/alpha_business_app/models/rl_config.py
+++ b/webserver/alpha_business_app/models/rl_config.py
@@ -4,15 +4,14 @@
class RlConfig(AbstractConfig, models.Model):
+ stable_baseline_test = models.FloatField(null=True, default=None)
+ epsilon_decay_last_frame = models.IntegerField(null=True, default=None)
epsilon_start = models.FloatField(null=True, default=None)
sync_target_frames = models.IntegerField(null=True, default=None)
replay_start_size = models.IntegerField(null=True, default=None)
+ testvalue2 = models.FloatField(null=True, default=None)
+ gamma = models.FloatField(null=True, default=None)
+ epsilon_final = models.FloatField(null=True, default=None)
replay_size = models.IntegerField(null=True, default=None)
batch_size = models.IntegerField(null=True, default=None)
- stable_baseline_test = models.FloatField(null=True, default=None)
learning_rate = models.FloatField(null=True, default=None)
- epsilon_decay_last_frame = models.IntegerField(null=True, default=None)
- testvalue2 = models.FloatField(null=True, default=None)
- epsilon_final = models.FloatField(null=True, default=None)
- threshold = models.FloatField(null=True, default=None)
- gamma = models.FloatField(null=True, default=None)
diff --git a/webserver/alpha_business_app/tests/constant_tests.py b/webserver/alpha_business_app/tests/constant_tests.py
index 810b5b51..d636024a 100644
--- a/webserver/alpha_business_app/tests/constant_tests.py
+++ b/webserver/alpha_business_app/tests/constant_tests.py
@@ -7,8 +7,8 @@
'environment-episodes': [''],
'environment-plot_interval': [''],
'environment-marketplace': ['recommerce.market.circular.circular_sim_market.CircularEconomyRebuyPriceMonopoly'],
- 'environment-agents-name': ['Rule_Based Agent'],
- 'environment-agents-agent_class': ['recommerce.market.circular.circular_vendors.RuleBasedCERebuyAgent'],
+ 'environment-agents-name': ['QLearning Agent'],
+ 'environment-agents-agent_class': ['recommerce.rl.q_learning.q_learning_agent.QLearningAgent'],
'environment-agents-argument': [''],
'hyperparameter-rl-gamma': ['0.99'],
'hyperparameter-rl-batch_size': ['32'],
@@ -35,8 +35,8 @@
'enable_live_draw': False,
'agents': [
{
- 'name': 'Rule_Based Agent',
- 'agent_class': 'recommerce.market.circular.circular_vendors.RuleBasedCERebuyAgent',
+ 'name': 'QLearning Agent',
+ 'agent_class': 'recommerce.rl.q_learning.q_learning_agent.QLearningAgent',
'argument': ''
}
]
@@ -126,7 +126,9 @@
'replay_start_size': None,
'epsilon_decay_last_frame': None,
'epsilon_start': None,
- 'epsilon_final': None
+ 'epsilon_final': None,
+ 'testvalue2': None,
+ 'stable_baseline_test': None
},
'sim_market': {
'max_storage': None,
diff --git a/webserver/alpha_business_app/tests/test_prefill.py b/webserver/alpha_business_app/tests/test_prefill.py
index 7e0f215a..aeb9b92a 100644
--- a/webserver/alpha_business_app/tests/test_prefill.py
+++ b/webserver/alpha_business_app/tests/test_prefill.py
@@ -8,7 +8,10 @@
# from ..buttons import ButtonHandler
from ..config_merger import ConfigMerger
from ..config_parser import ConfigModelParser
-from ..models.config import Config, EnvironmentConfig, HyperparameterConfig, RlConfig
+from ..models.config import Config
+from ..models.environment_config import EnvironmentConfig
+from ..models.hyperparameter_config import HyperparameterConfig
+from ..models.rl_config import RlConfig
from .constant_tests import EMPTY_STRUCTURE_CONFIG, EXAMPLE_HIERARCHY_DICT, EXAMPLE_HIERARCHY_DICT2
# from unittest.mock import patch
@@ -35,6 +38,9 @@ def test_merge_one_config(self):
expected_dict = copy.deepcopy(config_dict)
expected_dict['environment']['episodes'] = None
expected_dict['environment']['plot_interval'] = None
+ expected_dict['hyperparameter']['rl']['testvalue2'] = None
+ expected_dict['hyperparameter']['rl']['stable_baseline_test'] = None
+
empty_config = Config.get_empty_structure_dict()
merger = ConfigMerger()
actual_config = merger._merge_config_into_base_config(empty_config, config_dict)
@@ -98,8 +104,8 @@ def test_merge_two_configs_with_conflicts(self):
'task': 'monitoring',
'agents': [
{
- 'name': 'Rule_Based Agent',
- 'agent_class': 'recommerce.market.circular.circular_vendors.RuleBasedCERebuyAgent',
+ 'name': 'QLearning Agent',
+ 'agent_class': 'recommerce.rl.q_learning.q_learning_agent.QLearningAgent',
'argument': ''
},
{
@@ -122,7 +128,10 @@ def test_merge_two_configs_with_conflicts(self):
'sync_target_frames': 100,
'replay_start_size': 1000,
'epsilon_decay_last_frame': 7500,
- 'epsilon_start': 0.9, 'epsilon_final': 0.2
+ 'epsilon_start': 0.9,
+ 'epsilon_final': 0.2,
+ 'testvalue2': None,
+ 'stable_baseline_test': None
},
'sim_market': {
'max_storage': 80,
@@ -153,7 +162,9 @@ def test_merge_two_configs_with_conflicts(self):
'replay_start_size': 'changed hyperparameter-rl replay_start_size from 10000 to 1000',
'epsilon_decay_last_frame': 'changed hyperparameter-rl epsilon_decay_last_frame from 75000 to 7500',
'epsilon_start': 'changed hyperparameter-rl epsilon_start from 1.0 to 0.9',
- 'epsilon_final': 'changed hyperparameter-rl epsilon_final from 0.1 to 0.2'
+ 'epsilon_final': 'changed hyperparameter-rl epsilon_final from 0.1 to 0.2',
+ 'testvalue2': None,
+ 'stable_baseline_test': None
},
'sim_market': {
'max_storage': 'changed hyperparameter-sim_market max_storage from 100 to 80',
diff --git a/webserver/alpha_business_app/tests/test_utils.py b/webserver/alpha_business_app/tests/test_utils.py
index 4da233c8..21345127 100644
--- a/webserver/alpha_business_app/tests/test_utils.py
+++ b/webserver/alpha_business_app/tests/test_utils.py
@@ -4,7 +4,7 @@
class UtilsTest(TestCase):
- def test_get_structure_dict_for(self):
+ def test_get_structure_dict_for_config(self):
expected_dict = {
'environment': {
'task': None,
@@ -16,7 +16,6 @@ def test_get_structure_dict_for(self):
},
'hyperparameter': {
'sim_market': {
- 'class': None,
'max_storage': None,
'episode_length': None,
'max_price': None,
@@ -33,7 +32,6 @@ def test_get_structure_dict_for(self):
'testvalue2': None,
'sync_target_frames': None,
'batch_size': None,
- 'threshold': None,
'epsilon_final': None,
'stable_baseline_test': None,
'gamma': None,
@@ -52,7 +50,6 @@ def test_get_structure_dict_for_rl(self):
'testvalue2': None,
'sync_target_frames': None,
'batch_size': None,
- 'threshold': None,
'epsilon_final': None,
'stable_baseline_test': None,
'gamma': None,
@@ -62,7 +59,6 @@ def test_get_structure_dict_for_rl(self):
def test_get_structure_dict_for_sim_market(self):
expected_dict = {
- 'class': None,
'max_storage': None,
'episode_length': None,
'max_price': None,
From f50fe65fccbed1db759c42ffb6e3afd0e277fe4a Mon Sep 17 00:00:00 2001
From: Judith <39854388+felix-20@users.noreply.github.com>
Date: Thu, 2 Jun 2022 11:21:10 +0200
Subject: [PATCH 3/3] implement new config validation for webserver
---
recommerce/configuration/config_validation.py | 2 +
webserver/alpha_business_app/handle_files.py | 30 ++++++---
.../tests/test_config_flat_dict_parser.py | 4 +-
.../tests/test_config_model.py | 1 -
.../tests/test_config_model_parser.py | 4 +-
.../test_data/test_environment_config.json | 5 +-
.../test_data/test_hyperparameter_config.json | 24 -------
.../tests/test_data/test_mixed_config.json | 15 -----
.../tests/test_data/test_rl_config.json | 12 ++++
.../test_data/test_sim_market_config.json | 11 ++++
.../tests/test_file_handling.py | 66 +++++++++----------
11 files changed, 84 insertions(+), 90 deletions(-)
delete mode 100644 webserver/alpha_business_app/tests/test_data/test_hyperparameter_config.json
delete mode 100644 webserver/alpha_business_app/tests/test_data/test_mixed_config.json
create mode 100644 webserver/alpha_business_app/tests/test_data/test_rl_config.json
create mode 100644 webserver/alpha_business_app/tests/test_data/test_sim_market_config.json
diff --git a/recommerce/configuration/config_validation.py b/recommerce/configuration/config_validation.py
index 8dcd18b2..96de3a05 100644
--- a/recommerce/configuration/config_validation.py
+++ b/recommerce/configuration/config_validation.py
@@ -31,6 +31,8 @@ def validate_config(config: dict, config_is_final: bool) -> tuple:
# try to split the config. If any keys are unknown, an AssertionError will be thrown
hyperparameter_config, environment_config = split_mixed_config(config)
+ agent_class = None
+ market_class = None
if 'marketplace' in environment_config:
market_class = get_class(environment_config['marketplace'])
# the first agent is always the relevant one
diff --git a/webserver/alpha_business_app/handle_files.py b/webserver/alpha_business_app/handle_files.py
index c7d095b6..6a523135 100644
--- a/webserver/alpha_business_app/handle_files.py
+++ b/webserver/alpha_business_app/handle_files.py
@@ -56,23 +56,37 @@ def handle_uploaded_file(request, uploaded_config) -> HttpResponse:
except ValueError as value:
return render(request, 'upload.html', {'error': str(value)})
+ # Validate the config file using the recommerce validation functionality
validate_status, validate_data = validate_config(content_as_dict, False)
if not validate_status:
return render(request, 'upload.html', {'error': validate_data})
- hyperparameter_config, environment_config = validate_data
+ config = validate_data
+ assert len(config.keys()) == 1, f'This config ({config} as multiple keys, should only be one ("environment", "rl" or "sim_market"))'
+ # recommerce returns either {'environment': {}}, {'rl': {}} or {'sim_markte': {}}}
+ # for parsing we need to know what the toplevel key is
+ top_level = config.keys()[0]
+ if top_level != 'environment':
+ top_level = 'hyperparameter'
+
+ # parse config model to datastructure
parser = ConfigModelParser()
- web_hyperparameter_config = None
- web_environment_config = None
try:
- web_hyperparameter_config = parser.parse_config_dict_to_datastructure('hyperparameter', hyperparameter_config)
- web_environment_config = parser.parse_config_dict_to_datastructure('environment', environment_config)
- except ValueError:
- return render(request, 'upload.html', {'error': 'Your config is wrong'})
+ resulting_config_part = parser.parse_config_dict_to_datastructure(top_level, config)
+ except ValueError as e:
+ return render(request, 'upload.html', {'error': f'Your config is wrong {e}'})
+
+ # Make it a real config object
+ environment_config = None
+ hyperparameter_config = None
+ if top_level == 'environment':
+ environment_config = resulting_config_part
+ else:
+ hyperparameter_config = resulting_config_part
given_name = request.POST['config_name']
config_name = given_name if given_name else uploaded_config.name
- Config.objects.create(environment=web_environment_config, hyperparameter=web_hyperparameter_config, name=config_name, user=request.user)
+ Config.objects.create(environment=environment_config, hyperparameter=hyperparameter_config, name=config_name, user=request.user)
return redirect('/configurator', {'success': 'You successfully uploaded a config file'})
diff --git a/webserver/alpha_business_app/tests/test_config_flat_dict_parser.py b/webserver/alpha_business_app/tests/test_config_flat_dict_parser.py
index 58d52fc8..cd150967 100644
--- a/webserver/alpha_business_app/tests/test_config_flat_dict_parser.py
+++ b/webserver/alpha_business_app/tests/test_config_flat_dict_parser.py
@@ -209,8 +209,8 @@ def test_parsing_config_dict(self):
all_agents = environment_agents.agentconfig_set.all()
assert 1 == len(all_agents)
- assert 'recommerce.market.circular.circular_vendors.RuleBasedCERebuyAgent' == all_agents[0].agent_class
- assert 'Rule_Based Agent' == all_agents[0].name
+ assert 'recommerce.rl.q_learning.q_learning_agent.QLearningAgent' == all_agents[0].agent_class
+ assert 'QLearning Agent' == all_agents[0].name
assert '' == all_agents[0].argument
def test_parsing_agents(self):
diff --git a/webserver/alpha_business_app/tests/test_config_model.py b/webserver/alpha_business_app/tests/test_config_model.py
index c24988b1..da8287f7 100644
--- a/webserver/alpha_business_app/tests/test_config_model.py
+++ b/webserver/alpha_business_app/tests/test_config_model.py
@@ -164,7 +164,6 @@ def test_get_empty_structure_dict_for_rl(self):
'epsilon_start': None,
'replay_size': None,
'stable_baseline_test': None,
- 'threshold': None,
'epsilon_decay_last_frame': None,
'batch_size': None,
'epsilon_final': None,
diff --git a/webserver/alpha_business_app/tests/test_config_model_parser.py b/webserver/alpha_business_app/tests/test_config_model_parser.py
index 9be3515a..4132894e 100644
--- a/webserver/alpha_business_app/tests/test_config_model_parser.py
+++ b/webserver/alpha_business_app/tests/test_config_model_parser.py
@@ -98,8 +98,8 @@ def test_parsing_config_dict(self):
all_agents = environment_agents.agentconfig_set.all()
assert 1 == len(all_agents)
- assert 'recommerce.market.circular.circular_vendors.RuleBasedCERebuyAgent' == all_agents[0].agent_class
- assert 'Rule_Based Agent' == all_agents[0].name
+ assert 'recommerce.rl.q_learning.q_learning_agent.QLearningAgent' == all_agents[0].agent_class
+ assert 'QLearning Agent' == all_agents[0].name
assert '' == all_agents[0].argument
def test_parsing_agents(self):
diff --git a/webserver/alpha_business_app/tests/test_data/test_environment_config.json b/webserver/alpha_business_app/tests/test_data/test_environment_config.json
index edfa9644..b06c4d07 100644
--- a/webserver/alpha_business_app/tests/test_data/test_environment_config.json
+++ b/webserver/alpha_business_app/tests/test_data/test_environment_config.json
@@ -6,7 +6,7 @@
"marketplace": "recommerce.market.circular.circular_sim_market.CircularEconomyRebuyPriceMonopoly",
"agents": [
{
- "name": "Rule_Based Agent",
+ "name": "QLearning Agent",
"agent_class": "recommerce.market.circular.circular_vendors.RuleBasedCERebuyAgent",
"argument": ""
},
@@ -15,5 +15,6 @@
"agent_class": "recommerce.rl.q_learning.q_learning_agent.QLearningAgent",
"argument": "CircularEconomyRebuyPriceMonopoly_QLearningAgent.dat"
}
- ]
+ ],
+ "config_type": "environment"
}
diff --git a/webserver/alpha_business_app/tests/test_data/test_hyperparameter_config.json b/webserver/alpha_business_app/tests/test_data/test_hyperparameter_config.json
deleted file mode 100644
index 11278d8b..00000000
--- a/webserver/alpha_business_app/tests/test_data/test_hyperparameter_config.json
+++ /dev/null
@@ -1,24 +0,0 @@
-{
- "rl": {
- "class" : "recommerce.rl.q_learning.q_learning_agent.QLearningAgent",
- "gamma" : 0.99,
- "batch_size" : 32,
- "replay_size" : 100000,
- "learning_rate" : 1e-6,
- "sync_target_frames" : 1000,
- "replay_start_size" : 10000,
- "epsilon_decay_last_frame" : 75000,
- "epsilon_start" : 1.0,
- "epsilon_final" : 0.1
- },
- "sim_market": {
- "class": "recommerce.market.circular.circular_sim_market.CircularEconomyRebuyPriceMonopoly",
- "max_storage": 100,
- "episode_length": 50,
- "max_price": 10,
- "max_quality": 50,
- "number_of_customers": 20,
- "production_price": 3,
- "storage_cost_per_product": 0.1
- }
-}
diff --git a/webserver/alpha_business_app/tests/test_data/test_mixed_config.json b/webserver/alpha_business_app/tests/test_data/test_mixed_config.json
deleted file mode 100644
index c58239e5..00000000
--- a/webserver/alpha_business_app/tests/test_data/test_mixed_config.json
+++ /dev/null
@@ -1,15 +0,0 @@
-{
- "task": "training",
- "sim_market": {
- "class": "recommerce.market.circular.circular_sim_market.CircularEconomyRebuyPriceMonopoly",
- "max_storage": 100,
- "episode_length": 50
- },
- "enable_live_draw": false,
- "rl": {
- "class" : "recommerce.rl.q_learning.q_learning_agent.QLearningAgent",
- "gamma" : 0.99,
- "batch_size" : 32
- },
- "episodes": 50
-}
diff --git a/webserver/alpha_business_app/tests/test_data/test_rl_config.json b/webserver/alpha_business_app/tests/test_data/test_rl_config.json
new file mode 100644
index 00000000..7d1eea3b
--- /dev/null
+++ b/webserver/alpha_business_app/tests/test_data/test_rl_config.json
@@ -0,0 +1,12 @@
+{
+ "gamma" : 0.99,
+ "batch_size" : 32,
+ "replay_size" : 100000,
+ "learning_rate" : 1e-6,
+ "sync_target_frames" : 1000,
+ "replay_start_size" : 10000,
+ "epsilon_decay_last_frame" : 75000,
+ "epsilon_start" : 1.0,
+ "epsilon_final" : 0.1,
+ "config_type": "rl"
+}
diff --git a/webserver/alpha_business_app/tests/test_data/test_sim_market_config.json b/webserver/alpha_business_app/tests/test_data/test_sim_market_config.json
new file mode 100644
index 00000000..e1642d9b
--- /dev/null
+++ b/webserver/alpha_business_app/tests/test_data/test_sim_market_config.json
@@ -0,0 +1,11 @@
+{
+ "class": "recommerce.market.circular.circular_sim_market.CircularEconomyRebuyPriceMonopoly",
+ "max_storage": 100,
+ "episode_length": 50,
+ "max_price": 10,
+ "max_quality": 50,
+ "number_of_customers": 20,
+ "production_price": 3,
+ "storage_cost_per_product": 0.1,
+ "config_type": "market"
+}
diff --git a/webserver/alpha_business_app/tests/test_file_handling.py b/webserver/alpha_business_app/tests/test_file_handling.py
index 82aaa0b5..780ddda8 100644
--- a/webserver/alpha_business_app/tests/test_file_handling.py
+++ b/webserver/alpha_business_app/tests/test_file_handling.py
@@ -8,7 +8,11 @@
from ..config_parser import ConfigModelParser
from ..handle_files import handle_uploaded_file
-from ..models.config import *
+from ..models.agents_config import AgentsConfig
+from ..models.config import Config
+from ..models.environment_config import EnvironmentConfig
+from ..models.rl_config import RlConfig
+from ..models.sim_market_config import SimMarketConfig
class MockedResponse():
@@ -92,10 +96,10 @@ def test_objects_from_parse_dict(self):
else:
assert 32 == getattr(resulting_config.rl, name)
- def test_parsing_with_only_hyperparameter(self):
+ def test_parsing_with_rl_hyperparameter(self):
# get a test config to be parsed
path_to_test_data = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'test_data')
- with open(os.path.join(path_to_test_data, 'test_hyperparameter_config.json'), 'r') as file:
+ with open(os.path.join(path_to_test_data, 'test_rl_config.json'), 'r') as file:
content = file.read()
# mock uploaded file with test config
test_uploaded_file = MockedUploadedFile('config.json', content.encode())
@@ -108,9 +112,9 @@ def test_parsing_with_only_hyperparameter(self):
assert Config == type(final_config)
assert final_config.environment is None
assert final_config.hyperparameter is not None
+ assert final_config.hyperparameter.sim_market is None
hyperparameter_rl_config: RlConfig = final_config.hyperparameter.rl
- hyperparameter_sim_market_config: SimMarketConfig = final_config.hyperparameter.sim_market
assert hyperparameter_rl_config is not None
assert final_config.hyperparameter.sim_market is not None
@@ -125,6 +129,28 @@ def test_parsing_with_only_hyperparameter(self):
assert 1.0 == hyperparameter_rl_config.epsilon_start
assert 0.1 == hyperparameter_rl_config.epsilon_final
+ def test_parsing_with_sim_market_hyperparameter(self):
+ # get a test config to be parsed
+ path_to_test_data = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'test_data')
+ with open(os.path.join(path_to_test_data, 'test_sim_market_config.json'), 'r') as file:
+ content = file.read()
+ # mock uploaded file with test config
+ test_uploaded_file = MockedUploadedFile('config.json', content.encode())
+ # test method
+ with patch('alpha_business_app.handle_files.redirect') as redirect_mock:
+ handle_uploaded_file(self._setup_request(), test_uploaded_file)
+ redirect_mock.assert_called_once()
+ # assert the datastructure, that should be present afterwards
+ final_config: Config = Config.objects.all().first()
+ assert Config == type(final_config)
+ assert final_config.environment is None
+ assert final_config.hyperparameter is not None
+ assert final_config.hyperparameter.rl is None
+
+ hyperparameter_sim_market_config: SimMarketConfig = final_config.hyperparameter.sim_market
+
+ assert final_config.hyperparameter.sim_market is not None
+
assert 100 == hyperparameter_sim_market_config.max_storage
assert 50 == hyperparameter_sim_market_config.episode_length
assert 10 == hyperparameter_sim_market_config.max_price
@@ -168,38 +194,6 @@ def test_parsing_with_only_environment(self):
assert 'recommerce.rl.q_learning.q_learning_agent.QLearningAgent' == all_agents[1].agent_class
assert 'CircularEconomyRebuyPriceMonopoly_QLearningAgent.dat' == all_agents[1].argument
- def test_parsing_mixed_config(self):
- # get a test config to be parsed
- path_to_test_data = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'test_data')
- with open(os.path.join(path_to_test_data, 'test_mixed_config.json'), 'r') as file:
- content = file.read()
- # mock uploaded file with test config
- test_uploaded_file = MockedUploadedFile('config.json', content.encode())
- # test method
- with patch('alpha_business_app.handle_files.redirect') as redirect_mock:
- handle_uploaded_file(self._setup_request(), test_uploaded_file)
- redirect_mock.assert_called_once()
- # assert the datastructure, that should be present afterwards
- final_config: Config = Config.objects.all().first()
- assert Config == type(final_config)
- assert final_config.environment is not None
- assert final_config.hyperparameter is not None
-
- environment_config: EnvironmentConfig = final_config.environment
- hyperparameter_config: HyperparameterConfig = final_config.hyperparameter
-
- assert 'training' == environment_config.task
- assert environment_config.enable_live_draw is False
- assert 50 == environment_config.episodes
-
- assert hyperparameter_config.sim_market is not None
- assert hyperparameter_config.rl is not None
-
- assert 100 == hyperparameter_config.sim_market.max_storage
- assert 50 == hyperparameter_config.sim_market.episode_length
- assert 0.99 == hyperparameter_config.rl.gamma
- assert 32 == hyperparameter_config.rl.batch_size
-
def test_parsing_invalid_rl_parameters(self):
test_uploaded_file = MockedUploadedFile('config.json', b'{"rl": {"test":"bla"}}')
with patch('alpha_business_app.handle_files.render') as render_mock: