diff --git a/feed/_historical.py b/feed/_historical.py index 4010a8f3..dad26348 100644 --- a/feed/_historical.py +++ b/feed/_historical.py @@ -1,5 +1,4 @@ import asyncio -import bisect from typing import AsyncIterator, List from core.actors import StrategyActor @@ -83,9 +82,16 @@ def __init__( self.exchange = exchange self.ts = ts self.config_service = config_service.get("backtest") - self.buffer: List[Bar] = [] + self.queue = asyncio.Queue() + self.batch_size = self.config_service["batch_size"] async def on_receive(self, msg: StartHistoricalFeed): + producer = asyncio.create_task(self._producer(msg)) + consumer = asyncio.create_task(self._consumer()) + + await asyncio.gather(producer, consumer) + + async def _producer(self, msg: StartHistoricalFeed): symbol, timeframe = msg.symbol, msg.timeframe async with AsyncHistoricalData( @@ -94,36 +100,29 @@ async def on_receive(self, msg: StartHistoricalFeed): timeframe, msg.in_sample, msg.out_sample, - self.config_service["batch_size"], + self.batch_size, ) as stream: - async for bars in self.batched(stream, self.config_service["buff_size"]): - self._update_buffer(bars) - await self._process_buffer() + async for batch in self.batched(stream, self.batch_size): + await self.queue.put(batch) - await self._process_remaining_buffer() + await self.queue.put(None) - def _update_buffer(self, batch: List[Bar]): - for bar in batch: - bisect.insort(self.buffer, bar, key=lambda x: x.ohlcv.timestamp) - - async def _process_buffer(self): - buff_size = self.config_service["buff_size"] + async def _consumer(self): + while True: + batch = await self.queue.get() - while len(self.buffer) >= buff_size: - bars = [self.buffer.pop(0) for _ in range(buff_size)] - await self._outbox(bars) - await self._handle_market(bars) + if batch is None: + break - async def _process_remaining_buffer(self): - buff_size = self.config_service["buff_size"] + await self._process_batch(batch) + self.queue.task_done() - while self.buffer: - bars = [self.buffer.pop(0) for _ in range(min(len(self.buffer), buff_size))] - await self._outbox(bars) - await self._handle_market(bars) + async def _process_batch(self, batch: List[Bar]): + await self._outbox(batch) + await self._handle_market(batch) - async def _handle_market(self, bars: List[Bar]) -> None: - for bar in bars: + async def _handle_market(self, batch: List[Bar]) -> None: + for bar in batch: await self.tell( NewMarketDataReceived( self.symbol, self.timeframe, bar.ohlcv, bar.closed @@ -131,13 +130,14 @@ async def _handle_market(self, bars: List[Bar]) -> None: ) await asyncio.sleep(0.0001) - async def _outbox(self, bars: List[Bar]) -> None: - ts = [] - for bar in bars: - if bar.closed: - ts.append(self.ts.upsert(self.symbol, self.timeframe, bar.ohlcv)) + async def _outbox(self, batch: List[Bar]) -> None: + tasks = [ + self.ts.upsert(self.symbol, self.timeframe, bar.ohlcv) + for bar in batch + if bar.closed + ] - await asyncio.gather(*ts) + await asyncio.gather(*tasks) @staticmethod async def batched(stream: AsyncIterator[Bar], batch_size: int):