train_agent.py
"""
Parent fine-tuning agent class.
"""
import logging
import os
import random

import hydra
import numpy as np
import torch
import wandb
from omegaconf import OmegaConf

from env.gym_utils import make_async

log = logging.getLogger(__name__)

class TrainAgent:
def __init__(self, cfg):
super().__init__()
self.cfg = cfg
self.device = cfg.device
self.seed = cfg.get("seed", 42)
random.seed(self.seed)
np.random.seed(self.seed)
torch.manual_seed(self.seed)
# Wandb
        self.use_wandb = cfg.wandb is not None
        if self.use_wandb:
wandb.init(
# entity=cfg.wandb.entity,
project=cfg.wandb.project,
name=cfg.wandb.run,
config=OmegaConf.to_container(cfg, resolve=True),
)
# Make vectorized env
self.env_name = cfg.env.name
env_type = cfg.env.get("env_type", None)
self.venv = make_async(
cfg.env.name,
env_type=env_type,
num_envs=cfg.env.n_envs,
asynchronous=True,
max_episode_steps=cfg.env.max_episode_steps,
wrappers=cfg.env.get("wrappers", None),
robomimic_env_cfg_path=cfg.get("robomimic_env_cfg_path", None),
shape_meta=cfg.get("shape_meta", None),
use_image_obs=cfg.env.get("use_image_obs", False),
render=cfg.env.get("render", False),
render_offscreen=cfg.env.get("save_video", False),
obs_dim=cfg.obs_dim,
action_dim=cfg.action_dim,
**cfg.env.specific if "specific" in cfg.env else {},
)
        if env_type != "furniture":
            # Seed each parallel env differently; otherwise they may start from
            # identical initial states. Furniture (IsaacGym) envs do not need
            # explicit seeding here.
            self.venv.seed([self.seed + i for i in range(cfg.env.n_envs)])
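        # Rollout dimensions and stepping parameters from the config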
self.n_envs = cfg.env.n_envs
self.n_cond_step = cfg.cond_steps
self.obs_dim = cfg.obs_dim
self.action_dim = cfg.action_dim
self.act_steps = cfg.act_steps
self.horizon_steps = cfg.horizon_steps
self.max_episode_steps = cfg.env.max_episode_steps
self.reset_at_iteration = cfg.env.get("reset_at_iteration", True)
self.save_full_observations = cfg.env.get("save_full_observations", False)
self.furniture_sparse_reward = (
cfg.env.specific.get("sparse_reward", False)
if "specific" in cfg.env
else False
) # furniture specific, for best reward calculation
# Batch size for gradient update
self.batch_size: int = cfg.train.batch_size
# Build model and load checkpoint
self.model = hydra.utils.instantiate(cfg.model)
# Training params
self.itr = 0
self.n_train_itr = cfg.train.n_train_itr
self.val_freq = cfg.train.val_freq
self.force_train = cfg.train.get("force_train", False)
self.n_steps = cfg.train.n_steps
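        # Success threshold: for furniture envs, the number of part pairs to
        # assemble; otherwise taken directly from the env config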
self.best_reward_threshold_for_success = (
len(self.venv.pairs_to_assemble)
if env_type == "furniture"
else cfg.env.best_reward_threshold_for_success
)
self.max_grad_norm = cfg.train.get("max_grad_norm", None)
# Logging, rendering, checkpoints
self.logdir = cfg.logdir
self.render_dir = os.path.join(self.logdir, "render")
self.checkpoint_dir = os.path.join(self.logdir, "checkpoint")
self.result_path = os.path.join(self.logdir, "result.pkl")
os.makedirs(self.render_dir, exist_ok=True)
os.makedirs(self.checkpoint_dir, exist_ok=True)
self.save_trajs = cfg.train.get("save_trajs", False)
self.log_freq = cfg.train.get("log_freq", 1)
self.save_model_freq = cfg.train.save_model_freq
self.render_freq = cfg.train.render.freq
self.n_render = cfg.train.render.num
self.render_video = cfg.env.get("save_video", False)
assert self.n_render <= self.n_envs, "n_render must be <= n_envs"
assert not (
self.n_render <= 0 and self.render_video
), "Need to set n_render > 0 if saving video"
self.traj_plotter = (
hydra.utils.instantiate(cfg.train.plotter)
if "plotter" in cfg.train
else None
)

    def run(self):
        """Main training loop; implemented by subclasses."""

    def save_model(self):
        """
        Save model weights and the current iteration count to disk; no EMA.
        """
data = {
"itr": self.itr,
"model": self.model.state_dict(),
}
savepath = os.path.join(self.checkpoint_dir, f"state_{self.itr}.pt")
torch.save(data, savepath)
log.info(f"Saved model to {savepath}")

    def load(self, itr):
        """
        Load a checkpoint from disk and restore the model and iteration count.
        """
loadpath = os.path.join(self.checkpoint_dir, f"state_{itr}.pt")
data = torch.load(loadpath, weights_only=True)
self.itr = data["itr"]
self.model.load_state_dict(data["model"])

    def reset_env_all(self, verbose=False, options_venv=None, **kwargs):
        if options_venv is None:
            # Use one copy of the keyword options per parallel env
            options_venv = [dict(kwargs) for _ in range(self.n_envs)]
obs_venv = self.venv.reset_arg(options_list=options_venv)
        # If the vectorized env returned a list of per-env dicts, stack them
        # into a single dict of batched arrays
if isinstance(obs_venv, list):
obs_venv = {
key: np.stack([obs_venv[i][key] for i in range(self.n_envs)])
for key in obs_venv[0].keys()
}
if verbose:
for index in range(self.n_envs):
                log.info(
f"<-- Reset environment {index} with options {options_venv[index]}"
)
return obs_venv

    def reset_env(self, env_ind, verbose=False):
task = {}
obs = self.venv.reset_one_arg(env_ind=env_ind, options=task)
if verbose:
logging.info(f"<-- Reset environment {env_ind} with task {task}")
return obs
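
# Illustrative usage sketch: subclasses implement `run()` with their own rollout
# and update loop. The subclass name and loop body below are hypothetical
# examples, not a prescribed implementation.
#
#     class TrainMyAgent(TrainAgent):
#         def run(self):
#             while self.itr < self.n_train_itr:
#                 obs_venv = self.reset_env_all()
#                 # ...collect self.n_steps of experience and update self.model...
#                 if self.itr % self.save_model_freq == 0:
#                     self.save_model()
#                 self.itr += 1
#
#     agent = TrainMyAgent(cfg)  # cfg: a Hydra/OmegaConf config, as in __init__
#     agent.run()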