13
13
14
14
# Author: Gabriel Hurtado <[email protected] >
15
15
# License: BSD 3 clause
16
- import numpy as np
17
16
import pandas as pd
18
17
from .ts_feature_extractor import TimeSeriesFeatureExtractor
19
18
from .generative_regressor import GenerativeRegressor
20
19
20
+
21
21
class TSFEGenReg :
22
22
def __init__ (self ,
23
23
check_sizes , check_indexs , max_dists ,
@@ -50,6 +50,27 @@ def __init__(self,
50
50
check_sizes = check_sizes , check_indexs = check_indexs )
51
51
52
52
def train_submission (self , module_path , X_df , y_array , train_is = None ):
53
+ """Train model.
54
+
55
+ Parameters
56
+ ----------
57
+ module_path : string,
58
+ Path of the model to train.
59
+
60
+ X_df : pandas dataframe
61
+ Training data. Each sample contains data of a given timestep. Note
62
+ that the targets have to be included in the training samples as the
63
+ chaining rule is used: feature p - 1 of the target is needed to
64
+ predict feature p of the target.
65
+
66
+ y_array : numpy array, shape (n_samples,)
67
+ Training targets.
68
+
69
+ Returns
70
+ -------
71
+ fe, reg : tuple
72
+ Trained feature extractor and generative regressor.
73
+ """
53
74
54
75
# FE uses is o(t-1), a(t-1) concatenated without a(t)
55
76
# If train is none here, it still should not be a slice,
@@ -59,9 +80,9 @@ def train_submission(self, module_path, X_df, y_array, train_is=None):
59
80
module_path , X_df , y_array , train_is )
60
81
if train_is is None :
61
82
train_is = slice (None , None , None )
62
- cols_for_extraction = self .target_column_observation_names + \
63
- self .target_column_action_names + \
64
- self .restart_names
83
+ cols_for_extraction = ( self .target_column_observation_names +
84
+ self .target_column_action_names +
85
+ self .restart_names )
65
86
X_train_df = self .feature_extractor_workflow .test_submission (
66
87
fe , X_df [cols_for_extraction ][{self .timestamp_name : train_is }])
67
88
obs = ['y_' + obs for obs in self .target_column_observation_names ]
@@ -74,9 +95,9 @@ def test_submission(self, trained_model, X_df):
74
95
75
96
fe , reg = trained_model
76
97
77
- cols_for_extraction = self .target_column_observation_names + \
78
- self .target_column_action_names + \
79
- self .restart_names
98
+ cols_for_extraction = ( self .target_column_observation_names +
99
+ self .target_column_action_names +
100
+ self .restart_names )
80
101
81
102
X_test_df = self .feature_extractor_workflow .test_submission (
82
103
fe , X_df [cols_for_extraction ])
@@ -93,22 +114,42 @@ def test_submission(self, trained_model, X_df):
93
114
return y_pred_obs
94
115
95
116
def step (self , trained_model , X_df , seed = None ):
117
+ """Sample next observation.
118
+
119
+ The next observation is sampled from the trained model given a history.
120
+
121
+ Parameters
122
+ ----------
123
+ trained_model : tuple
124
+ Trained model returned by the train_submission method.
125
+
126
+ X_df : pandas dataframe
127
+ History used to sample the next observation.
128
+
129
+ For reinforcement learning, each sample of the history is assumed
130
+ to contain one observation and one action, the action being the one
131
+ selected after the observation. The action of the last row is the
132
+ one for which we want to sample the next observation.
133
+
134
+ Return
135
+ ------
136
+ sample_df : pandas dataframe
137
+ The next observation.
138
+ """
96
139
97
140
fe , reg = trained_model
98
141
99
- cols_for_extraction = self .target_column_observation_names + \
100
- self .target_column_action_names + \
101
- self .restart_names
142
+ cols_for_extraction = ( self .target_column_observation_names +
143
+ self .target_column_action_names +
144
+ self .restart_names )
102
145
103
146
X_test_array = self .feature_extractor_workflow .test_submission (
104
147
fe , X_df [cols_for_extraction ])
105
148
106
149
# We only care about sampling for the last provided timestep
107
150
X_test_array = X_test_array .iloc [- 1 ]
108
151
109
-
110
152
sampled = self .regressor_workflow .step (reg , X_test_array , seed )
111
-
112
153
sampled_df = pd .DataFrame (sampled )
113
154
114
155
new_names = self .target_column_observation_names
0 commit comments