This repository has been archived by the owner on May 6, 2021. It is now read-only.

Fix bug in multi action ppo #169

Merged
merged 2 commits into from
Apr 14, 2021
Changes from 1 commit
5 changes: 2 additions & 3 deletions src/algorithms/policy_gradient/ppo.jl
@@ -267,11 +267,11 @@ function _update!(p::PPOPolicy, t::AbstractTrajectory)
     if AC.actor isa GaussianNetwork
         μ, σ = AC.actor(s)
         if ndims(a) == 2
-            log_p′ₐ = sum(normlogpdf(μ, σ, a), dims = 1)
+            log_p′ₐ = vec(sum(normlogpdf(μ, σ, a), dims = 1))
         else
             log_p′ₐ = normlogpdf(μ, σ, a)
         end
-        entropy_loss = mean((log(2.0f0π) + 1) / 2 .+ sum(log.(σ), dims = 1))
+        entropy_loss = mean(size(σ, 1) * (log(2.0f0π) + 1) .+ sum(log.(σ), dims = 1)) / 2
     else
         # actor is assumed to return discrete logits
         logit′ = AC.actor(s)
@@ -280,7 +280,6 @@ function _update!(p::PPOPolicy, t::AbstractTrajectory)
         log_p′ₐ = log_p′[CartesianIndex.(a, 1:length(a))]
         entropy_loss = -sum(p′ .* log_p′) * 1 // size(p′, 2)
     end
-
     ratio = exp.(log_p′ₐ .- log_p)
     surr1 = ratio .* adv
     surr2 = clamp.(ratio, 1.0f0 - clip_range, 1.0f0 + clip_range) .* adv
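
For context, the first change addresses a shape pitfall in the Gaussian (continuous, possibly multi-dimensional action) branch: sum(..., dims = 1) returns a 1×N matrix, and subtracting a length-N vector of stored log-probabilities from it broadcasts to an N×N matrix instead of a length-N vector. Below is a minimal sketch of the issue in Julia, using a stand-in normlogpdf and hypothetical shapes for log_p; it is not the package's actual code path.

    # Stand-in for the elementwise Gaussian log-density used in ppo.jl (assumed form).
    normlogpdf(μ, σ, x) = -((x .- μ) .^ 2) ./ (2 .* σ .^ 2) .- log.(σ) .- log(2.0f0π) / 2

    μ = rand(Float32, 3, 5)   # 3 action dimensions × 5 samples (hypothetical shapes)
    σ = ones(Float32, 3, 5)
    a = rand(Float32, 3, 5)
    log_p = rand(Float32, 5)  # stored log-probs as a length-5 vector (assumed)

    bad  = sum(normlogpdf(μ, σ, a), dims = 1)        # 1×5 Matrix
    good = vec(sum(normlogpdf(μ, σ, a), dims = 1))   # length-5 Vector

    size(bad .- log_p)   # (5, 5) — broadcasting quietly yields a matrix
    size(good .- log_p)  # (5,)  — the intended per-sample log-ratio

Since adv is also stored per sample, a matrix-valued ratio would propagate into surr1 and surr2, which appears to be the shape bug the vec wrapper guards against.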
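
The second change concerns the entropy bonus of a multi-dimensional Gaussian policy. For a diagonal Gaussian with d independent components, the closed-form entropy is H = d/2 * (log(2π) + 1) + Σᵢ log σᵢ, so the constant term grows with the action dimension; the old line applied (log(2π) + 1) / 2 only once, while the fixed line scales it by size(σ, 1). Below is a small sketch of that closed form with a hypothetical helper (not part of the package), assuming σ holds per-dimension standard deviations as produced by GaussianNetwork.

    # Closed-form entropy of a d-dimensional diagonal Gaussian (σ = standard deviations).
    diag_gauss_entropy(σ::AbstractVector) =
        length(σ) * (log(2.0f0π) + 1) / 2 + sum(log.(σ))

    σ = Float32[0.5, 1.0, 2.0]   # hypothetical 3-dimensional action
    diag_gauss_entropy(σ)        # ≈ 4.26; applying the constant only once gives ≈ 1.42 here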