Commit

Changed variable name (PascalCase to snake_case)
brahimdriss committed Jul 11, 2023
1 parent 5ed8121 commit 9dfd91a
Showing 2 changed files with 11 additions and 11 deletions.
rlberry/agents/torch/a2c/a2c.py (4 changes: 2 additions, 2 deletions)
@@ -148,7 +148,7 @@ def reset(self):
)
self._policy_old.load_state_dict(self._policy.state_dict())

- self.MseLoss = nn.MSELoss()
+ self.mse_loss = nn.MSELoss()

self.memory = ReplayBuffer(max_replay_size=self.batch_size, rng=self.rng)
self.memory.setup_entry("states", dtype=np.float32)
@@ -303,7 +303,7 @@ def _update(self):
pg_loss = -logprobs * advantages
loss = (
pg_loss
- + 0.5 * self.MseLoss(state_values, rewards)
+ + 0.5 * self.mse_loss(state_values, rewards)
- self.entr_coef * dist_entropy
)

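Note: the rename only changes the identifier; the value-loss term itself is unchanged. A minimal, self-contained sketch of how such an nn.MSELoss attribute enters a combined A2C-style loss (tensor names and shapes below are illustrative, not taken from rlberry):

    import torch
    import torch.nn as nn

    # Stand-in rollout data (hypothetical shapes).
    logprobs = torch.randn(8)         # log pi(a_t | s_t)
    advantages = torch.randn(8)       # advantage estimates, treated as constants
    state_values = torch.randn(8)     # critic predictions V(s_t)
    rewards = torch.randn(8)          # value targets
    dist_entropy = torch.tensor(0.5)  # policy entropy
    entr_coef = 0.01

    mse_loss = nn.MSELoss()           # the attribute renamed in this commit

    pg_loss = -logprobs * advantages  # policy-gradient term
    loss = pg_loss + 0.5 * mse_loss(state_values, rewards) - entr_coef * dist_entropy
    loss = loss.mean()                # scalar that is actually backpropagated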
rlberry/agents/torch/sac/sac.py (18 changes: 9 additions, 9 deletions)
@@ -206,7 +206,7 @@ def reset(self, **kwargs):
self.q2.parameters(), **self.q_optimizer_kwargs
)
# Define the loss
- self.MseLoss = nn.MSELoss()
+ self.mse_loss = nn.MSELoss()

# Automatic entropy tuning
if self.autotune_alpha:
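The context line above references automatic entropy tuning. For orientation only, a sketch of the standard SAC temperature update (the textbook form, not necessarily rlberry's exact implementation; all names are illustrative):

    import torch

    action_dim = 2
    target_entropy = -float(action_dim)  # common heuristic: -|A|

    log_alpha = torch.zeros(1, requires_grad=True)
    alpha_optimizer = torch.optim.Adam([log_alpha], lr=3e-4)

    log_prob = torch.randn(32, 1)  # log pi(a|s) for sampled actions (stand-in data)

    # Push alpha up when policy entropy (-log_prob) falls below the target, and down otherwise.
    alpha_loss = -(log_alpha * (log_prob + target_entropy).detach()).mean()
    alpha_optimizer.zero_grad()
    alpha_loss.backward()
    alpha_optimizer.step()

    alpha = log_alpha.exp().item()  # temperature used in the actor and critic updates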
@@ -399,10 +399,10 @@ def _update(self):
)
# Compute the next state's Q-values
q1_next_target = self.q1_target(
- torch.cat([next_state, next_state_actions], dim=1)
+ torch.cat([next_state, next_state_actions], dim=-1)
)
q2_next_target = self.q2_target(
- torch.cat([next_state, next_state_actions], dim=1)
+ torch.cat([next_state, next_state_actions], dim=-1)
)
# Compute Q targets:
# - Compute the minimum Q-values between Q1 and Q2
@@ -418,10 +418,10 @@ def _update(self):
).view(-1)

# Compute Q loss
- q1_v = self.q1(torch.cat([states, actions], dim=1))
- q2_v = self.q2(torch.cat([states, actions], dim=1))
- q1_loss_v = self.MseLoss(q1_v.squeeze(), next_q_value)
- q2_loss_v = self.MseLoss(q2_v.squeeze(), next_q_value)
+ q1_v = self.q1(torch.cat([states, actions], dim=-1))
+ q2_v = self.q2(torch.cat([states, actions], dim=-1))
+ q1_loss_v = self.mse_loss(q1_v.squeeze(), next_q_value)
+ q2_loss_v = self.mse_loss(q2_v.squeeze(), next_q_value)
q_loss_v = q1_loss_v + q2_loss_v

# Update Q networks
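The switch from dim=1 to dim=-1 concatenates along the last axis instead of the second one. A small illustration of why the two are interchangeable for 2-D batches but not once an extra leading dimension appears (shapes are made up for the example):

    import torch

    states = torch.randn(32, 4)   # batch of observations
    actions = torch.randn(32, 2)  # batch of continuous actions

    # For 2-D tensors both calls build the same (32, 6) state-action input.
    assert torch.equal(
        torch.cat([states, actions], dim=1),
        torch.cat([states, actions], dim=-1),
    )

    # With an extra leading dimension (e.g. several environments), only dim=-1
    # still joins the feature axes; dim=1 would require the other axes to match.
    s = torch.randn(5, 32, 4)
    a = torch.randn(5, 32, 2)
    print(torch.cat([s, a], dim=-1).shape)  # torch.Size([5, 32, 6])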
@@ -443,8 +443,8 @@ def _update(self):
states.detach().cpu().numpy()
)
# Compute the next state's Q-values
- q_out_v1 = self.q1(torch.cat([states, state_action], dim=1))
- q_out_v2 = self.q2(torch.cat([states, state_action], dim=1))
+ q_out_v1 = self.q1(torch.cat([states, state_action], dim=-1))
+ q_out_v2 = self.q2(torch.cat([states, state_action], dim=-1))
# Select the minimum Q to reduce over estimation and improve stability
q_out_v = torch.min(q_out_v1, q_out_v2)
# Compute policy loss:
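The hunk above ends just before the policy loss itself. As a hedged sketch, this is the standard SAC actor objective that the clipped double-Q value typically feeds into (illustrative names, not rlberry's code):

    import torch

    q_out_v1 = torch.randn(32, 1)  # Q1(s, a ~ pi)
    q_out_v2 = torch.randn(32, 1)  # Q2(s, a ~ pi)
    log_prob = torch.randn(32, 1)  # log pi(a | s) for the sampled actions
    alpha = 0.2                    # entropy temperature

    # Element-wise minimum to curb over-estimation, as in the diff above.
    q_out_v = torch.min(q_out_v1, q_out_v2)

    # Maximize Q - alpha * log_prob, i.e. minimize alpha * log_prob - Q.
    policy_loss = (alpha * log_prob - q_out_v).mean()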
