-
Notifications
You must be signed in to change notification settings - Fork 40
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Open ai gym wrapper + bump engine version (#52)
* wip * Update gym.py * . * wip * wip * Update Dockerfile.gym.dev * wip * Update gym.py * fwd model close * Update dev_gym.py * Update forward_model.py * wip * Update agent.py * Update gym.py * Update gym.py * Update gym.py * . * . * Update gym.py * wip * wip * Update gym.py * wip * wip * Update dev_gym.py * wip * wip * wip * wip * Update gym.py * wip * Update gym.py * Update gym.py * Update gym.py * . * Bump websockets version * Update gym.py * Update gym.py * wip * wip * gym * close * Update README.md * Update README.md * Update README.md
- Loading branch information
Showing
12 changed files
with
193 additions
and
28 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,26 @@ | ||
version: "3" | ||
services: | ||
gym: | ||
extends: | ||
file: base-compose.yml | ||
service: python3-gym-dev | ||
environment: | ||
- FWD_MODEL_CONNECTION_STRING=ws://fwd-server:6969/?role=admin | ||
depends_on: | ||
- fwd-server | ||
networks: | ||
- coderone-open-ai-gym-wrapper | ||
|
||
fwd-server: | ||
extends: | ||
file: base-compose.yml | ||
service: game-server | ||
environment: | ||
- TELEMETRY_ENABLED=0 | ||
- PORT=6969 | ||
- WORLD_SEED=1234 | ||
- PRNG_SEED=1234 | ||
networks: | ||
- coderone-open-ai-gym-wrapper | ||
networks: | ||
coderone-open-ai-gym-wrapper: null |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,6 @@ | ||
FROM python:3.8-bullseye | ||
|
||
COPY ./requirements.txt /app/requirements.txt | ||
WORKDIR /app | ||
RUN python -m pip install -r requirements.txt | ||
ENTRYPOINT PYTHONUNBUFFERED=1 python dev_gym.py |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,7 @@ | ||
# Overview | ||
|
||
`agent.py` - random agent | ||
|
||
`agent_fwd.py` - random agent that connects to forward model | ||
|
||
`dev_gym.py` - [open ai gym wrapper](https://gym.openai.com/) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,34 @@ | ||
import asyncio | ||
from typing import Dict | ||
from gym import Gym | ||
import os | ||
|
||
fwd_model_uri = os.environ.get( | ||
"FWD_MODEL_CONNECTION_STRING") or "ws://127.0.0.1:6969/?role=admin" | ||
|
||
mock_6x6_state: Dict = {"agents": {"a": {"agent_id": "a", "unit_ids": ["c", "e", "g"]}, "b": {"agent_id": "b", "unit_ids": ["d", "f", "h"]}}, "unit_state": {"c": {"coordinates": [0, 1], "hp": 3, "inventory": {"bombs": 3}, "blast_diameter": 3, "unit_id": "c", "agent_id": "a", "invulnerability": 0}, "d": {"coordinates": [5, 1], "hp": 3, "inventory": {"bombs": 3}, "blast_diameter": 3, "unit_id": "d", "agent_id": "b", "invulnerability": 0}, "e": {"coordinates": [3, 3], "hp": 3, "inventory": {"bombs": 3}, "blast_diameter": 3, "unit_id": "e", "agent_id": "a", "invulnerability": 0}, "f": {"coordinates": [2, 3], "hp": 3, "inventory": {"bombs": 3}, "blast_diameter": 3, "unit_id": "f", "agent_id": "b", "invulnerability": 0}, "g": {"coordinates": [2, 4], "hp": 3, "inventory": {"bombs": 3}, "blast_diameter": 3, "unit_id": "g", "agent_id": "a", "invulnerability": 0}, "h": {"coordinates": [3, 4], "hp": 3, "inventory": {"bombs": 3}, "blast_diameter": 3, "unit_id": "h", "agent_id": "b", "invulnerability": 0}}, "entities": [ | ||
{"created": 0, "x": 0, "y": 3, "type": "m"}, {"created": 0, "x": 5, "y": 3, "type": "m"}, {"created": 0, "x": 4, "y": 3, "type": "m"}, {"created": 0, "x": 1, "y": 3, "type": "m"}, {"created": 0, "x": 3, "y": 5, "type": "m"}, {"created": 0, "x": 2, "y": 5, "type": "m"}, {"created": 0, "x": 5, "y": 4, "type": "m"}, {"created": 0, "x": 0, "y": 4, "type": "m"}, {"created": 0, "x": 1, "y": 1, "type": "w", "hp": 1}, {"created": 0, "x": 4, "y": 1, "type": "w", "hp": 1}, {"created": 0, "x": 3, "y": 0, "type": "w", "hp": 1}, {"created": 0, "x": 2, "y": 0, "type": "w", "hp": 1}, {"created": 0, "x": 5, "y": 5, "type": "w", "hp": 1}, {"created": 0, "x": 0, "y": 5, "type": "w", "hp": 1}, {"created": 0, "x": 4, "y": 0, "type": "w", "hp": 1}, {"created": 0, "x": 1, "y": 0, "type": "w", "hp": 1}, {"created": 0, "x": 5, "y": 0, "type": "w", "hp": 1}, {"created": 0, "x": 0, "y": 0, "type": "w", "hp": 1}], "world": {"width": 6, "height": 6}, "tick": 0, "config": {"tick_rate_hz": 10, "game_duration_ticks": 300, "fire_spawn_interval_ticks": 2}} | ||
|
||
|
||
def calculate_reward(state: Dict): | ||
# custom reward function | ||
return 1 | ||
|
||
|
||
async def main(): | ||
gym = Gym(fwd_model_uri) | ||
await gym.connect() | ||
env = gym.make("bomberland-open-ai-gym", mock_6x6_state) | ||
for i_ in range(1000): | ||
actions = [] | ||
observation, done, info = await env.step(actions) | ||
reward = calculate_reward(observation) | ||
|
||
print(f"reward: {reward} done: {done} info: {info}") | ||
if done: | ||
await env.reset() | ||
await gym.close() | ||
|
||
|
||
if __name__ == "__main__": | ||
asyncio.run(main()) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,70 @@ | ||
import asyncio | ||
import json | ||
from typing import Callable, Dict, List | ||
|
||
import websockets | ||
from forward_model import ForwardModel | ||
|
||
|
||
class GymEnv(): | ||
def __init__(self, fwd_model: ForwardModel, channel: int, initial_state: Dict, send_next_state: Callable[[Dict, List[Dict], int], Dict]): | ||
self._state = initial_state | ||
self._initial_state = initial_state | ||
self._fwd = fwd_model | ||
self._channel = channel | ||
self._send = send_next_state | ||
|
||
async def reset(self): | ||
self._state = self._initial_state | ||
print("Resetting") | ||
|
||
async def step(self, actions): | ||
state = await self._send(self._state, actions, self._channel) | ||
self._state = state.get("next_state") | ||
return [state.get("next_state"), state.get("is_complete"), state.get("tick_result").get("events")] | ||
|
||
|
||
class Gym(): | ||
def __init__(self, fwd_model_uri: str): | ||
self._client_fwd = ForwardModel(fwd_model_uri) | ||
self._channel_counter = 0 | ||
self._channel_is_busy_status: Dict[int, bool] = {} | ||
self._channel_buffer: Dict[int, Dict] = {} | ||
self._client_fwd.set_next_state_callback(self._on_next_game_state) | ||
self._environments: Dict[str, GymEnv] = {} | ||
|
||
async def connect(self): | ||
loop = asyncio.get_event_loop() | ||
|
||
client_fwd_connection = await self._client_fwd.connect() | ||
|
||
loop = asyncio.get_event_loop() | ||
loop.create_task( | ||
self._client_fwd._handle_messages(client_fwd_connection)) | ||
|
||
async def close(self): | ||
await self._client_fwd.close() | ||
|
||
async def _on_next_game_state(self, state): | ||
channel = state.get("sequence_id") | ||
self._channel_is_busy_status[channel] = False | ||
self._channel_buffer[channel] = state | ||
|
||
def make(self, name: str, initial_state: Dict) -> GymEnv: | ||
if self._environments.get(name) is not None: | ||
raise Exception( | ||
f"environment \"{name}\" has already been instantiated") | ||
self._environments[name] = GymEnv( | ||
self._client_fwd, self._channel_counter, initial_state, self._send_next_state) | ||
self._channel_counter += 1 | ||
return self._environments[name] | ||
|
||
async def _send_next_state(self, state, actions, channel: int): | ||
self._channel_is_busy_status[channel] = True | ||
await self._client_fwd.send_next_state(channel, state, actions) | ||
while self._channel_is_busy_status[channel] == True: | ||
# TODO figure out why packets are not received without some sleep | ||
await asyncio.sleep(0.0001) | ||
result = self._channel_buffer[channel] | ||
del self._channel_buffer[channel] | ||
return result |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,2 +1,2 @@ | ||
asyncio==3.4.3 | ||
websockets==8.1 | ||
websockets==10.1 |