
Fix #251, ppo multidim action eval #177

Merged 1 commit from albheim_ppo_fix into master on Apr 27, 2021

Conversation

@albheim (Member) commented Apr 27, 2021

See JuliaReinforcementLearning/ReinforcementLearning.jl#251

Tested this with the following script:

using ReinforcementLearning
using Dates
using StableRNGs
using Flux
using Distributions

seed = 37

env = PendulumEnv(T = Float32, rng = StableRNG(hash(seed)))

ns = length(state_space(env))  # observation dimension
na = 2  # two action dimensions, to exercise the multidim fix

rng = StableRNG(seed)

policy = PPOPolicy(
    approximator = ActorCritic(
        actor = GaussianNetwork(
            pre = Chain(
                Dense(ns, 64, relu; initW = glorot_uniform(rng)),
                Dense(64, 64, relu; initW = glorot_uniform(rng)),
            ),
            μ = Chain(Dense(64, na, tanh; initW = glorot_uniform(rng))),  # mean head: na outputs
            logσ = Chain(Dense(64, na; initW = glorot_uniform(rng))),     # log-std head: na outputs
        ),
        critic = Chain(
            Dense(ns, 64, relu; initW = glorot_uniform(rng)),
            Dense(64, 64, relu; initW = glorot_uniform(rng)),
            Dense(64, 1; initW = glorot_uniform(rng)),
        ),
        optimizer = ADAM(3e-4),
    ) |> cpu,
    γ = 0.99f0,
    λ = 0.95f0,
    clip_range = 0.2f0,
    max_grad_norm = 0.5f0,
    n_epochs = 10,
    n_microbatches = 32,
    actor_loss_weight = 1.0f0,
    critic_loss_weight = 0.5f0,
    entropy_loss_weight = 0.00f0,
    dist = Normal,
    rng = rng,
    update_freq = 2000,
)

policy(env)

This returned a 2×1 matrix, so multi-dimensional actions can now be sampled.
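
For completeness, a quick shape check along the same lines (the assertion is illustrative and was not part of the original test):

action = policy(env)
@assert size(action) == (na, 1)  # one row per action dimension, one sampled column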

It no longer returns a scalar in the single-dimensional case, which makes this a breaking change. I don't think that special case is worth keeping, so I left it out for now; if you want it, I can add it back.
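
If downstream code relied on the old scalar return for 1-D action spaces, a minimal shim could restore it on the caller's side (hypothetical, not part of this PR):

act = policy(env)  # now always an array, e.g. a 1×1 matrix when na == 1
old_act = na == 1 ? only(act) : act  # Base.only unwraps a single-element collection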

@codecov-commenter commented

Codecov Report

Merging #177 (9237c70) into master (52a9c85) will not change coverage.
The diff coverage is 0.00%.


@@           Coverage Diff           @@
##           master     #177   +/-   ##
=======================================
  Coverage   61.31%   61.31%           
=======================================
  Files          69       69           
  Lines        2505     2505           
=======================================
  Hits         1536     1536           
  Misses        969      969           
Impacted Files                          Coverage Δ
src/algorithms/policy_gradient/ppo.jl   85.43% <0.00%> (ø)

Legend: Δ = absolute <relative> (impact), ø = not affected, ? = missing data

@findmyway (Member) commented

Thanks!

@findmyway findmyway merged commit d78f327 into master Apr 27, 2021
@albheim albheim deleted the albheim_ppo_fix branch April 28, 2021 06:09