@@ -295,26 +295,28 @@ def forward(self, x_in: Tuple):
295295 """
296296 data = x_in [0 ]
297297
298- start_token = self ._permute_transformer_inputs (data [:, - 1 :, :])
298+ pad_size = (0 , self .input_size - self .target_size )
299+ start_token = self ._permute_transformer_inputs (data [:, - 1 :, : self .target_size ])
300+ start_token_padded = F .pad (start_token , pad_size )
299301 if len (x_in ) == 3 :
300302 src , _ , tgt = x_in
301303 src = self ._permute_transformer_inputs (src )
302304 tgt_permuted = self ._permute_transformer_inputs (tgt )
303- tgt_padded = F .pad (tgt_permuted , ( 0 , self . input_size - self . target_size ) )
304- tgt = torch .cat ([start_token , tgt_padded ], dim = 0 )
305+ tgt_padded = F .pad (tgt_permuted , pad_size )
306+ tgt = torch .cat ([start_token_padded , tgt_padded ], dim = 0 )
305307 return self ._prediction_step (src , tgt )[:, :- 1 , :, :]
306308
307309 data , _ = x_in
308310
309311 src = self ._permute_transformer_inputs (data )
310- tgt = start_token
312+ tgt = start_token_padded
311313
312314 predictions = []
313315 for _ in range (self .output_chunk_length ):
314316 pred = self ._prediction_step (src , tgt )[:, - 1 , :, :]
315317 predictions .append (pred )
316318 tgt = torch .cat (
317- [tgt , pred .mean (dim = 2 ).unsqueeze (dim = 0 )],
319+ [tgt , F . pad ( pred .mean (dim = 2 ).unsqueeze (dim = 0 ), pad_size )],
318320 dim = 0 ,
319321 ) # take average of quantiles
320322 return torch .stack (predictions , dim = 1 )
0 commit comments