Skip to content

Commit

Permalink
e3b new formula
Browse files Browse the repository at this point in the history
  • Loading branch information
Joseph Suarez committed Jan 30, 2025
1 parent 4418024 commit 4416823
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 8 deletions.
4 changes: 2 additions & 2 deletions clean_pufferl.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@ def evaluate(data):
with profile.eval_misc:
value = value.flatten()
actions = actions.cpu().numpy()
mask = torch.as_tensor(mask)# * policy.mask)
mask = torch.as_tensor(mask)
o = o if config.cpu_offload else o_device
experience.store(o, value, actions, logprob, r, d, env_id, mask)

Expand Down Expand Up @@ -411,7 +411,7 @@ def __init__(self, batch_size, bptt_horizon, minibatch_size, hidden_size,
self.dones=torch.zeros(batch_size, pin_memory=pin)
self.truncateds=torch.zeros(batch_size, pin_memory=pin)
self.values=torch.zeros(batch_size, pin_memory=pin)
self.e3b_inv = 1*torch.eye(hidden_size).repeat(lstm_total_agents, 1, 1).to(device)
self.e3b_inv = 10*torch.eye(hidden_size).repeat(lstm_total_agents, 1, 1).to(device)

self.actions_np = np.asarray(self.actions)
self.logprobs_np = np.asarray(self.logprobs)
Expand Down
12 changes: 6 additions & 6 deletions pufferlib/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,16 +86,16 @@ def decode_actions(self, hidden, lookup, concat=True, e3b=None):
batch = hidden.shape[0]
return probs, value

intrinsic_reward = None
b = None
if e3b is not None:
phi = hidden.detach()
intrinsic_reward = (phi.unsqueeze(1) @ e3b @ phi.unsqueeze(2))
e3b = 0.95*e3b - (phi.unsqueeze(2) @ phi.unsqueeze(1))/(1 + intrinsic_reward)
intrinsic_reward = intrinsic_reward.squeeze()
intrinsic_reward = 0.1*torch.clamp(intrinsic_reward, -1, 1)
u = phi.unsqueeze(1) @ e3b
b = u @ phi.unsqueeze(2)
e3b = 0.99*e3b - (u.mT @ u) / (1 + b)
b = b.squeeze()

actions = self.decoder(hidden)
return actions, value, e3b, intrinsic_reward
return actions, value, e3b, b

class LSTMWrapper(nn.Module):
def __init__(self, env, policy, input_size=128, hidden_size=128, num_layers=1):
Expand Down

0 comments on commit 4416823

Please sign in to comment.