Skip to content
This repository has been archived by the owner on May 6, 2021. It is now read-only.

Commit

Permalink
improve basicdqn (#117)
Browse files Browse the repository at this point in the history
  • Loading branch information
findmyway authored Nov 22, 2020
1 parent ff95729 commit 5d0780b
Showing 1 changed file with 21 additions and 7 deletions.
28 changes: 21 additions & 7 deletions src/algorithms/dqns/basic_dqn.jl
Original file line number Diff line number Diff line change
Expand Up @@ -65,19 +65,33 @@ end
"""
    RLBase.update!(learner::BasicDQNLearner, T::AbstractTrajectory)

Draw a random batch of transitions from the trajectory `T` and hand it to the
`update!` method of `learner` that accepts a `NamedTuple` batch.

Nothing is sampled (and `nothing` is returned) while `T[:terminal]` holds fewer
than `learner.min_replay_history` entries. Otherwise `learner.batch_size`
indices are drawn (with replacement) from `1:length(T[:terminal])` using
`learner.rng`, and the `:state`, `:action`, `:reward`, `:terminal` and
`:next_state` traces are each wrapped with `consecutive_view` at those indices.
"""
function RLBase.update!(learner::BasicDQNLearner, T::AbstractTrajectory)
    n_stored = length(T[:terminal])
    if n_stored < learner.min_replay_history
        return nothing
    end

    inds = rand(learner.rng, 1:n_stored, learner.batch_size)

    # Every trace is viewed at the same sampled indices so the fields stay aligned.
    sampled(trace) = consecutive_view(trace, inds)

    batch = (
        state = sampled(T[:state]),
        action = sampled(T[:action]),
        reward = sampled(T[:reward]),
        terminal = sampled(T[:terminal]),
        next_state = sampled(T[:next_state]),
    )

    return update!(learner, batch)
end

function RLBase.update!(learner::BasicDQNLearner, batch::NamedTuple)

Q = learner.approximator
D = device(Q)
γ = learner.γ
loss_func = learner.loss_func
batch_size = learner.batch_size

inds = rand(learner.rng, 1:length(T[:terminal]), learner.batch_size)
batch_size = nframes(batch.terminal)

s = send_to_device(D, consecutive_view(T[:state], inds))
a = consecutive_view(T[:action], inds)
r = send_to_device(D, consecutive_view(T[:reward], inds))
t = send_to_device(D, consecutive_view(T[:terminal], inds))
s′ = send_to_device(D, consecutive_view(T[:next_state], inds))
s = send_to_device(D, batch.state)
a = batch.action
r = send_to_device(D, batch.reward)
t = send_to_device(D, batch.terminal)
s′ = send_to_device(D, batch.next_state)

a = CartesianIndex.(a, 1:batch_size)

Expand Down

0 comments on commit 5d0780b

Please sign in to comment.