1
1
from collections import OrderedDict
2
2
from enum import Enum
3
3
from functools import reduce , wraps
4
+ from numbers import Number
4
5
from typing import Callable , List , Optional , Sequence , Tuple , TypeVar , Union , overload
5
6
6
7
import torch
@@ -82,11 +83,11 @@ def wrapped(inputs: Sequence[Tensor]) -> Tensor:
82
83
return wrapped
83
84
84
85
85
def num_from(tuple_or_num: Union[Number, Tuple[Number, ...]], dim=0) -> Number:
    """Return a scalar number from either a number or a tuple of numbers.

    Args:
        tuple_or_num: A plain number, or a tuple of numbers (e.g. a
            stride/padding argument that may be given per-dimension).
        dim: Index to pick when ``tuple_or_num`` is a tuple. Defaults to 0.

    Returns:
        ``tuple_or_num`` itself when it is a number, otherwise its
        entry at position ``dim``.
    """
    is_scalar = isinstance(tuple_or_num, Number)
    return tuple_or_num if is_scalar else tuple_or_num[dim]
90
91
91
92
92
93
class FlattenableStateDict :
@@ -206,8 +207,8 @@ def __init__(
206
207
]
207
208
208
209
assert (
209
- len (set (int_from (getattr (m , "stride" , 1 )) for _ , m in modules )) == 1
210
- ), f"Expected all modules to have the same stride, but got strides { [(int_from (getattr (m , 'stride' , 1 ))) for _ , m in modules ]} "
210
+ len (set (num_from (getattr (m , "stride" , 1 )) for _ , m in modules )) == 1
211
+ ), f"Expected all modules to have the same stride, but got strides { [(num_from (getattr (m , 'stride' , 1 ))) for _ , m in modules ]} "
211
212
212
213
for key , module in modules :
213
214
self .add_module (key , module )
@@ -253,11 +254,11 @@ def delay(self) -> int:
253
254
254
255
@property
255
256
def stride (self ) -> int :
256
- return int_from (getattr (next (iter (self )), "stride" , 1 ))
257
+ return num_from (getattr (next (iter (self )), "stride" , 1 ))
257
258
258
259
@property
259
260
def padding (self ) -> int :
260
- return max (int_from (getattr (m , "padding" , 0 )) for m in self )
261
+ return max (num_from (getattr (m , "padding" , 0 )) for m in self )
261
262
262
263
def clean_state (self ):
263
264
for m in self :
@@ -375,12 +376,12 @@ def delay(self):
375
376
def stride (self ) -> int :
376
377
tot = 1
377
378
for m in self :
378
- tot *= int_from (getattr (m , "stride" , 1 ))
379
+ tot *= num_from (getattr (m , "stride" , 1 ))
379
380
return tot
380
381
381
382
@property
382
383
def padding (self ) -> int :
383
- return max (int_from (getattr (m , "padding" , 0 )) for m in self )
384
+ return max (num_from (getattr (m , "padding" , 0 )) for m in self )
384
385
385
386
@staticmethod
386
387
def build_from (module : nn .Sequential ) -> "Sequential" :
@@ -466,8 +467,8 @@ def __init__(
466
467
]
467
468
468
469
assert (
469
- len (set (int_from (getattr (m , "stride" , 1 )) for _ , m in modules )) == 1
470
- ), f"Expected all modules to have the same stride, but got strides { [(int_from (getattr (m , 'stride' , 1 ))) for _ , m in modules ]} "
470
+ len (set (num_from (getattr (m , "stride" , 1 )) for _ , m in modules )) == 1
471
+ ), f"Expected all modules to have the same stride, but got strides { [(num_from (getattr (m , 'stride' , 1 ))) for _ , m in modules ]} "
471
472
472
473
for key , module in modules :
473
474
self .add_module (key , module )
@@ -542,11 +543,11 @@ def delay(self) -> int:
542
543
543
544
@property
544
545
def stride (self ) -> int :
545
- return int_from (getattr (next (iter (self )), "stride" , 1 ))
546
+ return num_from (getattr (next (iter (self )), "stride" , 1 ))
546
547
547
548
@property
548
549
def padding (self ) -> int :
549
- return max (int_from (getattr (m , "padding" , 0 )) for m in self )
550
+ return max (num_from (getattr (m , "padding" , 0 )) for m in self )
550
551
551
552
def clean_state (self ):
552
553
for m in self :
@@ -561,14 +562,18 @@ def Residual(
561
562
module : CoModule ,
562
563
temporal_fill : PaddingMode = None ,
563
564
reduce : Reduction = "sum" ,
565
+ forward_shrink : bool = False ,
564
566
):
567
+ assert num_from (getattr (module , "stride" , 1 )) == 1 , (
568
+ "The simple `Residual` only works for modules with temporal stride=1. "
569
+ "Complex residuals can be achieved using `BroadcastReduce` or the `Broadcast`, `Parallel`, and `Reduce` modules."
570
+ )
571
+ temporal_fill = temporal_fill or getattr (
572
+ module , "temporal_fill" , PaddingMode .REPLICATE .value
573
+ )
565
574
return BroadcastReduce (
566
575
# Residual first yields easier broadcasting in reduce functions
567
- Delay (
568
- delay = module .delay ,
569
- temporal_fill = temporal_fill
570
- or getattr (module , "temporal_fill" , PaddingMode .REPLICATE .value ),
571
- ),
576
+ Delay (module .delay , temporal_fill , forward_shrink ),
572
577
module ,
573
578
reduce = reduce ,
574
579
auto_delay = False ,
0 commit comments