Skip to content

Commit

Permalink
amend
Browse files Browse the repository at this point in the history
  • Loading branch information
vmoens committed Jul 9, 2024
1 parent b077da1 commit 312b82e
Show file tree
Hide file tree
Showing 28 changed files with 29 additions and 29 deletions.
2 changes: 1 addition & 1 deletion sota-implementations/a2c/a2c_atari.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,7 +201,7 @@ def main(cfg: "DictConfig"): # noqa: F821
)

# Get test rewards
with torch.no_grad(), set_exploration_type(ExplorationType.MODE):
with torch.no_grad(), set_exploration_type(ExplorationType.DETERMINISTIC):
if ((i - 1) * frames_in_batch * frame_skip) // test_interval < (
i * frames_in_batch * frame_skip
) // test_interval:
Expand Down
2 changes: 1 addition & 1 deletion sota-implementations/a2c/a2c_mujoco.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,7 +186,7 @@ def main(cfg: "DictConfig"): # noqa: F821
)

# Get test rewards
with torch.no_grad(), set_exploration_type(ExplorationType.MODE):
with torch.no_grad(), set_exploration_type(ExplorationType.DETERMINISTIC):
prev_test_frame = ((i - 1) * frames_in_batch) // cfg.logger.test_interval
cur_test_frame = (i * frames_in_batch) // cfg.logger.test_interval
final = collected_frames >= collector.total_frames
Expand Down
2 changes: 1 addition & 1 deletion sota-implementations/cql/cql_offline.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,7 @@ def main(cfg: "DictConfig"): # noqa: F821

# evaluation
if i % evaluation_interval == 0:
with set_exploration_type(ExplorationType.MODE), torch.no_grad():
with set_exploration_type(ExplorationType.DETERMINISTIC), torch.no_grad():
eval_td = eval_env.rollout(
max_steps=eval_steps, policy=model[0], auto_cast_to_device=True
)
Expand Down
2 changes: 1 addition & 1 deletion sota-implementations/cql/cql_online.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,7 +204,7 @@ def main(cfg: "DictConfig"): # noqa: F821
cur_test_frame = (i * frames_per_batch) // evaluation_interval
final = current_frames >= collector.total_frames
if (i >= 1 and (prev_test_frame < cur_test_frame)) or final:
with set_exploration_type(ExplorationType.MODE), torch.no_grad():
with set_exploration_type(ExplorationType.DETERMINISTIC), torch.no_grad():
eval_start = time.time()
eval_rollout = eval_env.rollout(
eval_rollout_steps,
Expand Down
2 changes: 1 addition & 1 deletion sota-implementations/cql/discrete_cql_online.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,7 +183,7 @@ def main(cfg: "DictConfig"): # noqa: F821

# Evaluation
if abs(collected_frames % eval_iter) < frames_per_batch:
with set_exploration_type(ExplorationType.MODE), torch.no_grad():
with set_exploration_type(ExplorationType.DETERMINISTIC), torch.no_grad():
eval_start = time.time()
eval_rollout = eval_env.rollout(
eval_rollout_steps,
Expand Down
2 changes: 1 addition & 1 deletion sota-implementations/ddpg/ddpg.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,7 +185,7 @@ def main(cfg: "DictConfig"): # noqa: F821

# Evaluation
if abs(collected_frames % eval_iter) < frames_per_batch:
with set_exploration_type(ExplorationType.MODE), torch.no_grad():
with set_exploration_type(ExplorationType.DETERMINISTIC), torch.no_grad():
eval_start = time.time()
eval_rollout = eval_env.rollout(
eval_rollout_steps,
Expand Down
2 changes: 1 addition & 1 deletion sota-implementations/decision_transformer/dt.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ def main(cfg: "DictConfig"): # noqa: F821
to_log = {"train/loss": loss_vals["loss"]}

# Evaluation
with set_exploration_type(ExplorationType.MODE), torch.no_grad():
with set_exploration_type(ExplorationType.DETERMINISTIC), torch.no_grad():
if i % pretrain_log_interval == 0:
eval_td = test_env.rollout(
max_steps=eval_steps,
Expand Down
2 changes: 1 addition & 1 deletion sota-implementations/decision_transformer/online_dt.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@ def main(cfg: "DictConfig"): # noqa: F821
}

# Evaluation
with torch.no_grad(), set_exploration_type(ExplorationType.MODE):
with torch.no_grad(), set_exploration_type(ExplorationType.DETERMINISTIC):
inference_policy.eval()
if i % pretrain_log_interval == 0:
eval_td = test_env.rollout(
Expand Down
2 changes: 1 addition & 1 deletion sota-implementations/discrete_sac/discrete_sac.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,7 +204,7 @@ def main(cfg: "DictConfig"): # noqa: F821
cur_test_frame = (i * frames_per_batch) // eval_iter
final = current_frames >= collector.total_frames
if (i >= 1 and (prev_test_frame < cur_test_frame)) or final:
with set_exploration_type(ExplorationType.MODE), torch.no_grad():
with set_exploration_type(ExplorationType.DETERMINISTIC), torch.no_grad():
eval_start = time.time()
eval_rollout = eval_env.rollout(
eval_rollout_steps,
Expand Down
2 changes: 1 addition & 1 deletion sota-implementations/dqn/dqn_atari.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,7 +199,7 @@ def main(cfg: "DictConfig"): # noqa: F821
)

# Get and log evaluation rewards and eval time
with torch.no_grad(), set_exploration_type(ExplorationType.MODE):
with torch.no_grad(), set_exploration_type(ExplorationType.DETERMINISTIC):
prev_test_frame = ((i - 1) * frames_per_batch) // test_interval
cur_test_frame = (i * frames_per_batch) // test_interval
final = current_frames >= collector.total_frames
Expand Down
2 changes: 1 addition & 1 deletion sota-implementations/dqn/dqn_cartpole.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,7 +180,7 @@ def main(cfg: "DictConfig"): # noqa: F821
)

# Get and log evaluation rewards and eval time
with torch.no_grad(), set_exploration_type(ExplorationType.MODE):
with torch.no_grad(), set_exploration_type(ExplorationType.DETERMINISTIC):
prev_test_frame = ((i - 1) * frames_per_batch) // test_interval
cur_test_frame = (i * frames_per_batch) // test_interval
final = current_frames >= collector.total_frames
Expand Down
4 changes: 2 additions & 2 deletions sota-implementations/dreamer/dreamer.py
Original file line number Diff line number Diff line change
Expand Up @@ -284,7 +284,7 @@ def compile_rssms(module):
# Evaluation
if (i % eval_iter) == 0:
# Real env
with set_exploration_type(ExplorationType.MODE), torch.no_grad():
with set_exploration_type(ExplorationType.DETERMINISTIC), torch.no_grad():
eval_rollout = test_env.rollout(
eval_rollout_steps,
policy,
Expand All @@ -298,7 +298,7 @@ def compile_rssms(module):
log_metrics(logger, eval_metrics, collected_frames)
# Simulated env
if model_based_env_eval is not None:
with set_exploration_type(ExplorationType.MODE), torch.no_grad():
with set_exploration_type(ExplorationType.DETERMINISTIC), torch.no_grad():
eval_rollout = model_based_env_eval.rollout(
eval_rollout_steps,
policy,
Expand Down
2 changes: 1 addition & 1 deletion sota-implementations/dreamer/dreamer_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -535,7 +535,7 @@ def _dreamer_make_actor_real(
SafeProbabilisticModule(
in_keys=["loc", "scale"],
out_keys=[action_key],
default_interaction_type=InteractionType.MODE,
default_interaction_type=InteractionType.DETERMINISTIC,
distribution_class=TanhNormal,
distribution_kwargs={"tanh_loc": True},
spec=CompositeSpec(
Expand Down
2 changes: 1 addition & 1 deletion sota-implementations/impala/impala_multi_node_ray.py
Original file line number Diff line number Diff line change
Expand Up @@ -247,7 +247,7 @@ def main(cfg: "DictConfig"): # noqa: F821
)

# Get test rewards
with torch.no_grad(), set_exploration_type(ExplorationType.MODE):
with torch.no_grad(), set_exploration_type(ExplorationType.DETERMINISTIC):
if ((i - 1) * frames_in_batch * frame_skip) // test_interval < (
i * frames_in_batch * frame_skip
) // test_interval:
Expand Down
2 changes: 1 addition & 1 deletion sota-implementations/impala/impala_multi_node_submitit.py
Original file line number Diff line number Diff line change
Expand Up @@ -239,7 +239,7 @@ def main(cfg: "DictConfig"): # noqa: F821
)

# Get test rewards
with torch.no_grad(), set_exploration_type(ExplorationType.MODE):
with torch.no_grad(), set_exploration_type(ExplorationType.DETERMINISTIC):
if ((i - 1) * frames_in_batch * frame_skip) // test_interval < (
i * frames_in_batch * frame_skip
) // test_interval:
Expand Down
2 changes: 1 addition & 1 deletion sota-implementations/impala/impala_single_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,7 +217,7 @@ def main(cfg: "DictConfig"): # noqa: F821
)

# Get test rewards
with torch.no_grad(), set_exploration_type(ExplorationType.MODE):
with torch.no_grad(), set_exploration_type(ExplorationType.DETERMINISTIC):
if ((i - 1) * frames_in_batch * frame_skip) // test_interval < (
i * frames_in_batch * frame_skip
) // test_interval:
Expand Down
2 changes: 1 addition & 1 deletion sota-implementations/iql/discrete_iql.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,7 +186,7 @@ def main(cfg: "DictConfig"): # noqa: F821

# Evaluation
if abs(collected_frames % eval_iter) < frames_per_batch:
with set_exploration_type(ExplorationType.MODE), torch.no_grad():
with set_exploration_type(ExplorationType.DETERMINISTIC), torch.no_grad():
eval_start = time.time()
eval_rollout = eval_env.rollout(
eval_rollout_steps,
Expand Down
2 changes: 1 addition & 1 deletion sota-implementations/iql/iql_offline.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ def main(cfg: "DictConfig"): # noqa: F821

# evaluation
if i % evaluation_interval == 0:
with set_exploration_type(ExplorationType.MODE), torch.no_grad():
with set_exploration_type(ExplorationType.DETERMINISTIC), torch.no_grad():
eval_td = eval_env.rollout(
max_steps=eval_steps, policy=model[0], auto_cast_to_device=True
)
Expand Down
2 changes: 1 addition & 1 deletion sota-implementations/iql/iql_online.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,7 +184,7 @@ def main(cfg: "DictConfig"): # noqa: F821

# Evaluation
if abs(collected_frames % eval_iter) < frames_per_batch:
with set_exploration_type(ExplorationType.MODE), torch.no_grad():
with set_exploration_type(ExplorationType.DETERMINISTIC), torch.no_grad():
eval_start = time.time()
eval_rollout = eval_env.rollout(
eval_rollout_steps,
Expand Down
2 changes: 1 addition & 1 deletion sota-implementations/multiagent/iql.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,7 +206,7 @@ def train(cfg: "DictConfig"): # noqa: F821
and cfg.logger.backend
):
evaluation_start = time.time()
with torch.no_grad(), set_exploration_type(ExplorationType.MODE):
with torch.no_grad(), set_exploration_type(ExplorationType.DETERMINISTIC):
env_test.frames = []
rollouts = env_test.rollout(
max_steps=cfg.env.max_steps,
Expand Down
2 changes: 1 addition & 1 deletion sota-implementations/multiagent/maddpg_iddpg.py
Original file line number Diff line number Diff line change
Expand Up @@ -230,7 +230,7 @@ def train(cfg: "DictConfig"): # noqa: F821
and cfg.logger.backend
):
evaluation_start = time.time()
with torch.no_grad(), set_exploration_type(ExplorationType.MODE):
with torch.no_grad(), set_exploration_type(ExplorationType.DETERMINISTIC):
env_test.frames = []
rollouts = env_test.rollout(
max_steps=cfg.env.max_steps,
Expand Down
2 changes: 1 addition & 1 deletion sota-implementations/multiagent/mappo_ippo.py
Original file line number Diff line number Diff line change
Expand Up @@ -236,7 +236,7 @@ def train(cfg: "DictConfig"): # noqa: F821
and cfg.logger.backend
):
evaluation_start = time.time()
with torch.no_grad(), set_exploration_type(ExplorationType.MODE):
with torch.no_grad(), set_exploration_type(ExplorationType.DETERMINISTIC):
env_test.frames = []
rollouts = env_test.rollout(
max_steps=cfg.env.max_steps,
Expand Down
2 changes: 1 addition & 1 deletion sota-implementations/multiagent/qmix_vdn.py
Original file line number Diff line number Diff line change
Expand Up @@ -241,7 +241,7 @@ def train(cfg: "DictConfig"): # noqa: F821
and cfg.logger.backend
):
evaluation_start = time.time()
with torch.no_grad(), set_exploration_type(ExplorationType.MODE):
with torch.no_grad(), set_exploration_type(ExplorationType.DETERMINISTIC):
env_test.frames = []
rollouts = env_test.rollout(
max_steps=cfg.env.max_steps,
Expand Down
2 changes: 1 addition & 1 deletion sota-implementations/multiagent/sac.py
Original file line number Diff line number Diff line change
Expand Up @@ -300,7 +300,7 @@ def train(cfg: "DictConfig"): # noqa: F821
and cfg.logger.backend
):
evaluation_start = time.time()
with torch.no_grad(), set_exploration_type(ExplorationType.MODE):
with torch.no_grad(), set_exploration_type(ExplorationType.DETERMINISTIC):
env_test.frames = []
rollouts = env_test.rollout(
max_steps=cfg.env.max_steps,
Expand Down
2 changes: 1 addition & 1 deletion sota-implementations/ppo/ppo_atari.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,7 +217,7 @@ def main(cfg: "DictConfig"): # noqa: F821
)

# Get test rewards
with torch.no_grad(), set_exploration_type(ExplorationType.MODE):
with torch.no_grad(), set_exploration_type(ExplorationType.DETERMINISTIC):
if ((i - 1) * frames_in_batch * frame_skip) // test_interval < (
i * frames_in_batch * frame_skip
) // test_interval:
Expand Down
2 changes: 1 addition & 1 deletion sota-implementations/ppo/ppo_mujoco.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,7 +210,7 @@ def main(cfg: "DictConfig"): # noqa: F821
)

# Get test rewards
with torch.no_grad(), set_exploration_type(ExplorationType.MODE):
with torch.no_grad(), set_exploration_type(ExplorationType.DETERMINISTIC):
if ((i - 1) * frames_in_batch) // cfg_logger_test_interval < (
i * frames_in_batch
) // cfg_logger_test_interval:
Expand Down
2 changes: 1 addition & 1 deletion sota-implementations/sac/sac.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,7 +197,7 @@ def main(cfg: "DictConfig"): # noqa: F821

# Evaluation
if abs(collected_frames % eval_iter) < frames_per_batch:
with set_exploration_type(ExplorationType.MODE), torch.no_grad():
with set_exploration_type(ExplorationType.DETERMINISTIC), torch.no_grad():
eval_start = time.time()
eval_rollout = eval_env.rollout(
eval_rollout_steps,
Expand Down
2 changes: 1 addition & 1 deletion sota-implementations/td3/td3.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,7 +195,7 @@ def main(cfg: "DictConfig"): # noqa: F821

# Evaluation
if abs(collected_frames % eval_iter) < frames_per_batch:
with set_exploration_type(ExplorationType.MODE), torch.no_grad():
with set_exploration_type(ExplorationType.DETERMINISTIC), torch.no_grad():
eval_start = time.time()
eval_rollout = eval_env.rollout(
eval_rollout_steps,
Expand Down

0 comments on commit 312b82e

Please sign in to comment.