Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

PWhiddy patches applied pokemon_red.py #49

Open
wants to merge 2 commits into
base: 0.5-cleanup
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
161 changes: 132 additions & 29 deletions pufferlib/environments/pokemon_red.py
Original file line number Diff line number Diff line change
@@ -87,6 +87,11 @@ def __init__(
self.early_stopping = early_stopping
self.save_video = save_video
self.fast_video = fast_video
self.explore_weight = 1 if 'explore_weight' not in config else config['explore_weight']
self.use_screen_explore = True if 'use_screen_explore' not in config else config['use_screen_explore']
self.reward_scale = 1 if 'reward_scale' not in config else config['reward_scale']
self.extra_buttons = False if 'extra_buttons' not in config else config['extra_buttons']

self.save_screenshots = save_screenshots
self.video_interval = video_interval_mul * self.act_freq
self.downsample_factor = downsample_factor
@@ -100,6 +105,7 @@ def __init__(
self.instance_id = str(uuid.uuid4())[:8]

self.s_path.mkdir(exist_ok=True)
self.reset_count = 0
self.all_runs = []

# Set this in SOME subclasses
@@ -112,11 +118,15 @@ def __init__(
WindowEvent.PRESS_ARROW_RIGHT,
WindowEvent.PRESS_ARROW_UP,
WindowEvent.PRESS_BUTTON_A,
WindowEvent.PRESS_BUTTON_B,
WindowEvent.PRESS_BUTTON_START,
WindowEvent.PASS
WindowEvent.PRESS_BUTTON_B
]

if self.extra_buttons:
self.valid_actions.extend([
WindowEvent.PRESS_BUTTON_START,
WindowEvent.PASS
])

self.release_arrow = [
WindowEvent.RELEASE_ARROW_DOWN,
WindowEvent.RELEASE_ARROW_LEFT,
@@ -154,17 +164,20 @@ def __init__(
)

self.screen = self.pyboy.botsupport_manager().screen()
self.pyboy.set_emulation_speed(0 if headless else 6)
if not config['headless']:
self.pyboy.set_emulation_speed(6)
self.reset()

def reset(self, seed=None):
self.seed = seed

# restart game, skipping credits
with open(self.init_state, "rb") as f:
self.pyboy.load_state(f)

self.init_knn()

if self.use_screen_explore:
self.init_knn()
else:
self.init_map_mem

self.recent_memory = np.zeros(
(self.output_shape[1]*self.memory_height, 3),
@@ -176,6 +189,8 @@ def reset(self, seed=None):
), dtype=np.uint8,
)

self.agent_stats = []

if self.save_video:
base_dir = self.s_path / Path('rollouts')
base_dir.mkdir(exist_ok=True)
@@ -199,10 +214,20 @@ def reset(self, seed=None):
self.death_count = 0
self.step_count = 0
self.reset_count += 1

self.compute_rewards()
self.total_reward = sum(self.rewards.values())

self.levels_satisfied = False
self.base_explore = 0
self.max_opponent_level = 0
self.max_event_rew = 0
self.max_level_rew = 0
self.last_health = 1
self.total_healing_rew = 0
self.died_count = 0
self.step_count = 0
self.progress_reward = self.get_game_state_reward()
self.total_reward = sum([val for _, val in self.progress_reward.items()])
self.reset_count += 1

return self.render(), {}

@@ -213,6 +238,9 @@ def init_knn(self):
self.knn_index.init_index(
max_elements=self.num_elements, ef_construction=100, M=16)

# Start coordinate-based exploration tracking from scratch by binding
# `self.seen_coords` to a new empty dict.
# NOTE(review): this diff view has lost indentation; the assignment below is
# the method body.
# NOTE(review): `reset` refers to `self.init_map_mem` without call parentheses
# in its non-screen-explore branch, so this method may never actually run
# from there -- confirm and add `()` if a call was intended.
def init_map_mem(self):
self.seen_coords = {}

def render(self, reduce_res=True, add_memory=True, update_mem=True):
game_pixels_render = self.screen.screen_ndarray() # (144, 160, 3)

@@ -240,9 +268,10 @@ def render(self, reduce_res=True, add_memory=True, update_mem=True):
), axis=0)

return game_pixels_render

def step(self, action):
self.run_action_on_emulator(action)
self.append_agent_stats(action)

self.recent_frames = np.roll(self.recent_frames, 1, axis=0)
obs_memory = self.render()
@@ -252,9 +281,15 @@ def step(self, action):
obs_flat = obs_memory[
frame_start:frame_start+self.output_shape[0], ...].flatten().astype(np.float32)

self.update_frame_knn_index(obs_flat)
if self.use_screen_explore:
self.update_frame_knn_index(obs_flat)
else:
self.update_seen_coords()

self.update_heal_reward()
new_reward, new_prog = self.compute_rewards()
self.cfg["state_params"]["health"] = self.read_hp_fraction()
self.last_health = self.read_hp_fraction()
#self.cfg["state_params"]["health"] = self.read_hp_fraction()

# shift over short term reward memory
self.recent_memory = np.roll(self.recent_memory, 3)
@@ -272,6 +307,8 @@ def step(self, action):
def run_action_on_emulator(self, action):
# press button then release after some steps
self.pyboy.send_input(self.valid_actions[action])
if not self.save_video and self.headless:
self.pyboy._rendering(False)
for i in range(self.act_freq):
# release action, so they are stateless
if i == 8:
@@ -281,14 +318,16 @@ def run_action_on_emulator(self, action):
if action > 3 and action < 6:
# release button
self.pyboy.send_input(self.release_button[action - 4])
if action == WindowEvent.PRESS_BUTTON_START:
if self.valid_actions[action] == WindowEvent.PRESS_BUTTON_START:
self.pyboy.send_input(WindowEvent.RELEASE_BUTTON_START)
if self.save_video and not self.fast_video:
self.add_video_frame()
if i == self.act_freq-1:
self.pyboy._rendering(True)
self.pyboy.tick()
if self.save_video and self.fast_video:
self.add_video_frame()

def add_video_frame(self):
self.full_frame_writer.add_image(self.render(reduce_res=False, update_mem=False))
self.model_frame_writer.add_image(self.render(reduce_res=True, update_mem=False))
@@ -298,14 +337,29 @@ def get_agent_stats(self, action):
y_pos = self.read_m(0xD361)
map_n = self.read_m(0xD35E)
levels = [self.read_m(a) for a in [0xD18C, 0xD1B8, 0xD1E4, 0xD210, 0xD23C, 0xD268]]
if self.use_screen_explore:
expl = ('frames', self.knn_index.get_current_count())
else:
expl = ('coord_count', len(self.seen_coords))
return {
'step': self.step_count, 'x': x_pos, 'y': y_pos, 'map': map_n,
'last_action': action,
'pcount': self.read_m(0xD163), 'levels': levels, 'ptypes': self.read_party(),
'hp': self.read_hp_fraction(),
'frames': self.knn_index.get_current_count(),
'deaths': self.death_count, 'badge': self.get_badges(),
'event': self.rewards["event"], 'healr': self.cfg["reward_params"]["total_healing_rew"]
#'event': self.rewards["event"], 'healr': self.cfg["reward_params"]["total_healing_rew"]
'event': self.reward_scale*self.update_max_event_rew(),
#'party_xp': self.reward_scale*0.1*sum(poke_xps),
'level': self.reward_scale*self.get_levels_reward(),
'heal': self.reward_scale*self.total_healing_rew,
'op_lvl': self.reward_scale*self.update_max_op_level(),
'dead': self.reward_scale*-0.1*self.died_count,
'badge': self.reward_scale*self.get_badges() * 5,
#'op_poke': self.reward_scale*self.max_opponent_poke * 800,
#'money': self.reward_scale* money * 3,
#'seen_poke': self.reward_scale * seen_poke_count * 400,
'explore': self.reward_scale * self.get_knn_reward()
}

def update_frame_knn_index(self, frame_vec):
@@ -327,17 +381,53 @@ def update_frame_knn_index(self, frame_vec):
self.knn_index.add_items(
frame_vec, np.array([self.knn_index.get_current_count()]))

# Record the current position in `self.seen_coords`.
# Three bytes are read through `self.read_m` and formatted into a string key;
# the value stored under that key is the current `self.step_count`, so a
# revisited coordinate overwrites its earlier entry instead of adding one.
# NOTE(review): this diff view has lost indentation; presumably the three
# assignments directly after the `if` line form its body and the final
# assignment runs unconditionally -- confirm against the real file.
def update_seen_coords(self):
# 0xD361 and 0xD35E are the same addresses `get_agent_stats` reads as
# `y_pos` and `map_n`.
x_pos = self.read_m(0xD362)
y_pos = self.read_m(0xD361)
map_n = self.read_m(0xD35E)
coord_string = f"x:{x_pos} y:{y_pos} m:{map_n}"
# The first time the level sum reaches 22, remember how many coordinates
# had been seen so far in `base_explore`, then restart the dict.
# `levels_satisfied` keeps this from happening again until `reset` clears it.
if self.get_levels_sum() >= 22 and not self.levels_satisfied:
self.levels_satisfied = True
self.base_explore = len(self.seen_coords)
self.seen_coords = {}

self.seen_coords[coord_string] = self.step_count

def compute_rewards(self):
# addresses from https://datacrystal.romhacking.net/wiki/Pok%C3%A9mon_Red/Blue:RAM_map
# https://github.com/pret/pokered/blob/91dc3c9f9c8fd529bb6e8307b58b96efa0bec67e/constants/event_constants.asm

self.rewards_old = self.rewards.copy()

# adds up all event flags, exclude museum ticket
event_flags_start = 0xD747
event_flags_end = 0xD886
museum_ticket = (0xD754, 0)
base_event_flags = 13
return max(
sum(
[
self.bit_count(self.read_m(i))
for i in range(event_flags_start, event_flags_end)
]
)
- base_event_flags
- int(self.read_bit(museum_ticket[0], museum_ticket[1])),
0,)

# healing reward
curr_health = self.read_hp_fraction()
self.rewards["healing"] = self.cfg["rewards"]["healing_scale"] * max(0, curr_health - self.cfg["state_params"]["health"])
if self.cfg["state_params"]["health"] <= 0: self.death_count += 1
self.cfg["state_params"]["health"] = curr_health
'''
# Not sure where to integrate
prog = self.progress_reward
# these values are only used by memory
return (prog['level'] * 100 / self.reward_scale,
self.read_hp_fraction()*2000,
prog['explore'] * 150 / (self.explore_weight * self.reward_scale))
'''

# event reward
curr_event_rew = max(sum([self.bit_count(self.read_m(i)) for i in range(0xD747, 0xD886)]) - 13, 0)
@@ -362,6 +452,9 @@ def compute_rewards(self):
self.rewards["badges"] = self.cfg["rewards"]["badge_scale"] * self.get_badges()

# exploration reward
pre_rew = self.explore_weight * 0.005
post_rew = self.explore_weight * 0.01
cur_size = self.knn_index.get_current_count() if self.use_screen_explore else len(self.seen_coords)
curr_size = self.knn_index.get_current_count()
base = (self.cfg["state_params"]["base_explore"] if self.cfg["state_params"]["levels_satisfied"] else curr_size) * self.cfg["rewards"]["knn_pre_scale"]
post = (curr_size if self.cfg["state_params"]["levels_satisfied"] else 0) * self.cfg["rewards"]["knn_post_scale"]
@@ -430,24 +523,31 @@ def save_and_print_info(self, done, obs_memory):
print(f'\r{prog_string}', end='', flush=True)

if self.step_count % 50 == 0:
plt.imsave(
self.s_path / Path(f'curframe_{self.instance_id}.jpeg'),
self.render(reduce_res=False)
try:
plt.imsave(
self.s_path / Path(f'curframe_{self.instance_id}.jpeg'),
self.render(reduce_res=False))
except Exception as e:
print(f"Error saving iamge: {e}")
)

if self.print_rewards and done:
print('', flush=True)
if self.save_final_state:
fs_path = self.s_path / Path('final_states')
fs_path.mkdir(exist_ok=True)
plt.imsave(
fs_path / Path(f'frame_r{self.total_reward:.4f}_{self.reset_count}_small.jpeg'),
obs_memory
)
plt.imsave(
fs_path / Path(f'frame_r{self.total_reward:.4f}_{self.reset_count}_full.jpeg'),
self.render(reduce_res=False)
)
try:
plt.imsave(
fs_path / Path(f'frame_r{self.total_reward:.4f}_{self.reset_count}_small.jpeg'),
obs_memory)
except Exception as e:
print(f"Error saving image: {e}")
try:
plt.imsave(
fs_path / Path(f'frame_r{self.total_reward:.4f}_{self.reset_count}_full.jpeg'),
self.render(reduce_res=False))
except Exception as e:
print(f"Error saving iamge: {e}")

if self.save_video and done:
self.full_frame_writer.close()
@@ -476,9 +576,12 @@ def read_party(self):
def save_screenshot(self, name):
ss_dir = self.s_path / Path('screenshots')
ss_dir.mkdir(exist_ok=True)
plt.imsave(
ss_dir / Path(f'frame{self.instance_id}_r{self.total_reward:.4f}_{self.reset_count}_{name}.jpeg'),
self.render(reduce_res=False))
try:
plt.imsave(
ss_dir / Path(f'frame{self.instance_id}_r{self.total_reward:.4f}_{self.reset_count}_{name}.jpeg'),
self.render(reduce_res=False))
except Exception as e:
print(f"Error saving iamge: {e}")

def read_hp_fraction(self):
hp_sum = sum([self.read_hp(add) for add in [0xD16C, 0xD198, 0xD1C4, 0xD1F0, 0xD21C, 0xD248]])