Skip to content

Commit

Permalink
pr update
Browse files Browse the repository at this point in the history
  • Loading branch information
Lin-Dongzhao committed Jul 30, 2024
1 parent b750fee commit 6acab12
Showing 1 changed file with 18 additions and 18 deletions.
36 changes: 18 additions & 18 deletions rqalpha/mod/rqalpha_mod_sys_accounts/api/api_stock.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
import datetime
from decimal import Decimal, getcontext
from itertools import chain
from typing import Dict, List, Optional, Union, Tuple
from typing import Dict, List, Optional, Union, Tuple, Callable
import math
from collections import defaultdict

Expand Down Expand Up @@ -72,26 +72,14 @@ def _get_account_position_ins(id_or_ins):
return account, position, ins


def _round_order_quantity(ins, quantity, method: str = "round_down") -> int:
"""
根据合约的 round_lot 对订单数量进行取整
:param method: 取整方式,可选择 'round_down'(向下取整), 'round_up'(向上取整), 'round'(四舍五入)
"""
def _round_order_quantity(ins, quantity, method: Callable = int) -> int:
if ins.type == "CS" and ins.board_type == "KSH":
# KSH can buy(sell) 201, 202 shares
return 0 if abs(quantity) < KSH_MIN_AMOUNT else int(quantity)
else:
round_lot = ins.round_lot
try:
if method == "round_down":
return int(Decimal(quantity) / Decimal(round_lot)) * round_lot
elif method == "round_up":
return math.ceil(Decimal(quantity) / Decimal(round_lot)) * round_lot
elif method == "round":
return round(Decimal(quantity) / Decimal(round_lot)) * round_lot
else:
raise RuntimeError("Rounding method only support 'round_down', 'round_up' and 'round'")
return method(Decimal(quantity) / Decimal(round_lot)) * round_lot
except ValueError:
raise

Expand Down Expand Up @@ -387,6 +375,15 @@ def order_target_portfolio(
))

account_value = account.total_value
if total_percent == 1:
# 在此处形成的订单不包含交易费用,需要预留一点余额以供交易费用使用
estimate_transaction_cost = 0
for order_book_id, (target_percent, open_style, close_style, last_price) in target.items():
current_value = current_quantities.get(order_book_id, 0) * last_price
change_value = target_percent * account_value - current_value
estimate_transaction_cost += env.get_transaction_cost_with_value(change_value)
account_value = account_value - estimate_transaction_cost

close_orders, open_orders = [], []
waiting_to_buy = defaultdict()
for order_book_id, (target_percent, open_style, close_style, last_price) in target.items():
Expand All @@ -399,7 +396,7 @@ def order_target_portfolio(
env.order_creation_failed(order_book_id=order_book_id, reason=reason)
continue
delta_quantity = (account_value * target_percent / close_price) - current_quantities.get(order_book_id, 0)
delta_quantity = _round_order_quantity(env.data_proxy.instrument(order_book_id), delta_quantity, method="round")
delta_quantity = _round_order_quantity(env.data_proxy.instrument(order_book_id), delta_quantity, method=round)

# 优先生成卖单,以便计算出剩余现金,进行买单数量的计算
if delta_quantity == 0:
Expand All @@ -417,16 +414,19 @@ def order_target_portfolio(
estimate_cash = account.cash + sum([o.quantity * o.frozen_price - env.get_order_transaction_cost(o) for o in close_orders])
for order_book_id, (delta_quantity, position_effect, open_style, last_price) in waiting_to_buy.items():
order_price = delta_quantity * last_price
if order_price + env.get_transaction_cost_with_value(order_price) > estimate_cash:
transaction_cost = env.get_transaction_cost_with_value(order_price)
if order_price + transaction_cost > estimate_cash:
delta_quantity = estimate_cash / last_price
delta_quantity = _round_order_quantity(env.data_proxy.instrument(order_book_id), delta_quantity)
if delta_quantity == 0:
continue
order_price = delta_quantity * last_price
transaction_cost = env.get_transaction_cost_with_value(order_price)
order = Order.__from_create__(order_book_id, delta_quantity, SIDE.BUY, open_style, position_effect)
if isinstance(open_style, MarketOrder):
order.set_frozen_price(last_price)
open_orders.append(order)
estimate_cash -= order.quantity * order.frozen_price + env.get_order_transaction_cost(order)
estimate_cash -= order_price + transaction_cost

return list(env.submit_order(o) for o in chain(close_orders, open_orders))

Expand Down

0 comments on commit 6acab12

Please sign in to comment.