Skip to content
This repository has been archived by the owner on May 6, 2021. It is now read-only.

Commit

Permalink
Fix #251, ppo multidim action eval (#177)
Browse files Browse the repository at this point in the history
  • Loading branch information
albheim authored Apr 27, 2021
1 parent 52a9c85 commit d78f327
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions src/algorithms/policy_gradient/ppo.jl
Original file line number Diff line number Diff line change
Expand Up @@ -173,11 +173,11 @@ RLBase.prob(p::PPOPolicy, env::MultiThreadEnv) = prob(p, state(env))
function RLBase.prob(p::PPOPolicy, env::AbstractEnv)
s = state(env)
s = Flux.unsqueeze(s, ndims(s) + 1)
prob(p, s)[1]
prob(p, s)
end

(p::PPOPolicy)(env::MultiThreadEnv) = rand.(p.rng, prob(p, env))
(p::PPOPolicy)(env::AbstractEnv) = rand(p.rng, prob(p, env))
(p::PPOPolicy)(env::AbstractEnv) = rand.(p.rng, prob(p, env))

function (agent::Agent{<:PPOPolicy})(env::MultiThreadEnv)
dist = prob(agent.policy, env)
Expand Down

0 comments on commit d78f327

Please sign in to comment.