Skip to content

Commit

Permalink
Refactor tests
Browse files Browse the repository at this point in the history
  • Loading branch information
MDUYN committed May 28, 2024
1 parent 9bcc907 commit c8c8acf
Showing 1 changed file with 65 additions and 30 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,20 @@
from datetime import datetime, timedelta
from unittest import TestCase

from dateutil import parser
from dateutil.tz import tzutc
from polars import DataFrame

from investing_algorithm_framework.domain import OperationalException
from investing_algorithm_framework.domain import OperationalException, \
TimeFrame, DATETIME_FORMAT
from investing_algorithm_framework.infrastructure import \
CSVOHLCVMarketDataSource


class Test(TestCase):
"""
Test cases for the CSVOHLCVMarketDataSource class.
"""

def setUp(self) -> None:
self.resource_dir = os.path.abspath(
Expand All @@ -33,13 +38,10 @@ def test_right_columns(self):
file_name = "OHLCV_BTC-EUR_BINANCE" \
"_2h_2023-08-07:07:59_2023-12-02:00:00.csv"
data_source = CSVOHLCVMarketDataSource(
csv_file_path=f"{self.resource_dir}/"
"market_data_sources/"
csv_file_path=f"{self.resource_dir}/market_data_sources/"
f"{file_name}",
window_size=10
)
df = data_source \
.get_data(datetime(year=2023, month=12, day=17, hour=0, minute=0))
df = data_source.get_data()
self.assertEqual(
["Datetime", "Open", "High", "Low", "Close", "Volume"], df.columns
)
Expand All @@ -65,12 +67,55 @@ def test_start_date(self):
"market_data_sources/"
f"{file_name}",
window_size=10,
timeframe=TimeFrame.TWO_HOUR,
)
self.assertEqual(
start_date,
csv_ohlcv_market_data_source.start_date.replace(microsecond=0),
)

def test_start_date_with_window_size(self):
start_date = datetime(
year=2023, month=8, day=7, hour=10, minute=0, tzinfo=tzutc()
)
file_name = "OHLCV_BTC-EUR_BINANCE" \
"_2h_2023-08-07:07:59_2023-12-02:00:00.csv"
csv_ohlcv_market_data_source = CSVOHLCVMarketDataSource(
csv_file_path=f"{self.resource_dir}/"
"market_data_sources/"
f"{file_name}",
window_size=12,
timeframe=TimeFrame.TWO_HOUR,
)
data = csv_ohlcv_market_data_source.get_data(start_date=start_date)
self.assertEqual(12, len(data))
first_date = parser.parse(data["Datetime"][0])
self.assertEqual(
start_date.strftime(DATETIME_FORMAT),
first_date.strftime(DATETIME_FORMAT)
)

def test_start_date_with_backtest_index_date(self):
start_date = datetime(2023, 8, 7, 8, 0, tzinfo=tzutc())
file_name = "OHLCV_BTC-EUR_BINANCE" \
"_2h_2023-08-07:07:59_2023-12-02:00:00.csv"
csv_ohlcv_market_data_source = CSVOHLCVMarketDataSource(
csv_file_path=f"{self.resource_dir}/"
"market_data_sources/"
f"{file_name}",
window_size=10,
timeframe=TimeFrame.TWO_HOUR,
)
data = csv_ohlcv_market_data_source.get_data(
backtest_index_date=datetime(2023, 8, 7, 8, 0, tzinfo=tzutc())
)
self.assertEqual(10, len(data))
first_date = parser.parse(data["Datetime"][0])
self.assertEqual(
start_date.strftime(DATETIME_FORMAT),
first_date.strftime(DATETIME_FORMAT)
)

def test_end_date(self):
end_date = datetime(2023, 12, 2, 0, 0, tzinfo=tzutc())
file_name = "OHLCV_BTC-EUR_BINANCE" \
Expand All @@ -79,29 +124,25 @@ def test_end_date(self):
csv_file_path=f"{self.resource_dir}/"
"market_data_sources/"
f"{file_name}",
window_size=10,
)
self.assertEqual(
end_date,
csv_ohlcv_market_data_source.end_date.replace(microsecond=0),
)

def test_empty(self):
start_date = datetime(2023, 8, 7, 8, 0, tzinfo=tzutc())
file_name = "OHLCV_BTC-EUR_BINANCE" \
"_2h_2023-08-07:07:59_2023-12-02:00:00.csv"
data_source = CSVOHLCVMarketDataSource(
csv_file_path=f"{self.resource_dir}/"
"market_data_sources/"
f"{file_name}",
window_size=10,
timeframe="15m",
timeframe="2h",
)
self.assertFalse(data_source.empty())
self.assertEqual(start_date, data_source.start_date)
data_source.start_date = datetime(2023, 12, 25, 0, 0, tzinfo=tzutc())
data_source.end_date = datetime(2023, 12, 16, 0, 0, tzinfo=tzutc())
self.assertTrue(data_source.empty())
start_date = datetime(2023, 8, 7, 8, 0, tzinfo=tzutc())
end_date = datetime(2023, 12, 2, 0, 0, tzinfo=tzutc())
self.assertFalse(data_source.empty(start_date, end_date))

def test_get_data(self):
file_name = "OHLCV_BTC-EUR_BINANCE" \
Expand All @@ -110,17 +151,16 @@ def test_get_data(self):
csv_file_path=f"{self.resource_dir}/"
"market_data_sources/"
f"{file_name}",
window_size=10,
timeframe="15m",
window_size=200,
timeframe="2h",
)
number_of_runs = 0
backtest_index_date = datasource.start_date

while not datasource.empty():
data = datasource.get_data()
datasource.start_date = datasource.start_date + timedelta(days=1)
datasource.end_date = datasource.end_date + timedelta(days=1)
data = datasource.get_data(backtest_index_date=backtest_index_date)
backtest_index_date = parser.parse(data["Datetime"][-1])
self.assertTrue(len(data) > 0)
self.assertAlmostEqual(10, len(data), 2)
self.assertTrue(isinstance(data, DataFrame))
number_of_runs += 1

Expand All @@ -135,7 +175,7 @@ def test_get_identifier(self):
f"{file_name}",
identifier="test",
window_size=10,
timeframe="15m",
timeframe="2h",
)
self.assertEqual("test", datasource.get_identifier())

Expand All @@ -147,9 +187,7 @@ def test_get_market(self):
"market_data_sources/"
f"{file_name}",
market="test",
timeframe="15m",
start_date=datetime(2023, 12, 1),
end_date=datetime(2023, 12, 25),
timeframe="2h",
)
self.assertEqual("test", datasource.get_market())

Expand All @@ -162,7 +200,7 @@ def test_get_symbol(self):
f"{file_name}",
symbol="BTC/EUR",
window_size=10,
timeframe="15m",
timeframe="2h",
)
self.assertEqual("BTC/EUR", datasource.get_symbol())

Expand All @@ -173,10 +211,7 @@ def test_get_timeframe(self):
csv_file_path=f"{self.resource_dir}/"
"market_data_sources/"
f"{file_name}",
timeframe="15m",
timeframe="2h",
window_size=10,
)
self.assertEqual("15m", datasource.get_timeframe())

def test_get_data(self):
pass
self.assertEqual("2h", datasource.get_timeframe())

0 comments on commit c8c8acf

Please sign in to comment.