Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix_stopping_game #365

Merged
merged 3 commits into from
May 30, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from unittest.mock import patch, MagicMock
from gymnasium.spaces import Box, Discrete
import numpy as np
from gym_csle_stopping_game.util.stopping_game_util import StoppingGameUtil
from gym_csle_stopping_game.envs.stopping_game_env import StoppingGameEnv
from gym_csle_stopping_game.dao.stopping_game_config import StoppingGameConfig
from gym_csle_stopping_game.dao.stopping_game_state import StoppingGameState
Expand All @@ -23,19 +24,19 @@ def setup_env(self) -> None:
:return: None
"""
env_name = "test_env"
T = np.array([[[0.1, 0.9], [0.4, 0.6]], [[0.7, 0.3], [0.2, 0.8]]])
O = np.array([0, 1])
Z = np.array([[[0.8, 0.2], [0.5, 0.5]], [[0.4, 0.6], [0.9, 0.1]]])
T = StoppingGameUtil.transition_tensor(L=3, p=0)
O = StoppingGameUtil.observation_space(n=100)
Z = StoppingGameUtil.observation_tensor(n=100)
R = np.zeros((2, 3, 3, 3))
S = np.array([0, 1, 2])
A1 = np.array([0, 1, 2])
A2 = np.array([0, 1, 2])
S = StoppingGameUtil.state_space()
A1 = StoppingGameUtil.defender_actions()
A2 = StoppingGameUtil.attacker_actions()
L = 2
R_INT = 1
R_COST = 2
R_SLA = 3
R_ST = 4
b1 = np.array([0.6, 0.4])
b1 = StoppingGameUtil.b1()
save_dir = "save_directory"
checkpoint_traces_freq = 100
gamma = 0.9
Expand Down Expand Up @@ -69,12 +70,12 @@ def test_stopping_game_init_(self) -> None:

:return: None
"""
T = np.array([[[0.1, 0.9], [0.4, 0.6]], [[0.7, 0.3], [0.2, 0.8]]])
O = np.array([0, 1])
A1 = np.array([0, 1, 2])
A2 = np.array([0, 1, 2])
T = StoppingGameUtil.transition_tensor(L=3, p=0)
O = StoppingGameUtil.observation_space(n=100)
A1 = StoppingGameUtil.defender_actions()
A2 = StoppingGameUtil.attacker_actions()
L = 2
b1 = np.array([0.6, 0.4])
b1 = StoppingGameUtil.b1()
attacker_observation_space = Box(
low=np.array([0.0, 0.0, 0.0]),
high=np.array([float(L), 1.0, 2.0]),
Expand Down Expand Up @@ -304,7 +305,7 @@ def test_is_state_terminal(self) -> None:
assert not env.is_state_terminal(state_tuple)

with pytest.raises(ValueError):
env.is_state_terminal([1, 2, 3]) # type: ignore
env.is_state_terminal([1, 2, 3]) # type: ignore

def test_get_observation_from_history(self) -> None:
"""
Expand Down Expand Up @@ -346,26 +347,6 @@ def test_step(self) -> None:
:return: None
"""
env = StoppingGameEnv(self.config)
env.state = MagicMock()
env.state.s = 1
env.state.l = 2
env.state.t = 0
env.state.attacker_observation.return_value = np.array([1, 2, 3])
env.state.defender_observation.return_value = np.array([4, 5, 6])
env.state.b = np.array([0.5, 0.5, 0.0])

env.trace = MagicMock()
env.trace.defender_rewards = []
env.trace.attacker_rewards = []
env.trace.attacker_actions = []
env.trace.defender_actions = []
env.trace.infos = []
env.trace.states = []
env.trace.beliefs = []
env.trace.infrastructure_metrics = []
env.trace.attacker_observations = []
env.trace.defender_observations = []

with patch("gym_csle_stopping_game.util.stopping_game_util.StoppingGameUtil.sample_next_state",
return_value=2):
with patch("gym_csle_stopping_game.util.stopping_game_util.StoppingGameUtil.sample_next_observation",
Expand All @@ -376,32 +357,20 @@ def test_step(self) -> None:
1,
(
np.array(
[[0.2, 0.8, 0.0], [0.6, 0.4, 0.0], [0.5, 0.5, 0.0]]
[[0.2, 0.8], [0.6, 0.4], [0.5, 0.5]]
),
2,
),
)
observations, rewards, terminated, truncated, info = env.step(
action_profile
)

assert (observations[0] == np.array([4, 5, 6])).all(), "Incorrect defender observations"
assert (observations[1] == np.array([1, 2, 3])).all(), "Incorrect attacker observations"
assert observations[0].all() == np.array([1, 0.7]).all(), "Incorrect defender observations"
assert observations[1].all() == np.array([1, 2, 3]).all(), "Incorrect attacker observations"
assert rewards == (0, 0)
assert not terminated
assert not truncated
assert env.trace.defender_rewards[-1] == 0
assert env.trace.attacker_rewards[-1] == 0
assert env.trace.attacker_actions[-1] == 2
assert env.trace.defender_actions[-1] == 1
assert env.trace.infos[-1] == info
assert env.trace.states[-1] == 2
print(env.trace.beliefs)
assert env.trace.beliefs[-1] == 0.7
assert env.trace.infrastructure_metrics[-1] == 1
assert (env.trace.attacker_observations[-1] == np.array([1, 2, 3])).all()
assert (env.trace.defender_observations[-1] == np.array([4, 5, 6])).all()


def test_info(self) -> None:
"""
Tests the function of adding the cumulative reward and episode length to the info dict
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,12 @@
from gym_csle_stopping_game.dao.stopping_game_attacker_mdp_config import (
StoppingGameAttackerMdpConfig,
)
from gym_csle_stopping_game.util.stopping_game_util import StoppingGameUtil
from gym_csle_stopping_game.envs.stopping_game_env import StoppingGameEnv
from csle_common.dao.training.policy import Policy
from csle_common.dao.training.random_policy import RandomPolicy
from csle_common.dao.training.player_type import PlayerType
from csle_common.dao.simulation_config.action import Action
import pytest
from unittest.mock import MagicMock
import numpy as np
Expand All @@ -25,19 +29,19 @@ def setup_env(self) -> None:
:return: None
"""
env_name = "test_env"
T = np.array([[[0.1, 0.9], [0.4, 0.6]], [[0.7, 0.3], [0.2, 0.8]]])
O = np.array([0, 1])
Z = np.array([[[0.8, 0.2], [0.5, 0.5]], [[0.4, 0.6], [0.9, 0.1]]])
T = StoppingGameUtil.transition_tensor(L=3, p=0)
O = StoppingGameUtil.observation_space(n=100)
Z = StoppingGameUtil.observation_tensor(n=100)
R = np.zeros((2, 3, 3, 3))
S = np.array([0, 1, 2])
A1 = np.array([0, 1, 2])
A2 = np.array([0, 1, 2])
S = StoppingGameUtil.state_space()
A1 = StoppingGameUtil.defender_actions()
A2 = StoppingGameUtil.attacker_actions()
L = 2
R_INT = 1
R_COST = 2
R_SLA = 3
R_ST = 4
b1 = np.array([0.6, 0.4])
b1 = StoppingGameUtil.b1()
save_dir = "save_directory"
checkpoint_traces_freq = 100
gamma = 0.9
Expand Down Expand Up @@ -107,9 +111,8 @@ def test_reset(self) -> None:
)

env = StoppingGameMdpAttackerEnv(config=attacker_mdp_config)
attacker_obs, info = env.reset()
assert env.latest_defender_obs.all() == np.array([2, 0.4]).all() # type: ignore
assert info == {}
info = env.reset()
assert info[-1] == {}

def test_set_model(self) -> None:
"""
Expand Down Expand Up @@ -144,7 +147,7 @@ def test_set_state(self) -> None:
)

env = StoppingGameMdpAttackerEnv(config=attacker_mdp_config)
assert not env.set_state(1) # type: ignore
assert not env.set_state(1) # type: ignore

def test_calculate_stage_policy(self) -> None:
"""
Expand Down Expand Up @@ -190,7 +193,7 @@ def test_get_attacker_dist(self) -> None:
def test_render(self) -> None:
"""
Tests the function for rendering the environment

:return: None
"""
defender_strategy = MagicMock(spec=Policy)
Expand Down Expand Up @@ -317,7 +320,7 @@ def test_get_actions_from_particles(self) -> None:
particles = [1, 2, 3]
t = 0
observation = 0
expected_actions = [0, 1, 2]
expected_actions = [0, 1]
assert (
env.get_actions_from_particles(particles, t, observation)
== expected_actions
Expand All @@ -326,18 +329,32 @@ def test_get_actions_from_particles(self) -> None:
def test_step(self) -> None:
"""
Tests the function for taking a step in the environment by executing the given action

:return: None
"""
defender_strategy = MagicMock(spec=Policy)
defender_stage_strategy = np.zeros((3, 2))
defender_stage_strategy[0][0] = 0.9
defender_stage_strategy[0][1] = 0.1
defender_stage_strategy[1][0] = 0.9
defender_stage_strategy[1][1] = 0.1
defender_actions = list(map(lambda x: Action(id=x, descr=""), self.config.A1))
defender_strategy = RandomPolicy(
actions=defender_actions,
player_type=PlayerType.DEFENDER,
stage_policy_tensor=list(defender_stage_strategy),
)
attacker_mdp_config = StoppingGameAttackerMdpConfig(
env_name="test_env",
stopping_game_config=self.config,
defender_strategy=defender_strategy,
stopping_game_name="csle-stopping-game-v1",
)

env = StoppingGameMdpAttackerEnv(config=attacker_mdp_config)
pi2 = np.array([[0.5, 0.5]])
with pytest.raises(AssertionError):
env.step(pi2)
env.reset()
pi2 = env.calculate_stage_policy(o=list(env.latest_attacker_obs), a2=0) # type: ignore
attacker_obs, reward, terminated, truncated, info = env.step(pi2)
assert isinstance(attacker_obs[0], float) # type: ignore
assert isinstance(terminated, bool) # type: ignore
assert isinstance(truncated, bool) # type: ignore
assert isinstance(reward, float) # type: ignore
assert isinstance(info, dict) # type: ignore
Original file line number Diff line number Diff line change
@@ -1,9 +1,14 @@
from gym_csle_stopping_game.envs.stopping_game_pomdp_defender_env import StoppingGamePomdpDefenderEnv
from gym_csle_stopping_game.envs.stopping_game_pomdp_defender_env import (
StoppingGamePomdpDefenderEnv,
)
from gym_csle_stopping_game.dao.stopping_game_config import StoppingGameConfig
from gym_csle_stopping_game.dao.stopping_game_defender_pomdp_config import StoppingGameDefenderPomdpConfig
from gym_csle_stopping_game.dao.stopping_game_defender_pomdp_config import (
StoppingGameDefenderPomdpConfig,
)
from gym_csle_stopping_game.envs.stopping_game_env import StoppingGameEnv
from gym_csle_stopping_game.util.stopping_game_util import StoppingGameUtil
from csle_common.dao.training.policy import Policy
from csle_common.dao.simulation_config.action import Action
from csle_common.dao.training.random_policy import RandomPolicy
from csle_common.dao.training.player_type import PlayerType
import pytest
Expand Down Expand Up @@ -219,7 +224,7 @@ def test_set_state(self) -> None:
stopping_game_name="csle-stopping-game-v1",
)
env = StoppingGamePomdpDefenderEnv(config=defender_pomdp_config)
assert env.set_state(1) is None # type: ignore
assert env.set_state(1) is None # type: ignore

def test_get_observation_from_history(self) -> None:
"""
Expand Down Expand Up @@ -301,7 +306,10 @@ def test_get_actions_from_particles(self) -> None:
t = 0
observation = 0
expected_actions = [0, 1]
assert env.get_actions_from_particles(particles, t, observation) == expected_actions
assert (
env.get_actions_from_particles(particles, t, observation)
== expected_actions
)

def test_step(self) -> None:
"""
Expand All @@ -315,8 +323,12 @@ def test_step(self) -> None:
attacker_stage_strategy[1][0] = 0.9
attacker_stage_strategy[1][1] = 0.1
attacker_stage_strategy[2] = attacker_stage_strategy[1]
attacker_strategy = RandomPolicy(actions=list(self.config.A2), player_type=PlayerType.ATTACKER,
stage_policy_tensor=list(attacker_stage_strategy))
attacker_actions = list(map(lambda x: Action(id=x, descr=""), self.config.A2))
attacker_strategy = RandomPolicy(
actions=attacker_actions,
player_type=PlayerType.ATTACKER,
stage_policy_tensor=list(attacker_stage_strategy),
)
defender_pomdp_config = StoppingGameDefenderPomdpConfig(
env_name="test_env",
stopping_game_config=self.config,
Expand All @@ -328,9 +340,9 @@ def test_step(self) -> None:
env.reset()
defender_obs, reward, terminated, truncated, info = env.step(a1)
assert len(defender_obs) == 2
assert isinstance(defender_obs[0], float) # type: ignore
assert isinstance(defender_obs[1], float) # type: ignore
assert isinstance(reward, float) # type: ignore
assert isinstance(terminated, bool) # type: ignore
assert isinstance(truncated, bool) # type: ignore
assert isinstance(info, dict) # type: ignore
assert isinstance(defender_obs[0], float) # type: ignore
assert isinstance(defender_obs[1], float) # type: ignore
assert isinstance(reward, float) # type: ignore
assert isinstance(terminated, bool) # type: ignore
assert isinstance(truncated, bool) # type: ignore
assert isinstance(info, dict) # type: ignore