Skip to content

Commit

Permalink
pqo (#51)
Browse files Browse the repository at this point in the history
* pqo
  • Loading branch information
m5l14i11 authored Oct 11, 2024
1 parent 850231d commit e1a3f5e
Showing 1 changed file with 62 additions and 85 deletions.
147 changes: 62 additions & 85 deletions reef/_actor.py
Original file line number Diff line number Diff line change
@@ -1,36 +1,46 @@
import asyncio
import logging
import time
from typing import Dict, Tuple
from dataclasses import dataclass, field

from coral import DataSourceFactory
from core.actors import BaseActor
from core.events.market import NewMarketOrderReceived
from core.interfaces.abstract_config import AbstractConfig
from core.models.datasource_type import DataSourceType
from core.models.entity.order import Order
from core.models.order_type import OrderStatus, OrderType
from core.models.protocol_type import ProtocolType
from core.models.symbol import Symbol

logger = logging.getLogger(__name__)


@dataclass(order=True)
class PQOrder:
order_id: str = field(compare=False)
symbol: Symbol = field(compare=False)
datasource: DataSourceType = field(compare=False)
timestamp: float = field(default_factory=lambda: time.time(), compare=True)


class ReefActor(BaseActor):
def __init__(
self, datasource_factory: DataSourceFactory, config_service: AbstractConfig
):
super().__init__()
self._lock = asyncio.Lock()
self._orders: Dict[str, Tuple[float, Symbol, DataSourceType]] = {}
self._order_queue = asyncio.PriorityQueue()
self._datasource_factory = datasource_factory
self._tasks = set()
self.order_config = config_service.get("order")

def on_start(self):
task = asyncio.create_task(self._monitor_orders())
task.add_done_callback(self._tasks.discard)
self._tasks.add(task)
worker_task = asyncio.create_task(self._process_orders())
worker_task.add_done_callback(self._tasks.discard)
self._tasks.add(worker_task)

poll_task = asyncio.create_task(self._fetch_open_orders())
poll_task.add_done_callback(self._tasks.discard)
self._tasks.add(poll_task)

def on_stop(self):
for task in list(self._tasks):
Expand All @@ -42,102 +52,69 @@ def pre_receive(self, event: NewMarketOrderReceived):
return (
isinstance(event, NewMarketOrderReceived)
and event.order.type != OrderType.PAPER
and event.order.status == OrderStatus.PENDING
)

async def on_receive(self, event: NewMarketOrderReceived):
match event.order.status:
case OrderStatus.EXECUTED:
await self._clear_order(event.order)
case OrderStatus.PENDING:
await self._append_order(event.order, event.symbol, event.datasource)
order = event.order
pq_order = PQOrder(order.id, event.symbol, event.datasource)

async def _clear_order(self, order: Order):
async with self._lock:
if order.id in self._orders:
self._orders.pop(order.id)
await self._order_queue.put(pq_order)

logging.info(f"Order {order.id} cleared.")
logging.info(f"Order {order.id} {order.status} added to the queue.")

async def _append_order(
self, order: Order, symbol: Symbol, datasource: DataSourceType
):
async with self._lock:
self._orders[order.id] = (time.time(), symbol, datasource)

logging.info(f"Order {order.id} appended for symbol {symbol.name}.")

async def _monitor_orders(self):
counter = 0
async def _process_orders(self):
expiration_time = self.order_config.get("expiration_time", 10)
monitor_interval = self.order_config.get("monitor_interval", 10)
untracked_interval = self.order_config.get("untracked_interval", 8)

try:
while True:
async with asyncio.TaskGroup() as task_group:
await asyncio.sleep(monitor_interval)
await task_group.create_task(self._cancel_expired_orders())

if counter % untracked_interval == 0:
await task_group.create_task(self._fetch_open_orders())
counter = 0
counter += 1
except asyncio.CancelledError:
logging.info("Monitoring task canceled.")
except Exception as e:
logging.error(f"Error in monitoring orders: {str(e)}")
pq_order = await self._order_queue.get()
current_time = time.time()

async def _cancel_expired_orders(self):
curr_time = time.time()
expiration_time = self.order_config.get("expiration_time", 10)
expired_orders = []
if current_time - pq_order.timestamp > expiration_time:
await self._cancel_order(
pq_order.order_id, pq_order.symbol, pq_order.datasource
)
else:
sleep_time = expiration_time - (current_time - pq_order.timestamp)

async with self._lock:
expired_orders = [
(order_id, symbol, datasource)
for order_id, (timestamp, symbol, datasource) in self._orders.items()
if curr_time - timestamp > expiration_time
]
if sleep_time > 0:
await asyncio.sleep(sleep_time)

if expired_orders:
logging.info(f"Found {len(expired_orders)} expired orders. Canceling...")
await self._cancel_order(
pq_order.order_id, pq_order.symbol, pq_order.datasource
)

for order_id, symbol, datasource in expired_orders:
service = self._datasource_factory.create(datasource, ProtocolType.REST)
service.cancel_order(order_id, symbol)
self._order_queue.task_done()

await asyncio.sleep(monitor_interval)
except asyncio.CancelledError:
logging.info("Order processing task was cancelled.")
except Exception as e:
logging.error(f"Error in processing or monitoring orders: {str(e)}")

async def _cancel_order(
self, order_id: str, symbol: Symbol, datasource: DataSourceType
):
service = self._datasource_factory.create(datasource, ProtocolType.REST)

async with self._lock:
self._orders.pop(order_id, None)
logging.info(f"Order {order_id} for symbol {symbol.name} canceled.")
await asyncio.to_thread(service.cancel_order, order_id, symbol)

logging.info(f"Order {order_id} for symbol {symbol.name} canceled.")

async def _fetch_open_orders(self):
orders = []
services = [
self._datasource_factory.create(DataSourceType.BYBIT, ProtocolType.REST),
DataSourceType.BYBIT,
]
monitor_interval = self.order_config.get("monitor_interval", 10)

for service in services:
orders += service.fetch_all_open_orders()
for datasource in services:
service = self._datasource_factory.create(datasource, ProtocolType.REST)
orders = await asyncio.to_thread(service.fetch_all_open_orders)

curr_time = time.time()
expiration_time = self.order_config.get("expiration_time", 10)
for order_id, symbol in orders:
pq_order = PQOrder(order_id, symbol, datasource)
await self._order_queue.put(pq_order)

async with self._lock:
for order_id, order_symbol in orders:
if order_id in self._orders:
timestamp, symbol, datasource = self._orders.get(order_id)

if curr_time - timestamp > expiration_time:
service = self._datasource_factory.create(
datasource, ProtocolType.REST
)
service.cancel_order(order_id, symbol)
self._orders.pop(order_id, None)

logging.info(
f"Order {order_id} for symbol {symbol.name} canceled."
)
else:
logging.warning(
f"Untracked order {order_id} found, attempting cancellation."
)
for service in services:
service.cancel_order(order_id, order_symbol)
await asyncio.sleep(monitor_interval)

0 comments on commit e1a3f5e

Please sign in to comment.