Skip to content

Commit

Permalink
upd
Browse files Browse the repository at this point in the history
  • Loading branch information
m5l14i11 committed Oct 10, 2024
1 parent 850231d commit 83ea998
Showing 1 changed file with 91 additions and 85 deletions.
176 changes: 91 additions & 85 deletions reef/_actor.py
Original file line number Diff line number Diff line change
@@ -1,36 +1,48 @@
import asyncio
import heapq
import logging
import time
from typing import Dict, Tuple
from dataclasses import dataclass

from coral import DataSourceFactory
from core.actors import BaseActor
from core.events.market import NewMarketOrderReceived
from core.interfaces.abstract_config import AbstractConfig
from core.models.datasource_type import DataSourceType
from core.models.entity.order import Order
from core.models.order_type import OrderStatus, OrderType
from core.models.order_type import OrderType
from core.models.protocol_type import ProtocolType
from core.models.symbol import Symbol

logger = logging.getLogger(__name__)


@dataclass(order=True)
class PQOrder:
    """Priority-queue entry for a pending order awaiting expiry cancellation.

    Field order matters: ``order=True`` generates comparisons field by
    field, so entries sort primarily by ``timestamp`` (oldest first) and
    ties break on ``order_id`` — suitable for use with ``heapq``.
    """

    # Creation time of the queue entry (seconds since epoch, time.time()).
    timestamp: float
    # Exchange/application order identifier; unique, so comparisons never
    # fall through to the non-orderable Symbol/DataSourceType fields.
    order_id: str
    # Trading symbol the order belongs to.
    symbol: Symbol
    # Which datasource (exchange connector) the order was placed through.
    datasource: DataSourceType


class ReefActor(BaseActor):
def __init__(
    self, datasource_factory: DataSourceFactory, config_service: AbstractConfig
):
    """Initialize the actor.

    Args:
        datasource_factory: Factory producing datasource services used to
            query and cancel orders.
        config_service: Configuration provider; the ``"order"`` section
            supplies expiration/monitoring intervals.
    """
    super().__init__()
    # Guards every access to the shared order heap below.
    self._lock = asyncio.Lock()
    # Min-heap of PQOrder entries, ordered by creation timestamp
    # (oldest first). Manipulated exclusively via heapq.
    self._orders: list = []
    self._datasource_factory = datasource_factory
    # Strong references to background tasks so the event loop does not
    # garbage-collect them mid-flight.
    self._tasks = set()
    self.order_config = config_service.get("order")

def on_start(self):
    """Spawn the background monitor and queue-worker tasks.

    Both tasks are retained in ``self._tasks`` (and discarded on
    completion) so they are not garbage-collected while running.
    """
    for coro in (self._monitor_orders(), self._process_order_queue()):
        task = asyncio.create_task(coro)
        # Drop the reference once the task finishes so the set stays small.
        task.add_done_callback(self._tasks.discard)
        self._tasks.add(task)

def on_stop(self):
for task in list(self._tasks):
Expand All @@ -45,99 +57,93 @@ def pre_receive(self, event: NewMarketOrderReceived):
)

async def on_receive(self, event: NewMarketOrderReceived):
    """Enqueue a received market order for delayed cancellation.

    Every incoming order is pushed onto the min-heap keyed by arrival
    time; the queue worker (``_process_order_queue``) cancels it once the
    configured expiration elapses.

    Args:
        event: Incoming order event carrying the order, its symbol, and
            the datasource it was placed through.
    """
    order = event.order

    # Take the lock for consistency with the monitor/worker tasks, which
    # also mutate the heap under self._lock.
    async with self._lock:
        heapq.heappush(
            self._orders,
            PQOrder(time.time(), order.id, event.symbol, event.datasource),
        )

    # Use the module-level logger rather than the root logger.
    logger.info(f"Order {order.id} {order.status} added to the queue.")

async def _process_order_queue(self):
    """Worker loop: pop orders whose delay has elapsed and cancel them.

    The lock is held only while inspecting/popping the heap — never
    across a sleep or the network cancel call — so other coroutines
    (``on_receive``, ``_monitor_orders``) are not starved while this
    worker waits.  ``asyncio.CancelledError`` deliberately propagates so
    the task can be cancelled on shutdown.
    """
    while True:
        try:
            due_order = None
            delay = 1.0  # default poll interval when the heap is empty

            async with self._lock:
                if self._orders:
                    head = self._orders[0]
                    now = time.time()
                    if head.timestamp > now:
                        # Head not yet due: sleep (outside the lock)
                        # exactly until it becomes due.
                        delay = head.timestamp - now
                    else:
                        due_order = heapq.heappop(self._orders)

            if due_order is None:
                await asyncio.sleep(delay)
                continue

            await self._cancel_order(
                due_order.order_id,
                due_order.symbol,
                due_order.datasource,
            )
        except Exception as e:
            logger.error(f"Error processing order queue: {str(e)}")

async def _cancel_order(
    self, order_id: str, symbol: Symbol, datasource: DataSourceType
):
    """Cancel a single order on its originating datasource.

    The blocking REST call is offloaded to a thread via
    ``asyncio.to_thread`` so the event loop is not stalled.

    Args:
        order_id: Identifier of the order to cancel.
        symbol: Symbol the order was placed for.
        datasource: Datasource (exchange) that owns the order.
    """
    service = self._datasource_factory.create(datasource, ProtocolType.REST)

    await asyncio.to_thread(service.cancel_order, order_id, symbol)

    logger.info(f"Order {order_id} for symbol {symbol.name} canceled.")

async def _monitor_orders(self):
    """Periodic safety net: cancel any queued order older than the TTL.

    Every ``monitor_interval`` seconds, orders whose age exceeds
    ``expiration_time`` are removed from the heap and canceled.  The
    network cancel calls happen after the lock is released so the heap
    is not held hostage to REST latency.
    """
    expiration_time = self.order_config.get("expiration_time", 10)
    monitor_interval = self.order_config.get("monitor_interval", 10)

    try:
        while True:
            await asyncio.sleep(monitor_interval)

            async with self._lock:
                current_time = time.time()
                expired_orders = [
                    order
                    for order in self._orders
                    if current_time - order.timestamp > expiration_time
                ]
                remaining = [
                    order for order in self._orders if order not in expired_orders
                ]
                # Filtering a heap with a comprehension does not, in
                # general, preserve the heap invariant — restore it.
                heapq.heapify(remaining)
                self._orders = remaining

            # Cancel outside the lock: _cancel_order does blocking I/O
            # (offloaded to a thread) and never touches the heap.
            for expired_order in expired_orders:
                await self._cancel_order(
                    expired_order.order_id,
                    expired_order.symbol,
                    expired_order.datasource,
                )
    except asyncio.CancelledError:
        logger.info("Monitoring task canceled.")
    except Exception as e:
        # The scraped diff truncated this handler ("logging.e"); restore
        # a complete error log consistent with the other loops.
        logger.error(f"Error in monitoring orders: {str(e)}")

async def _fetch_open_orders(self):
    """Discover open orders on the exchange and enqueue them for expiry.

    Orders found via the REST API (e.g. placed outside this process or
    missed events) are pushed onto the heap with the current time as
    their birth timestamp, so the normal expiration machinery will
    eventually cancel them.
    """
    # Datasources to poll; only Bybit is wired up for now.
    datasources = [DataSourceType.BYBIT]

    for datasource in datasources:
        service = self._datasource_factory.create(datasource, ProtocolType.REST)
        # Blocking REST fetch → thread pool, keeping the loop responsive.
        orders = await asyncio.to_thread(service.fetch_all_open_orders)

        async with self._lock:
            for order_id, symbol in orders:
                heapq.heappush(
                    self._orders,
                    PQOrder(time.time(), order_id, symbol, datasource),
                )

0 comments on commit 83ea998

Please sign in to comment.