Skip to content

Commit

Permalink
upd
Browse files Browse the repository at this point in the history
  • Loading branch information
m5l14i11 committed Sep 14, 2024
1 parent b129f47 commit 9f5b578
Show file tree
Hide file tree
Showing 3 changed files with 52 additions and 25 deletions.
4 changes: 3 additions & 1 deletion core/interfaces/abstract_exchange.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,9 @@ def fetch_position(self, symbol: Symbol, side: PositionSide):
pass

@abstractmethod
def fetch_trade(self, symbol: Symbol, side: PositionSide, limit: int):
def fetch_trade(
self, symbol: Symbol, side: PositionSide, since: int, size: float, limit: int
):
pass

@abstractmethod
Expand Down
60 changes: 39 additions & 21 deletions exchange/_bybit.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,12 +95,14 @@ def cancel_order(self, order_id: str, symbol: Symbol):
logger.error(f"{symbol}: {e}")
return

def fetch_trade(self, symbol: Symbol, side: PositionSide, limit: int):
trades = sorted(
self.connector.fetch_my_trades(symbol.name, limit=limit * 3),
key=lambda trade: trade["timestamp"],
reverse=True,
)
def fetch_trade(
self, symbol: Symbol, side: PositionSide, since: int, size: float, limit: int
):
last_trades = [
trade
for trade in self.connector.fetch_my_trades(symbol.name, limit=limit * 2)
if trade["timestamp"] >= since
]

def round_down_to_minute(timestamp):
return datetime.utcfromtimestamp(timestamp // 1000).replace(
Expand All @@ -109,21 +111,36 @@ def round_down_to_minute(timestamp):

aggregated_trades = defaultdict(lambda: {"amount": 0, "price": 0, "fee": 0})

for trade in trades:
if trade["side"] == "buy" if side == PositionSide.SHORT else "sell":
timestamp = round_down_to_minute(trade["timestamp"])
aggregated_trades[timestamp]["amount"] += trade["amount"]
aggregated_trades[timestamp]["price"] += trade["price"]
aggregated_trades[timestamp]["fee"] += trade["fee"]["cost"]

for timestamp, trade_data in aggregated_trades.items():
count = sum(
1
for item in trades
if round_down_to_minute(item["timestamp"]) == timestamp
)
if count > 0:
trade_data["price"] /= count
opposite_side = "buy" if side == PositionSide.SHORT else "sell"

trade_stack = sorted(
[trade for trade in last_trades if trade["side"] == opposite_side],
key=lambda trade: trade["timestamp"],
reverse=True,
)

acc_size = 0

for trade in trade_stack:
if acc_size >= size:
break

trade_amount = trade["amount"]
key = round_down_to_minute(trade["timestamp"])

trade_amount = min(trade["amount"], size - acc_size)

aggregated_trades[key]["amount"] += trade_amount
aggregated_trades[key]["price"] += trade["price"] * trade_amount
aggregated_trades[key]["fee"] += (
trade["fee"]["cost"] / trade["amount"]
) * trade_amount

acc_size += trade_amount

for _, trade_data in aggregated_trades.items():
if trade_data["amount"] > 0:
trade_data["price"] /= trade_data["amount"]

return next(iter(aggregated_trades.values()), None)

Expand Down Expand Up @@ -234,6 +251,7 @@ def fetch_position(self, symbol: Symbol, side: PositionSide):
else PositionSide.SHORT,
"entry_price": float(position.get("entryPrice", 0)),
"position_size": float(position.get("contracts", 0)),
"open_fee": -float(position.get("info", {}).get("curRealisedPnl", 0)),
}

return None
Expand Down
13 changes: 10 additions & 3 deletions sor/_router.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ def get_open_position(self, query: GetOpenPosition):
status=OrderStatus.EXECUTED,
size=broker_position["position_size"],
price=broker_position["entry_price"],
fee=broker_position["open_fee"],
)

@query_handler(GetClosePosition)
Expand All @@ -57,7 +58,11 @@ def get_close_position(self, query: GetClosePosition):
symbol = position.signal.symbol

trade = self.exchange.fetch_trade(
symbol, position.side, self.config["max_order_slice"]
symbol,
position.side,
position.last_modified,
position.size,
self.config["max_order_slice"],
)

if not trade:
Expand Down Expand Up @@ -88,13 +93,15 @@ def get_account_balance(self, query: GetBalance):
def has_position(self, query: HasPosition):
position = query.position
symbol = position.signal.symbol
side = position.side

existing_position = self.exchange.fetch_position(symbol, position.side)
existing_position = self.exchange.fetch_position(symbol, side)

if existing_position:
logging.info(f"Position for {symbol} on {position.side} already exists")
logging.info(f"Position check: {side} position for {symbol} exists.")
return True

logging.info(f"Position check: No existing {side} position found for {symbol}.")
return False

@command_handler(UpdateSettings)
Expand Down

0 comments on commit 9f5b578

Please sign in to comment.