-
Notifications
You must be signed in to change notification settings - Fork 62
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
24 changed files
with
1,833 additions
and
1,729 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
""" mjai protocol bot implementations""" | ||
from .common import * | ||
from .bot import * |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,64 @@ | ||
""" Bot represents a mjai protocol bot | ||
implement wrappers for supportting different bot types | ||
""" | ||
|
||
from abc import ABC, abstractmethod | ||
|
||
import common.mj_helper as mj_helper | ||
from common.utils import BOT_TYPE | ||
|
||
def reaction_convert_meta(reaction:dict): | ||
""" add meta_options to reaction """ | ||
if 'meta' in reaction: | ||
meta = reaction['meta'] | ||
reaction['meta_options'] = mj_helper.meta_to_options(meta) | ||
|
||
class Bot(ABC): | ||
""" Bot Interface class | ||
bot follows mjai protocol | ||
ref: https://mjai.app/docs/highlevel-api | ||
Note: reach msg is implemented differently. | ||
Reach msg has reach_dahai attached, which is a 'dahai' msg, indicating the dahai action after reach | ||
msgs have 'meta_options', which is a translation of 'meta' into list of (mjai tile, weight)""" | ||
|
||
def __init__(self, bot_type:BOT_TYPE, name:str="Bot") -> None: | ||
self.type = bot_type | ||
self.name = name | ||
self._initialized:bool = False | ||
self.seat:int = None | ||
|
||
def init_bot(self, seat:int): | ||
""" Initialize the bot before the game starts. Bot must be initialized before a new game""" | ||
self.seat = seat | ||
self._init_bot_impl() | ||
self._initialized = True | ||
|
||
@property | ||
def initialized(self) -> bool: | ||
""" return True if bot is initialized""" | ||
return self._initialized | ||
|
||
@abstractmethod | ||
def _init_bot_impl(self): | ||
""" Initialize the bot before the game starts.""" | ||
|
||
@abstractmethod | ||
def react(self, input_msg:dict) -> dict | None: | ||
""" input mjai msg and get bot output if any, or None if not""" | ||
|
||
def react_batch(self, input_list:list[dict]) -> dict | None: | ||
""" input list of mjai msg and get the last output, if any""" | ||
|
||
# default implementation is to iterate and feed to bot | ||
if len(input_list) == 0: | ||
return None | ||
for msg in input_list[:-1]: | ||
msg['can_act'] = False | ||
self.react(msg) | ||
last_reaction = self.react(input_list[-1]) | ||
return last_reaction | ||
|
||
|
||
|
||
|
||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,71 @@ | ||
""" Bot Mortal Local """ | ||
|
||
from pathlib import Path | ||
import threading | ||
import json | ||
import libriichi | ||
from mjai.engine import get_engine | ||
from common.utils import ModelFileException | ||
from common.mj_helper import MJAI_TYPE | ||
from common.log_helper import LOGGER | ||
from .bot import Bot, BOT_TYPE, reaction_convert_meta | ||
|
||
# mjai Bot class from rust library | ||
# pylint: disable=no-member | ||
MjaiBot = libriichi.mjai.Bot | ||
|
||
class BotMortalLocal(Bot): | ||
""" Mortal model based mjai bot""" | ||
def __init__(self, model_file:str) -> None: | ||
""" params: | ||
model_file: path to the mortal model file | ||
""" | ||
super().__init__(BOT_TYPE.LOCAL, "Local Mortal Bot - " + model_file) | ||
self.model_file = model_file | ||
if not Path(self.model_file).exists(): | ||
raise ModelFileException(f"Cannot find model file:{self.model_file}") | ||
|
||
self.mjai_bot:MjaiBot = None | ||
|
||
self.ignore_next_turn_self_reach:bool = False | ||
self.str_input_history:list = [] | ||
# thread lock for mjai.bot access | ||
# "mutable borrow" issue when running multiple methods at the same time | ||
self.lock = threading.Lock() | ||
|
||
def _init_bot_impl(self): | ||
engine = get_engine(self.model_file) | ||
self.mjai_bot = MjaiBot(engine, self.seat) | ||
self.str_input_history.clear() | ||
|
||
def react(self, input_msg:dict) -> dict: | ||
|
||
if self.ignore_next_turn_self_reach: # ignore repetitive self reach. only for the very next msg | ||
if input_msg['type'] == MJAI_TYPE.REACH and input_msg['actor'] == self.seat: | ||
LOGGER.debug("Ignoring repetitive self reach msg, reach msg already sent to AI last turn") | ||
return None | ||
self.ignore_next_turn_self_reach = False | ||
|
||
str_input = json.dumps(input_msg) | ||
self.str_input_history.append(str_input) | ||
with self.lock: | ||
react_str = self.mjai_bot.react(str_input) | ||
if react_str is None: | ||
return None | ||
reaction = json.loads(react_str) | ||
reaction_convert_meta(reaction) | ||
# Special treatment for self reach output msg | ||
# mjai only outputs dahai msg after the reach msg | ||
if reaction['type'] == MJAI_TYPE.REACH and reaction['actor'] == self.seat: # Self reach | ||
# get the subsequent dahai message, | ||
# appeding it to the reach reaction msg as 'reach_dahai' key | ||
LOGGER.debug("Send reach msg to get reach_dahai. Cannot go back to unreach!") | ||
# TODO make a clone of mjai_bot so reach can be tested to get dahai without affecting the game | ||
|
||
reach_msg = {'type': MJAI_TYPE.REACH, 'actor': self.seat} | ||
reach_dahai_str = self.mjai_bot.react(json.dumps(reach_msg)) | ||
reach_dahai = json.loads(reach_dahai_str) | ||
reaction_convert_meta(reach_dahai) | ||
reaction['reach_dahai'] = reach_dahai | ||
self.ignore_next_turn_self_reach = True # ignore very next reach msg | ||
return reaction |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,155 @@ | ||
""" Bot for mjapi""" | ||
|
||
import time | ||
from common.settings import Settings | ||
from common.log_helper import LOGGER | ||
from common.utils import random_str | ||
from common.mj_helper import MJAI_TYPE | ||
from .mjapi import MJAPI_Client | ||
|
||
from .bot import Bot, BOT_TYPE, reaction_convert_meta | ||
|
||
|
||
|
||
class BotMjapi(Bot): | ||
""" Bot using mjapi online API""" | ||
batch_size = 24 | ||
retries = 3 | ||
retry_interval = 1 | ||
bound = 256 | ||
|
||
""" MJAPI based mjai bot""" | ||
def __init__(self, setting:Settings) -> None: | ||
super().__init__(BOT_TYPE.MJAPI, "MJAPI Bot - " + setting.mjapi_url) | ||
self.settings = setting | ||
self.mjapi = MJAPI_Client(self.settings.mjapi_url) | ||
self._login_or_reg() | ||
self.id = -1 | ||
self.ignore_next_turn_self_reach:bool = False | ||
|
||
def _login_or_reg(self): | ||
if not self.settings.mjapi_user: | ||
self.settings.mjapi_user = random_str(6) | ||
LOGGER.info("Set random mjapi username:%s", self.settings.mjapi_user) | ||
try: | ||
self.mjapi.login(self.settings.mjapi_user, self.settings.mjapi_secret) | ||
except Exception as e: | ||
LOGGER.warning("Error login: %s", e) | ||
# try register | ||
res_reg = self.mjapi.register(self.settings.mjapi_user) | ||
self.settings.mjapi_secret = res_reg['secret'] | ||
self.settings.save_json() | ||
LOGGER.info("Registered new user [%s] with MJAPI. User name and secret saved to settings.", self.settings.mjapi_user) | ||
self.mjapi.login(self.settings.mjapi_user, self.settings.mjapi_secret) | ||
|
||
model_list = self.mjapi.list_models() | ||
if not model_list: | ||
raise RuntimeError("No models available in MJAPI") | ||
self.settings.mjapi_models = model_list | ||
if self.settings.mjapi_model_select in model_list: | ||
# OK | ||
pass | ||
else: | ||
LOGGER.debug( | ||
"mjapi selected model %s N/A, using last one from available list %s", | ||
self.settings.mjapi_model_select, model_list[-1]) | ||
self.settings.mjapi_model_select = model_list[-1] | ||
self.model_name = self.settings.mjapi_model_select | ||
LOGGER.info("Login to MJAPI successful with user: %s, model_name=%s", self.settings.mjapi_user, self.model_name) | ||
|
||
def __del__(self): | ||
self.mjapi.stop_bot() | ||
self.mjapi.logout() | ||
|
||
def _init_bot_impl(self): | ||
self.mjapi.start_bot(self.seat, BotMjapi.bound, self.model_name) | ||
self.id = -1 | ||
|
||
def _process_reaction(self, reaction, recurse): | ||
if reaction: | ||
reaction_convert_meta(reaction) | ||
else: | ||
return None | ||
|
||
# process self reach | ||
if recurse and reaction['type'] == MJAI_TYPE.REACH and reaction['actor'] == self.seat: | ||
LOGGER.debug("Send reach msg to get reach_dahai.") | ||
reach_msg = {'type': MJAI_TYPE.REACH, 'actor': self.seat} | ||
reach_dahai = self.react(reach_msg, recurse=False) | ||
reaction['reach_dahai'] = self._process_reaction(reach_dahai, False) | ||
self.ignore_next_turn_self_reach = True | ||
|
||
return reaction | ||
|
||
def react(self, input_msg:dict, recurse=True) -> dict | None: | ||
# input_msg['can_act'] = True | ||
msg_type = input_msg['type'] | ||
if msg_type in [MJAI_TYPE.START_GAME, MJAI_TYPE.END_GAME, MJAI_TYPE.END_KYOKU]: | ||
# ignore no effect msgs | ||
return None | ||
if self.ignore_next_turn_self_reach: | ||
if msg_type == MJAI_TYPE.REACH and input_msg['actor'] == self.seat: | ||
LOGGER.debug("Ignoring repetitive self reach msg, reach msg already sent to AI last turn") | ||
return None | ||
self.ignore_next_turn_self_reach = False | ||
|
||
old_id = self.id | ||
err = None | ||
self.id = (self.id + 1) % BotMjapi.bound | ||
reaction = None | ||
for _ in range(BotMjapi.retries): | ||
try: | ||
reaction = self.mjapi.act(self.id, input_msg) | ||
err = None | ||
break | ||
except BaseException as e: | ||
err = e | ||
time.sleep(BotMjapi.retry_interval) | ||
if err: | ||
self.id = old_id | ||
raise err | ||
return self._process_reaction(reaction, recurse) | ||
|
||
def react_batch(self, input_list: list[dict]) -> dict | None: | ||
if self.ignore_next_turn_self_reach and len(input_list) > 0: | ||
if input_list[0]['type'] == MJAI_TYPE.REACH and input_list[0]['actor'] == self.seat: | ||
LOGGER.debug("Ignoring repetitive self reach msg, reach msg already sent to AI last turn") | ||
input_list = input_list[1:] | ||
self.ignore_next_turn_self_reach = False | ||
if len(input_list) == 0: | ||
return None | ||
num_batches = (len(input_list) - 1) // BotMjapi.batch_size + 1 | ||
reaction = None | ||
for (i, start) in enumerate(range(0, len(input_list), BotMjapi.batch_size)): | ||
reaction = self._react_batch_impl( | ||
input_list[start:start + BotMjapi.batch_size], | ||
can_act=(i + 1 == num_batches)) | ||
return reaction | ||
|
||
def _react_batch_impl(self, input_list, can_act): | ||
if len(input_list) == 0: | ||
return None | ||
batch_data = [] | ||
|
||
old_id = self.id | ||
err = None | ||
for (i, msg) in enumerate(input_list): | ||
self.id = (self.id + 1) % BotMjapi.bound | ||
if i + 1 == len(input_list) and not can_act: | ||
msg = msg.copy() | ||
msg['can_act'] = False | ||
action = {'seq': self.id, 'data': msg} | ||
batch_data.append(action) | ||
reaction = None | ||
for _ in range(BotMjapi.retries): | ||
try: | ||
reaction = self.mjapi.batch(batch_data) | ||
err = None | ||
break | ||
except BaseException as e: | ||
err = e | ||
time.sleep(BotMjapi.retry_interval) | ||
if err: | ||
self.id = old_id | ||
raise err | ||
return self._process_reaction(reaction, True) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,19 @@ | ||
""" Bot factory""" | ||
|
||
from common.settings import Settings | ||
from common.utils import MODEL_FOLDER, sub_file | ||
from .bot import Bot, BOT_TYPE | ||
from .bot_local import BotMortalLocal | ||
from .bot_mjapi import BotMjapi | ||
|
||
|
||
def get_bot(settings:Settings) -> Bot: | ||
""" create the Bot instance based on settings""" | ||
if settings.model_type == BOT_TYPE.LOCAL.value: | ||
bot = BotMortalLocal(sub_file(MODEL_FOLDER, settings.model_file)) | ||
elif settings.model_type == BOT_TYPE.MJAPI.value: | ||
bot = BotMjapi(settings) | ||
else: | ||
raise ValueError(f"Unknown model type: {settings.model_type}") | ||
|
||
return bot |
Oops, something went wrong.