Skip to content

Commit

Permalink
upd
Browse files Browse the repository at this point in the history
  • Loading branch information
m5l14i11 committed Sep 12, 2024
1 parent 9cce8ed commit ce6cf0e
Show file tree
Hide file tree
Showing 2 changed files with 76 additions and 44 deletions.
106 changes: 71 additions & 35 deletions exchange/_bybit_ws.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,15 +47,24 @@ def __init__(self, wss: str):
self._channels = set()
self._lock = asyncio.Lock()

async def _connect_to_websocket(self):
async def _connect(self):
if not self.ws or not self.ws.open:
self.ws = await websockets.connect(
self.wss,
open_timeout=None,
ping_interval=18,
ping_timeout=10,
close_timeout=None,
)
try:
self.ws = await websockets.connect(
self.wss,
open_timeout=2,
ping_interval=18,
ping_timeout=10,
close_timeout=3,
)

await self._wait_for_ws()
await self._subscribe()

logger.info("WebSocket connection established.")
except Exception as e:
logger.error(f"Failed to connect to WebSocket: {e}")
raise ConnectionError(e)

@retry(
max_retries=13,
Expand All @@ -69,35 +78,39 @@ async def _connect_to_websocket(self):
)
async def run(self):
await self.close()
await self._connect_to_websocket()
await self._subscribe()
await self._connect()

async def close(self):
if self.ws and self.ws.open:
await self._unsubscribe()
await self.ws.close()
await self.ws.wait_closed()

@retry(
max_retries=13,
initial_retry_delay=1,
handled_exceptions=(RuntimeError, ConnectionClosedError),
handled_exceptions=(
ConnectionError,
RuntimeError,
ConnectionClosedError,
CancelledError,
),
)
async def receive(self, symbol, timeframe):
await self._connect_to_websocket()
await self._connect()

async for message in self.ws:
data = json.loads(message)

if self.TOPIC_KEY not in data:
continue
try:
data = json.loads(message)

topic = data[self.TOPIC_KEY].split(".")
if not self._is_valid_message(symbol, timeframe, data):
continue

if symbol.name == topic[2] and timeframe == self.TIMEFRAMES.get(topic[1]):
return [
Bar(OHLCV.from_dict(ohlcv), ohlcv.get(self.CONFIRM_KEY))
for ohlcv in data.get(self.DATA_KEY, None)
if ohlcv
]
return self._parse_ohlcv(data)
except (json.JSONDecodeError, KeyError) as e:
logger.error(f"Malformed message received: {e}")
except Exception as e:
logger.exception(f"Unexpected error while receiving message: {e}")

async def subscribe(self, symbol, timeframe):
async with self._lock:
Expand All @@ -109,17 +122,16 @@ async def unsubscribe(self, symbol, timeframe):
async with self._lock:
if (symbol, timeframe) in self._channels:
self._channels.remove((symbol, timeframe))
await self._unsubscribe()
await self._subscribe()

async def _subscribe(self):
if not self.ws or not self.ws.open:
return

channels = [
f"{self.KLINE_CHANNEL}.{self.INTERVALS[timeframe]}.{symbol.name}"
for symbol, timeframe in self._channels
]
subscribe_message = {"op": self.SUBSCRIBE_OPERATION, "args": channels}
subscribe_message = {
"op": self.SUBSCRIBE_OPERATION,
"args": self._get_channels_args(),
}

try:
logger.info(f"Subscribe to: {subscribe_message}")
Expand All @@ -128,17 +140,41 @@ async def _subscribe(self):
logger.error(f"Failed to send subscribe message: {e}")

async def _unsubscribe(self):
if not self.ws or not self.ws.open:
if not self.ws or not self.ws.open or not self._channels:
return

channels = [
f"{self.KLINE_CHANNEL}.{self.INTERVALS[timeframe]}.{symbol.name}"
for symbol, timeframe in self._channels
]
unsubscribe_message = {"op": self.UNSUBSCRIBE_OPERATION, "args": channels}
unsubscribe_message = {
"op": self.UNSUBSCRIBE_OPERATION,
"args": self._get_channels_args(),
}

try:
logger.info(f"Unsubscribe from: {unsubscribe_message}")
await self.ws.send(json.dumps(unsubscribe_message))
except Exception as e:
logger.error(f"Failed to send unsubscribe message: {e}")

async def _wait_for_ws(self):
while not self.ws or not self.ws.open:
await asyncio.sleep(1.0)

def _is_valid_message(self, symbol, timeframe, data):
if self.TOPIC_KEY not in data:
return False

topic = data[self.TOPIC_KEY].split(".")

return symbol.name == topic[2] and timeframe == self.TIMEFRAMES.get(topic[1])

def _parse_ohlcv(self, data):
return [
Bar(OHLCV.from_dict(ohlcv), ohlcv.get(self.CONFIRM_KEY))
for ohlcv in data.get(self.DATA_KEY, [])
if ohlcv
]

def _get_channels_args(self):
return [
f"{self.KLINE_CHANNEL}.{self.INTERVALS[timeframe]}.{symbol.name}"
for symbol, timeframe in self._channels
]
14 changes: 5 additions & 9 deletions infrastructure/retry.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,16 +52,12 @@ def handle_retry_sync(func, *args, **kwargs):
raise max_retries_exception

def wrapper(func):
if asyncio.iscoroutinefunction(func):
async def wrapped_async(*args, **kwargs):
return await handle_retry_async(func, *args, **kwargs)

async def wrapped(*args, **kwargs):
return await handle_retry_async(func, *args, **kwargs)
def wrapped_sync(*args, **kwargs):
return handle_retry_sync(func, *args, **kwargs)

else:

def wrapped(*args, **kwargs):
return handle_retry_sync(func, *args, **kwargs)

return wrapped
return wrapped_async if asyncio.iscoroutinefunction(func) else wrapped_sync

return wrapper

0 comments on commit ce6cf0e

Please sign in to comment.