@@ -45,6 +45,8 @@ def __init__(
45
45
groups : int = 1 ,
46
46
bias : bool = True ,
47
47
padding_mode : PaddingMode = "zeros" ,
48
+ device = None ,
49
+ dtype = None ,
48
50
temporal_fill : PaddingMode = "zeros" ,
49
51
):
50
52
assert issubclass (
@@ -82,6 +84,8 @@ def __init__(
82
84
groups = groups ,
83
85
bias = bias ,
84
86
padding_mode = padding_mode ,
87
+ device = device ,
88
+ dtype = dtype ,
85
89
)
86
90
self .make_padding = {
87
91
PaddingMode .ZEROS .value : torch .zeros_like ,
@@ -142,7 +146,7 @@ def _forward_step(self, input: Tensor, prev_state: State) -> Tuple[Tensor, State
142
146
), f"A tensor of shape { (* self .input_shape_desciption [:2 ], * self .input_shape_desciption [3 :])} should be passed as input but got { input .shape } "
143
147
144
148
# e.g. B, C -> B, C, 1
145
- x = input .unsqueeze (2 )
149
+ x = input .unsqueeze (2 ). to ( device = self . weight . device )
146
150
147
151
if self .padding_mode == "zeros" :
148
152
x = self ._conv_func (
@@ -239,6 +243,8 @@ def __init__(
239
243
groups : int = 1 ,
240
244
bias : bool = True ,
241
245
padding_mode : PaddingMode = "zeros" ,
246
+ device = None ,
247
+ dtype = None ,
242
248
temporal_fill : PaddingMode = "zeros" ,
243
249
):
244
250
r"""Applies a continual 1D convolution over an input signal composed of several input
@@ -295,6 +301,8 @@ def __init__(
295
301
groups ,
296
302
bias ,
297
303
padding_mode ,
304
+ device ,
305
+ dtype ,
298
306
temporal_fill ,
299
307
)
300
308
@@ -338,6 +346,8 @@ def __init__(
338
346
groups : int = 1 ,
339
347
bias : bool = True ,
340
348
padding_mode : PaddingMode = "zeros" ,
349
+ device = None ,
350
+ dtype = None ,
341
351
temporal_fill : PaddingMode = "zeros" ,
342
352
):
343
353
r"""Applies a continual 2D convolution over an input signal composed of several input
@@ -394,6 +404,8 @@ def __init__(
394
404
groups ,
395
405
bias ,
396
406
padding_mode ,
407
+ device ,
408
+ dtype ,
397
409
temporal_fill ,
398
410
)
399
411
@@ -437,6 +449,8 @@ def __init__(
437
449
groups : int = 1 ,
438
450
bias : bool = True ,
439
451
padding_mode : PaddingMode = "zeros" ,
452
+ device = None ,
453
+ dtype = None ,
440
454
temporal_fill : PaddingMode = "zeros" ,
441
455
):
442
456
r"""Applies a continual 3D convolution over an input signal composed of several input
@@ -495,6 +509,8 @@ def __init__(
495
509
groups ,
496
510
bias ,
497
511
padding_mode ,
512
+ device ,
513
+ dtype ,
498
514
temporal_fill ,
499
515
)
500
516
0 commit comments