From 9f5b578eb6f59d68a4895457879ef891fd9ddeff Mon Sep 17 00:00:00 2001 From: m5l14i11 Date: Sat, 14 Sep 2024 17:51:40 +0300 Subject: [PATCH] upd --- core/interfaces/abstract_exchange.py | 4 +- exchange/_bybit.py | 60 ++++++++++++++++++---------- sor/_router.py | 13 ++++-- 3 files changed, 52 insertions(+), 25 deletions(-) diff --git a/core/interfaces/abstract_exchange.py b/core/interfaces/abstract_exchange.py index 93185644..93761033 100644 --- a/core/interfaces/abstract_exchange.py +++ b/core/interfaces/abstract_exchange.py @@ -42,7 +42,9 @@ def fetch_position(self, symbol: Symbol, side: PositionSide): pass @abstractmethod - def fetch_trade(self, symbol: Symbol, side: PositionSide, limit: int): + def fetch_trade( + self, symbol: Symbol, side: PositionSide, since: int, size: float, limit: int + ): pass @abstractmethod diff --git a/exchange/_bybit.py b/exchange/_bybit.py index 51e0f3b9..b6d3e9a6 100644 --- a/exchange/_bybit.py +++ b/exchange/_bybit.py @@ -95,12 +95,14 @@ def cancel_order(self, order_id: str, symbol: Symbol): logger.error(f"{symbol}: {e}") return - def fetch_trade(self, symbol: Symbol, side: PositionSide, limit: int): - trades = sorted( - self.connector.fetch_my_trades(symbol.name, limit=limit * 3), - key=lambda trade: trade["timestamp"], - reverse=True, - ) + def fetch_trade( + self, symbol: Symbol, side: PositionSide, since: int, size: float, limit: int + ): + last_trades = [ + trade + for trade in self.connector.fetch_my_trades(symbol.name, limit=limit * 2) + if trade["timestamp"] >= since + ] def round_down_to_minute(timestamp): return datetime.utcfromtimestamp(timestamp // 1000).replace( @@ -109,21 +111,36 @@ def round_down_to_minute(timestamp): aggregated_trades = defaultdict(lambda: {"amount": 0, "price": 0, "fee": 0}) - for trade in trades: - if trade["side"] == "buy" if side == PositionSide.SHORT else "sell": - timestamp = round_down_to_minute(trade["timestamp"]) - aggregated_trades[timestamp]["amount"] += trade["amount"] - aggregated_trades[timestamp]["price"] += trade["price"] - aggregated_trades[timestamp]["fee"] += trade["fee"]["cost"] - - for timestamp, trade_data in aggregated_trades.items(): - count = sum( - 1 - for item in trades - if round_down_to_minute(item["timestamp"]) == timestamp - ) - if count > 0: - trade_data["price"] /= count + opposite_side = "buy" if side == PositionSide.SHORT else "sell" + + trade_stack = sorted( + [trade for trade in last_trades if trade["side"] == opposite_side], + key=lambda trade: trade["timestamp"], + reverse=True, + ) + + acc_size = 0 + + for trade in trade_stack: + if acc_size >= size: + break + + trade_amount = trade["amount"] + key = round_down_to_minute(trade["timestamp"]) + + trade_amount = min(trade["amount"], size - acc_size) + + aggregated_trades[key]["amount"] += trade_amount + aggregated_trades[key]["price"] += trade["price"] * trade_amount + aggregated_trades[key]["fee"] += ( + trade["fee"]["cost"] / trade["amount"] + ) * trade_amount + + acc_size += trade_amount + + for _, trade_data in aggregated_trades.items(): + if trade_data["amount"] > 0: + trade_data["price"] /= trade_data["amount"] return next(iter(aggregated_trades.values()), None) @@ -234,6 +251,7 @@ def fetch_position(self, symbol: Symbol, side: PositionSide): else PositionSide.SHORT, "entry_price": float(position.get("entryPrice", 0)), "position_size": float(position.get("contracts", 0)), + "open_fee": -float(position.get("info", {}).get("curRealisedPnl", 0)), } return None diff --git a/sor/_router.py b/sor/_router.py index 8d468d79..57acfa41 100644 --- a/sor/_router.py +++ b/sor/_router.py @@ -49,6 +49,7 @@ def get_open_position(self, query: GetOpenPosition): status=OrderStatus.EXECUTED, size=broker_position["position_size"], price=broker_position["entry_price"], + fee=broker_position["open_fee"], ) @query_handler(GetClosePosition) @@ -57,7 +58,11 @@ def get_close_position(self, query: GetClosePosition): symbol = position.signal.symbol trade = self.exchange.fetch_trade( - symbol, position.side, self.config["max_order_slice"] + symbol, + position.side, + position.last_modified, + position.size, + self.config["max_order_slice"], ) if not trade: @@ -88,13 +93,15 @@ def get_account_balance(self, query: GetBalance): def has_position(self, query: HasPosition): position = query.position symbol = position.signal.symbol + side = position.side - existing_position = self.exchange.fetch_position(symbol, position.side) + existing_position = self.exchange.fetch_position(symbol, side) if existing_position: - logging.info(f"Position for {symbol} on {position.side} already exists") + logging.info(f"Position check: {side} position for {symbol} exists.") return True + logging.info(f"Position check: No existing {side} position found for {symbol}.") return False @command_handler(UpdateSettings)