Skip to content

Commit d77df0a

Browse files
committed
docstring for step and train + pep 8
1 parent 5f68f90 commit d77df0a

File tree

1 file changed

+53
-12
lines changed

1 file changed

+53
-12
lines changed

rampwf/workflows/ts_fe_gen_reg.py

Lines changed: 53 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -13,11 +13,11 @@
1313

1414
# Author: Gabriel Hurtado <[email protected]>
1515
# License: BSD 3 clause
16-
import numpy as np
1716
import pandas as pd
1817
from .ts_feature_extractor import TimeSeriesFeatureExtractor
1918
from .generative_regressor import GenerativeRegressor
2019

20+
2121
class TSFEGenReg:
2222
def __init__(self,
2323
check_sizes, check_indexs, max_dists,
@@ -50,6 +50,27 @@ def __init__(self,
5050
check_sizes=check_sizes, check_indexs=check_indexs)
5151

5252
def train_submission(self, module_path, X_df, y_array, train_is=None):
53+
"""Train model.
54+
55+
Parameters
56+
----------
57+
module_path : string,
58+
Path of the model to train.
59+
60+
X_df : pandas dataframe
61+
Training data. Each sample contains data of a given timestep. Note
62+
that the targets have to be included in the training samples as the
63+
chaining rule is used: feature p - 1 of the target is needed to
64+
predict feature p of the target.
65+
66+
y_array : numpy array, shape (n_samples,)
67+
Training targets.
68+
69+
Returns
70+
-------
71+
fe, reg : tuple
72+
Trained feature extractor and generative regressor.
73+
"""
5374

5475
# FE uses is o(t-1), a(t-1) concatenated without a(t)
5576
# If train is none here, it still should not be a slice,
@@ -59,9 +80,9 @@ def train_submission(self, module_path, X_df, y_array, train_is=None):
5980
module_path, X_df, y_array, train_is)
6081
if train_is is None:
6182
train_is = slice(None, None, None)
62-
cols_for_extraction = self.target_column_observation_names + \
63-
self.target_column_action_names + \
64-
self.restart_names
83+
cols_for_extraction = (self.target_column_observation_names +
84+
self.target_column_action_names +
85+
self.restart_names)
6586
X_train_df = self.feature_extractor_workflow.test_submission(
6687
fe, X_df[cols_for_extraction][{self.timestamp_name: train_is}])
6788
obs = ['y_' + obs for obs in self.target_column_observation_names]
@@ -74,9 +95,9 @@ def test_submission(self, trained_model, X_df):
7495

7596
fe, reg = trained_model
7697

77-
cols_for_extraction = self.target_column_observation_names + \
78-
self.target_column_action_names + \
79-
self.restart_names
98+
cols_for_extraction = (self.target_column_observation_names +
99+
self.target_column_action_names +
100+
self.restart_names)
80101

81102
X_test_df = self.feature_extractor_workflow.test_submission(
82103
fe, X_df[cols_for_extraction])
@@ -93,22 +114,42 @@ def test_submission(self, trained_model, X_df):
93114
return y_pred_obs
94115

95116
def step(self, trained_model, X_df, seed=None):
117+
"""Sample next observation.
118+
119+
The next observation is sampled from the trained model given a history.
120+
121+
Parameters
122+
----------
123+
trained_model : tuple
124+
Trained model returned by the train_submission method.
125+
126+
X_df : pandas dataframe
127+
History used to sample the next observation.
128+
129+
For reinforcement learning, each sample of the history is assumed
130+
to contain one observation and one action, the action being the one
131+
selected after the observation. The action of the last row is the
132+
one for which we want to sample the next observation.
133+
134+
Return
135+
------
136+
sample_df : pandas dataframe
137+
The next observation.
138+
"""
96139

97140
fe, reg = trained_model
98141

99-
cols_for_extraction = self.target_column_observation_names + \
100-
self.target_column_action_names + \
101-
self.restart_names
142+
cols_for_extraction = (self.target_column_observation_names +
143+
self.target_column_action_names +
144+
self.restart_names)
102145

103146
X_test_array = self.feature_extractor_workflow.test_submission(
104147
fe, X_df[cols_for_extraction])
105148

106149
# We only care about sampling for the last provided timestep
107150
X_test_array = X_test_array.iloc[-1]
108151

109-
110152
sampled = self.regressor_workflow.step(reg, X_test_array, seed)
111-
112153
sampled_df = pd.DataFrame(sampled)
113154

114155
new_names = self.target_column_observation_names

0 commit comments

Comments
 (0)