forked from smearle/control-pcgrl
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathenvs.py
98 lines (89 loc) · 4.1 KB
/
envs.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
from gym_pcgrl import wrappers, conditional_wrappers
#from stable_baselines3.common.vec_env import SubprocVecEnv, DummyVecEnv
from stable_baselines.common.vec_env import SubprocVecEnv, DummyVecEnv
from utils import RenderMonitor, get_map_width
from gym import spaces
from pdb import set_trace as TT
def make_env(env_name, representation, rank=0, log_dir=None, **kwargs):
    '''
    Return a function that will initialize the environment when called.

    Args:
        env_name: Environment name forwarded to the representation wrapper.
        representation: 'wide' selects ActionMapImagePCGRLWrapper, 'cellular'
            selects CAactionWrapper; any other value selects
            CroppedImagePCGRLWrapper (crop size from kwargs['cropped_size'],
            default 28).
        rank: Index of this environment; used in the bootstrap directory name
            and passed to RenderMonitor.
        log_dir: Optional log directory. When set (and non-empty), the
            environment is wrapped in RenderMonitor; together with
            kwargs['add_bootstrap'] it also enables EliteBootStrapping.
        **kwargs: Options read here ('max_step', 'render', 'evaluate',
            'alp_gmm', 'cropped_size', 'add_bootstrap', 'cond_metrics') and
            forwarded to the wrappers.

    Returns:
        A zero-argument callable that builds and returns the wrapped
        environment. The callable does not modify ``kwargs``, so it can be
        invoked more than once with the same result.
    '''
    max_step = kwargs.get('max_step', None)
    render = kwargs.get('render', False)
    evaluate = kwargs.get('evaluate', False)
    alp_gmm = kwargs.get('alp_gmm', False)

    def _thunk():
        # Imported locally: the module's import block does not provide `os`.
        import os

        # Exactly one representation wrapper is built. (Previously the 'wide'
        # branch was a separate `if`, so its env was immediately replaced by
        # the cropped wrapper from the `else` below.)
        if representation == 'wide':
            env = wrappers.ActionMapImagePCGRLWrapper(env_name, **kwargs)
        elif representation == 'cellular':
            env = wrappers.CAactionWrapper(env_name, **kwargs)
        else:
            crop_size = kwargs.get('cropped_size', 28)
            env = wrappers.CroppedImagePCGRLWrapper(env_name, crop_size, **kwargs)

        # NOTE: a disabled 'evo_compare' option used to loosen the problem's
        # static targets (path-length / sol-length) at this point; it was
        # marked as not working and has been left out.

        env.configure(**kwargs)

        if max_step is not None:
            env = wrappers.MaxStep(env, max_step)

        if log_dir is not None and kwargs.get('add_bootstrap', False):
            bootstrap_dir = os.path.join(log_dir, "bootstrap{}/".format(rank))
            env = wrappers.EliteBootStrapping(env, bootstrap_dir)

        # 'cond_metrics' is handed to ParamRew as `ctrl_metrics` and withheld
        # from the kwargs of every later wrapper. Work on a filtered copy
        # rather than popping from the shared closure dict, so calling the
        # thunk again still sees the metrics.
        ctrl_metrics = kwargs.get('cond_metrics', [])
        later_kwargs = {k: v for k, v in kwargs.items() if k != 'cond_metrics'}
        env = conditional_wrappers.ParamRew(env, ctrl_metrics=ctrl_metrics, **later_kwargs)

        # If not conditional, the ParamRew wrapper should just be fixed at
        # default static targets.
        if not evaluate:
            if alp_gmm:
                env = conditional_wrappers.ALPGMMTeacher(env, **later_kwargs)
            else:
                env = conditional_wrappers.UniformNoiseyTargets(env, **later_kwargs)

        if render or (log_dir is not None and len(log_dir) > 0):
            # RenderMonitor must come last
            env = RenderMonitor(env, rank, log_dir, **later_kwargs)
        return env

    return _thunk
def make_vec_envs(env_name, representation, log_dir, **kwargs):
    '''
    Prepare a vectorized environment using a list of 'make_env' functions.

    Args:
        env_name: Environment name; also used to look up the map width.
        representation: Representation name forwarded to make_env.
        log_dir: Log directory forwarded to make_env for the real workers.
        **kwargs: Must contain 'n_cpu' (number of environments; removed before
            forwarding). 'map_width' is set here from get_map_width(env_name).
            Everything else is forwarded to make_env.

    Returns:
        A tuple ``(env, action_space, n_tools)``: the vectorized environment,
        the action space of a single environment, and a tool count derived
        from that action space.

    Raises:
        KeyError: if 'n_cpu' is not supplied.
        Exception: if the action space is not Discrete, MultiDiscrete or Box.
    '''
    map_width = get_map_width(env_name)
    kwargs['map_width'] = map_width
    n_cpu = kwargs.pop('n_cpu')

    if n_cpu > 1:
        env_fns = [
            make_env(env_name, representation, idx, log_dir, **kwargs)
            for idx in range(n_cpu)
        ]
        env = SubprocVecEnv(env_fns)
    else:
        env = DummyVecEnv([make_env(env_name, representation, 0, log_dir, **kwargs)])

    # A hack :~) -- build one extra, unlogged environment (rank -1, no
    # log_dir) solely to inspect its action space.
    dummy_env = make_env(env_name, representation, -1, None, **kwargs)()
    try:
        action_space = dummy_env.action_space
        if isinstance(action_space, spaces.Discrete):
            n_tools = action_space.n // (map_width ** 2)
        elif isinstance(action_space, spaces.MultiDiscrete):
            n_tools = action_space.nvec[2]
        elif isinstance(action_space, spaces.Box):
            n_tools = action_space.shape[0] // (map_width ** 2)
        else:
            raise Exception(
                "Unsupported action space type: {}".format(type(action_space).__name__))
    finally:
        # Always release the throwaway environment, including when the action
        # space turns out to be unsupported.
        dummy_env.env.close()
        del dummy_env
    return env, action_space, n_tools