Skip to content

Commit

Permalink
big changes
Browse files Browse the repository at this point in the history
  • Loading branch information
whyb committed Dec 26, 2024
1 parent ea9b22a commit 95b93eb
Show file tree
Hide file tree
Showing 4 changed files with 216 additions and 109 deletions.
4 changes: 4 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
*.pth
*.onnx
*.pt
/__pycache__/*
135 changes: 76 additions & 59 deletions model.py
Original file line number Diff line number Diff line change
@@ -1,24 +1,29 @@
import os
import math
import torch
import torch.nn as nn
import numpy as np
import random


BOARD_SIZE = 8 # 定义棋盘大小
WIN_CONDITION = 5 # 胜利条件


# 游戏环境
class Gomoku:
def __init__(self):
self.board = np.zeros((BOARD_SIZE, BOARD_SIZE), dtype=int)
self.current_player = 1
self.winning_line = []


def reset(self):
self.board.fill(0)
self.current_player = 1
self.winning_line = []


def is_winning_move(self, x, y):
# 检查五子连珠的胜利条件
def count_consecutive(player, dx, dy):
Expand All @@ -33,6 +38,7 @@ def count_consecutive(player, dx, dy):
break
return count, line


player = self.board[x, y]
directions = [(1, 0), (0, 1), (1, 1), (1, -1)]
for dx, dy in directions:
Expand All @@ -43,81 +49,85 @@ def count_consecutive(player, dx, dy):
return True
return False


def step(self, action):
# 解析动作坐标, 将传入的action转换为棋盘上的坐标
# 解析动作坐标, 将传入的 action 转换为棋盘上的坐标
x, y = action // BOARD_SIZE, action % BOARD_SIZE
# 检查目标位置是否已被占用
if self.board[x, y] != 0:
return -1, True
if self.board[x, y]!= 0:
return -1, True, 0
# 落子
self.board[x, y] = self.current_player
if self.is_winning_move(x, y):
return self.current_player, True
if self.current_player == 1:
return 1, True, 10000 # Player 1 五子连珠获胜
else:
return 2, True, -10000 # Player 2 五子连珠获胜

# 切换到另外一个棋手 1变2,2变1
# 切换到另外一个棋手 1 变 2,2 变 1
self.current_player = 3 - self.current_player

# 中间奖励score机制
# 中间奖励 score 机制
score = self.evaluate_board()
return score, False
return self.board[x, y], False, score

def evaluate_board(self):
score = 0
directions = [(1, 0), (0, 1), (1, 1), (1, -1)]

def evaluate_line(player, x, y, dx, dy):
count = 1
block = 0
for step in range(1, WIN_CONDITION):
nx, ny = x + dx * step, y + dy * step
if 0 <= nx < BOARD_SIZE and 0 <= ny < BOARD_SIZE:
if self.board[nx, ny] == player:
count += 1
elif self.board[nx, ny] == 0:
break
else:
block += 1
break
else:
block += 1
break
for step in range(1, WIN_CONDITION):
nx, ny = x - dx * step, y - dy * step
if 0 <= nx < BOARD_SIZE and 0 <= ny < BOARD_SIZE:
if self.board[nx, ny] == player:
count += 1
elif self.board[nx, ny] == 0:
break
else:
block += 1
break
def evaluate_board(self):
def count_consecutive(player, x, y, dx, dy):
"""
计算在特定方向上玩家的连续棋子数
:param player: 玩家编号(1 或 2)
:param x: 起始 x 坐标
:param y: 起始 y 坐标
:param dx: x 方向增量
:param dy: y 方向增量
:return: 连续棋子数
"""
count = 0
for step in range(WIN_CONDITION):
nx = x + dx * step
ny = y + dy * step
if 0 <= nx < BOARD_SIZE and 0 <= ny < BOARD_SIZE and self.board[nx, ny] == player:
count += 1
else:
block += 1
break
return count, block
return count

for i in range(BOARD_SIZE):
for j in range(BOARD_SIZE):
if self.board[i, j] != 0:
player = self.board[i, j]
for dx, dy in directions:
count, block = evaluate_line(player, i, j, dx, dy)
if count >= WIN_CONDITION:
score += 10000
elif count == 4 and block == 0:
score += 500
elif count == 4 and block == 1:
score += 100
elif count == 3 and block == 0:
score += 50
elif count == 3 and block == 1:
score += 10
elif count == 2 and block == 0:
score += 5
elif count == 2 and block == 1:
score += 1

directions = [(1, 0), (0, 1), (1, 1), (1, -1)]
score = 0
for x in range(BOARD_SIZE):
for y in range(BOARD_SIZE):
player = self.board[x, y]
if player == 0:
continue
for dx, dy in directions:
count = count_consecutive(player, x, y, dx, dy)
if count == 5: # 五子连珠
score += 10000
elif count == 4: # 四子连珠
score += 500
elif count == 3: # 三子连珠
score += 100
elif count == 2: # 二子连珠
score += 10

return score


def simulate_move(self, action):
x, y = action // BOARD_SIZE, action % BOARD_SIZE
if self.board[x, y]!= 0:
return False
self.board[x, y] = self.current_player
self.current_player = 3 - self.current_player
return True


def evaluate_state(self):
return self.evaluate_board()


def print_board(self):
for i in range(BOARD_SIZE):
row = ''
Expand All @@ -129,6 +139,7 @@ def print_board(self):
print(row)
print()


# Version #1
class GomokuNetV1(nn.Module):
def __init__(self):
Expand All @@ -137,12 +148,14 @@ def __init__(self):
self.fc2 = nn.Linear(256, 256)
self.fc3 = nn.Linear(256, BOARD_SIZE * BOARD_SIZE)


def forward(self, x):
x = torch.relu(self.fc1(x))
x = torch.relu(self.fc2(x))
x = self.fc3(x)
return x


# 卷积神经网络(CNN)
class GomokuNetV2(nn.Module):
def __init__(self):
Expand All @@ -152,6 +165,7 @@ def __init__(self):
self.fc1 = nn.Linear(128 * BOARD_SIZE * BOARD_SIZE, 256)
self.fc2 = nn.Linear(256, BOARD_SIZE * BOARD_SIZE)


def forward(self, x):
x = torch.relu(self.conv1(x.view(-1, 1, BOARD_SIZE, BOARD_SIZE)))
x = torch.relu(self.conv2(x))
Expand All @@ -160,16 +174,19 @@ def forward(self, x):
x = self.fc2(x)
return x


def get_valid_action(logits, board, epsilon=0.1):
logits = logits.flatten() # 展平logits,确保其形状为(BOARD_SIZE * BOARD_SIZE,)
logits = logits.flatten() # 展平 logits,确保其形状为(BOARD_SIZE * BOARD_SIZE,)
valid_actions = [(logits[i].item(), i) for i in range(BOARD_SIZE * BOARD_SIZE) if board[i // BOARD_SIZE, i % BOARD_SIZE] == 0]
valid_actions.sort(reverse=True, key=lambda x: x[0]) # 根据 logits 从大到小排序


if random.random() < epsilon:
return random.choice(valid_actions)[1] if valid_actions else -1
else:
return valid_actions[0][1] if valid_actions else -1


def load_model_if_exists(model, file_path):
if os.path.exists(file_path):
model.load_state_dict(torch.load(file_path))
Expand Down
Loading

0 comments on commit 95b93eb

Please sign in to comment.