diff --git a/gym_requirements.txt b/gym_requirements.txt
index bf3d717..80d4506 100644
--- a/gym_requirements.txt
+++ b/gym_requirements.txt
@@ -20,8 +20,10 @@ pyparsing==2.4.2
 python-dateutil==2.8.0
 pytz==2019.3
 scikit-learn==0.21.3
-scipy==1.3.1
+# scipy==1.3.1
 six==1.12.0
 sklearn==0.0
-torch==1.3.0
-torchvision==0.4.1
\ No newline at end of file
+torch
+# ==1.3.0
+torchvision
+# ==0.4.1
diff --git a/opt_helpers/ppo_update.py b/opt_helpers/ppo_update.py
index c203833..97f492a 100644
--- a/opt_helpers/ppo_update.py
+++ b/opt_helpers/ppo_update.py
@@ -76,11 +76,12 @@ def sl_updates(self, rollouts, agent_in, heuristic_teacher):
         return aggregate_actor_loss

     def batch_updates(self, rollouts, agent_in, go_deeper=False):
-        # batch_size = max(rollouts.step // 32, 1)
-        # num_iters = rollouts.step // batch_size
-        batch_size = 8
-        num_iters = 4
-
+        if self.actor.input_dim < 10:
+            batch_size = max(rollouts.step // 32, 1)
+            num_iters = rollouts.step // batch_size
+        else:
+            num_iters = 4
+            batch_size = 8
         total_action_loss = torch.Tensor([0])
         total_value_loss = torch.Tensor([0])
         for iteration in range(num_iters):
@@ -97,7 +98,7 @@ def batch_updates(self, rollouts, agent_in, go_deeper=False):
                 deep_total_action_loss = deep_total_action_loss.cuda()
             samples = [rollouts.sample() for _ in range(batch_size)]
             samples = [sample for sample in samples if sample != False]
-            if len(samples) <= 0:
+            if len(samples) <= 1:
                 continue
             state = torch.cat([sample['state'][0] for sample in samples], dim=0)
             action_probs = torch.Tensor([sample['action_prob'] for sample in samples])
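
A minimal sketch of how the new batch-size heuristic in `batch_updates` behaves, assuming the logic shown in the hunk above. The `input_dim` threshold of 10, the 32-step divisor, and the 8x4 fallback are taken from the diff; the `Rollout` stand-in class and the `choose_batching` helper are hypothetical names used only for illustration.

```python
# Hypothetical illustration of the batch-size heuristic introduced in the patch.
# `Rollout` is a stand-in for the real rollout buffer; only `.step` matters here.

class Rollout:
    def __init__(self, step):
        self.step = step  # number of transitions collected this update


def choose_batching(rollouts, input_dim):
    """Mirror the patched logic: small observation spaces get rollout-scaled
    batches, larger ones keep the fixed batch_size=8 / num_iters=4 schedule."""
    if input_dim < 10:
        batch_size = max(rollouts.step // 32, 1)
        num_iters = rollouts.step // batch_size
    else:
        batch_size = 8
        num_iters = 4
    return batch_size, num_iters


# A 128-step rollout with a 4-dim observation -> batch_size=4, num_iters=32.
print(choose_batching(Rollout(128), input_dim=4))
# The same rollout with a 37-dim observation keeps the fixed schedule -> (8, 4).
print(choose_batching(Rollout(128), input_dim=37))
```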