From a2df15d4741770b7ff7e935c6cc70cadf3a30a2a Mon Sep 17 00:00:00 2001 From: huangshiyu Date: Wed, 25 Oct 2023 11:29:14 +0800 Subject: [PATCH] add RandomAgent for Arena --- examples/arena/evaluate_more_envs.py | 6 ++--- examples/arena/run_arena.py | 14 +++++++++--- openrl/arena/agents/random_agent.py | 29 ++++++++++++++++++++++++ tests/test_arena/test_new_envs.py | 3 ++- tests/test_arena/test_reproducibility.py | 3 ++- 5 files changed, 47 insertions(+), 8 deletions(-) create mode 100644 openrl/arena/agents/random_agent.py diff --git a/examples/arena/evaluate_more_envs.py b/examples/arena/evaluate_more_envs.py index f55dc576..3b7bfe07 100644 --- a/examples/arena/evaluate_more_envs.py +++ b/examples/arena/evaluate_more_envs.py @@ -17,12 +17,12 @@ """""" from pettingzoo.butterfly import cooperative_pong_v5 -from pettingzoo.classic import connect_four_v3, go_v5, texas_holdem_no_limit_v6,rps_v2 +from pettingzoo.classic import connect_four_v3, go_v5, rps_v2, texas_holdem_no_limit_v6 from pettingzoo.mpe import simple_push_v3 - from openrl.arena import make_arena from openrl.arena.agents.local_agent import LocalAgent +from openrl.arena.agents.random_agent import RandomAgent from openrl.envs.PettingZoo.registration import register from openrl.envs.wrappers.pettingzoo_wrappers import RecordWinner @@ -79,7 +79,7 @@ def run_arena( arena = make_arena(env_id, env_wrappers=env_wrappers, use_tqdm=False) agent1 = LocalAgent("../selfplay/opponent_templates/random_opponent") - agent2 = LocalAgent("../selfplay/opponent_templates/random_opponent") + agent2 = RandomAgent() arena.reset( agents={"agent1": agent1, "agent2": agent2}, diff --git a/examples/arena/run_arena.py b/examples/arena/run_arena.py index e880884c..fdc0776a 100644 --- a/examples/arena/run_arena.py +++ b/examples/arena/run_arena.py @@ -17,6 +17,7 @@ """""" from openrl.arena import make_arena from openrl.arena.agents.local_agent import LocalAgent +from openrl.arena.agents.random_agent import RandomAgent from 
openrl.envs.wrappers.pettingzoo_wrappers import RecordWinner @@ -37,7 +38,7 @@ def run_arena( arena = make_arena("tictactoe_v3", env_wrappers=env_wrappers, use_tqdm=use_tqdm) agent1 = LocalAgent("../selfplay/opponent_templates/random_opponent") - agent2 = LocalAgent("../selfplay/opponent_templates/random_opponent") + agent2 = RandomAgent() arena.reset( agents={"agent1": agent1, "agent2": agent2}, @@ -52,5 +53,12 @@ def run_arena( if __name__ == "__main__": - run_arena(render=False, parallel=True, seed=0, total_games=100, max_game_onetime=10) - # run_arena(render=False, parallel=False, seed=1, total_games=1, max_game_onetime=1,use_tqdm=False) + # run_arena(render=False, parallel=True, seed=0, total_games=100, max_game_onetime=10) + run_arena( + render=False, + parallel=False, + seed=1, + total_games=300, + max_game_onetime=1, + use_tqdm=False, + ) diff --git a/openrl/arena/agents/random_agent.py b/openrl/arena/agents/random_agent.py new file mode 100644 index 00000000..d09e5e15 --- /dev/null +++ b/openrl/arena/agents/random_agent.py @@ -0,0 +1,29 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# Copyright 2023 The OpenRL Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""""" +from openrl.arena.agents.base_agent import BaseAgent +from openrl.selfplay.opponents.base_opponent import BaseOpponent +from openrl.selfplay.opponents.random_opponent import RandomOpponent +from openrl.selfplay.opponents.utils import load_opponent_from_path + + +class RandomAgent(BaseAgent): + """Arena agent whose underlying opponent is a ``RandomOpponent``.""" + + def _new_agent(self) -> BaseOpponent: + """Return a newly constructed ``RandomOpponent``.""" + return RandomOpponent() diff --git a/tests/test_arena/test_new_envs.py b/tests/test_arena/test_new_envs.py index 5dc6231e..7a5dc01d 100644 --- a/tests/test_arena/test_new_envs.py +++ b/tests/test_arena/test_new_envs.py @@ -26,6 +26,7 @@ from examples.custom_env.rock_paper_scissors import RockPaperScissors from openrl.arena import make_arena from openrl.arena.agents.local_agent import LocalAgent +from openrl.arena.agents.random_agent import RandomAgent from openrl.envs.PettingZoo.registration import register from openrl.envs.wrappers.pettingzoo_wrappers import RecordWinner @@ -82,7 +83,7 @@ def run_arena( arena = make_arena(env_id, env_wrappers=env_wrappers, use_tqdm=False) agent1 = LocalAgent("./examples/selfplay/opponent_templates/random_opponent") - agent2 = LocalAgent("./examples/selfplay/opponent_templates/random_opponent") + agent2 = RandomAgent() arena.reset( agents={"agent1": agent1, "agent2": agent2}, diff --git a/tests/test_arena/test_reproducibility.py b/tests/test_arena/test_reproducibility.py index 0d186ab0..9ced525c 100644 --- a/tests/test_arena/test_reproducibility.py +++ b/tests/test_arena/test_reproducibility.py @@ -22,6 +22,7 @@ from openrl.arena import make_arena from openrl.arena.agents.local_agent import LocalAgent +from openrl.arena.agents.random_agent import RandomAgent from openrl.envs.wrappers.pettingzoo_wrappers import RecordWinner @@ -41,7 +42,7 @@ def run_arena( arena = make_arena("tictactoe_v3", env_wrappers=env_wrappers, use_tqdm=False) agent1 = LocalAgent("./examples/selfplay/opponent_templates/random_opponent") - agent2 = 
LocalAgent("./examples/selfplay/opponent_templates/random_opponent") + agent2 = RandomAgent() arena.reset( agents={"agent1": agent1, "agent2": agent2},