Commit
Fix a problem with the recent version of PyTorch
ikostrikov2 committed Mar 14, 2017
1 parent abe0de2 commit 15dd5e5
Showing 3 changed files with 11 additions and 5 deletions.
5 changes: 1 addition & 4 deletions README.md

@@ -1,8 +1,5 @@
 # pytorch-a3c
 
-NEED TO USE V-0.1.9 (or lower) OF PYTORCH, AND NOT V-0.1.10 BECAUSE OF THIS ISSUE:
-https://discuss.pytorch.org/t/problem-on-variable-grad-data/957/7
-
 This is a PyTorch implementation of Asynchronous Advantage Actor Critic (A3C) from ["Asynchronous Methods for Deep Reinforcement Learning"](https://arxiv.org/pdf/1602.01783v1.pdf).
 
 This implementation is inspired by [Universe Starter Agent](https://github.com/openai/universe-starter-agent).
@@ -14,7 +11,7 @@ Contributions are very welcome. If you know how to make this code better, don't
 
 ## Usage
 ```
-python main.py --env-name "PongDeterministic-v3" --num-processes 16
+OMP_NUM_THREADS=1 python main.py --env-name "PongDeterministic-v3" --num-processes 16
 ```
 
 This code runs evaluation in a separate thread in addition to 16 processes.
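Capping each worker at a single OpenMP thread avoids CPU oversubscription when 16 processes run at once, and the cap generally has to be in the environment before torch is first imported, which is why it goes in the launch command rather than in Python code. A minimal sketch, assuming (hypothetically) that the cap is set programmatically instead of via the command-line prefix:

```python
import os

# Hypothetical alternative to the OMP_NUM_THREADS=1 command-line prefix.
# The variable must be set before torch is imported, since OpenMP sizes
# its thread pool when the runtime initializes.
os.environ["OMP_NUM_THREADS"] = "1"

# import torch  # only import torch after the cap is in place
print(os.environ["OMP_NUM_THREADS"])
```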
1 change: 0 additions & 1 deletion main.py

@@ -38,7 +38,6 @@
 args = parser.parse_args()
 
 torch.manual_seed(args.seed)
-torch.set_num_threads(1)
 
 env = create_atari_env(args.env_name)
 shared_model = ActorCritic(
10 changes: 10 additions & 0 deletions model.py

@@ -59,6 +59,16 @@ def __init__(self, num_inputs, action_space):
         self.lstm.bias_hh.data.fill_(0)
 
         self.train()
+        self.__dummy_backprob()
+
+    def __dummy_backprob(self):
+        # See: https://discuss.pytorch.org/t/problem-on-variable-grad-data/957/7
+        # An ugly hack until there is a better solution.
+        inputs = Variable(torch.randn(1, 1, 42, 42))
+        hx, cx = Variable(torch.randn(1, 256)), Variable(torch.randn(1, 256))
+        outputs = self((inputs, (hx, cx)))
+        loss = (outputs[0].mean() + outputs[1].mean()) * 0.0
+        loss.backward()
 
     def forward(self, inputs):
         inputs, (hx, cx) = inputs
Expand Down
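The dummy backward pass works because parameter `.grad` buffers in PyTorch are allocated lazily: they stay `None` until a first backward pass touches them, which breaks code that assumes the buffers exist (e.g. for sharing gradients across processes). A minimal sketch of the same trick, written against the modern PyTorch API rather than the `Variable` API the commit targets:

```python
import torch

# Before any backward pass, parameter .grad buffers do not exist yet.
layer = torch.nn.Linear(4, 2)
assert all(p.grad is None for p in layer.parameters())

# Zero-valued loss, same trick as the commit: it changes nothing
# numerically, but forces autograd to allocate the gradient buffers.
out = layer(torch.randn(1, 4))
loss = out.mean() * 0.0
loss.backward()

# Every .grad buffer now exists and holds zeros.
assert all(p.grad is not None for p in layer.parameters())
```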
