Skip to content

Commit

Permalink
remove hard-coded commenting from batch size parameter
Browse files Browse the repository at this point in the history
  • Loading branch information
ProLoNets committed Jul 30, 2021
1 parent 79c63c5 commit ea426e0
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 9 deletions.
8 changes: 5 additions & 3 deletions gym_requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,10 @@ pyparsing==2.4.2
python-dateutil==2.8.0
pytz==2019.3
scikit-learn==0.21.3
scipy==1.3.1
# scipy==1.3.1
six==1.12.0
sklearn==0.0
torch==1.3.0
torchvision==0.4.1
torch
# ==1.3.0
torchvision
# ==0.4.1
13 changes: 7 additions & 6 deletions opt_helpers/ppo_update.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,11 +76,12 @@ def sl_updates(self, rollouts, agent_in, heuristic_teacher):
return aggregate_actor_loss

def batch_updates(self, rollouts, agent_in, go_deeper=False):
# batch_size = max(rollouts.step // 32, 1)
# num_iters = rollouts.step // batch_size
batch_size = 8
num_iters = 4

if self.actor.input_dim < 10:
batch_size = max(rollouts.step // 32, 1)
num_iters = rollouts.step // batch_size
else:
num_iters = 4
batch_size = 8
total_action_loss = torch.Tensor([0])
total_value_loss = torch.Tensor([0])
for iteration in range(num_iters):
Expand All @@ -97,7 +98,7 @@ def batch_updates(self, rollouts, agent_in, go_deeper=False):
deep_total_action_loss = deep_total_action_loss.cuda()
samples = [rollouts.sample() for _ in range(batch_size)]
samples = [sample for sample in samples if sample != False]
if len(samples) <= 0:
if len(samples) <= 1:
continue
state = torch.cat([sample['state'][0] for sample in samples], dim=0)
action_probs = torch.Tensor([sample['action_prob'] for sample in samples])
Expand Down

0 comments on commit ea426e0

Please sign in to comment.