Skip to content

Commit

Permalink
Fix datetime formatting and support configuration
Browse files Browse the repository at this point in the history
  • Loading branch information
MDUYN committed Apr 4, 2024
1 parent 3da38ac commit 097e251
Show file tree
Hide file tree
Showing 13 changed files with 107 additions and 45 deletions.
2 changes: 1 addition & 1 deletion examples/backtest/algorithm/strategy.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ def apply_strategy(self, algorithm: Algorithm, market_data):
fast = ti.sma(df['Close'].to_numpy(), self.fast)
slow = ti.sma(df['Close'].to_numpy(), self.slow)
trend = ti.sma(df['Close'].to_numpy(), self.trend)
price = ticker_data['bid']
price = ticker_data["bid"]

if not algorithm.has_position(target_symbol) \
and is_crossover(fast, slow) \
Expand Down
5 changes: 5 additions & 0 deletions investing_algorithm_framework/app/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -279,6 +279,10 @@ def _initialize_app_for_backtest(
market_data_source.to_backtest_market_data_source()
for market_data_source in market_data_sources
]

for market_data_source in backtest_market_data_sources:
market_data_source.config = self.config

self.container.market_data_source_service.override(
BacktestMarketDataSourceService(
market_data_sources=backtest_market_data_sources,
Expand Down Expand Up @@ -787,6 +791,7 @@ def run_backtests(
return reports

def add_market_data_source(self, market_data_source):
market_data_source.config = self.config
self._market_data_source_service.add(market_data_source)

def add_market_credential(self, market_credential: MarketCredential):
Expand Down
2 changes: 2 additions & 0 deletions investing_algorithm_framework/dependency_container.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ class DependencyContainer(containers.DeclarativeContainer):
)
order_service = providers.Factory(
OrderService,
configuration_service=configuration_service,
order_repository=order_repository,
order_fee_repository=order_fee_repository,
portfolio_repository=portfolio_repository,
Expand Down Expand Up @@ -125,6 +126,7 @@ class DependencyContainer(containers.DeclarativeContainer):
)
backtest_service = providers.Factory(
BacktestService,
configuration_service=configuration_service,
order_service=order_service,
portfolio_repository=portfolio_repository,
performance_service=performance_service,
Expand Down
1 change: 1 addition & 0 deletions investing_algorithm_framework/domain/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@ class Config(dict):
SQLITE_INITIALIZED = False
BACKTEST_DATA_DIRECTORY_NAME = "backtest_data"
SYMBOLS = None
DATETIME_FORMAT = "%Y-%m-%d %H:%M:%S"

def __init__(self, resource_directory=None):
super().__init__()
Expand Down
6 changes: 1 addition & 5 deletions investing_algorithm_framework/domain/models/order/order.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import logging

from dateutil.parser import parse
from dateutil.tz import gettz

from investing_algorithm_framework.domain.exceptions import \
OperationalException
Expand Down Expand Up @@ -308,10 +307,7 @@ def from_ccxt_order(ccxt_order):
remaining=ccxt_order.get("remaining", None),
cost=ccxt_order.get("cost", None),
fee=OrderFee.from_ccxt_fee(ccxt_order.get("fee", None)),
created_at=parse(
ccxt_order.get("datetime", None),
tzinfos={"UTC": gettz("UTC")}
)
created_at=parse(ccxt_order.get("datetime", None))
)

def __repr__(self):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,14 @@ def __init__(
self._backtest_data_start_date = backtest_data_start_date
self._backtest_data_index_date = backtest_data_index_date

@property
def config(self):
return self._config

@config.setter
def config(self, value):
self._config = value

def _data_source_exists(self, file_path):
"""
Function to check if the data source exists.
Expand Down Expand Up @@ -61,8 +69,6 @@ def _data_source_exists(self, file_path):

return True
except Exception as e:
logger.error(f"Error reading {file_path}")
logger.error(e)
return False

def write_data_to_file_path(self, data_file, data):
Expand Down Expand Up @@ -151,10 +157,10 @@ def backtest_data_index_date(self, value):
class MarketDataSource(ABC):

def __init__(
self,
identifier,
market,
symbol,
self,
identifier,
market,
symbol,
):
self._identifier = identifier
self._market = market
Expand All @@ -168,6 +174,14 @@ def initialize(self, config):
def identifier(self):
return self._identifier

@property
def config(self):
return self._config

@config.setter
def config(self, value):
self._config = value

def get_identifier(self):
return self.identifier

Expand Down
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
import logging
import logging
import os
from datetime import timedelta

import polars
from dateutil import parser

from investing_algorithm_framework.domain import RESOURCE_DIRECTORY, \
BACKTEST_DATA_DIRECTORY_NAME, DATETIME_FORMAT_BACKTESTING, \
Expand Down Expand Up @@ -59,7 +61,7 @@ def __init__(
start_date_func=start_date_func,
end_date=end_date,
end_date_func=end_date_func,
window_size=window_size
window_size=window_size,
)

def prepare_data(
Expand All @@ -80,7 +82,6 @@ def prepare_data(
When downloading the data it will use the ccxt library.
"""
# Calculating the backtest data start date

difference = self.end_date - self.start_date
total_minutes = 0

Expand Down Expand Up @@ -127,7 +128,10 @@ def prepare_data(
)

# Get the OHLCV data from the ccxt market service
market_service = CCXTMarketService(self.market_credential_service)
market_service = CCXTMarketService(
market_credential_service=self.market_credential_service,
)
market_service.config = config
ohlcv = market_service.get_ohlcv(
symbol=self.symbol,
time_frame=self.timeframe,
Expand Down Expand Up @@ -171,6 +175,7 @@ def get_data(self, backtest_index_date, **kwargs):
from_timestamp = backtest_index_date - timedelta(
minutes=self.total_minutes_timeframe
)
datetime_format = self._config["DATETIME_FORMAT"]
self.backtest_data_index_date = backtest_index_date\
.replace(microsecond=0)
from_timestamp = from_timestamp.replace(microsecond=0)
Expand All @@ -193,8 +198,8 @@ def get_data(self, backtest_index_date, **kwargs):
file_path, columns=self.column_names, separator=","
)
df = df.filter(
(df['Datetime'] >= from_timestamp.strftime(DATETIME_FORMAT))
& (df['Datetime'] <= to_timestamp.strftime(DATETIME_FORMAT))
(df['Datetime'] >= from_timestamp.strftime(datetime_format))
& (df['Datetime'] <= to_timestamp.strftime(datetime_format))
)
return df

Expand All @@ -209,6 +214,9 @@ def empty(self):
def file_name(self):
return self._create_file_path().split("/")[-1]

def write_data_to_file_path(self, data_file, data: polars.DataFrame):
data.write_csv(data_file)


class CCXTTickerBacktestMarketDataSource(
TickerMarketDataSource, BacktestMarketDataSource
Expand Down Expand Up @@ -304,7 +312,10 @@ def prepare_data(
)

# Get the OHLCV data from the ccxt market service
market_service = CCXTMarketService(self.market_credential_service)
market_service = CCXTMarketService(
market_credential_service=self.market_credential_service
)
market_service.config = config
ohlcv = market_service.get_ohlcv(
symbol=self.symbol,
time_frame=self.timeframe,
Expand Down Expand Up @@ -363,13 +374,12 @@ def get_data(self, **kwargs):
# Filter the data based on the backtest index date and the end date
df = polars.read_csv(file_path)
df = df.filter(
(df['Datetime'] >= backtest_index_date
.strftime(DATETIME_FORMAT))
(df['Datetime'] >= backtest_index_date.strftime(DATETIME_FORMAT))
)

first_row = df.head(1)[0]
first_row_datetime = parser.parse(first_row["Datetime"][0])

if first_row["Datetime"][0] > end_date.strftime(DATETIME_FORMAT):
if first_row_datetime > end_date:
logger.warning(
f"No ticker data available for the given backtest "
f"index date {backtest_index_date} and symbol {self.symbol} "
Expand All @@ -386,12 +396,17 @@ def get_data(self, **kwargs):
"datetime": first_row["Datetime"][0],
}

def write_data_to_file_path(self, data_file, data: polars.DataFrame):
data.write_csv(data_file)


class CCXTOHLCVMarketDataSource(OHLCVMarketDataSource):

def get_data(self, **kwargs):
market_service = CCXTMarketService(self.market_credential_service)

market_service = CCXTMarketService(
market_credential_service=self.market_credential_service,
)
market_service.config = self.config
if self.start_date is None:
raise OperationalException(
"Either start_date or start_date_func should be set "
Expand Down Expand Up @@ -422,7 +437,10 @@ def to_backtest_market_data_source(self) -> BacktestMarketDataSource:
class CCXTOrderBookMarketDataSource(OrderBookMarketDataSource):

def get_data(self, **kwargs):
market_service = CCXTMarketService(self.market_credential_service)
market_service = CCXTMarketService(
market_credential_service=self.market_credential_service
)
market_service.config = self.config
return market_service.get_order_book(
symbol=self.symbol, market=self.market
)
Expand All @@ -438,7 +456,8 @@ def __init__(
identifier,
market,
symbol=None,
backtest_timeframe=None
backtest_timeframe=None,

):
super().__init__(
identifier=identifier,
Expand All @@ -448,7 +467,10 @@ def __init__(
self._backtest_timeframe = backtest_timeframe

def get_data(self, **kwargs):
market_service = CCXTMarketService(self.market_credential_service)
market_service = CCXTMarketService(
market_credential_service=self.market_credential_service
)
market_service.config = self.config

if self.market is None:

Expand Down
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
import logging
from typing import Dict
from datetime import datetime
from time import sleep
import polars as pl
from typing import Dict

import ccxt
import polars as pl
from dateutil import parser
from dateutil.tz import gettz

from investing_algorithm_framework.domain import OperationalException, Order, \
CCXT_DATETIME_FORMAT, MarketService
MarketService

logger = logging.getLogger(__name__)

Expand All @@ -20,6 +20,19 @@ class CCXTMarketService(MarketService):
msec = 1000
minute = 60 * msec

def __init__(self, market_credential_service):
super(CCXTMarketService, self).__init__(
market_credential_service=market_credential_service,
)

@property
def config(self):
return self._config

@config.setter
def config(self, config):
self._config = config

def initialize_exchange(self, market, market_credential):
market = market.lower()
if not hasattr(ccxt, market):
Expand Down Expand Up @@ -149,6 +162,7 @@ def get_order(self, order, market):
def get_orders(self, symbol, market, since: datetime = None):
market_credential = self.get_market_credential(market)
exchange = self.initialize_exchange(market, market_credential)
datetime_format = self.config["DATETIME_FORMAT"]

if not exchange.has['fetchOrders']:
raise OperationalException(
Expand All @@ -157,7 +171,7 @@ def get_orders(self, symbol, market, since: datetime = None):
)

if since is not None:
since = exchange.parse8601(since.strftime(":%Y-%m-%d %H:%M:%S"))
since = exchange.parse8601(datetime_format)

try:
ccxt_orders = exchange.fetchOrders(symbol, since=since)
Expand Down Expand Up @@ -343,6 +357,7 @@ def get_closed_orders(
def get_ohlcv(
self, symbol, time_frame, from_timestamp, market, to_timestamp=None
) -> pl.DataFrame:
datetime_format = self.config["DATETIME_FORMAT"]
market_credential = self.get_market_credential(market)
exchange = self.initialize_exchange(market, market_credential)

Expand All @@ -353,14 +368,14 @@ def get_ohlcv(
)

from_time_stamp = exchange.parse8601(
from_timestamp.strftime(CCXT_DATETIME_FORMAT)
from_timestamp.strftime(datetime_format)
)

if to_timestamp is None:
to_timestamp = exchange.milliseconds()
else:
to_timestamp = exchange.parse8601(
to_timestamp.strftime(CCXT_DATETIME_FORMAT)
to_timestamp.strftime(datetime_format)
)
data = []

Expand All @@ -374,17 +389,16 @@ def get_ohlcv(
from_time_stamp = to_timestamp

for candle in ohlcv:
datetime_stamp = parser.parse(
exchange.iso8601(candle[0]),
tzinfos={"UTC": gettz("UTC")}
datetime_stamp = parser.parse(exchange.iso8601(candle[0]))

)
to_timestamp_datetime = parser.parse(
exchange.iso8601(to_timestamp),
tzinfos={"UTC": gettz("UTC")}
)

if datetime_stamp <= to_timestamp_datetime:
datetime_stamp = datetime_stamp\
.strftime(datetime_format)

data.append([datetime_stamp] + candle[1:])

sleep(exchange.rateLimit / 1000)
Expand Down
Loading

0 comments on commit 097e251

Please sign in to comment.