Continual Inference Networks (CINs) are neural networks that operate on a continual stream of input data and produce a new prediction for each time-step.
They are ideal for online detection and monitoring scenarios, but can also be used successfully in offline settings.
All networks and network modules that do not utilise temporal information can be used in a Continual Inference Network (e.g. `nn.Conv1d` and `nn.Conv2d` on spatial data such as an image).
Moreover, recurrent modules (e.g. `nn.LSTM` and `nn.GRU`) that summarise past events in an internal state are also usable in CINs.
Some example CINs and non-CINs are illustrated below.
__CIN__:
```
    O          O          O      (output)
    ↑          ↑          ↑
 nn.LSTM    nn.LSTM    nn.LSTM   (temporal LSTM)
    ↑          ↑          ↑
nn.Conv2D  nn.Conv2D  nn.Conv2D  (spatial 2D conv)
    ↑          ↑          ↑
    I          I          I      (input frame)
```
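To make the state-based processing that keeps recurrent modules CIN-compatible concrete, here is a minimal pure-Python sketch (illustrative only, not part of the library): a toy stateful module that, like an RNN cell, summarises all past frames in an internal state and emits one output per input frame.

```python
class StatefulEMA:
    """Toy stateful module: keeps a running summary of past frames
    (like an RNN hidden state) and emits one output per frame."""

    def __init__(self, alpha=0.5):
        self.alpha = alpha
        self.state = 0.0  # summary of all past inputs

    def forward_step(self, x):
        # Blend the new frame into the internal state and predict from it.
        self.state = self.alpha * x + (1 - self.alpha) * self.state
        return self.state


ema = StatefulEMA(alpha=0.5)
outputs = [ema.forward_step(x) for x in [1.0, 1.0, 1.0]]
# The state converges toward the input level: [0.5, 0.75, 0.875]
```

Because the module only consumes the current frame plus its own state, no temporal buffering of raw inputs is needed between time-steps.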
However, modules that operate on temporal data under the assumption that more temporal context than the current frame is available (e.g. the spatio-temporal `nn.Conv3d` used by many SotA video recognition models) cannot be applied directly.
__Not CIN__:
```
        Θ            (output)
        ↑
    nn.Conv3d        (spatio-temporal 3D conv)
        ↑
-----------------    (concatenate frames to clip)
  ↑     ↑     ↑
  I     I     I      (input frame)
```

Sometimes, though, the computations in such modules can be cleverly restructured to work for online inference as well:
__CIN__:
```
    O          O          Θ      (output)
    ↑          ↑          ↑
co.Conv3d  co.Conv3d  co.Conv3d  (continual spatio-temporal 3D conv)
    ↑          ↑          ↑
    I          I          I      (input frame)
```
Here, the `Θ` outputs of `nn.Conv3d` and `co.Conv3d` are identical! ✨
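This equivalence can be illustrated with a toy 1D example in pure Python (illustrative names, not the library API): a continual convolution that buffers the last `k - 1` frames between steps produces exactly the outputs of a regular temporal convolution over the full clip.

```python
from collections import deque


def regular_temporal_conv(inputs, kernel):
    """Regular 1D temporal convolution (valid padding) over a whole clip."""
    k = len(kernel)
    return [sum(kernel[j] * inputs[t + j] for j in range(k))
            for t in range(len(inputs) - k + 1)]


class ContinualConv:
    """Frame-by-frame equivalent: buffers the most recent inputs between steps."""

    def __init__(self, kernel):
        self.kernel = kernel
        self.buffer = deque(maxlen=len(kernel))

    def forward_step(self, x):
        self.buffer.append(x)
        if len(self.buffer) < len(self.kernel):
            return None  # not enough temporal context yet (warm-up)
        return sum(w * v for w, v in zip(self.kernel, self.buffer))


kernel = [0.25, 0.5, 0.25]
clip = [1.0, 2.0, 3.0, 4.0, 5.0]
cont = ContinualConv(kernel)
stepwise = [y for y in (cont.forward_step(x) for x in clip) if y is not None]
assert stepwise == regular_temporal_conv(clip, kernel)  # identical outputs
```

The continual version stores only `k - 1` past frames and performs one kernel application per incoming frame, instead of recomputing the whole clip.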
The last conversion from a non-CIN to a CIN is possible due to a recent breakthrough in Online Action Detection, namely [Continual Convolutions].
### Continual Convolutions
Below, principle sketches compare regular and continual convolutions during online / continual inference:
A regular temporal convolutional layer leads to redundant computations during online processing of video clips, as illustrated by the repeated convolution of inputs (green b, c, d) with a kernel (blue α, β) in the temporal dimension. Moreover, prior inputs (b, c, d) must be stored between time-steps for online processing tasks.

In a Continual Convolution, an input (green d or e) is convolved with a kernel (blue α, β). The intermediary feature-maps corresponding to all but the last temporal position are stored, while the last feature map and prior memory are summed to produce the resulting output. For a continual stream of inputs, Continual Convolutions produce identical outputs to regular convolutions.
<br><br>
</div>
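The memory-based scheme in the sketch can be written out in a few lines of pure Python for a size-2 kernel (α, β); class and variable names are illustrative, not the library's. Each step stores the partial product for the kernel's first position and completes the previous step's partial sum with the new frame.

```python
class ContinualConvMemory:
    """Continual convolution via partial sums for a size-2 kernel (alpha, beta).
    Stores alpha*x (feature map for the kernel's first temporal position)
    and sums beta*x with the memory carried from the previous step."""

    def __init__(self, alpha, beta):
        self.alpha = alpha
        self.beta = beta
        self.memory = None  # partial sum carried between steps

    def forward_step(self, x):
        out = None
        if self.memory is not None:
            out = self.memory + self.beta * x  # complete a receptive field
        self.memory = self.alpha * x  # partial product for the next step
        return out


conv = ContinualConvMemory(alpha=2.0, beta=3.0)
outs = [conv.forward_step(x) for x in [1.0, 2.0, 3.0]]
# Warm-up, then exactly the regular conv of [1, 2, 3] with kernel [2, 3]:
# [None, 2*1 + 3*2, 2*2 + 3*3] = [None, 8.0, 13.0]
```

Each frame is convolved with the kernel exactly once, so no computation is repeated and no raw frames need to be stored.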
As illustrated, Continual Convolutions can lead to major improvements in computational efficiency when online / frame-by-frame predictions are required! 🚀
For more information, we refer to the [seminal paper on Continual Convolutions](https://arxiv.org/abs/2106.00050).
## Forward modes
The library components feature three distinct forward modes, which are handy for different situations.
This method is handy for efficient training on clip-based data.
```
P I I I P   (I: input frame, P: padding)
```
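The padding sketch above can be made concrete with a toy clip-wise forward pass (hypothetical helper, not the library API): zero-padding a three-frame clip with one frame on each side lets a size-3 temporal kernel produce one output per input frame.

```python
def forward_clip(inputs, kernel, pad=1):
    """Clip-wise forward: zero-pad temporally, then valid convolution."""
    padded = [0.0] * pad + list(inputs) + [0.0] * pad  # P I I I P
    k = len(kernel)
    return [sum(kernel[j] * padded[t + j] for j in range(k))
            for t in range(len(padded) - k + 1)]


outputs = forward_clip([1.0, 2.0, 3.0], kernel=[1.0, 1.0, 1.0], pad=1)
# One output per input frame: [0+1+2, 1+2+3, 2+3+0] = [3.0, 6.0, 5.0]
```

The padding preserves the temporal length of the clip, which is what makes clip-based training convenient.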
## Modules
The repository contains custom online inference-friendly versions of common network building blocks, as well as handy wrappers and a global conversion function from `torch.nn` to `continual` (`co`) modules.
<!-- TODO: Replace with link to docs once they are set up -->
- Convolutions:
  - `co.Conv1d`
  - `co.Conv2d`
  - `co.Conv3d`

- Pooling:
  - `co.AvgPool1d`
  - `co.AvgPool2d`
  - `co.AvgPool3d`
  - `co.MaxPool1d`
  - `co.MaxPool2d`
  - `co.MaxPool3d`
  - `co.AdaptiveAvgPool1d`
  - `co.AdaptiveAvgPool2d`
  - `co.AdaptiveAvgPool3d`
  - `co.AdaptiveMaxPool1d`
  - `co.AdaptiveMaxPool2d`
  - `co.AdaptiveMaxPool3d`

- Containers:
  - `co.Sequential` - sequential wrapper for modules. It automatically converts those `torch.nn` modules which are safe to use during continual inference, including all batch normalisation and activation functions.
  - `co.Parallel` - parallel wrapper for modules.
  - `co.Residual` - residual wrapper for modules.
  - `co.Delay` - pure delay module (e.g. needed in residuals).

- Converters:
  <!-- - `co.Residual` - residual connection, which automatically adds delay if needed -->
  - `co.continual` - conversion function from non-continual modules to continual modules.
  - `co.forward_stepping` - functional wrapper, which enhances temporally local non-continual modules with the `forward_stepping` functions.
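The container idea can be sketched in a few lines of pure Python (toy classes mirroring the roles of `co.Sequential` and `co.Delay`; not the actual library implementation): a sequential wrapper simply threads each frame through its step-wise modules, propagating `None` while downstream buffers are still warming up.

```python
class Sequential:
    """Toy sequential container for step-wise modules."""

    def __init__(self, *modules):
        self.modules = modules

    def forward_step(self, x):
        for m in self.modules:
            x = m.forward_step(x)
            if x is None:  # a module is still filling its temporal buffer
                return None
        return x


class Scale:
    """Stateless per-frame module (safe for continual inference)."""

    def __init__(self, factor):
        self.factor = factor

    def forward_step(self, x):
        return self.factor * x


class Delay:
    """Pure one-step delay (cf. the role of co.Delay in residuals)."""

    def __init__(self):
        self.prev = None

    def forward_step(self, x):
        out, self.prev = self.prev, x
        return out


net = Sequential(Scale(2.0), Delay())
outs = [net.forward_step(x) for x in [1.0, 2.0, 3.0]]
# First step warms up the delay: [None, 2.0, 4.0]
```

Keeping the per-module state inside each module is what lets containers compose arbitrarily while remaining valid CINs.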
## Advanced examples
### Continual 3D [MBConv](https://arxiv.org/pdf/1801.04381.pdf)