From 024cfeeb6ef4b040b621a8d6ad58201a86395086 Mon Sep 17 00:00:00 2001 From: PimLeerkes Date: Tue, 11 Feb 2025 12:55:10 +0100 Subject: [PATCH 1/3] pgc now also works with integers instead of pgc state objects --- stormvogel/pgc.py | 43 ++++++++++++----------- tests/test_pgc.py | 89 +++++++++++++++++++++++++++++++++++++++++++++-- 2 files changed, 109 insertions(+), 23 deletions(-) diff --git a/stormvogel/pgc.py b/stormvogel/pgc.py index 4a4f4d3..fad153d 100644 --- a/stormvogel/pgc.py +++ b/stormvogel/pgc.py @@ -30,11 +30,11 @@ def __eq__(self, other): def build_pgc( - delta, # Callable[[State, Action], list[tuple[float, State]]], - initial_state_pgc: State, + delta, + initial_state_pgc, rewards=None, labels=None, - available_actions=None, # Callable[[State], list[Action]] | None = None, + available_actions=None, modeltype: stormvogel.model.ModelType = stormvogel.model.ModelType.MDP, ) -> stormvogel.model.Model: """ @@ -53,11 +53,7 @@ def build_pgc( # we create the model with the given type and initial state model = stormvogel.model.new_model(modeltype=modeltype, create_initial_state=False) - model.new_state( - labels=["init"], - features=initial_state_pgc.__dict__, - name=str(initial_state_pgc.__dict__), - ) + model.new_state(labels=["init", str(initial_state_pgc)]) # we continue calling delta and adding new states until no states are # left to be checked @@ -92,14 +88,12 @@ def build_pgc( for tuple in tuples: if tuple[1] not in states_seen: states_seen.append(tuple[1]) - new_state = model.new_state( - name=str(tuple[1].__dict__), features=tuple[1].__dict__ - ) + new_state = model.new_state(labels=str(tuple[1])) branch.append((tuple[0], new_state)) states_to_be_visited.append(tuple[1]) else: branch.append( - (tuple[0], model.get_state_by_name(str(tuple[1].__dict__))) + (tuple[0], model.get_states_with_label(str(tuple[1]))[0]) ) if branch != []: transition[stormvogel_action] = stormvogel.model.Branch(branch) @@ -110,21 +104,19 @@ def build_pgc( for tuple in tuples: 
if tuple[1] not in states_seen: states_seen.append(tuple[1]) - new_state = model.new_state( - name=str(tuple[1].__dict__), features=tuple[1].__dict__ - ) + new_state = model.new_state(labels=str(tuple[1])) branch.append((tuple[0], new_state)) states_to_be_visited.append(tuple[1]) else: branch.append( - (tuple[0], model.get_state_by_name(str(tuple[1].__dict__))) + (tuple[0], model.get_states_with_label(str(tuple[1]))[0]) ) if branch != []: transition[stormvogel.model.EmptyAction] = stormvogel.model.Branch( branch ) - s = model.get_state_by_name(str(state.__dict__)) + s = model.get_states_with_label(str(state))[0] assert s is not None model.add_transitions( s, @@ -146,7 +138,7 @@ def build_pgc( assert available_actions is not None for action in available_actions(state): rewardlist = rewards(state, action) - s = model.get_state_by_name(str(state.__dict__)) + s = model.get_states_with_label(str(state))[0] assert s is not None for index, reward in enumerate(rewardlist): a = model.get_action_with_labels(frozenset(action.labels)) @@ -164,7 +156,7 @@ def build_pgc( for state in states_seen: rewardlist = rewards(state) - s = model.get_state_by_name(str(state.__dict__)) + s = model.get_states_with_label(str(state))[0] assert s is not None for index, reward in enumerate(rewardlist): model.rewards[index].set_state_reward(s, reward) @@ -172,9 +164,20 @@ def build_pgc( # we add the labels if labels is not None: for state in states_seen: - s = model.get_state_by_name(str(state.__dict__)) + s = model.get_states_with_label(str(state))[0] + if "init" in s.labels: + s.labels = ["init"] + else: + s.labels = [] assert s is not None for label in labels(state): s.add_label(label) + else: + for state in states_seen: + s = model.get_states_with_label(str(state))[0] + if "init" in s.labels: + s.labels = ["init"] + else: + s.labels = [] return model diff --git a/tests/test_pgc.py b/tests/test_pgc.py index cf1453f..a46330c 100644 --- a/tests/test_pgc.py +++ b/tests/test_pgc.py @@ -55,9 +55,92 
@@ def delta(s: pgc.State, action: pgc.Action): # we build the model in the regular way: regular_model = model.new_mdp(create_initial_state=False) - state1 = regular_model.new_state(labels=["init", "1"], features={"x": 1}) - state2 = regular_model.new_state(labels=["2"], features={"x": 2}) - state0 = regular_model.new_state(labels=["0"], features={"x": 0}) + state1 = regular_model.new_state(labels=["init", "1"]) + state2 = regular_model.new_state(labels=["2"]) + state0 = regular_model.new_state(labels=["0"]) + left = regular_model.new_action(frozenset({"left"})) + right = regular_model.new_action(frozenset({"right"})) + branch12 = model.Branch([(0.5, state1), (0.5, state2)]) + branch10 = model.Branch([(0.5, state1), (0.5, state0)]) + branch01 = model.Branch([(0.5, state0), (0.5, state1)]) + branch21 = model.Branch([(0.5, state2), (0.5, state1)]) + + regular_model.add_transitions( + state1, model.Transition({left: branch12, right: branch10}) + ) + regular_model.add_transitions(state2, model.Transition({right: branch21})) + regular_model.add_transitions(state0, model.Transition({left: branch01})) + + rewardmodel = regular_model.add_rewards("rewardmodel: 0") + for i in range(2 * N): + pair = regular_model.get_state_action_pair(i) + assert pair is not None + rewardmodel.set_state_action_reward(pair[0], pair[1], 1) + rewardmodel = regular_model.add_rewards("rewardmodel: 1") + for i in range(2 * N): + pair = regular_model.get_state_action_pair(i) + assert pair is not None + rewardmodel.set_state_action_reward(pair[0], pair[1], 2) + + assert regular_model == pgc_model + + +def test_pgc_mdp_int(): + # we build the model with pgc: + N = 2 + p = 0.5 + initial_state = math.floor(N / 2) + + left = pgc.Action(["left"]) + right = pgc.Action(["right"]) + + def available_actions(s): + if s == 1: + return [left, right] + elif s == 2: + return [right] + else: + return [left] + + def rewards(s, a: pgc.Action): + return [1, 2] + + def labels(s): + return [str(s)] + + def delta(s, 
action: pgc.Action): + if action == left: + return ( + [ + (p, s + 1), + (1 - p, s), + ] + if s < N + else [] + ) + elif action == right: + return ( + [ + (p, s - 1), + (1 - p, s), + ] + if s > 0 + else [] + ) + + pgc_model = pgc.build_pgc( + delta=delta, + available_actions=available_actions, + initial_state_pgc=initial_state, + labels=labels, + rewards=rewards, + ) + + # we build the model in the regular way: + regular_model = model.new_mdp(create_initial_state=False) + state1 = regular_model.new_state(labels=["init", "1"]) + state2 = regular_model.new_state(labels=["2"]) + state0 = regular_model.new_state(labels=["0"]) left = regular_model.new_action(frozenset({"left"})) right = regular_model.new_action(frozenset({"right"})) branch12 = model.Branch([(0.5, state1), (0.5, state2)]) From 80277e983ad33f9d755ee72752fad2b71c9a60fc Mon Sep 17 00:00:00 2001 From: PimLeerkes Date: Tue, 11 Feb 2025 13:09:49 +0100 Subject: [PATCH 2/3] example with strings now also works --- tests/test_pgc.py | 23 +++++++++++++++++++++++ 1 file changed, 23 insertions(+) diff --git a/tests/test_pgc.py b/tests/test_pgc.py index a46330c..a19bf1f 100644 --- a/tests/test_pgc.py +++ b/tests/test_pgc.py @@ -291,3 +291,26 @@ def delta(s: pgc.State): rewardmodel.set_state_reward(state, 2) assert pgc_model == regular_model + + +def test_pgc_dtmc_string(): + def delta(current_state): + match current_state: + case "hungry": + return [(1.0, "eating")] + case "eating": + return [(1.0, "hungry")] + + pgc_model = pgc.build_pgc( + delta, initial_state_pgc="hungry", modeltype=model.ModelType.DTMC + ) + + regular_model = model.new_dtmc() + regular_model.set_transitions( + regular_model.get_initial_state(), [(1, regular_model.new_state())] + ) + regular_model.set_transitions( + regular_model.get_state_by_id(1), [(1, regular_model.get_initial_state())] + ) + + assert pgc_model == regular_model From b142fa6010eb66e71491177c642c25de31457d85 Mon Sep 17 00:00:00 2001 From: PimLeerkes Date: Tue, 11 Feb 2025 14:45:41 
+0100 Subject: [PATCH 3/3] made a pgc tutorial notebook --- docs/getting_started/04_mdp.ipynb | 2 +- docs/getting_started/05_simulator.ipynb | 1345 +++++++++++++++-- docs/getting_started/07_pomdp.ipynb | 2 +- .../09_pgc_model_builder.ipynb | 317 ++++ docs/getting_started/model.html | 1339 +++++++--------- tests/test_pgc.py | 18 +- 6 files changed, 2106 insertions(+), 917 deletions(-) create mode 100644 docs/getting_started/09_pgc_model_builder.ipynb diff --git a/docs/getting_started/04_mdp.ipynb b/docs/getting_started/04_mdp.ipynb index 76a823e..4358f05 100644 --- a/docs/getting_started/04_mdp.ipynb +++ b/docs/getting_started/04_mdp.ipynb @@ -411,7 +411,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.12.3" + "version": "3.12.7" } }, "nbformat": 4, diff --git a/docs/getting_started/05_simulator.ipynb b/docs/getting_started/05_simulator.ipynb index d97e509..bf11b10 100644 --- a/docs/getting_started/05_simulator.ipynb +++ b/docs/getting_started/05_simulator.ipynb @@ -10,7 +10,7 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 1, "id": "a8ddc37c-66d2-43e4-8162-6be19a1d70a1", "metadata": {}, "outputs": [], @@ -23,7 +23,7 @@ }, { "cell_type": "code", - "execution_count": 7, + "execution_count": 2, "id": "cab40f99-3460-4497-8b9f-3d669eee1e11", "metadata": {}, "outputs": [], @@ -104,111 +104,1248 @@ }, { "cell_type": "code", - "execution_count": 8, + "execution_count": 3, "id": "c129cf62-40ca-4246-8718-5c859744e7f8", "metadata": { "scrolled": true }, "outputs": [ { - "ename": "TypeError", - "evalue": "argument of type 'bool' is not iterable", - "output_type": "error", - "traceback": [ - "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", - "\u001b[0;31mTypeError\u001b[0m Traceback (most recent call last)", - "Cell \u001b[0;32mIn[8], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m vis \u001b[38;5;241m=\u001b[39m show(mdp, 
layout\u001b[38;5;241m=\u001b[39m\u001b[43mLayout\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mlayouts/monty.json\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m)\u001b[49m, save_and_embed\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mTrue\u001b[39;00m)\n", - "File \u001b[0;32m~/git/stormvogel/env/lib/python3.12/site-packages/stormvogel/layout.py:29\u001b[0m, in \u001b[0;36mLayout.__init__\u001b[0;34m(self, path, path_relative)\u001b[0m\n\u001b[1;32m 27\u001b[0m default_str \u001b[38;5;241m=\u001b[39m f\u001b[38;5;241m.\u001b[39mread()\n\u001b[1;32m 28\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mdefault_dict: \u001b[38;5;28mdict\u001b[39m \u001b[38;5;241m=\u001b[39m json\u001b[38;5;241m.\u001b[39mloads(default_str)\n\u001b[0;32m---> 29\u001b[0m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mload\u001b[49m\u001b[43m(\u001b[49m\u001b[43mpath\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mpath_relative\u001b[49m\u001b[43m)\u001b[49m\n", - "File \u001b[0;32m~/git/stormvogel/env/lib/python3.12/site-packages/stormvogel/layout.py:43\u001b[0m, in \u001b[0;36mLayout.load\u001b[0;34m(self, path, path_relative)\u001b[0m\n\u001b[1;32m 41\u001b[0m parsed_dict \u001b[38;5;241m=\u001b[39m json\u001b[38;5;241m.\u001b[39mloads(parsed_str)\n\u001b[1;32m 42\u001b[0m \u001b[38;5;66;03m# Combine the parsed dict with default to fill missing keys as default values.\u001b[39;00m\n\u001b[0;32m---> 43\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mlayout: \u001b[38;5;28mdict\u001b[39m \u001b[38;5;241m=\u001b[39m \u001b[43mstormvogel\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mrdict\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mmerge_dict\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 44\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mdefault_dict\u001b[49m\u001b[43m,\u001b[49m\u001b[43m 
\u001b[49m\u001b[43mparsed_dict\u001b[49m\n\u001b[1;32m 45\u001b[0m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 47\u001b[0m \u001b[38;5;66;03m# Load in schema for the dict_editor.\u001b[39;00m\n\u001b[1;32m 48\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m \u001b[38;5;28mopen\u001b[39m(os\u001b[38;5;241m.\u001b[39mpath\u001b[38;5;241m.\u001b[39mjoin(PACKAGE_ROOT_DIR, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mlayouts/schema.json\u001b[39m\u001b[38;5;124m\"\u001b[39m)) \u001b[38;5;28;01mas\u001b[39;00m f:\n", - "File \u001b[0;32m~/git/stormvogel/env/lib/python3.12/site-packages/stormvogel/rdict.py:46\u001b[0m, in \u001b[0;36mmerge_dict\u001b[0;34m(dict1, dict2)\u001b[0m\n\u001b[1;32m 44\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(val, \u001b[38;5;28mdict\u001b[39m):\n\u001b[1;32m 45\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m key \u001b[38;5;129;01min\u001b[39;00m dict2 \u001b[38;5;129;01mand\u001b[39;00m \u001b[38;5;28mtype\u001b[39m(dict2[key] \u001b[38;5;241m==\u001b[39m \u001b[38;5;28mdict\u001b[39m):\n\u001b[0;32m---> 46\u001b[0m \u001b[43mmerge_dict\u001b[49m\u001b[43m(\u001b[49m\u001b[43mdict1\u001b[49m\u001b[43m[\u001b[49m\u001b[43mkey\u001b[49m\u001b[43m]\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mdict2\u001b[49m\u001b[43m[\u001b[49m\u001b[43mkey\u001b[49m\u001b[43m]\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 47\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m 48\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m key \u001b[38;5;129;01min\u001b[39;00m dict2:\n", - "File \u001b[0;32m~/git/stormvogel/env/lib/python3.12/site-packages/stormvogel/rdict.py:48\u001b[0m, in \u001b[0;36mmerge_dict\u001b[0;34m(dict1, dict2)\u001b[0m\n\u001b[1;32m 46\u001b[0m merge_dict(dict1[key], dict2[key])\n\u001b[1;32m 47\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m---> 48\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[43mkey\u001b[49m\u001b[43m 
\u001b[49m\u001b[38;5;129;43;01min\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[43mdict2\u001b[49m:\n\u001b[1;32m 49\u001b[0m dict1[key] \u001b[38;5;241m=\u001b[39m dict2[key]\n\u001b[1;32m 51\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m key, val \u001b[38;5;129;01min\u001b[39;00m dict2\u001b[38;5;241m.\u001b[39mitems():\n", - "\u001b[0;31mTypeError\u001b[0m: argument of type 'bool' is not iterable" - ] - } - ], - "source": [ - "vis = show(mdp, layout=Layout(\"layouts/monty.json\"), save_and_embed=True)" - ] - }, - { - "cell_type": "markdown", - "id": "b5b2990c-65ed-4d7b-a4b8-f303843622e5", - "metadata": {}, - "source": [ - "We want to simulate this model. That is, we start at the initial state and then we walk through the model by choosing random actions.\n", - "\n", - "When we do this, we get a partial model as a result that contains everything we discovered during this walk. \n", - "\n", - "Try running this multiple times, and observe that sometimes we get to the target and sometimes we do not." 
- ] - }, - { - "cell_type": "code", - "execution_count": 4, - "id": "eb0fadc0-7bb6-4c1d-ae3e-9e16527726ab", - "metadata": {}, - "outputs": [ - { - "ename": "TypeError", - "evalue": "argument of type 'bool' is not iterable", - "output_type": "error", - "traceback": [ - "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", - "\u001b[0;31mTypeError\u001b[0m Traceback (most recent call last)", - "Cell \u001b[0;32mIn[4], line 12\u001b[0m\n\u001b[1;32m 8\u001b[0m partial_model \u001b[38;5;241m=\u001b[39m stormvogel\u001b[38;5;241m.\u001b[39msimulator\u001b[38;5;241m.\u001b[39msimulate(mdp, steps\u001b[38;5;241m=\u001b[39msteps, seed\u001b[38;5;241m=\u001b[39mseed)\n\u001b[1;32m 9\u001b[0m \u001b[38;5;66;03m# We could also provide a seed.\u001b[39;00m\n\u001b[1;32m 10\u001b[0m \u001b[38;5;66;03m#partial_model = stormvogel.simulator.simulate(mdp, steps=steps, seed=seed)\u001b[39;00m\n\u001b[0;32m---> 12\u001b[0m vis \u001b[38;5;241m=\u001b[39m show(partial_model, save_and_embed\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mTrue\u001b[39;00m, layout\u001b[38;5;241m=\u001b[39m\u001b[43mLayout\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mlayouts/small_monty.json\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m)\u001b[49m)\n", - "File \u001b[0;32m~/git/stormvogel/env/lib/python3.12/site-packages/stormvogel/layout.py:29\u001b[0m, in \u001b[0;36mLayout.__init__\u001b[0;34m(self, path, path_relative)\u001b[0m\n\u001b[1;32m 27\u001b[0m default_str \u001b[38;5;241m=\u001b[39m f\u001b[38;5;241m.\u001b[39mread()\n\u001b[1;32m 28\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mdefault_dict: \u001b[38;5;28mdict\u001b[39m \u001b[38;5;241m=\u001b[39m json\u001b[38;5;241m.\u001b[39mloads(default_str)\n\u001b[0;32m---> 29\u001b[0m 
\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mload\u001b[49m\u001b[43m(\u001b[49m\u001b[43mpath\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mpath_relative\u001b[49m\u001b[43m)\u001b[49m\n", - "File \u001b[0;32m~/git/stormvogel/env/lib/python3.12/site-packages/stormvogel/layout.py:43\u001b[0m, in \u001b[0;36mLayout.load\u001b[0;34m(self, path, path_relative)\u001b[0m\n\u001b[1;32m 41\u001b[0m parsed_dict \u001b[38;5;241m=\u001b[39m json\u001b[38;5;241m.\u001b[39mloads(parsed_str)\n\u001b[1;32m 42\u001b[0m \u001b[38;5;66;03m# Combine the parsed dict with default to fill missing keys as default values.\u001b[39;00m\n\u001b[0;32m---> 43\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mlayout: \u001b[38;5;28mdict\u001b[39m \u001b[38;5;241m=\u001b[39m \u001b[43mstormvogel\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mrdict\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mmerge_dict\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 44\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mdefault_dict\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mparsed_dict\u001b[49m\n\u001b[1;32m 45\u001b[0m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 47\u001b[0m \u001b[38;5;66;03m# Load in schema for the dict_editor.\u001b[39;00m\n\u001b[1;32m 48\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m \u001b[38;5;28mopen\u001b[39m(os\u001b[38;5;241m.\u001b[39mpath\u001b[38;5;241m.\u001b[39mjoin(PACKAGE_ROOT_DIR, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mlayouts/schema.json\u001b[39m\u001b[38;5;124m\"\u001b[39m)) \u001b[38;5;28;01mas\u001b[39;00m f:\n", - "File \u001b[0;32m~/git/stormvogel/env/lib/python3.12/site-packages/stormvogel/rdict.py:46\u001b[0m, in \u001b[0;36mmerge_dict\u001b[0;34m(dict1, dict2)\u001b[0m\n\u001b[1;32m 44\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(val, 
\u001b[38;5;28mdict\u001b[39m):\n\u001b[1;32m 45\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m key \u001b[38;5;129;01min\u001b[39;00m dict2 \u001b[38;5;129;01mand\u001b[39;00m \u001b[38;5;28mtype\u001b[39m(dict2[key] \u001b[38;5;241m==\u001b[39m \u001b[38;5;28mdict\u001b[39m):\n\u001b[0;32m---> 46\u001b[0m \u001b[43mmerge_dict\u001b[49m\u001b[43m(\u001b[49m\u001b[43mdict1\u001b[49m\u001b[43m[\u001b[49m\u001b[43mkey\u001b[49m\u001b[43m]\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mdict2\u001b[49m\u001b[43m[\u001b[49m\u001b[43mkey\u001b[49m\u001b[43m]\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 47\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m 48\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m key \u001b[38;5;129;01min\u001b[39;00m dict2:\n", - "File \u001b[0;32m~/git/stormvogel/env/lib/python3.12/site-packages/stormvogel/rdict.py:48\u001b[0m, in \u001b[0;36mmerge_dict\u001b[0;34m(dict1, dict2)\u001b[0m\n\u001b[1;32m 46\u001b[0m merge_dict(dict1[key], dict2[key])\n\u001b[1;32m 47\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m---> 48\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[43mkey\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;129;43;01min\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[43mdict2\u001b[49m:\n\u001b[1;32m 49\u001b[0m dict1[key] \u001b[38;5;241m=\u001b[39m dict2[key]\n\u001b[1;32m 51\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m key, val \u001b[38;5;129;01min\u001b[39;00m dict2\u001b[38;5;241m.\u001b[39mitems():\n", - "\u001b[0;31mTypeError\u001b[0m: argument of type 'bool' is not iterable" - ] - } - ], - "source": [ - "# we can choose how many steps we take:\n", - "steps = 4\n", - "\n", - "# and we can specify a seed if we want:\n", - "seed = 12345676346\n", - "\n", - "# then we run the simulator:\n", - "partial_model = stormvogel.simulator.simulate(mdp, steps=steps, seed=seed)\n", - "# We could also provide a seed.\n", - "#partial_model = stormvogel.simulator.simulate(mdp, steps=steps, seed=seed)\n", - "\n", - "vis = 
show(partial_model, save_and_embed=True, layout=Layout(\"layouts/small_monty.json\"))" - ] - }, - { - "cell_type": "markdown", - "id": "49e3893d-bc35-4648-87eb-74a6a222ebf0", - "metadata": {}, - "source": [ - "We can also provide a scheduler (i.e. policy) which chooses what actions we should take at all time.\n", - "\n", - "In this case, we always take the first action, which means that we open door 0, and don't switch doors." - ] - }, - { - "cell_type": "code", - "execution_count": 5, - "id": "59ac1e34-866c-42c4-b19b-c2a15c830e2e", - "metadata": {}, - "outputs": [ - { - "ename": "TypeError", - "evalue": "argument of type 'bool' is not iterable", - "output_type": "error", - "traceback": [ - "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", - "\u001b[0;31mTypeError\u001b[0m Traceback (most recent call last)", - "Cell \u001b[0;32mIn[5], line 8\u001b[0m\n\u001b[1;32m 5\u001b[0m scheduler \u001b[38;5;241m=\u001b[39m stormvogel\u001b[38;5;241m.\u001b[39mresult\u001b[38;5;241m.\u001b[39mScheduler(mdp, taken_actions)\n\u001b[1;32m 7\u001b[0m partial_model \u001b[38;5;241m=\u001b[39m stormvogel\u001b[38;5;241m.\u001b[39msimulator\u001b[38;5;241m.\u001b[39msimulate(mdp, steps\u001b[38;5;241m=\u001b[39msteps, scheduler\u001b[38;5;241m=\u001b[39mscheduler, seed\u001b[38;5;241m=\u001b[39mseed)\n\u001b[0;32m----> 8\u001b[0m vis \u001b[38;5;241m=\u001b[39m show(partial_model, save_and_embed\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mTrue\u001b[39;00m, layout\u001b[38;5;241m=\u001b[39m\u001b[43mLayout\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mlayouts/small_monty.json\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m)\u001b[49m)\n", - "File \u001b[0;32m~/git/stormvogel/env/lib/python3.12/site-packages/stormvogel/layout.py:29\u001b[0m, in \u001b[0;36mLayout.__init__\u001b[0;34m(self, path, path_relative)\u001b[0m\n\u001b[1;32m 27\u001b[0m default_str \u001b[38;5;241m=\u001b[39m 
f\u001b[38;5;241m.\u001b[39mread()\n\u001b[1;32m 28\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mdefault_dict: \u001b[38;5;28mdict\u001b[39m \u001b[38;5;241m=\u001b[39m json\u001b[38;5;241m.\u001b[39mloads(default_str)\n\u001b[0;32m---> 29\u001b[0m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mload\u001b[49m\u001b[43m(\u001b[49m\u001b[43mpath\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mpath_relative\u001b[49m\u001b[43m)\u001b[49m\n", - "File \u001b[0;32m~/git/stormvogel/env/lib/python3.12/site-packages/stormvogel/layout.py:43\u001b[0m, in \u001b[0;36mLayout.load\u001b[0;34m(self, path, path_relative)\u001b[0m\n\u001b[1;32m 41\u001b[0m parsed_dict \u001b[38;5;241m=\u001b[39m json\u001b[38;5;241m.\u001b[39mloads(parsed_str)\n\u001b[1;32m 42\u001b[0m \u001b[38;5;66;03m# Combine the parsed dict with default to fill missing keys as default values.\u001b[39;00m\n\u001b[0;32m---> 43\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mlayout: \u001b[38;5;28mdict\u001b[39m \u001b[38;5;241m=\u001b[39m \u001b[43mstormvogel\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mrdict\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mmerge_dict\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 44\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mdefault_dict\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mparsed_dict\u001b[49m\n\u001b[1;32m 45\u001b[0m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 47\u001b[0m \u001b[38;5;66;03m# Load in schema for the dict_editor.\u001b[39;00m\n\u001b[1;32m 48\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m \u001b[38;5;28mopen\u001b[39m(os\u001b[38;5;241m.\u001b[39mpath\u001b[38;5;241m.\u001b[39mjoin(PACKAGE_ROOT_DIR, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mlayouts/schema.json\u001b[39m\u001b[38;5;124m\"\u001b[39m)) \u001b[38;5;28;01mas\u001b[39;00m f:\n", - "File 
\u001b[0;32m~/git/stormvogel/env/lib/python3.12/site-packages/stormvogel/rdict.py:46\u001b[0m, in \u001b[0;36mmerge_dict\u001b[0;34m(dict1, dict2)\u001b[0m\n\u001b[1;32m 44\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(val, \u001b[38;5;28mdict\u001b[39m):\n\u001b[1;32m 45\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m key \u001b[38;5;129;01min\u001b[39;00m dict2 \u001b[38;5;129;01mand\u001b[39;00m \u001b[38;5;28mtype\u001b[39m(dict2[key] \u001b[38;5;241m==\u001b[39m \u001b[38;5;28mdict\u001b[39m):\n\u001b[0;32m---> 46\u001b[0m \u001b[43mmerge_dict\u001b[49m\u001b[43m(\u001b[49m\u001b[43mdict1\u001b[49m\u001b[43m[\u001b[49m\u001b[43mkey\u001b[49m\u001b[43m]\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mdict2\u001b[49m\u001b[43m[\u001b[49m\u001b[43mkey\u001b[49m\u001b[43m]\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 47\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m 48\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m key \u001b[38;5;129;01min\u001b[39;00m dict2:\n", - "File \u001b[0;32m~/git/stormvogel/env/lib/python3.12/site-packages/stormvogel/rdict.py:48\u001b[0m, in \u001b[0;36mmerge_dict\u001b[0;34m(dict1, dict2)\u001b[0m\n\u001b[1;32m 46\u001b[0m merge_dict(dict1[key], dict2[key])\n\u001b[1;32m 47\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m---> 48\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[43mkey\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;129;43;01min\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[43mdict2\u001b[49m:\n\u001b[1;32m 49\u001b[0m dict1[key] \u001b[38;5;241m=\u001b[39m dict2[key]\n\u001b[1;32m 51\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m key, val \u001b[38;5;129;01min\u001b[39;00m dict2\u001b[38;5;241m.\u001b[39mitems():\n", - "\u001b[0;31mTypeError\u001b[0m: argument of type 'bool' is not iterable" - ] + "data": { + "text/html": [ + "\n", + " " + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "vis = show(mdp, 
layout=Layout(\"layouts/monty.json\"), save_and_embed=True)" + ] + }, + { + "cell_type": "markdown", + "id": "b5b2990c-65ed-4d7b-a4b8-f303843622e5", + "metadata": {}, + "source": [ + "We want to simulate this model. That is, we start at the initial state and then we walk through the model by choosing random actions.\n", + "\n", + "When we do this, we get a partial model as a result that contains everything we discovered during this walk. \n", + "\n", + "Try running this multiple times, and observe that sometimes we get to the target and sometimes we do not." + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "eb0fadc0-7bb6-4c1d-ae3e-9e16527726ab", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "\n", + " " + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "# we can choose how many steps we take:\n", + "steps = 4\n", + "\n", + "# and we can specify a seed if we want:\n", + "seed = 12345676346\n", + "\n", + "# then we run the simulator:\n", + "partial_model = stormvogel.simulator.simulate(mdp, steps=steps, seed=seed)\n", + "# We could also provide a seed.\n", + "#partial_model = stormvogel.simulator.simulate(mdp, steps=steps, seed=seed)\n", + "\n", + "vis = show(partial_model, save_and_embed=True, layout=Layout(\"layouts/small_monty.json\"))" + ] + }, + { + "cell_type": "markdown", + "id": "49e3893d-bc35-4648-87eb-74a6a222ebf0", + "metadata": {}, + "source": [ + "We can also provide a scheduler (i.e. policy) which chooses what actions we should take at all time.\n", + "\n", + "In this case, we always take the first action, which means that we open door 0, and don't switch doors." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "59ac1e34-866c-42c4-b19b-c2a15c830e2e", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "\n", + " " + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" } ], "source": [ @@ -241,7 +1378,7 @@ "text/html": [ "\n", " 0\n", + " else []\n", + " )" + ] + }, + { + "cell_type": "markdown", + "id": "9b7c4637-a48e-4841-bbf4-e2da20e37aee", + "metadata": {}, + "source": [ + "we can also optionally provide functions that assign rewards and labels" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "8531a632-23ae-4451-b5cc-ede2d16ec9ee", + "metadata": {}, + "outputs": [], + "source": [ + "def rewards(s: pgc.State, a: pgc.Action):\n", + " return [1, 2]\n", + "\n", + "def labels(s: pgc.State):\n", + " return [str(s.x)]" + ] + }, + { + "cell_type": "markdown", + "id": "bd5ccbfd-003b-417e-b669-233797084c80", + "metadata": {}, + "source": [ + "We then combine all of the above to call the build_pgc function that will build our model using the functions" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "7a920d7d-1430-43b5-b56e-3e74a7749ffe", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "ModelType.MDP with name None\n", + "\n", + "States:\n", + "State 0 with labels ['init', '5'] and features {}\n", + "State 1 with labels ['6'] and features {}\n", + "State 2 with labels ['4'] and features {}\n", + "State 3 with labels ['7'] and features {}\n", + "State 4 with labels ['3'] and features {}\n", + "State 5 with labels ['8'] and features {}\n", + "State 6 with labels ['2'] and features {}\n", + "State 7 with labels ['9'] and features {}\n", + "State 8 with labels ['1'] and features {}\n", + "State 9 with labels ['10'] and features {}\n", + "State 10 with labels ['0'] and features {}\n", + "\n", + "Transitions:\n", + "Action with labels frozenset({'left'}) => 0.5 -> State 1 with labels 
['6'] and features {}, 0.5 -> State 0 with labels ['init', '5'] and features {}; Action with labels frozenset({'right'}) => 0.5 -> State 2 with labels ['4'] and features {}, 0.5 -> State 0 with labels ['init', '5'] and features {}\n", + "Action with labels frozenset({'left'}) => 0.5 -> State 3 with labels ['7'] and features {}, 0.5 -> State 1 with labels ['6'] and features {}; Action with labels frozenset({'right'}) => 0.5 -> State 0 with labels ['init', '5'] and features {}, 0.5 -> State 1 with labels ['6'] and features {}\n", + "Action with labels frozenset({'left'}) => 0.5 -> State 0 with labels ['init', '5'] and features {}, 0.5 -> State 2 with labels ['4'] and features {}; Action with labels frozenset({'right'}) => 0.5 -> State 4 with labels ['3'] and features {}, 0.5 -> State 2 with labels ['4'] and features {}\n", + "Action with labels frozenset({'left'}) => 0.5 -> State 5 with labels ['8'] and features {}, 0.5 -> State 3 with labels ['7'] and features {}; Action with labels frozenset({'right'}) => 0.5 -> State 1 with labels ['6'] and features {}, 0.5 -> State 3 with labels ['7'] and features {}\n", + "Action with labels frozenset({'left'}) => 0.5 -> State 2 with labels ['4'] and features {}, 0.5 -> State 4 with labels ['3'] and features {}; Action with labels frozenset({'right'}) => 0.5 -> State 6 with labels ['2'] and features {}, 0.5 -> State 4 with labels ['3'] and features {}\n", + "Action with labels frozenset({'left'}) => 0.5 -> State 7 with labels ['9'] and features {}, 0.5 -> State 5 with labels ['8'] and features {}; Action with labels frozenset({'right'}) => 0.5 -> State 3 with labels ['7'] and features {}, 0.5 -> State 5 with labels ['8'] and features {}\n", + "Action with labels frozenset({'left'}) => 0.5 -> State 4 with labels ['3'] and features {}, 0.5 -> State 6 with labels ['2'] and features {}; Action with labels frozenset({'right'}) => 0.5 -> State 8 with labels ['1'] and features {}, 0.5 -> State 6 with labels ['2'] and features {}\n", + 
"Action with labels frozenset({'left'}) => 0.5 -> State 9 with labels ['10'] and features {}, 0.5 -> State 7 with labels ['9'] and features {}; Action with labels frozenset({'right'}) => 0.5 -> State 5 with labels ['8'] and features {}, 0.5 -> State 7 with labels ['9'] and features {}\n", + "Action with labels frozenset({'left'}) => 0.5 -> State 6 with labels ['2'] and features {}, 0.5 -> State 8 with labels ['1'] and features {}; Action with labels frozenset({'right'}) => 0.5 -> State 10 with labels ['0'] and features {}, 0.5 -> State 8 with labels ['1'] and features {}\n", + "Action with labels frozenset({'right'}) => 0.5 -> State 7 with labels ['9'] and features {}, 0.5 -> State 9 with labels ['10'] and features {}\n", + "Action with labels frozenset({'left'}) => 0.5 -> State 8 with labels ['1'] and features {}, 0.5 -> State 10 with labels ['0'] and features {}\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "feebb5c54aac47279bc1982b982921e9", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Output()" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "093c11bd82784020b34d5bb1e0c252e5", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Output()" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "ename": "RuntimeError", + "evalue": "This action is not available in this state", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mRuntimeError\u001b[0m Traceback (most recent call last)", + "Cell \u001b[0;32mIn[6], line 10\u001b[0m\n\u001b[1;32m 1\u001b[0m pgc_model \u001b[38;5;241m=\u001b[39m pgc\u001b[38;5;241m.\u001b[39mbuild_pgc(\n\u001b[1;32m 2\u001b[0m delta\u001b[38;5;241m=\u001b[39mdelta,\n\u001b[1;32m 3\u001b[0m 
available_actions\u001b[38;5;241m=\u001b[39mavailable_actions,\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 6\u001b[0m rewards\u001b[38;5;241m=\u001b[39mrewards,\n\u001b[1;32m 7\u001b[0m )\n\u001b[1;32m 9\u001b[0m \u001b[38;5;28mprint\u001b[39m(pgc_model)\n\u001b[0;32m---> 10\u001b[0m \u001b[43mshow\u001b[49m\u001b[43m(\u001b[49m\u001b[43mpgc_model\u001b[49m\u001b[43m)\u001b[49m\n", + "File \u001b[0;32m~/repositories/stormvogel/.venv/lib/python3.12/site-packages/stormvogel/show.py:53\u001b[0m, in \u001b[0;36mshow\u001b[0;34m(model, result, scheduler, name, layout, show_editor, separate_labels, debug_output)\u001b[0m\n\u001b[1;32m 41\u001b[0m \u001b[38;5;66;03m# do_display = not show_editor\u001b[39;00m\n\u001b[1;32m 42\u001b[0m vis \u001b[38;5;241m=\u001b[39m stormvogel\u001b[38;5;241m.\u001b[39mvisualization\u001b[38;5;241m.\u001b[39mVisualization(\n\u001b[1;32m 43\u001b[0m model\u001b[38;5;241m=\u001b[39mmodel,\n\u001b[1;32m 44\u001b[0m name\u001b[38;5;241m=\u001b[39mname,\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 51\u001b[0m do_init_server\u001b[38;5;241m=\u001b[39mdo_init_server,\n\u001b[1;32m 52\u001b[0m )\n\u001b[0;32m---> 53\u001b[0m \u001b[43mvis\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mshow\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 54\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m show_editor:\n\u001b[1;32m 55\u001b[0m e \u001b[38;5;241m=\u001b[39m stormvogel\u001b[38;5;241m.\u001b[39mlayout_editor\u001b[38;5;241m.\u001b[39mLayoutEditor(\n\u001b[1;32m 56\u001b[0m layout, vis, do_display\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mFalse\u001b[39;00m, debug_output\u001b[38;5;241m=\u001b[39mdebug_output\n\u001b[1;32m 57\u001b[0m )\n", + "File \u001b[0;32m~/repositories/stormvogel/.venv/lib/python3.12/site-packages/stormvogel/visualization.py:110\u001b[0m, in \u001b[0;36mVisualization.show\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 108\u001b[0m 
\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mnt\u001b[38;5;241m.\u001b[39menable_exploration_mode(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mmodel\u001b[38;5;241m.\u001b[39mget_initial_state()\u001b[38;5;241m.\u001b[39mid)\n\u001b[1;32m 109\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mlayout\u001b[38;5;241m.\u001b[39mset_groups(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mseparate_labels)\n\u001b[0;32m--> 110\u001b[0m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m__add_states\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 111\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m__add_transitions()\n\u001b[1;32m 112\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m__update_physics_enabled()\n", + "File \u001b[0;32m~/repositories/stormvogel/.venv/lib/python3.12/site-packages/stormvogel/visualization.py:131\u001b[0m, in \u001b[0;36mVisualization.__add_states\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 128\u001b[0m res \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m__format_result(state)\n\u001b[1;32m 129\u001b[0m observations \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m__format_observations(state)\n\u001b[0;32m--> 131\u001b[0m rewards \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m__format_rewards\u001b[49m\u001b[43m(\u001b[49m\u001b[43mstate\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mstormvogel\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mmodel\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mEmptyAction\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 133\u001b[0m group \u001b[38;5;241m=\u001b[39m ( \u001b[38;5;66;03m# Use a non-default group if specified.\u001b[39;00m\n\u001b[1;32m 134\u001b[0m und(state\u001b[38;5;241m.\u001b[39mlabels[\u001b[38;5;241m0\u001b[39m])\n\u001b[1;32m 
135\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m (\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 139\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mstates\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 140\u001b[0m )\n\u001b[1;32m 142\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mnt\u001b[38;5;241m.\u001b[39madd_node(\n\u001b[1;32m 143\u001b[0m state\u001b[38;5;241m.\u001b[39mid,\n\u001b[1;32m 144\u001b[0m label\u001b[38;5;241m=\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m,\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;241m.\u001b[39mjoin(state\u001b[38;5;241m.\u001b[39mlabels) \u001b[38;5;241m+\u001b[39m rewards \u001b[38;5;241m+\u001b[39m res \u001b[38;5;241m+\u001b[39m observations,\n\u001b[1;32m 145\u001b[0m group\u001b[38;5;241m=\u001b[39mgroup,\n\u001b[1;32m 146\u001b[0m )\n", + "File \u001b[0;32m~/repositories/stormvogel/.venv/lib/python3.12/site-packages/stormvogel/visualization.py:226\u001b[0m, in \u001b[0;36mVisualization.__format_rewards\u001b[0;34m(self, s, a)\u001b[0m\n\u001b[1;32m 224\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m reward_model \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mmodel\u001b[38;5;241m.\u001b[39mrewards:\n\u001b[1;32m 225\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mmodel\u001b[38;5;241m.\u001b[39msupports_actions():\n\u001b[0;32m--> 226\u001b[0m reward \u001b[38;5;241m=\u001b[39m \u001b[43mreward_model\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mget_state_action_reward\u001b[49m\u001b[43m(\u001b[49m\u001b[43ms\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43ma\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 227\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m 228\u001b[0m reward \u001b[38;5;241m=\u001b[39m reward_model\u001b[38;5;241m.\u001b[39mget_state_reward(s)\n", + "File 
\u001b[0;32m~/repositories/stormvogel/.venv/lib/python3.12/site-packages/stormvogel/model.py:392\u001b[0m, in \u001b[0;36mRewardModel.get_state_action_reward\u001b[0;34m(self, state, action)\u001b[0m\n\u001b[1;32m 390\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m\n\u001b[1;32m 391\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m--> 392\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mRuntimeError\u001b[39;00m(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mThis action is not available in this state\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n\u001b[1;32m 393\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m 394\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mRuntimeError\u001b[39;00m(\n\u001b[1;32m 395\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mThe model this rewardmodel belongs to does not support actions\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 396\u001b[0m )\n", + "\u001b[0;31mRuntimeError\u001b[0m: This action is not available in this state" + ] + } + ], + "source": [ + "pgc_model = pgc.build_pgc(\n", + " delta=delta,\n", + " available_actions=available_actions,\n", + " initial_state_pgc=initial_state,\n", + " labels=labels,\n", + " rewards=rewards,\n", + ")\n", + "\n", + "print(pgc_model)\n", + "show(pgc_model)" + ] + }, + { + "cell_type": "markdown", + "id": "65fa94ec-4102-46e8-b98b-68931802514d", + "metadata": {}, + "source": [ + "We don't have to use the provided State class; in fact, we can use any object we like! Here is an example where we use integers instead of State objects." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "be844a3f-bbc8-4c8d-a351-7627f7dd33b9", + "metadata": {}, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "d12c8a7fd7494d93a1736369a771ec0b", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Output()" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "7db435fff12148f487ddec42f598e650", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Output()" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "1e103d3ed4184babb623054e3a710670", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(Output(), Output()))" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/plain": [ + "" + ] + }, + "execution_count": 7, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "def delta(state):\n", + " return [\n", + " (0.5, (state + 1) % 5),\n", + " (0.5, (state - 1) % 5),\n", + " ]\n", + " \n", + "def rewards(state):\n", + " return [state]\n", + "\n", + "pgc_model = pgc.build_pgc(delta, initial_state_pgc=0, rewards=rewards, modeltype=model.ModelType.DTMC)\n", + "show(pgc_model)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "d45f2907-f51b-4177-9900-d330c495eace", + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "14e107b0-b890-49ca-ad52-a4289d1776b7", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + 
"name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.12.7" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/docs/getting_started/model.html b/docs/getting_started/model.html index 3c41ae2..79f8c00 100644 --- a/docs/getting_started/model.html +++ b/docs/getting_started/model.html @@ -1,8 +1,8 @@