Skip to content

Commit

Permalink
can train tictactoe
Browse files Browse the repository at this point in the history
can train tictactoe
  • Loading branch information
huangshiyu13 authored Jul 21, 2023
2 parents b7777bf + 7cc9f8a commit 9e30d1f
Show file tree
Hide file tree
Showing 16 changed files with 803 additions and 132 deletions.
13 changes: 11 additions & 2 deletions examples/selfplay/README.md
Original file line number Diff line number Diff line change
@@ -1,7 +1,16 @@
## How to Use

Users can train selfplay strategy in connect3 via:
Users can train Tic-Tac-Toe via:

```shell
python train_selfplay.py --config selfplay_connect3.yaml
python train_selfplay.py
```


## Play with a trained agent

Users can play with a trained agent via:

```shell
python human_vs_agent.py
```
73 changes: 73 additions & 0 deletions examples/selfplay/human_vs_agent.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
import numpy as np
import torch
from tictactoe_render import TictactoeRender

from openrl.configs.config import create_config_parser
from openrl.envs.common import make
from openrl.envs.wrappers import FlattenObservation
from openrl.modules.common import PPONet as Net
from openrl.runners.common import PPOAgent as Agent
from openrl.selfplay.wrappers.human_opponent_wrapper import HumanOpponentWrapper
from openrl.selfplay.wrappers.random_opponent_wrapper import RandomOpponentWrapper


def get_fake_env(env_num):
env = make(
"tictactoe_v3",
env_num=env_num,
asynchronous=True,
opponent_wrappers=[RandomOpponentWrapper],
env_wrappers=[FlattenObservation],
auto_reset=False,
)
return env


def get_human_env(env_num):
env = make(
"tictactoe_v3",
env_num=env_num,
asynchronous=True,
opponent_wrappers=[TictactoeRender, HumanOpponentWrapper],
env_wrappers=[FlattenObservation],
auto_reset=False,
)
return env


def human_vs_agent():
env_num = 1
fake_env = get_fake_env(env_num)
env = get_human_env(env_num)
cfg_parser = create_config_parser()
cfg = cfg_parser.parse_args()
net = Net(fake_env, cfg=cfg, device="cuda" if torch.cuda.is_available() else "cpu")

agent = Agent(net)

agent.load("./ppo_agent/")

total_reward = 0.0
ep_num = 5
for ep_now in range(ep_num):
agent.set_env(fake_env)
obs, info = env.reset()

done = False
step = 0

while not np.any(done):
# predict next action based on the observation
action, _ = agent.act(obs, info, deterministic=True)
obs, r, done, info = env.step(action)
step += 1

if np.any(done):
total_reward += np.mean(r) > 0
print(f"{ep_now}/{ep_num}: reward: {np.mean(r)}")
print(f"win rate: {total_reward / ep_num}")
env.close()


if __name__ == "__main__":
human_vs_agent()
59 changes: 59 additions & 0 deletions examples/selfplay/test_env.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Copyright 2023 The OpenRL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

""""""
import numpy as np
from tictactoe_render import TictactoeRender

from openrl.envs.common import make
from openrl.envs.wrappers import FlattenObservation
from openrl.selfplay.wrappers.random_opponent_wrapper import RandomOpponentWrapper


def test_env():
env_num = 1
render_model = None
render_model = "human"
env = make(
"tictactoe_v3",
render_mode=render_model,
env_num=env_num,
asynchronous=False,
opponent_wrappers=[TictactoeRender, RandomOpponentWrapper],
env_wrappers=[FlattenObservation],
)

obs, info = env.reset(seed=1)
done = False
step_num = 0
while not done:
action = env.random_action(info)

obs, done, r, info = env.step(action)

done = np.any(done)
step_num += 1
if done:
print(
"step:"
f" {step_num},{[env_info['final_observation'] for env_info in info]}"
)
else:
print(f"step: {step_num},{obs}")


if __name__ == "__main__":
test_env()
70 changes: 70 additions & 0 deletions examples/selfplay/tictactoe_render.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Copyright 2023 The OpenRL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

""""""
import time
from typing import Optional, Union

import pygame
from pettingzoo.utils.env import ActionType, AECEnv, ObsType
from pettingzoo.utils.wrappers.base import BaseWrapper
from tictactoe_utils.game import Game


class TictactoeRender(BaseWrapper):
def __init__(self, env: AECEnv):
super().__init__(env)

self.game = Game()
self.last_action = None
self.last_length = 0
self.render_mode = "game"

def reset(self, seed: Optional[int] = None, options: Optional[dict] = None):
super().reset(seed, options)
if self.render_mode == "game":
self.game.reset()
pygame.display.update()
time.sleep(0.3)

self.last_action = None

def step(self, action: ActionType) -> None:
result = super().step(action)
self.last_action = action
return result

def observe(self, agent: str) -> Optional[ObsType]:
obs = super().observe(agent)
if self.last_action is not None:
if self.render_mode == "game":
self.game.make_move(self.last_action // 3, self.last_action % 3)
pygame.display.update()
self.last_action = None
time.sleep(0.3)
return obs

def close(self):
self.game.close()
super().close()

def set_render_mode(self, render_mode: Union[None, str]):
self.render_mode = render_mode

def get_human_action(self, agent, observation, termination, truncation, info):
return self.game.get_human_action(
agent, observation, termination, truncation, info
)
17 changes: 17 additions & 0 deletions examples/selfplay/tictactoe_utils/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Copyright 2023 The OpenRL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

""""""
137 changes: 137 additions & 0 deletions examples/selfplay/tictactoe_utils/game.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,137 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Copyright 2023 The OpenRL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

""""""

import sys

import pygame

WIDTH = 600
HEIGHT = 600

ROWS = 3
COLS = 3
SQSIZE = WIDTH // COLS

LINE_WIDTH = 15
CIRC_WIDTH = 15
CROSS_WIDTH = 20

RADIUS = SQSIZE // 4

OFFSET = 50

# --- COLORS ---

BG_COLOR = (28, 170, 156)
LINE_COLOR = (23, 145, 135)
CIRC_COLOR = (239, 231, 200)
CROSS_COLOR = (66, 66, 66)


class Game:
def __init__(self):
self.screen = None

def reset(self):
if self.screen is None:
pygame.init()
self.screen = pygame.display.set_mode((WIDTH, HEIGHT))
pygame.display.set_caption("TIC TAC TOE")
self.screen.fill(BG_COLOR)

self.player = 1 # 1-cross #2-circles
self.running = True
self.show_lines()

# --- DRAW METHODS ---
def show_lines(self):
# bg
self.screen.fill(BG_COLOR)

# vertical
pygame.draw.line(
self.screen, LINE_COLOR, (SQSIZE, 0), (SQSIZE, HEIGHT), LINE_WIDTH
)
pygame.draw.line(
self.screen,
LINE_COLOR,
(WIDTH - SQSIZE, 0),
(WIDTH - SQSIZE, HEIGHT),
LINE_WIDTH,
)

# horizontal
pygame.draw.line(
self.screen, LINE_COLOR, (0, SQSIZE), (WIDTH, SQSIZE), LINE_WIDTH
)
pygame.draw.line(
self.screen,
LINE_COLOR,
(0, HEIGHT - SQSIZE),
(WIDTH, HEIGHT - SQSIZE),
LINE_WIDTH,
)

def draw_fig(self, row, col):
if self.player == 1:
# draw cross
# desc line
start_desc = (col * SQSIZE + OFFSET, row * SQSIZE + OFFSET)
end_desc = (col * SQSIZE + SQSIZE - OFFSET, row * SQSIZE + SQSIZE - OFFSET)
pygame.draw.line(
self.screen, CROSS_COLOR, start_desc, end_desc, CROSS_WIDTH
)
# asc line
start_asc = (col * SQSIZE + OFFSET, row * SQSIZE + SQSIZE - OFFSET)
end_asc = (col * SQSIZE + SQSIZE - OFFSET, row * SQSIZE + OFFSET)
pygame.draw.line(self.screen, CROSS_COLOR, start_asc, end_asc, CROSS_WIDTH)

elif self.player == 2:
# draw circle
center = (col * SQSIZE + SQSIZE // 2, row * SQSIZE + SQSIZE // 2)
pygame.draw.circle(self.screen, CIRC_COLOR, center, RADIUS, CIRC_WIDTH)

# --- OTHER METHODS ---

def make_move(self, row, col):
self.draw_fig(row, col)
self.next_turn()

def next_turn(self):
self.player = self.player % 2 + 1

def close(self):
self.screen.fill((0, 0, 0, 0))
pygame.display.update()
del self.screen
pygame.quit()

def get_human_action(self, agent, observation, termination, truncation, info):
action_mask = observation["action_mask"]
while True:
for event in pygame.event.get():
if event.type == pygame.QUIT:
self.close()
sys.exit()
if event.type == pygame.MOUSEBUTTONDOWN:
pos = event.pos
row = pos[1] // SQSIZE
col = pos[0] // SQSIZE
action = row * 3 + col
if action_mask[action]:
return action
Loading

0 comments on commit 9e30d1f

Please sign in to comment.