-
Notifications
You must be signed in to change notification settings - Fork 62
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
can train tictactoe
- Loading branch information
Showing
16 changed files
with
803 additions
and
132 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,7 +1,16 @@ | ||
## How to Use | ||
|
||
Users can train selfplay strategy in connect3 via: | ||
Users can train a Tic-Tac-Toe self-play agent via: | ||
|
||
```shell | ||
python train_selfplay.py --config selfplay_connect3.yaml | ||
python train_selfplay.py | ||
``` | ||
|
||
|
||
## Play with a trained agent | ||
|
||
Users can play with a trained agent via: | ||
|
||
```shell | ||
python human_vs_agent.py | ||
``` |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,73 @@ | ||
import numpy as np | ||
import torch | ||
from tictactoe_render import TictactoeRender | ||
|
||
from openrl.configs.config import create_config_parser | ||
from openrl.envs.common import make | ||
from openrl.envs.wrappers import FlattenObservation | ||
from openrl.modules.common import PPONet as Net | ||
from openrl.runners.common import PPOAgent as Agent | ||
from openrl.selfplay.wrappers.human_opponent_wrapper import HumanOpponentWrapper | ||
from openrl.selfplay.wrappers.random_opponent_wrapper import RandomOpponentWrapper | ||
|
||
|
||
def get_fake_env(env_num):
    """Create a Tic-Tac-Toe vector env played against a random opponent.

    Args:
        env_num: Number of parallel environments, forwarded to ``make``.

    Returns:
        The environment returned by ``make("tictactoe_v3", ...)``, configured
        with ``RandomOpponentWrapper`` as the opponent, flattened
        observations, asynchronous execution and ``auto_reset`` disabled.
    """
    make_kwargs = dict(
        env_num=env_num,
        asynchronous=True,
        opponent_wrappers=[RandomOpponentWrapper],
        env_wrappers=[FlattenObservation],
        auto_reset=False,
    )
    return make("tictactoe_v3", **make_kwargs)
|
||
|
||
def get_human_env(env_num):
    """Create a Tic-Tac-Toe vector env whose opponent is a human player.

    Args:
        env_num: Number of parallel environments, forwarded to ``make``.

    Returns:
        The environment returned by ``make("tictactoe_v3", ...)``, with
        ``TictactoeRender`` followed by ``HumanOpponentWrapper`` as opponent
        wrappers, flattened observations, asynchronous execution and
        ``auto_reset`` disabled.
    """
    make_kwargs = dict(
        env_num=env_num,
        asynchronous=True,
        opponent_wrappers=[TictactoeRender, HumanOpponentWrapper],
        env_wrappers=[FlattenObservation],
        auto_reset=False,
    )
    return make("tictactoe_v3", **make_kwargs)
|
||
|
||
def human_vs_agent():
    """Play five Tic-Tac-Toe episodes between a human and a saved PPO agent.

    The network is built from a random-opponent env (``get_fake_env``), the
    agent weights are loaded from ``./ppo_agent/``, and the games are played
    in the human-opponent env (``get_human_env``). After each episode the
    mean reward is printed; at the end the fraction of episodes with a
    positive mean reward is printed as the win rate.
    """
    env_num = 1
    fake_env = get_fake_env(env_num)
    env = get_human_env(env_num)

    cfg_parser = create_config_parser()
    cfg = cfg_parser.parse_args()
    device = "cuda" if torch.cuda.is_available() else "cpu"
    net = Net(fake_env, cfg=cfg, device=device)

    agent = Agent(net)
    agent.load("./ppo_agent/")

    # Counts episodes whose mean reward is positive (kept as float so the
    # printed win rate matches the original formatting).
    wins = 0.0
    ep_num = 5
    for ep_now in range(ep_num):
        agent.set_env(fake_env)
        obs, info = env.reset()

        done = False
        while not np.any(done):
            # Predict the next action based on the current observation.
            action, _ = agent.act(obs, info, deterministic=True)
            obs, r, done, info = env.step(action)

        # The loop body runs at least once, so ``r`` holds the final reward.
        wins += np.mean(r) > 0
        print(f"{ep_now}/{ep_num}: reward: {np.mean(r)}")

    print(f"win rate: {wins / ep_num}")
    env.close()
    # The helper env was created with asynchronous=True as well; release it
    # too instead of leaving it open until interpreter exit.
    fake_env.close()
|
||
|
||
# Start the human-vs-agent match only when this file is run as a script.
if __name__ == "__main__":
    human_vs_agent()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,59 @@ | ||
#!/usr/bin/env python | ||
# -*- coding: utf-8 -*- | ||
# Copyright 2023 The OpenRL Authors. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# https://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
"""""" | ||
import numpy as np | ||
from tictactoe_render import TictactoeRender | ||
|
||
from openrl.envs.common import make | ||
from openrl.envs.wrappers import FlattenObservation | ||
from openrl.selfplay.wrappers.random_opponent_wrapper import RandomOpponentWrapper | ||
|
||
|
||
def test_env():
    """Run one rendered Tic-Tac-Toe episode with random actions on both sides.

    Builds a single synchronous env with ``TictactoeRender`` and
    ``RandomOpponentWrapper``, resets it with ``seed=1`` and steps with
    ``env.random_action(info)`` until any sub-environment reports done.
    Each step prints the step counter together with the observation, or the
    ``final_observation`` entries from ``info`` on the terminating step.
    """
    env_num = 1
    render_mode = "human"
    env = make(
        "tictactoe_v3",
        render_mode=render_mode,
        env_num=env_num,
        asynchronous=False,
        opponent_wrappers=[TictactoeRender, RandomOpponentWrapper],
        env_wrappers=[FlattenObservation],
    )

    obs, info = env.reset(seed=1)
    done = False
    step_num = 0
    while not done:
        action = env.random_action(info)

        # Unpack as (obs, reward, done, info), matching human_vs_agent.py in
        # this commit; the swapped order ended the loop on the reward value
        # rather than on the done flag.
        obs, r, done, info = env.step(action)

        done = np.any(done)
        step_num += 1
        if done:
            print(
                "step:"
                f" {step_num},{[env_info['final_observation'] for env_info in info]}"
            )
        else:
            print(f"step: {step_num},{obs}")
|
||
|
||
# Run the rendered random-play smoke test only when executed as a script.
if __name__ == "__main__":
    test_env()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,70 @@ | ||
#!/usr/bin/env python | ||
# -*- coding: utf-8 -*- | ||
# Copyright 2023 The OpenRL Authors. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# https://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
"""""" | ||
import time | ||
from typing import Optional, Union | ||
|
||
import pygame | ||
from pettingzoo.utils.env import ActionType, AECEnv, ObsType | ||
from pettingzoo.utils.wrappers.base import BaseWrapper | ||
from tictactoe_utils.game import Game | ||
|
||
|
||
class TictactoeRender(BaseWrapper):
    """AEC env wrapper that mirrors each move onto a pygame board.

    The wrapper remembers the most recent action passed to ``step`` and, the
    next time ``observe`` is called, draws it through ``Game.make_move``.
    Drawing only happens while ``render_mode`` is ``"game"``.
    """

    def __init__(self, env: AECEnv):
        super().__init__(env)

        self.game = Game()
        # Most recent action that has not been drawn yet.
        self.last_action = None
        self.last_length = 0
        self.render_mode = "game"

    def reset(self, seed: Optional[int] = None, options: Optional[dict] = None):
        """Reset the wrapped env and, in "game" mode, redraw an empty board."""
        super().reset(seed, options)
        if self.render_mode == "game":
            self.game.reset()
            pygame.display.update()
            time.sleep(0.3)

        self.last_action = None

    def step(self, action: ActionType) -> None:
        """Forward ``action`` to the wrapped env and remember it for drawing."""
        result = super().step(action)
        self.last_action = action
        return result

    def observe(self, agent: str) -> Optional[ObsType]:
        """Return the wrapped observation, drawing any pending move first.

        The pending action is split into ``row = action // 3`` and
        ``col = action % 3`` before being handed to the board.
        """
        obs = super().observe(agent)
        if self.last_action is not None:
            if self.render_mode == "game":
                self.game.make_move(self.last_action // 3, self.last_action % 3)
                pygame.display.update()
            self.last_action = None
            time.sleep(0.3)
        return obs

    def close(self):
        """Close the board window (if one is open) and the wrapped env."""
        # Game.close() touches self.screen unconditionally and deletes the
        # attribute afterwards, so calling it before any "game"-mode reset,
        # or a second time, would raise. Only close a window that exists.
        if getattr(self.game, "screen", None) is not None:
            self.game.close()
        super().close()

    def set_render_mode(self, render_mode: Union[None, str]):
        """Switch drawing on (``"game"``) or off (any other value)."""
        self.render_mode = render_mode

    def get_human_action(self, agent, observation, termination, truncation, info):
        """Delegate to ``Game.get_human_action`` to read the human's move."""
        return self.game.get_human_action(
            agent, observation, termination, truncation, info
        )
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,17 @@ | ||
#!/usr/bin/env python | ||
# -*- coding: utf-8 -*- | ||
# Copyright 2023 The OpenRL Authors. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# https://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
"""""" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,137 @@ | ||
#!/usr/bin/env python | ||
# -*- coding: utf-8 -*- | ||
# Copyright 2023 The OpenRL Authors. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# https://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
"""""" | ||
|
||
import sys | ||
|
||
import pygame | ||
|
||
# --- WINDOW (pixels) ---
WIDTH = 600
HEIGHT = 600

# --- BOARD ---
ROWS = 3
COLS = 3
# Side length of one board cell, derived from the window width.
SQSIZE = WIDTH // COLS

# --- STROKE WIDTHS ---
LINE_WIDTH = 15
CIRC_WIDTH = 15
CROSS_WIDTH = 20

# Circle radius is a quarter of a cell.
RADIUS = SQSIZE // 4

# Inset of the cross strokes from the cell edges.
OFFSET = 50

# --- COLORS (RGB) ---
BG_COLOR = (28, 170, 156)
LINE_COLOR = (23, 145, 135)
CIRC_COLOR = (239, 231, 200)
CROSS_COLOR = (66, 66, 66)
|
||
|
||
class Game:
    """Pygame-backed Tic-Tac-Toe board: draws moves and reads mouse clicks.

    The window is created lazily by ``reset()``; until then ``screen`` is
    ``None``. ``player`` selects the next figure to draw (1 = cross,
    2 = circle) and alternates after every ``make_move``.
    """

    def __init__(self):
        # No window yet; reset() opens it on first use.
        self.screen = None
        # Defined up front so the attributes exist before the first reset().
        self.player = 1  # 1-cross #2-circles
        self.running = False

    def reset(self):
        """Open the window if needed and draw an empty board."""
        if self.screen is None:
            pygame.init()
            self.screen = pygame.display.set_mode((WIDTH, HEIGHT))
            pygame.display.set_caption("TIC TAC TOE")

        self.player = 1  # 1-cross #2-circles
        self.running = True
        # show_lines() repaints the background itself before drawing the grid.
        self.show_lines()

    # --- DRAW METHODS ---
    def show_lines(self):
        """Fill the background and draw the two vertical and two horizontal grid lines."""
        # bg
        self.screen.fill(BG_COLOR)

        # vertical
        pygame.draw.line(
            self.screen, LINE_COLOR, (SQSIZE, 0), (SQSIZE, HEIGHT), LINE_WIDTH
        )
        pygame.draw.line(
            self.screen,
            LINE_COLOR,
            (WIDTH - SQSIZE, 0),
            (WIDTH - SQSIZE, HEIGHT),
            LINE_WIDTH,
        )

        # horizontal
        pygame.draw.line(
            self.screen, LINE_COLOR, (0, SQSIZE), (WIDTH, SQSIZE), LINE_WIDTH
        )
        pygame.draw.line(
            self.screen,
            LINE_COLOR,
            (0, HEIGHT - SQSIZE),
            (WIDTH, HEIGHT - SQSIZE),
            LINE_WIDTH,
        )

    def draw_fig(self, row, col):
        """Draw the current player's figure in cell ``(row, col)``."""
        if self.player == 1:
            # draw cross
            # desc line
            start_desc = (col * SQSIZE + OFFSET, row * SQSIZE + OFFSET)
            end_desc = (col * SQSIZE + SQSIZE - OFFSET, row * SQSIZE + SQSIZE - OFFSET)
            pygame.draw.line(
                self.screen, CROSS_COLOR, start_desc, end_desc, CROSS_WIDTH
            )
            # asc line
            start_asc = (col * SQSIZE + OFFSET, row * SQSIZE + SQSIZE - OFFSET)
            end_asc = (col * SQSIZE + SQSIZE - OFFSET, row * SQSIZE + OFFSET)
            pygame.draw.line(self.screen, CROSS_COLOR, start_asc, end_asc, CROSS_WIDTH)

        elif self.player == 2:
            # draw circle
            center = (col * SQSIZE + SQSIZE // 2, row * SQSIZE + SQSIZE // 2)
            pygame.draw.circle(self.screen, CIRC_COLOR, center, RADIUS, CIRC_WIDTH)

    # --- OTHER METHODS ---

    def make_move(self, row, col):
        """Draw the current player's figure at ``(row, col)`` and pass the turn."""
        self.draw_fig(row, col)
        self.next_turn()

    def next_turn(self):
        """Toggle ``player`` between 1 and 2."""
        self.player = self.player % 2 + 1

    def close(self):
        """Blank and close the window; a no-op when no window is open."""
        # reset() may never have run (or close() may already have been
        # called); in that case there is nothing to tear down.
        if self.screen is None:
            return
        self.screen.fill((0, 0, 0, 0))
        pygame.display.update()
        # Keep the attribute (as None) so a later reset() or close() still
        # works instead of raising AttributeError.
        self.screen = None
        self.running = False
        pygame.quit()

    def get_human_action(self, agent, observation, termination, truncation, info):
        """Block until the human clicks a legal cell and return its index.

        Args:
            agent, termination, truncation, info: Accepted but not used here.
            observation: Mapping with an ``"action_mask"`` entry indexed by
                action; a truthy entry marks the action as allowed.

        Returns:
            ``row * COLS + col`` of the clicked cell, once a click lands on a
            cell whose mask entry is truthy.

        Raises:
            SystemExit: When the window is closed (after ``close()``).
        """
        action_mask = observation["action_mask"]
        while True:
            for event in pygame.event.get():
                if event.type == pygame.QUIT:
                    self.close()
                    sys.exit()
                if event.type == pygame.MOUSEBUTTONDOWN:
                    pos = event.pos
                    row = pos[1] // SQSIZE
                    col = pos[0] // SQSIZE
                    action = row * COLS + col
                    # Ignore clicks on cells that are not currently legal.
                    if action_mask[action]:
                        return action
Oops, something went wrong.