-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathsettlementTools.py
154 lines (128 loc) · 5.28 KB
/
settlementTools.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
from __future__ import annotations
from typing import Callable, TYPE_CHECKING
if TYPE_CHECKING:
from Node import Node
import numpy as np
import globals
import worldTools
from MCTS.mcts import MCTS
from RootNode import RootNode
def runSearcher(
    rootNode: Node,
    rng: np.random.Generator = np.random.default_rng(),
    targetName: str = '',
    iterationLimit: int = 40000,
    explorationConstant: float = 1 / np.sqrt(2),
    clearActionCache: bool = False,
) -> list[Node]:
    """Run a Monte Carlo tree search from ``rootNode`` and finalize the best route.

    Args:
        rootNode: State handed to ``MCTS.search`` as the initial state.
        rng: Random generator forwarded to the ``MCTS`` constructor.
        targetName: Label used in the progress messages and passed on to
            ``finalizeTrace`` as the route name.
        iterationLimit: Forwarded to the ``MCTS`` constructor.
        explorationConstant: Forwarded to the ``MCTS`` constructor.
        clearActionCache: Forwarded to ``finalizeTrace``.

    Returns:
        The nodes of the searcher's best route, in order, with every
        ``RootNode`` instance left out.
    """
    print(f'Start MCTS for {targetName} (iterationLimit: {iterationLimit}, explorationConstant: {explorationConstant})')

    mcts = MCTS(
        iterationLimit=iterationLimit,
        rolloutPolicy=mctsRolloutPolicy,
        explorationConstant=explorationConstant,
        rng=rng,
    )
    mcts.search(initialState=rootNode)

    # Root nodes are only search anchors; they are not part of the returned trace.
    trace: list[Node] = [
        step for step in mcts.getBestRoute()
        if not isinstance(step, RootNode)
    ]

    print(f'Finished running MCTS for {targetName}. A trace of {len(trace)} was found.')
    finalizeTrace(trace, targetName, clearActionCache)
    return trace
def explorationConstantWorldScale() -> float:
    """Return ``worldTools.buildAreaSqrt()`` divided by ten.

    NOTE(review): presumably meant as an exploration constant that scales
    with the build area size; confirm against the callers.
    """
    scaleDivisor = 10
    areaSqrt = worldTools.buildAreaSqrt()
    return areaSqrt / scaleDivisor
def mctsRolloutPolicy(state: Node, rng: np.random.Generator = np.random.default_rng()) -> float:
    """Play random actions from ``state`` until it is terminal and return the reward.

    When a state offers more than one action, the pick is biased towards
    cheaper actions: each action gets the weight ``1 - cost / totalCost`` and
    the weights are normalised to sum to one. If the total cost is zero the
    weights are undefined, so the pick is uniform instead.

    Args:
        state: Starting state. It must provide ``isTerminal()``,
            ``getPossibleActions()``, ``takeAction(action)`` and
            ``getReward()``; every action must expose a numeric ``cost``.
        rng: Random generator used to pick among multiple actions.

    Returns:
        ``getReward()`` of the terminal state that was reached.

    Raises:
        Exception: If a non-terminal state has no possible actions.
    """
    while not state.isTerminal():
        actions = state.getPossibleActions()
        if len(actions) == 0:
            raise Exception(f'Non-terminal state has no possible actions: {state}')

        if len(actions) == 1:
            selectedAction = actions[0]
        else:
            # Bias actions towards lower cost structures. The costs are summed
            # explicitly; summing the action objects themselves only works if
            # they happen to implement the numeric protocols.
            costs = [action.cost for action in actions]
            totalCost = sum(costs)
            if totalCost == 0:
                # All weights would be a division by zero; fall back to uniform.
                weights = None
            else:
                weights = np.array([1 - (cost / totalCost) for cost in costs])
                weights = weights / np.sum(weights)
            # Draw an index rather than an element so NumPy never has to
            # coerce the action objects into an array.
            selectedAction = actions[rng.choice(len(actions), p=weights)]

        state = state.takeAction(selectedAction)
    return state.getReward()
def finalizeTrace(nodeList: list[Node], routeName: str = None, clearActionCache: bool = False):
    """Call ``finalize`` on every node of a trace, handing each its successor.

    Args:
        nodeList: Nodes in trace order.
        routeName: Passed through unchanged to every ``finalize`` call.
        clearActionCache: Passed through unchanged to every ``finalize`` call.

    Each node receives the node that follows it in ``nodeList``; the last
    node receives ``None`` because it has no successor.
    """
    for position, current in enumerate(nodeList):
        followingIndex = position + 1
        successor = nodeList[followingIndex] if followingIndex < len(nodeList) else None
        current.finalize(successor, routeName, clearActionCache)
def findRandomConnectionNode(
    rng: np.random.Generator = np.random.default_rng(),
    nodeList: list[Node] = None,
) -> Node:
    """Pick a random node with an open connection slot from ``nodeList``.

    Args:
        rng: Random generator used for the pick.
        nodeList: Nodes to choose from; only those whose ``hasOpenSlot`` is
            truthy are eligible.

    Returns:
        One of the eligible nodes, chosen with ``rng.choice``.

    Raises:
        Exception: If ``nodeList`` is ``None`` or empty, or if none of its
            nodes has an open slot.
    """
    if nodeList is None or len(nodeList) == 0:
        raise Exception('Could not fit node with open connection slot')

    openNodes: list[Node] = [node for node in nodeList if node.hasOpenSlot]
    if not openNodes:
        raise Exception('Could not fit node with open connection slot')
    return rng.choice(openNodes)
def findRandomConnectionNodeGlobal(
    rng: np.random.Generator = np.random.default_rng()
) -> Node:
    """Pick a random node with an open connection slot from ``globals.nodeList``.

    Args:
        rng: Random generator used for the pick.

    Returns:
        One of the nodes in ``globals.nodeList`` whose ``hasOpenSlot`` is
        truthy, chosen with ``rng.choice``.

    Raises:
        Exception: If no node in ``globals.nodeList`` has an open slot.
    """
    openNodes: list[Node] = [node for node in globals.nodeList if node.hasOpenSlot]
    if not openNodes:
        raise Exception('Could not fit node with open connection slot')
    return rng.choice(openNodes)
def findConnectionNodeByRewardValue(
    rewardFunction: Callable[[Node], float] = None,
    nodeList: list[Node] = None,
) -> Node:
    """Return the open-slot node in ``nodeList`` with the lowest reward value.

    Args:
        rewardFunction: Called once for every node whose ``hasOpenSlot`` is
            truthy; its result is the value that gets minimised.
        nodeList: Nodes to choose from.

    Returns:
        The eligible node for which ``rewardFunction`` returned the smallest
        value, as determined by ``np.argmin``.

    Raises:
        Exception: If ``nodeList`` is ``None`` or empty, or if none of its
            nodes has an open slot.
    """
    if nodeList is None or len(nodeList) == 0:
        raise Exception('Could not fit node with open connection slot')

    # Score each eligible node as soon as it is encountered.
    scored = [(node, rewardFunction(node)) for node in nodeList if node.hasOpenSlot]
    if not scored:
        raise Exception('Could not fit node with open connection slot')

    openNodes, rewards = zip(*scored)
    return openNodes[np.argmin(rewards)]
def findConnectionNodeByRewardValueGlobal(
    rewardFunction: Callable[[Node], float] = None
) -> Node:
    """Return the open-slot node in ``globals.nodeList`` with the lowest reward value.

    Args:
        rewardFunction: Called once for every node whose ``hasOpenSlot`` is
            truthy; its result is the value that gets minimised.

    Returns:
        The eligible node for which ``rewardFunction`` returned the smallest
        value, as determined by ``np.argmin``.

    Raises:
        Exception: If no node in ``globals.nodeList`` has an open slot.
    """
    # Score each eligible node as soon as it is encountered.
    scored = [
        (node, rewardFunction(node))
        for node in globals.nodeList
        if node.hasOpenSlot
    ]
    if not scored:
        raise Exception('Could not fit node with open connection slot')

    openNodes, rewards = zip(*scored)
    return openNodes[np.argmin(rewards)]
def placeNodes():
    """Place every node in ``globals.nodeList`` and then empty the list.

    Runs in three passes over the list: pre-processing, placement and
    post-processing. The editor buffer is flushed between pre-processing and
    placement. The ``randomTickSpeed`` gamerule is changed around the passes
    and item entities are killed at the end.
    """
    # Freeze random ticks so no trees grow while structures are being placed.
    globals.editor.runCommandGlobal('gamerule randomTickSpeed 0')

    for pendingNode in globals.nodeList:
        pendingNode.doPreProcessingSteps()
    globals.editor.flushBuffer()

    for pendingNode in globals.nodeList:
        pendingNode.place()

    for pendingNode in globals.nodeList:
        pendingNode.doPostProcessingSteps()

    # Briefly raise the tick speed to 300 to speed up tree growth, then set it to 3.
    globals.editor.runCommandGlobal('gamerule randomTickSpeed 300')
    globals.editor.flushBuffer()
    globals.editor.runCommandGlobal('gamerule randomTickSpeed 3')

    # Kill all item entities.
    globals.editor.runCommandGlobal('kill @e[type=item]')

    print(globals.nodeList)

    # Empty the list so a later call cannot place the same nodes again.
    globals.nodeList.clear()