This repository has been archived by the owner on May 6, 2021. It is now read-only.

add GridWorlds environments (#152)
* ignore vim generated temp files

* add JuliaRL_BasicDQN_EmptyRoom experiment

* add per-step penalty and max-timeout per episode

* add test for JuliaRL_BasicDQN_EmptyRoom

* add JuliaRL_BasicDQN_EmptyRoom to README

* add note on importing GridWorlds
Sid-Bhatia-0 authored Mar 1, 2021
1 parent 7913db6 commit c5439e6
Showing 7 changed files with 118 additions and 3 deletions.
6 changes: 5 additions & 1 deletion .gitignore
@@ -1,4 +1,8 @@
.DS_Store
/Manifest.toml
/dev/
**/checkpoints/

# add vim generated temp files
*~
*.swp
3 changes: 2 additions & 1 deletion Project.toml
@@ -51,9 +51,10 @@ Zygote = "0.5, 0.6"
julia = "1.4"

[extras]
GridWorlds = "e15a9946-cd7f-4d03-83e2-6c30bacb0043"
OpenSpiel = "ceb70bd2-fe3f-44f0-b81f-41608acaf2f2"
ReinforcementLearningEnvironments = "25e41dd2-4622-11e9-1641-f1adca772921"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[targets]
test = ["Test", "ReinforcementLearningEnvironments", "OpenSpiel"]
test = ["Test", "ReinforcementLearningEnvironments", "OpenSpiel", "GridWorlds"]
3 changes: 2 additions & 1 deletion README.md
@@ -60,6 +60,7 @@ Some built-in experiments are exported to help new users to easily run benchmark
- ``E`JuliaRL_DeepCFR_OpenSpiel(leduc_poker)` ``
- ``E`JuliaRL_DQN_SnakeGame` ``
- ``E`JuliaRL_BC_CartPole` ``
- ``E`JuliaRL_BasicDQN_EmptyRoom` ``
- ``E`Dopamine_DQN_Atari(pong)` ``
- ``E`Dopamine_Rainbow_Atari(pong)` ``
- ``E`Dopamine_IQN_Atari(pong)` ``
@@ -87,7 +88,7 @@ julia> run(E`rlpyt_PPO_Atari(pong)`) # the Atari environment is provided in Arc
- Experiments on `CartPole` usually run faster with CPU only due to the overhead of sending data between CPU and GPU.
- It shouldn't surprise you that our experiments on `CartPole` are much faster than those written in Python. The secret is that our environment is written in Julia!
- Remember to set `JULIA_NUM_THREADS` to enable multi-threading when using algorithms like `A2C` and `PPO`.
- Experiments on `Atari` (`OpenSpiel`, `SnakeGame`) are only available after you have `ArcadeLearningEnvironment.jl` (`OpenSpiel.jl`, `SnakeGame.jl`) installed and `using ArcadeLearningEnvironment` (`using OpenSpiel`, `using SnakeGame`).
- Experiments on `Atari` (`OpenSpiel`, `SnakeGame`, `GridWorlds`) are only available after you have `ArcadeLearningEnvironment.jl` (`OpenSpiel.jl`, `SnakeGame.jl`, `GridWorlds.jl`) installed and `using ArcadeLearningEnvironment` (`using OpenSpiel`, `using SnakeGame`, `import GridWorlds`).

### Speed

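The README note above says the GridWorlds experiment becomes available only after GridWorlds.jl is installed and loaded with `import GridWorlds`. A minimal sketch of that workflow, assuming GridWorlds.jl is already added to the active environment; the `E` string macro is the same one used in the README examples above:

    using ReinforcementLearningZoo
    # Loading GridWorlds triggers the @require hook (next file) that registers
    # the experiment; note `import`, not `using`, as the README specifies.
    import GridWorlds

    run(E`JuliaRL_BasicDQN_EmptyRoom`)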
3 changes: 3 additions & 0 deletions src/ReinforcementLearningZoo.jl
@@ -42,6 +42,9 @@ function __init__()
@require OpenSpiel = "ceb70bd2-fe3f-44f0-b81f-41608acaf2f2" include(
"experiments/open_spiel/open_spiel.jl",
)
@require GridWorlds = "e15a9946-cd7f-4d03-83e2-6c30bacb0043" include(
"experiments/gridworlds/gridworlds.jl",
)
end
end

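The `@require` block above follows the Requires.jl pattern: the GridWorlds glue code is included only if a user actually loads GridWorlds, which keeps it an optional dependency of the package. A standalone sketch of the same pattern (the module name is illustrative; the UUID is the one used in this commit):

    module MyZoo

    using Requires

    function __init__()
        # This block fires only after the user executes `import GridWorlds`;
        # the UUID must match the one in GridWorlds' own Project.toml.
        @require GridWorlds = "e15a9946-cd7f-4d03-83e2-6c30bacb0043" include(
            "experiments/gridworlds/gridworlds.jl",
        )
    end

    end # module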
84 changes: 84 additions & 0 deletions src/experiments/gridworlds/JuliaRL_BasicDQN_EmptyRoom.jl
@@ -0,0 +1,84 @@
function RLCore.Experiment(
::Val{:JuliaRL},
::Val{:BasicDQN},
::Val{:EmptyRoom},
::Nothing;
seed = 123,
save_dir = nothing,
)
if isnothing(save_dir)
t = Dates.format(now(), "yyyy_mm_dd_HH_MM_SS")
save_dir = joinpath(pwd(), "checkpoints", "JuliaRL_BasicDQN_EmptyRoom$(t)")
end
log_dir = joinpath(save_dir, "tb_log")
lg = TBLogger(log_dir, min_level = Logging.Info)
rng = StableRNG(seed)

inner_env = GridWorlds.EmptyRoom(rng = rng)
action_space_mapping = x -> Base.OneTo(length(RLBase.action_space(inner_env)))
action_mapping = i -> RLBase.action_space(inner_env)[i]
env = RLEnvs.ActionTransformedEnv(inner_env, action_space_mapping = action_space_mapping, action_mapping = action_mapping)
env = RLEnvs.StateOverriddenEnv(env, x -> vec(Float32.(x)))
env = RewardOverriddenEnv(env, x -> x - convert(typeof(x), 0.01))
env = MaxTimeoutEnv(env, 240)

ns, na = length(state(env)), length(action_space(env))
agent = Agent(
policy = QBasedPolicy(
learner = BasicDQNLearner(
approximator = NeuralNetworkApproximator(
model = Chain(
Dense(ns, 128, relu; initW = glorot_uniform(rng)),
Dense(128, 128, relu; initW = glorot_uniform(rng)),
Dense(128, na; initW = glorot_uniform(rng)),
) |> cpu,
optimizer = ADAM(),
),
batch_size = 32,
min_replay_history = 100,
loss_func = huber_loss,
rng = rng,
),
explorer = EpsilonGreedyExplorer(
kind = :exp,
ϵ_stable = 0.01,
decay_steps = 500,
rng = rng,
),
),
trajectory = CircularArraySARTTrajectory(
capacity = 1000,
state = Vector{Float32} => (ns,),
),
)

stop_condition = StopAfterStep(10_000)

total_reward_per_episode = TotalRewardPerEpisode()
time_per_step = TimePerStep()
hook = ComposedHook(
total_reward_per_episode,
time_per_step,
DoEveryNStep() do t, agent, env
with_logger(lg) do
@info "training" loss = agent.policy.learner.loss
end
end,
DoEveryNEpisode() do t, agent, env
with_logger(lg) do
@info "training" reward = total_reward_per_episode.rewards[end] log_step_increment =
0
end
end,
)

description = """
This experiment uses three dense layers to approximate the Q value.
The testing environment is EmptyRoom.
You can view the runtime logs with `tensorboard --logdir $log_dir`.
Some useful statistics are stored in the `hook` field of this experiment.
"""

Experiment(agent, env, stop_condition, hook, description)
end
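
The `description` above points readers at the hooks for run statistics. A minimal sketch of constructing and running this experiment directly and then reading those hooks back, mirroring the testset added in test/runtests.jl below and assuming the same packages are loaded as in that file (hook indices follow the `ComposedHook` order defined above):

    using Statistics
    using ReinforcementLearningZoo
    import GridWorlds

    res = run(Experiment(Val(:JuliaRL), Val(:BasicDQN), Val(:EmptyRoom), nothing))
    avg_reward = mean(res.hook[1].rewards)    # TotalRewardPerEpisode
    avg_fps    = 1 / mean(res.hook[2].times)  # TimePerStep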
3 changes: 3 additions & 0 deletions src/experiments/gridworlds/gridworlds.jl
@@ -0,0 +1,3 @@
import .GridWorlds

include("JuliaRL_BasicDQN_EmptyRoom.jl")
19 changes: 19 additions & 0 deletions test/runtests.jl
@@ -8,6 +8,7 @@ using Statistics
using Random
using OpenSpiel
using StableRNGs
import GridWorlds

function get_optimal_kuhn_policy(α = 0.2)
TabularRandomPolicy(
@@ -96,6 +97,24 @@ end
@test e.hook[1][] == e.hook[0][] == [0.0]
end

@testset "GridWorlds" begin
mktempdir() do dir
for method in (:BasicDQN,)
res = run(
Experiment(
Val(:JuliaRL),
Val(method),
Val(:EmptyRoom),
nothing;
save_dir = joinpath(dir, "EmptyRoom", string(method)),
),
)
@info "stats for $method" avg_reward = mean(res.hook[1].rewards) avg_fps =
1 / mean(res.hook[2].times)
end
end
end

@testset "TabularCFR" begin
e = E`JuliaRL_TabularCFR_OpenSpiel(kuhn_poker)`
run(e)
