Add REM-DQN (Random Ensemble Mixture) method (#160)
* add some explanations

* Add REM DQN

* Add docs

* Add docs

* Modified implementation

* Some modifications

* fix conflict

Co-authored-by: Jun Tian <[email protected]>
pilgrimygy and findmyway authored Apr 4, 2021
1 parent 2908338 commit 8668f3c
Showing 5 changed files with 243 additions and 3 deletions.
src/algorithms/dqns/common.jl (2 changes: 1 addition & 1 deletion)
@@ -4,7 +4,7 @@

const PERLearners = Union{PrioritizedDQNLearner,RainbowLearner,IQNLearner}

-function RLBase.update!(learner::Union{DQNLearner,PERLearners}, t::AbstractTrajectory)
+function RLBase.update!(learner::Union{DQNLearner,REMDQNLearner,PERLearners}, t::AbstractTrajectory)
    length(t[:terminal]) - learner.sampler.n <= learner.min_replay_history && return

    learner.update_step += 1
src/algorithms/dqns/dqns.jl (3 changes: 2 additions & 1 deletion)
@@ -1,6 +1,7 @@
include("basic_dqn.jl")
include("dqn.jl")
include("prioritized_dqn.jl")
include("rem_dqn.jl")
include("rainbow.jl")
include("iqn.jl")
include("common.jl")
include("common.jl")
src/algorithms/dqns/rem_dqn.jl (147 changes: 147 additions & 0 deletions)
@@ -0,0 +1,147 @@
export REMDQNLearner

mutable struct REMDQNLearner{
    Tq<:AbstractApproximator,
    Tt<:AbstractApproximator,
    Tf,
    R<:AbstractRNG,
} <: AbstractLearner
    approximator::Tq
    target_approximator::Tt
    loss_func::Tf
    min_replay_history::Int
    update_freq::Int
    update_step::Int
    target_update_freq::Int
    sampler::NStepBatchSampler
    ensemble_num::Int
    ensemble_method::Symbol
    rng::R
    # for logging
    loss::Float32
end

"""
REMDQNLearner(;kwargs...)
See paper: [An Optimistic Perspective on Offline Reinforcement Learning](https://arxiv.org/abs/1907.04543)
# Keywords
- `approximator`::[`AbstractApproximator`](@ref): used to get Q-values of a state.
- `target_approximator`::[`AbstractApproximator`](@ref): similar to `approximator`, but used to estimate the target (the next state).
- `loss_func`: the loss function.
- `γ::Float32=0.99f0`: discount rate.
- `batch_size::Int=32`
- `update_horizon::Int=1`: length of update ('n' in n-step update).
- `min_replay_history::Int=32`: number of transitions that should be experienced before updating the `approximator`.
- `update_freq::Int=4`: the frequency of updating the `approximator`.
- `ensemble_num::Int=1`: the number of ensemble approximators.
- `ensemble_method::Symbol=:rand`: the method of combining Q values. ':rand' represents random ensemble mixture, and ':mean' is the average.
- `target_update_freq::Int=100`: the frequency of syncing `target_approximator`.
- `stack_size::Union{Int, Nothing}=4`: use the recent `stack_size` frames to form a stacked state.
- `traces = SARTS`, set to `SLARTSL` if you are to apply to an environment of `FULL_ACTION_SET`.
- `rng = Random.GLOBAL_RNG`
"""
function REMDQNLearner(;
    approximator::Tq,
    target_approximator::Tt,
    loss_func::Tf,
    stack_size::Union{Int,Nothing} = nothing,
    γ::Float32 = 0.99f0,
    batch_size::Int = 32,
    update_horizon::Int = 1,
    min_replay_history::Int = 32,
    update_freq::Int = 1,
    ensemble_num::Int = 1,
    ensemble_method::Symbol = :rand,
    target_update_freq::Int = 100,
    traces = SARTS,
    update_step = 0,
    rng = Random.GLOBAL_RNG,
) where {Tq,Tt,Tf}
    copyto!(approximator, target_approximator)
    sampler = NStepBatchSampler{traces}(;
        γ = γ,
        n = update_horizon,
        stack_size = stack_size,
        batch_size = batch_size,
    )
    REMDQNLearner(
        approximator,
        target_approximator,
        loss_func,
        min_replay_history,
        update_freq,
        update_step,
        target_update_freq,
        sampler,
        ensemble_num,
        ensemble_method,
        rng,
        0.0f0,
    )
end

Flux.functor(x::REMDQNLearner) = (Q = x.approximator, Qₜ = x.target_approximator),
y -> begin
    x = @set x.approximator = y.Q
    x = @set x.target_approximator = y.Qₜ
    x
end

function (learner::REMDQNLearner)(env)
    s = send_to_device(device(learner.approximator), state(env))
    s = Flux.unsqueeze(s, ndims(s) + 1)
    q = reshape(learner.approximator(s), :, learner.ensemble_num)
    vec(mean(q, dims = 2)) |> send_to_host
end

function RLBase.update!(learner::REMDQNLearner, batch::NamedTuple)
    Q = learner.approximator
    Qₜ = learner.target_approximator
    γ = learner.sampler.γ
    loss_func = learner.loss_func
    n = learner.sampler.n
    batch_size = learner.sampler.batch_size
    ensemble_num = learner.ensemble_num
    D = device(Q)
    # Weights of a convex combination of the ensemble's Q-value estimates:
    # random weights for `:rand` (random ensemble mixture), uniform weights otherwise (`:mean`).
    if learner.ensemble_method == :rand
        convex_polygon = rand(Float32, (1, ensemble_num))
    else
        convex_polygon = ones(Float32, (1, ensemble_num))
    end
    convex_polygon ./= sum(convex_polygon)
    convex_polygon = send_to_device(D, convex_polygon)

    s, a, r, t, s′ = (send_to_device(D, batch[x]) for x in SARTS)
    a = CartesianIndex.(a, 1:batch_size)

    target_q = Qₜ(s′)
    target_q = convex_polygon .* reshape(target_q, :, ensemble_num, batch_size)
    target_q = dropdims(sum(target_q, dims = 2), dims = 2)

    if haskey(batch, :next_legal_actions_mask)
        l′ = send_to_device(D, batch[:next_legal_actions_mask])
        target_q .+= ifelse.(l′, 0.0f0, typemin(Float32))
    end

    q′ = dropdims(maximum(target_q; dims = 1), dims = 1)
    G = r .+ γ^n .* (1 .- t) .* q′

    gs = gradient(params(Q)) do
        q = Q(s)
        q = convex_polygon .* reshape(q, :, ensemble_num, batch_size)
        q = dropdims(sum(q, dims = 2), dims = 2)[a]

        loss = loss_func(G, q)
        ignore() do
            learner.loss = loss
        end
        loss
    end

    update!(Q, gs)
end
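
To make the mixing step above concrete, the following standalone sketch (not part of the commit; the sizes `na`, `ensemble_num`, `batch_size` are toy values chosen only for illustration) reproduces what the `:rand` branch computes with plain Base Julia:

# One set of mixture weights per batch, drawn uniformly and normalised to sum to 1,
# i.e. a point on the probability simplex over the ensemble members.
na, ensemble_num, batch_size = 2, 4, 3

convex_polygon = rand(Float32, (1, ensemble_num))
convex_polygon ./= sum(convex_polygon)

# Stand-in for the flattened multi-head output Qₜ(s′) of shape (na * ensemble_num, batch_size).
target_q = rand(Float32, na * ensemble_num, batch_size)

# Reshape to (na, ensemble_num, batch_size), weight each head, and sum the heads away,
# leaving a single mixed Q-value table of shape (na, batch_size).
weighted = convex_polygon .* reshape(target_q, :, ensemble_num, batch_size)
mixed = dropdims(sum(weighted, dims = 2), dims = 2)

@assert size(mixed) == (na, batch_size)
@assert sum(convex_polygon) ≈ 1.0f0

With `:mean` the weights are uniform, so the mixture collapses to a plain average of the heads. Note that `update!` draws a single weight vector per batch and reuses it for both the online and the target network, so prediction and target are mixed with the same coefficients.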

src/experiments/rl_envs/JuliaRL_REMDQN_CartPole.jl (92 changes: 92 additions & 0 deletions)
@@ -0,0 +1,92 @@
function RLCore.Experiment(
    ::Val{:JuliaRL},
    ::Val{:REMDQN},
    ::Val{:CartPole},
    ::Nothing;
    save_dir = nothing,
    seed = 123,
)
    if isnothing(save_dir)
        t = Dates.format(now(), "yyyy_mm_dd_HH_MM_SS")
        save_dir = joinpath(pwd(), "checkpoints", "JuliaRL_REMDQN_CartPole_$(t)")
    end

    lg = TBLogger(joinpath(save_dir, "tb_log"), min_level = Logging.Info)
    rng = StableRNG(seed)

    env = CartPoleEnv(; T = Float32, rng = rng)
    ns, na = length(state(env)), length(action_space(env))
    ensemble_num = 6

    agent = Agent(
        policy = QBasedPolicy(
            learner = REMDQNLearner(
                approximator = NeuralNetworkApproximator(
                    model = Chain(
                        # Multi-head network (one head of `na` Q-values per ensemble member); see
                        # https://github.com/google-research/batch_rl/tree/b55ba35ebd2381199125dd77bfac9e9c59a64d74/batch_rl/multi_head
                        Dense(ns, 128, relu; initW = glorot_uniform(rng)),
                        Dense(128, 128, relu; initW = glorot_uniform(rng)),
                        Dense(128, na * ensemble_num; initW = glorot_uniform(rng)),
                    ) |> cpu,
                    optimizer = ADAM(),
                ),
                target_approximator = NeuralNetworkApproximator(
                    model = Chain(
                        Dense(ns, 128, relu; initW = glorot_uniform(rng)),
                        Dense(128, 128, relu; initW = glorot_uniform(rng)),
                        Dense(128, na * ensemble_num; initW = glorot_uniform(rng)),
                    ) |> cpu,
                ),
                loss_func = huber_loss,
                stack_size = nothing,
                batch_size = 32,
                update_horizon = 1,
                min_replay_history = 100,
                update_freq = 1,
                target_update_freq = 100,
                ensemble_num = ensemble_num,
                ensemble_method = :rand,
                rng = rng,
            ),
            explorer = EpsilonGreedyExplorer(
                kind = :exp,
                ϵ_stable = 0.01,
                decay_steps = 500,
                rng = rng,
            ),
        ),
        trajectory = CircularArraySARTTrajectory(
            capacity = 1000,
            state = Vector{Float32} => (ns,),
        ),
    )

    stop_condition = StopAfterStep(10_000)

    total_reward_per_episode = TotalRewardPerEpisode()
    time_per_step = TimePerStep()
    hook = ComposedHook(
        total_reward_per_episode,
        time_per_step,
        DoEveryNStep() do t, agent, env
            if agent.policy.learner.update_step % agent.policy.learner.update_freq == 0
                with_logger(lg) do
                    @info "training" loss = agent.policy.learner.loss
                end
            end
        end,
        DoEveryNEpisode() do t, agent, env
            with_logger(lg) do
                @info "training" reward = total_reward_per_episode.rewards[end] log_step_increment = 0
            end
        end,
    )

    description = """
    This experiment uses the `REMDQNLearner` method with three dense layers to approximate the Q value.
    The testing environment is CartPoleEnv.
    """

    Experiment(agent, env, stop_condition, hook, description)
end
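
As a side note on why the output layer has `na * ensemble_num` units: at action-selection time the learner reshapes that flat output into one column of `na` Q-values per head and averages across heads, as in `(learner::REMDQNLearner)(env)` above. A minimal standalone sketch (not part of the commit; toy numbers):

using Statistics

na, ensemble_num = 2, 6                       # CartPole has 2 actions; 6 heads as in this experiment
flat_q = rand(Float32, na * ensemble_num)     # stands in for approximator(state) on a single state

per_head = reshape(flat_q, :, ensemble_num)   # (na, ensemble_num): column k holds head k's Q-values
q_values = vec(mean(per_head, dims = 2))      # average over heads -> one Q-value per action

best_action = argmax(q_values)

Once registered, the experiment can be constructed and run the same way the test suite below does, e.g. something like `run(Experiment(Val(:JuliaRL), Val(:REMDQN), Val(:CartPole), nothing))`.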
test/runtests.jl (2 changes: 1 addition & 1 deletion)
@@ -33,7 +33,7 @@ end

@testset "training" begin
    mktempdir() do dir
-        for method in (:BasicDQN, :BC, :DQN, :PrioritizedDQN, :Rainbow, :IQN, :VPG)
+        for method in (:BasicDQN, :BC, :DQN, :PrioritizedDQN, :Rainbow, :REMDQN, :IQN, :VPG)
            res = run(
                Experiment(
                    Val(:JuliaRL),