add stock trading env #428

Merged
3 changes: 2 additions & 1 deletion .cspell/cspell.json
@@ -98,7 +98,8 @@
"boxoban",
"DATADEPS",
"umaze",
"pybullet"
"pybullet",
"turbulences"
],
"ignoreWords": [],
"minWordLength": 5,
17 changes: 17 additions & 0 deletions NEWS.md
@@ -27,12 +27,21 @@

### ReinforcementLearningBase.jl

#### v0.9.6

- Implement `Base.:(==)` for `Space`. [#428](https://github.com/JuliaReinforcementLearning/ReinforcementLearning.jl/pull/428)

#### v0.9.5

- Add default `Base.:(==)` and `Base.hash` method for `AbstractEnv`. [#348](https://github.com/JuliaReinforcementLearning/ReinforcementLearning.jl/pull/348)

### ReinforcementLearningCore.jl

#### v0.8.3

- Add two optional keyword arguments (`min_σ` and `max_σ`) to
`GaussianNetwork` to clip the output of `logσ`. [#428](https://github.com/JuliaReinforcementLearning/ReinforcementLearning.jl/pull/428)

#### v0.8.2

- Add GaussianNetwork and DuelingNetwork into ReinforcementLearningCore.jl as general components. [#370](https://github.com/JuliaReinforcementLearning/ReinforcementLearning.jl/pull/370)
@@ -60,6 +69,14 @@

### ReinforcementLearningEnvironments.jl

#### v0.6.3

- Add `StockTradingEnv` from the paper [Deep Reinforcement Learning for
Automated Stock Trading: An Ensemble
Strategy](https://github.com/AI4Finance-LLC/Deep-Reinforcement-Learning-for-Automated-Stock-Trading-Ensemble-Strategy-ICAIF-2020).
This environment is a good testbed for algorithms with multi-dimensional
continuous action spaces. [#428](https://github.com/JuliaReinforcementLearning/ReinforcementLearning.jl/pull/428)

#### v0.6.2

- Add `SequentialEnv` environment wrapper to turn a simultaneous environment
2 changes: 1 addition & 1 deletion docs/Manifest.toml
@@ -975,7 +975,7 @@ uuid = "de1b191a-4ae0-4afa-a27b-92d07f46b2d6"
version = "0.8.2"

[[ReinforcementLearningEnvironments]]
deps = ["IntervalSets", "MacroTools", "Markdown", "Random", "ReinforcementLearningBase", "Requires", "StatsBase"]
deps = ["DelimitedFiles", "IntervalSets", "LinearAlgebra", "MacroTools", "Markdown", "Pkg", "Random", "ReinforcementLearningBase", "Requires", "StatsBase"]
path = "../src/ReinforcementLearningEnvironments"
uuid = "25e41dd2-4622-11e9-1641-f1adca772921"
version = "0.6.2"
2 changes: 1 addition & 1 deletion src/ReinforcementLearningBase/Project.toml
@@ -1,7 +1,7 @@
name = "ReinforcementLearningBase"
uuid = "e575027e-6cd6-5018-9292-cdc6200d2b44"
authors = ["Johanni Brea <[email protected]>", "Jun Tian <[email protected]>"]
version = "0.9.6"
version = "0.9.7"

[deps]
AbstractTrees = "1520ce14-60c1-5f80-bbc7-55ef81b5835c"
1 change: 1 addition & 0 deletions src/ReinforcementLearningBase/src/base.jl
@@ -26,6 +26,7 @@ struct Space{T}
s::T
end

Base.:(==)(x::Space, y::Space) = x.s == y.s
Base.similar(s::Space, args...) = Space(similar(s.s, args...))
Base.getindex(s::Space, args...) = getindex(s.s, args...)
Base.setindex!(s::Space, args...) = setindex!(s.s, args...)
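A quick sketch of what the new equality method enables (assuming `IntervalSets` is loaded for the `..` syntax, as RLBase already requires):

using ReinforcementLearningBase, IntervalSets

Space(fill(-1f0..1f0, 3)) == Space(fill(-1f0..1f0, 3)) # true: delegates to `==` of the wrapped containers
Space([1, 2]) == Space([1, 2, 3])                      # false
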
2 changes: 1 addition & 1 deletion src/ReinforcementLearningCore/Project.toml
@@ -1,7 +1,7 @@
name = "ReinforcementLearningCore"
uuid = "de1b191a-4ae0-4afa-a27b-92d07f46b2d6"
authors = ["Jun Tian <[email protected]>"]
version = "0.8.2"
version = "0.8.3"

[deps]
AbstractTrees = "1520ce14-60c1-5f80-bbc7-55ef81b5835c"
@@ -68,16 +68,18 @@ end
#####

"""
GaussianNetwork(;pre=identity, μ, logσ)
GaussianNetwork(;pre=identity, μ, logσ, min_σ=0f0, max_σ=Inf32)

Returns `μ` and `logσ` when called.
Create a distribution to sample from
using `Normal.(μ, exp.(logσ))`.
Returns `μ` and `logσ` when called. Create a distribution to sample from using
`Normal.(μ, exp.(logσ))`. `min_σ` and `max_σ` are used to clip the output from
`logσ`.
"""
Base.@kwdef struct GaussianNetwork{P,U,S}
pre::P = identity
μ::U
logσ::S
min_σ::Float32 = 0f0
max_σ::Float32 = Inf32
end

Flux.@functor GaussianNetwork
@@ -91,7 +93,8 @@ This function is compatible with a multidimensional action space. When outputting
"""
function (model::GaussianNetwork)(rng::AbstractRNG, state; is_sampling::Bool=false, is_return_log_prob::Bool=false)
x = model.pre(state)
μ, logσ = model.μ(x), model.logσ(x)
μ, raw_logσ = model.μ(x), model.logσ(x)
logσ = clamp.(raw_logσ, log(model.min_σ), log(model.max_σ))
if is_sampling
π_dist = Normal.(μ, exp.(logσ))
z = rand.(rng, π_dist)
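A hedged usage sketch of the new clipping options; the layer sizes and `Flux.Dense` layers below are illustrative, not part of the PR:

using Flux, Random

model = GaussianNetwork(
    pre = Dense(4, 64, relu),
    μ = Dense(64, 2),
    logσ = Dense(64, 2),
    min_σ = 1f-3, # σ stays within [1f-3, 1f0] because logσ is clamped to [log(min_σ), log(max_σ)]
    max_σ = 1f0,
)

s = rand(Float32, 4, 1)                  # a batch with a single 4-dimensional state
μ, logσ = model(Random.default_rng(), s) # logσ is returned already clamped

With the defaults (`min_σ=0f0`, `max_σ=Inf32`) the clamp is a no-op, since `log` maps them to `-Inf32` and `Inf32`.
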
6 changes: 6 additions & 0 deletions src/ReinforcementLearningEnvironments/Artifacts.toml
@@ -0,0 +1,6 @@
[stock_trading_data]
git-tree-sha1 = "c2ef05aa70df44749bd43b2ab9a558ea6829b32b"

[[stock_trading_data.download]]
url = "https://github.com/JuliaReinforcementLearning/ReinforcementLearning.jl/releases/download/v0.9.0/stock_trading_data.tar.gz"
sha256 = "2abc589a9dfb5b2134ee531152bd361b08629938ea3bf53fe56270517d732c89"
2 changes: 1 addition & 1 deletion src/ReinforcementLearningEnvironments/Manifest.toml
@@ -160,7 +160,7 @@ uuid = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
deps = ["AbstractTrees", "CommonRLInterface", "Markdown", "Random", "Test"]
path = "../ReinforcementLearningBase"
uuid = "e575027e-6cd6-5018-9292-cdc6200d2b44"
version = "0.9.5"
version = "0.9.6"

[[Requires]]
deps = ["UUIDs"]
7 changes: 5 additions & 2 deletions src/ReinforcementLearningEnvironments/Project.toml
@@ -1,12 +1,15 @@
name = "ReinforcementLearningEnvironments"
uuid = "25e41dd2-4622-11e9-1641-f1adca772921"
authors = ["Jun Tian <[email protected]>"]
version = "0.6.2"
version = "0.6.3"

[deps]
DelimitedFiles = "8bb1440f-4735-579b-a4ab-409b98df4dab"
IntervalSets = "8197267c-284f-5f27-9208-e0e47529a953"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
Markdown = "d6f4376e-aef5-505a-96c1-9c027394607a"
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
ReinforcementLearningBase = "e575027e-6cd6-5018-9292-cdc6200d2b44"
Requires = "ae029012-a4dd-5104-9daa-d747884805df"
@@ -30,4 +33,4 @@ Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[targets]
test = ["ArcadeLearningEnvironment", "OpenSpiel", "OrdinaryDiffEq", "PyCall", "StableRNGs", "Statistics", "Test"]
test = ["ArcadeLearningEnvironment", "OpenSpiel", "OrdinaryDiffEq", "PyCall", "StableRNGs", "Statistics", "Test"]
@@ -1,4 +1,4 @@
include("wrappers/wrappers.jl")
include("examples/examples.jl")
include("non_interactive/non_interactive.jl")
include("wrappers/wrappers.jl")
include("3rd_party/structs.jl")
@@ -0,0 +1,175 @@
export StockTradingEnv, StockTradingEnvWithTurbulence

using Pkg.Artifacts
using DelimitedFiles
using LinearAlgebra: dot
using IntervalSets

function load_default_stock_data(s)
if s == "prices.csv" || s == "features.csv"
data, _ = readdlm(joinpath(artifact"stock_trading_data", s), ',', header=true)
collect(data')
elseif s == "turbulence.csv"
readdlm(joinpath(artifact"stock_trading_data", "turbulence.csv")) |> vec
else
@error "unknown dataset $s"
end
end

mutable struct StockTradingEnv{F<:AbstractMatrix{Float64}, P<:AbstractMatrix{Float64}} <: AbstractEnv
features::F
prices::P
HMAX_NORMALIZE::Float32
TRANSACTION_FEE_PERCENT::Float32
REWARD_SCALING::Float32
initial_account_balance::Float32
state::Vector{Float32}
total_cost::Float32
day::Int
first_day::Int
last_day::Int
daily_reward::Float32
end

_n_stocks(env::StockTradingEnv) = size(env.prices, 1)
_prices(env::StockTradingEnv) = @view(env.state[2:1+_n_stocks(env)])
_holds(env::StockTradingEnv) = @view(env.state[2+_n_stocks(env):_n_stocks(env)*2+1])
_features(env::StockTradingEnv) = @view(env.state[_n_stocks(env)*2+2:end])
_balance(env::StockTradingEnv) = @view env.state[1]
_total_asset(env::StockTradingEnv) = env.state[1] + dot(_prices(env), _holds(env))
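
# Illustrative layout of the state vector (hypothetical sizes: 2 stocks, 3 features):
#   state = [balance, p1, p2, h1, h2, f1, f2, f3]
#   _prices(env)   -> view of state[2:3]
#   _holds(env)    -> view of state[4:5]
#   _features(env) -> view of state[6:8]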

"""
StockTradingEnv(;kw...)

This environment was originally provided in [Deep Reinforcement Learning for Automated Stock Trading: An Ensemble Strategy](https://github.com/AI4Finance-LLC/Deep-Reinforcement-Learning-for-Automated-Stock-Trading-Ensemble-Strategy-ICAIF-2020).

# Keyword Arguments

- `initial_account_balance=1_000_000f0`.
- `prices=nothing`, `features=nothing`: price and feature matrices with one column per day; the default artifact data is loaded when `nothing`.
- `first_day=nothing`, `last_day=nothing`: the simulated day range; they default to the full column range of `prices`.
- `HMAX_NORMALIZE=100f0`: scale factor mapping an action in `[-1, 1]` to a number of shares.
- `TRANSACTION_FEE_PERCENT=0.001f0`: fee charged on every buy and sell.
- `REWARD_SCALING=1f-4`: multiplier applied to the daily reward.
"""
function StockTradingEnv(;
initial_account_balance=1_000_000f0,
features=nothing,
prices=nothing,
first_day=nothing,
last_day=nothing,
HMAX_NORMALIZE = 100f0,
TRANSACTION_FEE_PERCENT = 0.001f0,
REWARD_SCALING = 1f-4
)
prices = isnothing(prices) ? load_default_stock_data("prices.csv") : prices
features = isnothing(features) ? load_default_stock_data("features.csv") : features

@assert size(prices, 2) == size(features, 2)

first_day = isnothing(first_day) ? 1 : first_day
last_day = isnothing(last_day) ? size(prices, 2) : last_day
day = first_day

# [balance, stock_prices..., stock_holds..., features...]
state = zeros(Float32, 1 + size(prices, 1) * 2 + size(features, 1))

env = StockTradingEnv(
features,
prices,
HMAX_NORMALIZE,
TRANSACTION_FEE_PERCENT,
REWARD_SCALING,
initial_account_balance,
state,
0f0,
day,
first_day,
last_day,
0f0
)

_balance(env)[] = initial_account_balance
_prices(env) .= @view prices[:, day]
_features(env) .= @view features[:, day]

env
end

function (env::StockTradingEnv)(actions)
init_asset = _total_asset(env)

# sell first
for (i, s) in enumerate(actions)
if s < 0
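# a negative action s ∈ [-1, 0) sells up to HMAX_NORMALIZE * |s| shares, capped by current holdings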
sell = min(-env.HMAX_NORMALIZE * s, _holds(env)[i])
_holds(env)[i] -= sell
gain = _prices(env)[i] * sell
cost = gain * env.TRANSACTION_FEE_PERCENT
_balance(env)[] += gain - cost
env.total_cost += cost
end
end

# then buy
# better to shuffle?
for (i,b) in enumerate(actions)
if b > 0
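# a positive action b ∈ (0, 1] buys up to HMAX_NORMALIZE * b shares,
# capped by max_buy, the number of whole shares the current balance affords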
max_buy = div(_balance(env)[], _prices(env)[i])
buy = min(b*env.HMAX_NORMALIZE, max_buy)
_holds(env)[i] += buy
deduction = buy * _prices(env)[i]
cost = deduction * env.TRANSACTION_FEE_PERCENT
_balance(env)[] -= deduction + cost
env.total_cost += cost
end
end

env.day += 1
_prices(env) .= @view env.prices[:, env.day]
_features(env) .= @view env.features[:, env.day]

env.daily_reward = _total_asset(env) - init_asset
end

RLBase.reward(env::StockTradingEnv) = env.daily_reward * env.REWARD_SCALING
RLBase.is_terminated(env::StockTradingEnv) = env.day >= env.last_day
RLBase.state(env::StockTradingEnv) = env.state

function RLBase.reset!(env::StockTradingEnv)
env.day = env.first_day
_balance(env)[] = env.initial_account_balance
_prices(env) .= @view env.prices[:, env.day]
_features(env) .= @view env.features[:, env.day]
env.total_cost = 0f0
env.daily_reward = 0f0
end

RLBase.state_space(env::StockTradingEnv) = Space(fill(-Inf32..Inf32, length(state(env))))
RLBase.action_space(env::StockTradingEnv) = Space(fill(-1f0..1f0, length(_holds(env))))

RLBase.ChanceStyle(::StockTradingEnv) = DETERMINISTIC

# wrapper

struct StockTradingEnvWithTurbulence{E<:StockTradingEnv} <: AbstractEnvWrapper
env::E
turbulences::Vector{Float64}
turbulence_threshold::Float64
end

function StockTradingEnvWithTurbulence(;
turbulence_threshold=140.,
turbulences=nothing,
kw...
)
turbulences = isnothing(turbulences) ? load_default_stock_data("turbulence.csv") : turbulences

StockTradingEnvWithTurbulence(
StockTradingEnv(;kw...),
turbulences,
turbulence_threshold
)
end

function (w::StockTradingEnvWithTurbulence)(actions)
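# when turbulence reaches the threshold, suppress buying and force liquidation:
# negative actions become -Inf (sell all holdings), all others become 0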
if w.turbulences[w.env.day] >= w.turbulence_threshold
actions .= ifelse.(actions .< 0, -Inf32, 0f0)
end
w.env(actions)
end
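
A minimal end-to-end sketch, assuming the default artifact data can be downloaded, that `AbstractEnvWrapper` forwards the usual RLBase methods to the inner env, and that `rand(action_space(env))` samples uniformly from the interval space:

using ReinforcementLearningBase
using ReinforcementLearningEnvironments

env = StockTradingEnvWithTurbulence() # wraps a StockTradingEnv built from the artifact data
reset!(env)
while !is_terminated(env)
    env(rand(action_space(env)))      # random policy: one action in [-1, 1] per stock
end
@show reward(env)                     # scaled reward of the final trading day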
@@ -11,3 +11,4 @@ include("CartPoleEnv.jl")
include("MountainCarEnv.jl")
include("PendulumEnv.jl")
include("BitFlippingEnv.jl")
include("StockTradingEnv.jl")
@@ -0,0 +1,8 @@
@testset "StockTradingEnv" begin

env = StockTradingEnvWithTurbulence()

RLBase.test_interfaces!(env)
RLBase.test_runnable!(env)
end