Skip to content

Commit

Permalink
Refactor data source registration
Browse files Browse the repository at this point in the history
  • Loading branch information
MDUYN committed Jul 1, 2024
1 parent 07f69f2 commit bb7d5f6
Show file tree
Hide file tree
Showing 15 changed files with 961 additions and 111 deletions.
9 changes: 7 additions & 2 deletions investing_algorithm_framework/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,8 @@
TickerMarketDataSource, MarketService, BacktestReportsEvaluation, \
pretty_print_backtest_reports_evaluation, load_backtest_reports, \
RESERVED_BALANCES, APP_MODE, AppMode, DATETIME_FORMAT, \
load_backtest_report, BacktestDateRange
load_backtest_report, BacktestDateRange, create_ema_graph, \
create_prices_graph, create_rsi_graph, get_price_efficiency_ratio
from investing_algorithm_framework.infrastructure import \
CCXTOrderBookMarketDataSource, CCXTOHLCVMarketDataSource, \
CCXTTickerMarketDataSource, CSVOHLCVMarketDataSource, \
Expand Down Expand Up @@ -67,5 +68,9 @@
"load_backtest_report",
"BacktestDateRange",
"create_trade_exit_markers_chart",
"create_trade_entry_markers_chart"
"create_trade_entry_markers_chart",
"create_ema_graph",
"create_prices_graph",
"create_rsi_graph",
"get_price_efficiency_ratio"
]
124 changes: 85 additions & 39 deletions investing_algorithm_framework/app/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -254,10 +254,12 @@ def _initialize_app_for_backtest(
before running a backtest or a set of backtests and should be called
once.
:param backtest_date_range: instance of BacktestDateRange
:param pending_order_check_interval: The interval at which to check
pending orders (e.g. 1h, 1d, 1w)
:return: None
Args:
backtest_date_range: instance of BacktestDateRange
pending_order_check_interval: The interval at which to check
pending orders (e.g. 1h, 1d, 1w)
Return None
"""
# Set all config vars for backtesting
configuration_service = self.container.configuration_service()
Expand All @@ -275,7 +277,18 @@ def _initialize_app_for_backtest(
# Create resource dir if not exits
self._create_resource_directory_if_not_exists()

def _initialize_algorithm_for_backtest(self, algorithm):
def _create_backtest_database_if_not_exists(self):
"""
Create the backtest database if it does not exist. This method
should be called before running a backtest for an algorithm.
It creates the database if it does not exist.
Args:
None
Returns
None
"""
configuration_service = self.container.configuration_service()
resource_dir = configuration_service.config[RESOURCE_DIRECTORY]

Expand All @@ -301,15 +314,27 @@ def _initialize_algorithm_for_backtest(self, algorithm):
setup_sqlalchemy(self)
create_all_tables()

# Override the MarketDataSourceService service with the backtest
# market data source service equivalent. Additionally, convert the
# market data sources to backtest market data sources
# Get all market data source services
market_data_sources = self._market_data_source_service\
def _initialize_backtest_data_sources(self, algorithm):
"""
Initialize the backtest data sources for the algorithm. This method
should be called before running a backtest. It initializes the
backtest data sources for the algorithm. It takes all registered
data sources and converts them to backtest equivalents
Args:
algorithm: The algorithm to initialize for backtesting
Returns
None
"""

market_data_sources = self._market_data_source_service \
.get_market_data_sources()
backtest_market_data_sources = []

if algorithm.data_sources is not None \
and len(algorithm.data_sources) > 0:

for data_source in algorithm.data_sources:
self.add_market_data_source(data_source)

Expand All @@ -324,16 +349,36 @@ def _initialize_algorithm_for_backtest(self, algorithm):
if market_data_source is not None:
market_data_source.config = self.config

self.container.market_data_source_service.override(
BacktestMarketDataSourceService(
market_data_sources=backtest_market_data_sources,
market_service=self.container.market_service(),
market_credential_service=self.container
.market_credential_service(),
configuration_service=self.container
.configuration_service(),
)
# Override the market data source service with the backtest market
# data source service
self.container.market_data_source_service.override(
BacktestMarketDataSourceService(
market_data_sources=backtest_market_data_sources,
market_service=self.container.market_service(),
market_credential_service=self.container
.market_credential_service(),
configuration_service=self.container
.configuration_service(),
)
)

# Set all data sources to the algorithm
algorithm.add_data_sources(backtest_market_data_sources)

def _initialize_algorithm_for_backtest(self, algorithm):
"""
Function to initialize the algorithm for backtesting. This method
should be called before running a backtest. It initializes the
all data sources to backtest data sources and overrides the services
with the backtest services equivalents.
Args:
algorithm: The algorithm to initialize for backtesting
Return None
"""
self._create_backtest_database_if_not_exists()
self._initialize_backtest_data_sources(algorithm)

# Override the portfolio service with the backtest portfolio service
self.container.portfolio_service.override(
Expand Down Expand Up @@ -385,7 +430,6 @@ def _initialize_algorithm_for_backtest(self, algorithm):
market_credential_service = self.container.market_credential_service()
market_data_source_service = \
self.container.market_data_source_service()

# Initialize all services in the algorithm
algorithm.initialize_services(
configuration_service=self.container.configuration_service(),
Expand Down Expand Up @@ -444,17 +488,19 @@ def run(
raises an OperationalException. Then it initializes the algorithm
with the services and the configuration.
After the algorithm is initialized, it initializes the app and starts
the algorithm. If the app is running in stateless mode, it handles the
If the app is running in stateless mode, it handles the
payload. If the app is running in web mode, it starts the web app in a
separate thread.
:param payload: The payload to handle if the app is running in
stateless mode
:param number_of_iterations: The number of iterations to run the
algorithm for
:param sync: Whether to sync the portfolio with the exchange
:return: None
Args:
payload: The payload to handle if the app is running in
stateless mode
number_of_iterations: The number of iterations to run the
algorithm for
sync: Whether to sync the portfolio with the exchange
Returns:
None
"""

# Run all on_initialize hooks
Expand Down Expand Up @@ -676,21 +722,21 @@ def run_backtest(
Run a backtest for an algorithm. This method should be called when
running a backtest.
:param algorithm: The algorithm to run a backtest for (instance of
Algorithm)
:param backtest_date_range: The date range to run the backtest for
(instance of BacktestDateRange)
:param pending_order_check_interval: The interval at which to check
pending orders
:param output_directory: The directory to write the backtest report to
:return: Instance of BacktestReport
Args:
algorithm: The algorithm to run a backtest for (instance of
Algorithm)
backtest_date_range: The date range to run the backtest for
(instance of BacktestDateRange)
pending_order_check_interval: The interval at which to check
pending orders
output_directory: The directory to write the backtest report to
Returns:
Instance of BacktestReport
"""
logger.info("Initializing backtest")
self.algorithm = algorithm

market_data_sources = self._market_data_source_service\
.get_market_data_sources()

self._initialize_app_for_backtest(
backtest_date_range=backtest_date_range,
pending_order_check_interval=pending_order_check_interval,
Expand Down
76 changes: 76 additions & 0 deletions investing_algorithm_framework/app/strategy.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from investing_algorithm_framework.domain import \
TimeUnit, StrategyProfile, Trade
from .algorithm import Algorithm
import pandas as pd


class TradingStrategy:
Expand All @@ -11,6 +12,7 @@ class TradingStrategy:
strategy_id: str = None
decorated = None
market_data_sources = None
traces = None

def __init__(
self,
Expand Down Expand Up @@ -46,6 +48,8 @@ def __init__(

if strategy_id is not None:
self.strategy_id = strategy_id
else:
self.strategy_id = self.worker_id

# Check if time_unit is None
if self.time_unit is None:
Expand All @@ -59,6 +63,9 @@ def __init__(
f"Interval not set for strategy instance {self.strategy_id}"
)

# context initialization
self._context = None

def run_strategy(self, algorithm, market_data):
# Check pending orders before running the strategy
algorithm.check_pending_orders()
Expand Down Expand Up @@ -135,3 +142,72 @@ def strategy_identifier(self):
return self.strategy_id

return self.worker_id

@property
def context(self):
return self._context

@context.setter
def context(self, context):
self._context = context

def add_trace(
self,
symbol: str,
data,
drop_duplicates=True
) -> None:
"""
Add data to the straces object for a given symbol
Args:
symbol (str): The symbol
data (pd.DataFrame): The data to add to the tracing
drop_duplicates (bool): Drop duplicates
Returns:
None
"""

# Check if data is a DataFrame
if not isinstance(data, pd.DataFrame):
raise ValueError(
"Currently only pandas DataFrames are "
"supported as tracing data objects."
)

data: pd.DataFrame = data

# Check if index is a datetime object
if not isinstance(data.index, pd.DatetimeIndex):
raise ValueError("Dataframe Index must be a datetime object.")

if self.traces is None:
self.traces = {}

# Check if the key is already in the context dictionary
if symbol in self.traces:
# If the key is already in the context dictionary,
# append the new data to the existing data
combined = pd.concat([self.traces[symbol], data])
else:
# If the key is not in the context dictionary,
# add the new data to the context dictionary
combined = data

if drop_duplicates:
# Drop duplicates and sort the data by the index
combined = combined[~combined.index.duplicated(keep='first')]

# Set the datetime column as the index
combined.set_index(pd.DatetimeIndex(combined.index), inplace=True)
self.traces[symbol] = combined

def get_traces(self) -> dict:
"""
Get the traces object
Returns:
dict: The traces object
"""
return self.traces
6 changes: 6 additions & 0 deletions investing_algorithm_framework/domain/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@
load_backtest_report, \
csv_to_list, StoppableThread, pretty_print_backtest_reports_evaluation, \
pretty_print_backtest, load_csv_into_dict, load_backtest_reports
from .graphs import create_prices_graph, create_ema_graph, create_rsi_graph
from .metrics import get_price_efficiency_ratio

__all__ = [
'Config',
Expand Down Expand Up @@ -114,4 +116,8 @@
"RoundingService",
"BacktestDateRange",
"load_backtest_report",
"create_prices_graph",
"create_ema_graph",
"create_rsi_graph",
"get_price_efficiency_ratio"
]
Loading

0 comments on commit bb7d5f6

Please sign in to comment.