Skip to content

Commit fd13830

Browse files
committed
add comments
1 parent 30ccd78 commit fd13830

File tree

1 file changed

+11
-12
lines changed

1 file changed

+11
-12
lines changed

darts/models/forecasting/transformer_model.py

Lines changed: 11 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -289,15 +289,16 @@ def _permute_transformer_inputs(self, data):
289289

290290
def forward(self, x_in: Tuple):
291291
"""
292-
When teacher forcing, x_in = (past_target + past_covariates, static_covariates, future_targets)
293-
When inference, x_in = (past_target + past_covariates, static_covariates)
294-
292+
During training (teacher forcing) x_in = tuple(past_target + past_covariates, static_covariates, future_targets)
293+
During inference x_in = tuple(past_target + past_covariates, static_covariates)
295294
"""
296295
data = x_in[0]
297-
298296
pad_size = (0, self.input_size - self.target_size)
297+
298+
# start token consists only of target series, past covariates are substituted with 0 padding
299299
start_token = self._permute_transformer_inputs(data[:, -1:, : self.target_size])
300300
start_token_padded = F.pad(start_token, pad_size)
301+
301302
if len(x_in) == 3:
302303
src, _, tgt = x_in
303304
src = self._permute_transformer_inputs(src)
@@ -340,7 +341,7 @@ def _prediction_step(self, src: torch.Tensor, tgt: torch.Tensor):
340341
# Here we change the data format
341342
# from (1, batch_size, output_chunk_length * output_size)
342343
# to (batch_size, output_chunk_length, output_size, nr_params)
343-
predictions = out.permute(1, 0, 2)
344+
predictions = self._permute_transformer_inputs(out)
344345
predictions = predictions.view(batch_size, -1, self.target_size, self.nr_params)
345346

346347
return predictions
@@ -357,13 +358,11 @@ def _produce_train_output(self, input_batch: Tuple):
357358
Feeds PastCovariatesTorchModel with input and output chunks of a PastCovariatesSequentialDataset for
358359
training.
359360
360-
Parameters: if len(inp
361-
# print([x.shape if x is not None else x for x in train_batch], "TRAIN")ut_batch) != 4:
362-
a = 1
363-
----------
364-
a = 1
361+
Parameters:
365362
input_batch
366-
``(past_target, past_covariates, static_covariates)``
363+
``(past_target, past_covariates, static_covariates, future_target)`` during training
364+
365+
``(past_target, past_covariates, static_covariates)`` during validation (not teacher forced)
367366
"""
368367

369368
past_target, past_covariates, static_covariates = input_batch[:3]
@@ -375,7 +374,7 @@ def _produce_train_output(self, input_batch: Tuple):
375374
static_covariates,
376375
]
377376

378-
# add future targets when teacher forcing
377+
# add future targets when training (teacher forcing)
379378
if len(input_batch) == 4:
380379
inpt.append(input_batch[-1])
381380
return self(inpt)

0 commit comments

Comments
 (0)