Skip to content

Commit

Permalink
test the k-folds operation of our TimeSeriesCrossValidator
Browse files Browse the repository at this point in the history
  • Loading branch information
tnixon committed Nov 5, 2024
1 parent 4f85463 commit 30680d6
Show file tree
Hide file tree
Showing 3 changed files with 134 additions and 0 deletions.
20 changes: 20 additions & 0 deletions python/tests/ml_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from pyspark.ml.tuning import CrossValidator, ParamGridBuilder
from pyspark.ml.regression import GBTRegressor
from pyspark.ml.evaluation import RegressionEvaluator
from pyspark.sql import DataFrame

from tempo.ml import TimeSeriesCrossValidator

Expand Down Expand Up @@ -131,6 +132,25 @@ def test_gap_param(self):
# check the gap
self.assertEqual(tscv.getGap(), 2)

def test_kfolds(self):
# load test data
trades_df = self.get_test_df_builder("trades").as_sdf()
# construct with default parameters
tscv = TimeSeriesCrossValidator(timeSeriesCol='event_ts',
seriesIdCols=['symbol'],
gap=0)
# test the k-folds
k_folds = tscv._kFold(trades_df)
# check the number of folds
self.assertEqual(len(k_folds), tscv.getNumFolds())
# check each fold
for fold in k_folds:
self.assertIsInstance(fold, tuple)
self.assertEqual(len(fold), 2)
self.assertIsInstance(fold[0], DataFrame)
self.assertIsInstance(fold[1], DataFrame)


# MAIN
if __name__ == "__main__":
unittest.main()
13 changes: 13 additions & 0 deletions python/tests/unit_test_data/ml_tests.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
{
"TimeSeriesCrossValidatorTests": {
"test_kfolds": {
"trades": {
"df": {
"schema": "symbol string, event_ts string, trade_pr float",
"ts_convert": ["event_ts"],
"data": "trades.csv"
}
}
}
}
}
101 changes: 101 additions & 0 deletions python/tests/unit_test_data/trades.csv
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
symbol,event_ts,trade_pr
IBM,2017-08-31 00:57:25,347.9766055434685
IBM,2017-08-31 05:02:55,347.603478891568
IBM,2017-08-31 05:26:44,348.2851225377187
IBM,2017-08-31 05:38:08,347.8817054037267
IBM,2017-08-31 05:53:32,348.3718457507241
IBM,2017-08-31 06:56:22,349.40868952165323
IBM,2017-08-31 08:11:03,350.4640358206109
IBM,2017-08-31 10:49:09,347.716019602253
IBM,2017-08-31 11:01:38,347.2030920487126
IBM,2017-08-31 11:11:25,347.92907707949666
IBM,2017-08-31 11:49:55,346.1066922566784
IBM,2017-08-31 12:10:41,346.1236987198399
IBM,2017-08-31 13:02:47,349.20960037131124
IBM,2017-08-31 13:07:58,347.09158893690676
IBM,2017-08-31 14:15:49,347.45775383566854
IBM,2017-08-31 15:50:02,347.1668702661576
IBM,2017-08-31 17:27:50,348.56522298908044
IBM,2017-08-31 18:07:56,349.26325456538416
IBM,2017-08-31 19:09:47,349.34601689149946
IBM,2017-08-31 19:55:55,348.09936204319274
IBM,2017-08-31 20:17:15,347.1308847917395
IBM,2017-08-31 20:51:37,348.83766041227994
IBM,2017-08-31 21:37:17,348.4003780895007
K,2017-08-31 00:06:27,347.27138459233106
K,2017-08-31 00:18:46,347.9898553182071
K,2017-08-31 00:31:12,346.85852918073624
K,2017-08-31 00:51:16,346.91520445001134
K,2017-08-31 01:08:30,347.8078868655896
K,2017-08-31 01:34:54,347.2374835843108
K,2017-08-31 02:47:49,349.00659452619976
K,2017-08-31 02:49:22,347.4814105439092
K,2017-08-31 02:56:43,350.3539039043633
K,2017-08-31 03:01:33,349.5941805224711
K,2017-08-31 03:50:20,348.6119516556592
K,2017-08-31 03:52:18,348.18731148311406
K,2017-08-31 04:36:19,345.95795045531105
K,2017-08-31 05:27:12,346.6341114389929
K,2017-08-31 06:29:58,347.4121586706382
K,2017-08-31 06:32:30,346.7582132240916
K,2017-08-31 06:37:31,348.919146315238
K,2017-08-31 06:56:24,349.45235333868743
K,2017-08-31 08:38:22,347.6687817715506
K,2017-08-31 08:52:59,349.11648025163987
K,2017-08-31 09:22:55,347.16036576622395
K,2017-08-31 10:00:54,348.4869310969907
K,2017-08-31 10:52:36,348.44707325529976
K,2017-08-31 12:47:15,349.2617047407556
K,2017-08-31 13:17:24,349.16422862658777
K,2017-08-31 13:17:36,347.2034739832661
K,2017-08-31 13:42:17,350.3594725526159
K,2017-08-31 14:53:24,345.9384837375688
K,2017-08-31 15:14:08,346.3947630851533
K,2017-08-31 16:41:45,348.99202720361484
K,2017-08-31 18:41:52,348.7838699834772
K,2017-08-31 19:05:41,347.95173326760005
K,2017-08-31 19:25:27,348.16797905143034
K,2017-08-31 19:33:37,350.6567627351192
K,2017-08-31 20:21:47,347.9468144834939
K,2017-08-31 21:20:48,349.0419269428769
K,2017-08-31 21:36:07,347.38074751913484
K,2017-08-31 21:46:14,348.02539935462477
K,2017-08-31 21:58:11,346.98271245245644
K,2017-08-31 23:16:57,349.77827310811676
K,2017-08-31 23:29:40,348.9429200005411
KFS,2017-08-31 01:57:44,347.77347472191366
KFS,2017-08-31 01:58:19,347.3575869386784
KFS,2017-08-31 03:42:15,349.12235630639043
KFS,2017-08-31 10:26:57,347.9734526183446
KFS,2017-08-31 11:12:29,345.7111774398965
KFS,2017-08-31 12:30:27,347.9446791058658
KFS,2017-08-31 12:56:56,348.40914502757425
KFS,2017-08-31 20:18:30,348.5555420623246
KFS,2017-08-31 21:34:01,346.7731734554559
KFS,2017-08-31 22:32:59,348.6877379266723
KFS,2017-08-31 23:09:35,349.41137210604654
KFS,2017-08-31 23:11:17,349.0671659876273
KFS,2017-08-31 23:31:03,350.44123904624985
TBB,2017-08-31 00:54:21,347.0268200605267
TBB,2017-08-31 01:27:59,347.81625383701953
TBB,2017-08-31 01:29:59,346.7819013463641
TBB,2017-08-31 01:42:25,347.20721120029015
TBB,2017-08-31 02:28:56,347.4150394760788
TBB,2017-08-31 03:16:10,348.7008001367906
TBB,2017-08-31 04:38:04,348.0449984445236
TBB,2017-08-31 05:48:03,348.6731290332764
TBB,2017-08-31 08:32:07,350.7247367234809
TBB,2017-08-31 08:42:47,346.5096608964251
TBB,2017-08-31 10:58:42,348.4464129070117
TBB,2017-08-31 11:37:39,347.9739503215442
TBB,2017-08-31 12:25:31,349.7654451975011
TBB,2017-08-31 13:00:17,347.77438852748907
TBB,2017-08-31 14:46:22,348.6523007656035
TBB,2017-08-31 16:11:57,348.1998564265572
TBB,2017-08-31 16:54:51,347.86227977925466
TBB,2017-08-31 17:44:52,346.8702925232193
TBB,2017-08-31 18:26:52,347.85539454921854
TBB,2017-08-31 18:50:39,349.22132130112925
TBB,2017-08-31 19:03:36,346.8821233653525
TBB,2017-08-31 20:34:19,348.2391472198875
TBB,2017-08-31 20:36:40,347.0180283437618

0 comments on commit 30680d6

Please sign in to comment.