@@ -289,15 +289,16 @@ def _permute_transformer_inputs(self, data):
289289
290290 def forward (self , x_in : Tuple ):
291291 """
292- When teacher forcing, x_in = (past_target + past_covariates, static_covariates, future_targets)
293- When inference, x_in = (past_target + past_covariates, static_covariates)
294-
292+ During training (teacher forcing) x_in = tuple(past_target + past_covariates, static_covariates, future_targets)
293+ During inference x_in = tuple(past_target + past_covariates, static_covariates)
295294 """
296295 data = x_in [0 ]
297-
298296 pad_size = (0 , self .input_size - self .target_size )
297+
298+ # start token consists only of target series, past covariates are substituted with 0 padding
299299 start_token = self ._permute_transformer_inputs (data [:, - 1 :, : self .target_size ])
300300 start_token_padded = F .pad (start_token , pad_size )
301+
301302 if len (x_in ) == 3 :
302303 src , _ , tgt = x_in
303304 src = self ._permute_transformer_inputs (src )
@@ -340,7 +341,7 @@ def _prediction_step(self, src: torch.Tensor, tgt: torch.Tensor):
340341 # Here we change the data format
341342 # from (1, batch_size, output_chunk_length * output_size)
342343 # to (batch_size, output_chunk_length, output_size, nr_params)
343- predictions = out . permute ( 1 , 0 , 2 )
344+ predictions = self . _permute_transformer_inputs ( out )
344345 predictions = predictions .view (batch_size , - 1 , self .target_size , self .nr_params )
345346
346347 return predictions
@@ -357,13 +358,11 @@ def _produce_train_output(self, input_batch: Tuple):
357358 Feeds PastCovariatesTorchModel with input and output chunks of a PastCovariatesSequentialDataset for
358359 training.
359360
360- Parameters: if len(inp
361- # print([x.shape if x is not None else x for x in train_batch], "TRAIN")ut_batch) != 4:
362- a = 1
363- ----------
364- a = 1
361+ Parameters:
365362 input_batch
366- ``(past_target, past_covariates, static_covariates)``
363+ ``(past_target, past_covariates, static_covariates, future_target)`` during training
364+
365+ ``(past_target, past_covariates, static_covariates)`` during validation (not teacher forced)
367366 """
368367
369368 past_target , past_covariates , static_covariates = input_batch [:3 ]
@@ -375,7 +374,7 @@ def _produce_train_output(self, input_batch: Tuple):
375374 static_covariates ,
376375 ]
377376
378- # add future targets when teacher forcing
377+ # add future targets when training ( teacher forcing)
379378 if len (input_batch ) == 4 :
380379 inpt .append (input_batch [- 1 ])
381380 return self (inpt )
0 commit comments