From c8c8acfb7a45d14411c38caf781ec447fe492624 Mon Sep 17 00:00:00 2001 From: marcvanduyn Date: Tue, 28 May 2024 14:13:23 +0200 Subject: [PATCH] Refactor tests --- .../test_csv_ohlcv_market_data_source.py | 95 +++++++++++++------ 1 file changed, 65 insertions(+), 30 deletions(-) diff --git a/tests/infrastructure/market_data_sources/test_csv_ohlcv_market_data_source.py b/tests/infrastructure/market_data_sources/test_csv_ohlcv_market_data_source.py index efd499ca..40761b19 100644 --- a/tests/infrastructure/market_data_sources/test_csv_ohlcv_market_data_source.py +++ b/tests/infrastructure/market_data_sources/test_csv_ohlcv_market_data_source.py @@ -2,15 +2,20 @@ from datetime import datetime, timedelta from unittest import TestCase +from dateutil import parser from dateutil.tz import tzutc from polars import DataFrame -from investing_algorithm_framework.domain import OperationalException +from investing_algorithm_framework.domain import OperationalException, \ + TimeFrame, DATETIME_FORMAT from investing_algorithm_framework.infrastructure import \ CSVOHLCVMarketDataSource class Test(TestCase): + """ + Test cases for the CSVOHLCVMarketDataSource class. + """ def setUp(self) -> None: self.resource_dir = os.path.abspath( @@ -33,13 +38,10 @@ def test_right_columns(self): file_name = "OHLCV_BTC-EUR_BINANCE" \ "_2h_2023-08-07:07:59_2023-12-02:00:00.csv" data_source = CSVOHLCVMarketDataSource( - csv_file_path=f"{self.resource_dir}/" - "market_data_sources/" + csv_file_path=f"{self.resource_dir}/market_data_sources/" f"{file_name}", - window_size=10 ) - df = data_source \ - .get_data(datetime(year=2023, month=12, day=17, hour=0, minute=0)) + df = data_source.get_data() self.assertEqual( ["Datetime", "Open", "High", "Low", "Close", "Volume"], df.columns ) @@ -65,12 +67,55 @@ def test_start_date(self): "market_data_sources/" f"{file_name}", window_size=10, + timeframe=TimeFrame.TWO_HOUR, ) self.assertEqual( start_date, csv_ohlcv_market_data_source.start_date.replace(microsecond=0), ) + def test_start_date_with_window_size(self): + start_date = datetime( + year=2023, month=8, day=7, hour=10, minute=0, tzinfo=tzutc() + ) + file_name = "OHLCV_BTC-EUR_BINANCE" \ + "_2h_2023-08-07:07:59_2023-12-02:00:00.csv" + csv_ohlcv_market_data_source = CSVOHLCVMarketDataSource( + csv_file_path=f"{self.resource_dir}/" + "market_data_sources/" + f"{file_name}", + window_size=12, + timeframe=TimeFrame.TWO_HOUR, + ) + data = csv_ohlcv_market_data_source.get_data(start_date=start_date) + self.assertEqual(12, len(data)) + first_date = parser.parse(data["Datetime"][0]) + self.assertEqual( + start_date.strftime(DATETIME_FORMAT), + first_date.strftime(DATETIME_FORMAT) + ) + + def test_start_date_with_backtest_index_date(self): + start_date = datetime(2023, 8, 7, 8, 0, tzinfo=tzutc()) + file_name = "OHLCV_BTC-EUR_BINANCE" \ + "_2h_2023-08-07:07:59_2023-12-02:00:00.csv" + csv_ohlcv_market_data_source = CSVOHLCVMarketDataSource( + csv_file_path=f"{self.resource_dir}/" + "market_data_sources/" + f"{file_name}", + window_size=10, + timeframe=TimeFrame.TWO_HOUR, + ) + data = csv_ohlcv_market_data_source.get_data( + backtest_index_date=datetime(2023, 8, 7, 8, 0, tzinfo=tzutc()) + ) + self.assertEqual(10, len(data)) + first_date = parser.parse(data["Datetime"][0]) + self.assertEqual( + start_date.strftime(DATETIME_FORMAT), + first_date.strftime(DATETIME_FORMAT) + ) + def test_end_date(self): end_date = datetime(2023, 12, 2, 0, 0, tzinfo=tzutc()) file_name = "OHLCV_BTC-EUR_BINANCE" \ @@ -79,7 +124,6 @@ def test_end_date(self): csv_file_path=f"{self.resource_dir}/" "market_data_sources/" f"{file_name}", - window_size=10, ) self.assertEqual( end_date, @@ -87,7 +131,6 @@ def test_end_date(self): ) def test_empty(self): - start_date = datetime(2023, 8, 7, 8, 0, tzinfo=tzutc()) file_name = "OHLCV_BTC-EUR_BINANCE" \ "_2h_2023-08-07:07:59_2023-12-02:00:00.csv" data_source = CSVOHLCVMarketDataSource( @@ -95,13 +138,11 @@ def test_empty(self): "market_data_sources/" f"{file_name}", window_size=10, - timeframe="15m", + timeframe="2h", ) - self.assertFalse(data_source.empty()) - self.assertEqual(start_date, data_source.start_date) - data_source.start_date = datetime(2023, 12, 25, 0, 0, tzinfo=tzutc()) - data_source.end_date = datetime(2023, 12, 16, 0, 0, tzinfo=tzutc()) - self.assertTrue(data_source.empty()) + start_date = datetime(2023, 8, 7, 8, 0, tzinfo=tzutc()) + end_date = datetime(2023, 12, 2, 0, 0, tzinfo=tzutc()) + self.assertFalse(data_source.empty(start_date, end_date)) def test_get_data(self): file_name = "OHLCV_BTC-EUR_BINANCE" \ @@ -110,17 +151,16 @@ def test_get_data(self): csv_file_path=f"{self.resource_dir}/" "market_data_sources/" f"{file_name}", - window_size=10, - timeframe="15m", + window_size=200, + timeframe="2h", ) number_of_runs = 0 + backtest_index_date = datasource.start_date while not datasource.empty(): - data = datasource.get_data() - datasource.start_date = datasource.start_date + timedelta(days=1) - datasource.end_date = datasource.end_date + timedelta(days=1) + data = datasource.get_data(backtest_index_date=backtest_index_date) + backtest_index_date = parser.parse(data["Datetime"][-1]) self.assertTrue(len(data) > 0) - self.assertAlmostEqual(10, len(data), 2) self.assertTrue(isinstance(data, DataFrame)) number_of_runs += 1 @@ -135,7 +175,7 @@ def test_get_identifier(self): f"{file_name}", identifier="test", window_size=10, - timeframe="15m", + timeframe="2h", ) self.assertEqual("test", datasource.get_identifier()) @@ -147,9 +187,7 @@ def test_get_market(self): "market_data_sources/" f"{file_name}", market="test", - timeframe="15m", - start_date=datetime(2023, 12, 1), - end_date=datetime(2023, 12, 25), + timeframe="2h", ) self.assertEqual("test", datasource.get_market()) @@ -162,7 +200,7 @@ def test_get_symbol(self): f"{file_name}", symbol="BTC/EUR", window_size=10, - timeframe="15m", + timeframe="2h", ) self.assertEqual("BTC/EUR", datasource.get_symbol()) @@ -173,10 +211,7 @@ def test_get_timeframe(self): csv_file_path=f"{self.resource_dir}/" "market_data_sources/" f"{file_name}", - timeframe="15m", + timeframe="2h", window_size=10, ) - self.assertEqual("15m", datasource.get_timeframe()) - - def test_get_data(self): - pass + self.assertEqual("2h", datasource.get_timeframe())