From f1b70a39c1ca113a79cd4a7954ee24a25e5cc023 Mon Sep 17 00:00:00 2001 From: marcvanduyn Date: Fri, 22 Mar 2024 13:18:42 +0100 Subject: [PATCH] Fix ohlcv data sources tests --- .../models/market_data_sources/csv.py | 4 +- .../test_csv_ohlcv_market_data_source.py | 48 ++++++++++--------- 2 files changed, 27 insertions(+), 25 deletions(-) diff --git a/investing_algorithm_framework/infrastructure/models/market_data_sources/csv.py b/investing_algorithm_framework/infrastructure/models/market_data_sources/csv.py index c8d59cdb..198d4f44 100644 --- a/investing_algorithm_framework/infrastructure/models/market_data_sources/csv.py +++ b/investing_algorithm_framework/infrastructure/models/market_data_sources/csv.py @@ -84,7 +84,7 @@ def get_data( # it's not already if 'Datetime' in df.columns and pd.api.types.is_string_dtype( df['Datetime']): - df['Datetime'] = pd.to_datetime(df['Datetime']) + df['Datetime'] = pd.to_datetime(df['Datetime'], utc=True) # Filter rows based on the start and end dates filtered_df = df[ @@ -157,7 +157,7 @@ def get_data(self, index_datetime=None, **kwargs): # it's not already if 'Datetime' in df.columns and pd.api.types.is_string_dtype( df['Datetime']): - df['Datetime'] = pd.to_datetime(df['Datetime']) + df['Datetime'] = pd.to_datetime(df['Datetime'], utc=True) # Filter rows based on the start and end dates filtered_df = df[(df['Datetime'] <= index_datetime)] diff --git a/tests/infrastructure/market_data_sources/test_csv_ohlcv_market_data_source.py b/tests/infrastructure/market_data_sources/test_csv_ohlcv_market_data_source.py index 53ef2a77..cb9a7f02 100644 --- a/tests/infrastructure/market_data_sources/test_csv_ohlcv_market_data_source.py +++ b/tests/infrastructure/market_data_sources/test_csv_ohlcv_market_data_source.py @@ -1,5 +1,7 @@ import os from datetime import datetime, timedelta + +from dateutil.tz import tzutc from unittest import TestCase from investing_algorithm_framework.infrastructure import \ @@ -26,8 +28,8 @@ def setUp(self) -> None: ) def test_right_columns(self): - file_name = "OHLCV_BTC-EUR_BINANCE_15m_2023-12-" \ - "01:00:00_2023-12-25:00:00.csv" + file_name = "OHLCV_BTC-EUR_BINANCE" \ + "_2h_2023-08-07:07:59_2023-12-02:00:00.csv" CSVOHLCVMarketDataSource( csv_file_path=f"{self.resource_dir}/" "market_data_sources/" @@ -36,9 +38,9 @@ def test_right_columns(self): ) def test_start_date(self): - start_date = datetime(2023, 12, 1) - file_name = "OHLCV_BTC-EUR_BINANCE_15m_2023-12-" \ - "01:00:00_2023-12-25:00:00.csv" + start_date = datetime(2023, 8, 7, 8, 0, tzinfo=tzutc()) + file_name = "OHLCV_BTC-EUR_BINANCE" \ + "_2h_2023-08-07:07:59_2023-12-02:00:00.csv" csv_ohlcv_market_data_source = CSVOHLCVMarketDataSource( csv_file_path=f"{self.resource_dir}/" "market_data_sources/" @@ -51,9 +53,9 @@ def test_start_date(self): ) def test_end_date(self): - end_date = datetime(2023, 12, 25) - file_name = "OHLCV_BTC-EUR_BINANCE_15m_2023-12-" \ - "01:00:00_2023-12-25:00:00.csv" + end_date = datetime(2023, 12, 2, 0, 0, tzinfo=tzutc()) + file_name = "OHLCV_BTC-EUR_BINANCE" \ + "_2h_2023-08-07:07:59_2023-12-02:00:00.csv" csv_ohlcv_market_data_source = CSVOHLCVMarketDataSource( csv_file_path=f"{self.resource_dir}/" "market_data_sources/" @@ -66,9 +68,9 @@ def test_end_date(self): ) def test_empty(self): - start_date = datetime(2023, 12, 1) - file_name = "OHLCV_BTC-EUR_BINANCE_15m_2023-12-" \ - "01:00:00_2023-12-25:00:00.csv" + start_date = datetime(2023, 8, 7, 8, 0, tzinfo=tzutc()) + file_name = "OHLCV_BTC-EUR_BINANCE" \ + "_2h_2023-08-07:07:59_2023-12-02:00:00.csv" data_source = CSVOHLCVMarketDataSource( csv_file_path=f"{self.resource_dir}/" "market_data_sources/" @@ -78,13 +80,13 @@ def test_empty(self): ) self.assertFalse(data_source.empty()) self.assertEqual(start_date, data_source.start_date) - data_source.start_date = datetime(2023, 12, 25) - data_source.end_date = datetime(2023, 12, 16) + data_source.start_date = datetime(2023, 12, 25, 0, 0, tzinfo=tzutc()) + data_source.end_date = datetime(2023, 12, 16, 0, 0, tzinfo=tzutc()) self.assertTrue(data_source.empty()) def test_get_data(self): - file_name = "OHLCV_BTC-EUR_BINANCE_15m_2023-12-" \ - "01:00:00_2023-12-25:00:00.csv" + file_name = "OHLCV_BTC-EUR_BINANCE" \ + "_2h_2023-08-07:07:59_2023-12-02:00:00.csv" datasource = CSVOHLCVMarketDataSource( csv_file_path=f"{self.resource_dir}/" "market_data_sources/" @@ -104,8 +106,8 @@ def test_get_data(self): self.assertTrue(number_of_runs > 0) def test_get_identifier(self): - file_name = "OHLCV_BTC-EUR_BINANCE_15m_2023-12-" \ - "01:00:00_2023-12-25:00:00.csv" + file_name = "OHLCV_BTC-EUR_BINANCE" \ + "_2h_2023-08-07:07:59_2023-12-02:00:00.csv" datasource = CSVOHLCVMarketDataSource( csv_file_path=f"{self.resource_dir}/" "market_data_sources/" @@ -117,8 +119,8 @@ def test_get_identifier(self): self.assertEqual("test", datasource.get_identifier()) def test_get_market(self): - file_name = "OHLCV_BTC-EUR_BINANCE_15m_2023-12-" \ - "01:00:00_2023-12-25:00:00.csv" + file_name = "OHLCV_BTC-EUR_BINANCE" \ + "_2h_2023-08-07:07:59_2023-12-02:00:00.csv" datasource = CSVOHLCVMarketDataSource( csv_file_path=f"{self.resource_dir}/" "market_data_sources/" @@ -131,8 +133,8 @@ def test_get_market(self): self.assertEqual("test", datasource.get_market()) def test_get_symbol(self): - file_name = "OHLCV_BTC-EUR_BINANCE_15m_2023-12-" \ - "01:00:00_2023-12-25:00:00.csv" + file_name = "OHLCV_BTC-EUR_BINANCE" \ + "_2h_2023-08-07:07:59_2023-12-02:00:00.csv" datasource = CSVOHLCVMarketDataSource( csv_file_path=f"{self.resource_dir}/" "market_data_sources/" @@ -144,8 +146,8 @@ def test_get_symbol(self): self.assertEqual("BTC/EUR", datasource.get_symbol()) def test_get_timeframe(self): - file_name = "OHLCV_BTC-EUR_BINANCE_15m_2023-12-" \ - "01:00:00_2023-12-25:00:00.csv" + file_name = "OHLCV_BTC-EUR_BINANCE" \ + "_2h_2023-08-07:07:59_2023-12-02:00:00.csv" datasource = CSVOHLCVMarketDataSource( csv_file_path=f"{self.resource_dir}/" "market_data_sources/"