Skip to content

Commit 30ccd78

Browse files
committed
add 0 padding to targets, to compensate for missing past covariates
1 parent 492c3e5 commit 30ccd78

File tree

1 file changed

+7
-5
lines changed

1 file changed

+7
-5
lines changed

darts/models/forecasting/transformer_model.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -295,26 +295,28 @@ def forward(self, x_in: Tuple):
295295
"""
296296
data = x_in[0]
297297

298-
start_token = self._permute_transformer_inputs(data[:, -1:, :])
298+
pad_size = (0, self.input_size - self.target_size)
299+
start_token = self._permute_transformer_inputs(data[:, -1:, : self.target_size])
300+
start_token_padded = F.pad(start_token, pad_size)
299301
if len(x_in) == 3:
300302
src, _, tgt = x_in
301303
src = self._permute_transformer_inputs(src)
302304
tgt_permuted = self._permute_transformer_inputs(tgt)
303-
tgt_padded = F.pad(tgt_permuted, (0, self.input_size - self.target_size))
304-
tgt = torch.cat([start_token, tgt_padded], dim=0)
305+
tgt_padded = F.pad(tgt_permuted, pad_size)
306+
tgt = torch.cat([start_token_padded, tgt_padded], dim=0)
305307
return self._prediction_step(src, tgt)[:, :-1, :, :]
306308

307309
data, _ = x_in
308310

309311
src = self._permute_transformer_inputs(data)
310-
tgt = start_token
312+
tgt = start_token_padded
311313

312314
predictions = []
313315
for _ in range(self.output_chunk_length):
314316
pred = self._prediction_step(src, tgt)[:, -1, :, :]
315317
predictions.append(pred)
316318
tgt = torch.cat(
317-
[tgt, pred.mean(dim=2).unsqueeze(dim=0)],
319+
[tgt, F.pad(pred.mean(dim=2).unsqueeze(dim=0), pad_size)],
318320
dim=0,
319321
) # take average of quantiles
320322
return torch.stack(predictions, dim=1)

0 commit comments

Comments
 (0)