-
Notifications
You must be signed in to change notification settings - Fork 482
/
Copy pathforex_env.py
106 lines (80 loc) · 3.89 KB
/
forex_env.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
import numpy as np
from .trading_env import TradingEnv, Actions, Positions
class ForexEnv(TradingEnv):
    """Trading environment over the ``'Close'`` column of a price DataFrame.

    Rewards are price differences scaled by 10000 (labelled "pip" in the
    original code) and are only paid when an action closes the currently
    held position. Profit is tracked multiplicatively; ``unit_side`` selects
    which of the two fee/profit formulas is applied (see ``_settle_trade``).

    NOTE(review): ``_position``, ``_current_tick``, ``_last_trade_tick``,
    ``_total_profit``, ``_truncated``, ``_start_tick``, ``_end_tick``,
    ``prices`` and ``df`` are not defined in this class; they are presumably
    maintained by ``TradingEnv`` -- confirm against the base class.
    """

    def __init__(self, df, window_size, frame_bound, unit_side='left', render_mode=None):
        """Store the frame configuration and initialise the base environment.

        Args:
            df: Price data; ``_process_data`` reads its ``'Close'`` column.
            window_size: Number of ticks subtracted from ``frame_bound[0]``
                to find the first price that is kept.
            frame_bound: Pair ``(start, end)`` delimiting the slice of prices
                used by the environment.
            unit_side: ``'left'`` or ``'right'`` (case-insensitive).
            render_mode: Forwarded unchanged to ``TradingEnv.__init__``.

        Raises:
            AssertionError: If ``frame_bound`` does not have exactly two
                items or ``unit_side`` is not ``'left'``/``'right'``.
        """
        assert len(frame_bound) == 2, "frame_bound must be a (start, end) pair"
        assert unit_side.lower() in ['left', 'right'], "unit_side must be 'left' or 'right'"

        # NOTE(review): these are assigned before super().__init__();
        # presumably the base constructor calls _process_data(), which
        # reads frame_bound -- keep this ordering.
        self.frame_bound = frame_bound
        self.unit_side = unit_side.lower()
        super().__init__(df, window_size, render_mode)

        self.trade_fee = 0.0003  # unit

    def _process_data(self):
        """Slice the close prices to the configured frame and build features.

        Returns:
            tuple: ``(prices, signal_features)``, both ``float32`` arrays.
            ``signal_features`` has two columns: the price and its
            difference to the previous price (0 for the first row).

        Raises:
            IndexError: If ``frame_bound[0] - window_size`` does not address
                an existing price. Negative values are rejected explicitly:
                NumPy would otherwise wrap them around to the end of the
                series and silently produce an empty or wrong slice.
        """
        prices = self.df.loc[:, 'Close'].to_numpy()

        start = self.frame_bound[0] - self.window_size
        end = self.frame_bound[1]
        if start < 0 or start >= len(prices):
            raise IndexError(
                "frame_bound[0] - window_size = {} is outside the price series "
                "(length {})".format(start, len(prices))
            )
        prices = prices[start:end]

        # First row has no predecessor, so its difference is padded with 0.
        diff = np.insert(np.diff(prices), 0, 0)
        signal_features = np.column_stack((prices, diff))

        return prices.astype(np.float32), signal_features.astype(np.float32)

    def _is_closing_trade(self, action):
        """Return whether ``action`` reverses the currently held position.

        True for Buy while Short, or Sell while Long.
        """
        return (
            (action == Actions.Buy.value and self._position == Positions.Short) or
            (action == Actions.Sell.value and self._position == Positions.Long)
        )

    def _settle_trade(self, profit, position, last_trade_price, current_price):
        """Return ``profit`` after closing ``position`` at ``current_price``.

        With ``unit_side == 'left'`` only Short positions change the profit;
        with ``unit_side == 'right'`` only Long positions do. In every other
        combination ``profit`` is returned unchanged. ``trade_fee`` is
        subtracted from the price on the selling side of each formula.
        """
        if self.unit_side == 'left':
            if position == Positions.Short:
                quantity = profit * (last_trade_price - self.trade_fee)
                return quantity / current_price
        elif self.unit_side == 'right':
            if position == Positions.Long:
                quantity = profit / last_trade_price
                return quantity * (current_price - self.trade_fee)
        return profit

    def _calculate_reward(self, action):
        """Return the reward for ``action`` at the current tick.

        The reward is 0 unless the action closes the held position. In that
        case it is the price change since the last trade tick, times 10000,
        negated when the closed position was Short.
        """
        step_reward = 0  # pip

        if self._is_closing_trade(action):
            current_price = self.prices[self._current_tick]
            last_trade_price = self.prices[self._last_trade_tick]
            price_diff = current_price - last_trade_price

            if self._position == Positions.Short:
                step_reward += -price_diff * 10000
            elif self._position == Positions.Long:
                step_reward += price_diff * 10000

        return step_reward

    def _update_profit(self, action):
        """Update ``self._total_profit`` when a position is closed.

        The profit is settled when ``action`` closes the held position, and
        also whenever ``self._truncated`` is truthy, so that an open position
        is accounted for at that point.
        """
        if self._is_closing_trade(action) or self._truncated:
            current_price = self.prices[self._current_tick]
            last_trade_price = self.prices[self._last_trade_tick]
            self._total_profit = self._settle_trade(
                self._total_profit, self._position, last_trade_price, current_price
            )

    def max_possible_profit(self):
        """Return the profit of trading every price run perfectly.

        Walks the ticks from ``_start_tick`` to ``_end_tick`` and splits them
        into maximal runs of falling prices (strictly lower than the previous
        tick) and non-falling prices. Each run is settled as a Short or Long
        position respectively, compounding from an initial profit of 1.0 with
        the same formula ``_update_profit`` uses.
        """
        tick = self._start_tick
        last_trade_tick = tick - 1
        profit = 1.

        while tick <= self._end_tick:
            if self.prices[tick] < self.prices[tick - 1]:
                while (tick <= self._end_tick and
                       self.prices[tick] < self.prices[tick - 1]):
                    tick += 1
                position = Positions.Short
            else:
                # "not (a < b)" instead of "a >= b": identical for ordinary
                # numbers, but a NaN price fails both comparisons and would
                # otherwise leave this loop without advancing, spinning the
                # outer loop forever.
                while (tick <= self._end_tick and
                       not (self.prices[tick] < self.prices[tick - 1])):
                    tick += 1
                position = Positions.Long

            # The run ended on the previous tick; settle it against the
            # price at which the previous run ended.
            current_price = self.prices[tick - 1]
            last_trade_price = self.prices[last_trade_tick]
            profit = self._settle_trade(profit, position, last_trade_price, current_price)

            last_trade_tick = tick - 1

        return profit