diff --git a/risk/_actor.py b/risk/_actor.py index 3016f15f..f3426592 100644 --- a/risk/_actor.py +++ b/risk/_actor.py @@ -1,6 +1,6 @@ import asyncio from collections import deque -from typing import Optional, Union +from typing import List, Optional, Union from core.actors import Actor from core.events.ohlcv import NewMarketDataReceived @@ -118,14 +118,23 @@ async def _close_position(self, event: PositionClosed): async def _handle_market_risk(self, event: NewMarketDataReceived): async with self.lock: self._ohlcv.append(event.ohlcv) + visited = set() + ohlcvs = [] + + for i in range(len(self._ohlcv)): + if self._ohlcv[i].timestamp not in visited: + ohlcvs.append(self._ohlcv[i]) + visited.add(self._ohlcv[i].timestamp) + + ohlcvs = sorted(ohlcvs, key=lambda x: x.timestamp) long_position, short_position = self._position if long_position or short_position: long_position, short_position = await asyncio.gather( *[ - self._process_position(long_position), - self._process_position(short_position), + self._process_position(long_position, ohlcvs), + self._process_position(short_position, ohlcvs), ] ) @@ -167,15 +176,14 @@ async def _handle_signal_exit( event.exit_price, ) - async def _process_position(self, position: Optional[Position]): - ohlcvs = list(self._ohlcv) + async def _process_position( + self, position: Optional[Position], ohlcvs: List[OHLCV] + ): next_position = position if position and len(ohlcvs) > 1: next_position = position.next(ohlcvs) - last_candle = ohlcvs[-1] - - exit_event = self._create_exit_event(next_position, last_candle) + exit_event = self._create_exit_event(next_position, ohlcvs[-1]) if exit_event: await self.tell(exit_event)