Commit 69b8804

Merge pull request #6 from LukasHedegaard/improved-nn-interop
Improved nn interop
2 parents 6539098 + b4d26f5 commit 69b8804

18 files changed, with 340 additions and 197 deletions

CHANGELOG.md

Lines changed: 13 additions & 0 deletions
@@ -6,6 +6,19 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
66

77
## [Unreleased]
88

9+
## [0.2.2]
10+
## Added
11+
- Automatic conversion of batch normalisation and activation functions
12+
13+
## Fixed
14+
- Separate dilation and stride in pool
15+
16+
## Changed
17+
- Conv forward to use temporal padding (like nn.Conv)
18+
19+
## Removed
20+
- `co.BatchNorm2d`
21+
922
## [0.2.1]
1023
## Changed
1124
- Renamed `unsqueezed` to `forward_stepping`.

README.md

Lines changed: 150 additions & 78 deletions
@@ -28,46 +28,49 @@ Building blocks for Continual Inference Networks in PyTorch
2828
pip install continual-inference
2929
```
3030

31-
## Usage
31+
## Simple example
32+
Continual Modules are a weight-compatible drop-in replacement for torch.nn Modules, with the enhanced capability of efficient continual inference.
33+
3234
```python3
3335
import torch
34-
from torch import nn
3536
import continual as co
36-
# B, C, T, H, W
37-
example = torch.normal(mean=torch.zeros(5 * 3 * 3)).reshape((1, 1, 5, 3, 3))
37+
38+
# B, C, T, H, W
39+
example = torch.randn((1, 1, 5, 3, 3))
3840

39-
# Acts as a drop-in replacement for torch.nn modules ✅
40-
co_conv = co.Conv3d(in_channels=1, out_channels=1, kernel_size=(3, 3, 3))
41-
nn_conv = nn.Conv3d(in_channels=1, out_channels=1, kernel_size=(3, 3, 3))
42-
co_conv.load_state_dict(nn_conv.state_dict()) # ensure identical weights
41+
conv = co.Conv3d(in_channels=1, out_channels=1, kernel_size=(3, 3, 3))
4342

44-
co_output = co_conv(example) # Same exact computation
45-
nn_output = nn_conv(example) # Same exact computation
46-
assert torch.equal(co_output, nn_output)
43+
# Same exact computation as torch.nn.Conv3d ✅
44+
output = conv(example)
4745

4846
# But can also perform online inference efficiently 🚀
49-
firsts = co_conv.forward_steps(example[:, :, :4])
50-
last = co_conv.forward_step(example[:, :, 4])
47+
firsts = conv.forward_steps(example[:, :, :4])
48+
last = conv.forward_step(example[:, :, 4])
5149

52-
assert torch.allclose(nn_output[:, :, : co_conv.delay], firsts)
53-
assert torch.allclose(nn_output[:, :, co_conv.delay], last)
50+
assert torch.allclose(output[:, :, : conv.delay], firsts)
51+
assert torch.allclose(output[:, :, conv.delay], last)
5452
```
5553

54+
See also the "Advanced Examples" section.
55+
5656
## Continual Inference Networks (CINs)
5757
Continual Inference Networks are a type of neural network, which operate on a continual input stream of data and infer a new prediction for each new time-step.
58+
They are ideal for online detection and monitoring scenarios, but can also be used successfully in offline situations.
5859

59-
All networks and network-modules, that do not utilise temporal information can be used for an Online Inference Network (e.g. `nn.Conv1d` and `nn.Conv2d` on spatial data such as an image).
60+
All networks and network modules that do not utilise temporal information can be used in a Continual Inference Network (e.g. `nn.Conv1d` and `nn.Conv2d` on spatial data such as an image).
6061
Moreover, recurrent modules (e.g. `LSTM` and `GRU`), that summarize past events in an internal state are also useable in CINs.
6162

63+
Some example CINs and non-CINs are illustrated below.
64+
6265
__CIN__:
6366
```
64-
O O O (output)
65-
66-
LSTM LSTM LSTM (temporal LSTM)
67-
68-
Conv2D Conv2D Conv2D (spatial 2D conv)
69-
70-
I I I (input frame)
67+
O O O (output)
68+
69+
nn.LSTM nn.LSTM nn.LSTM (temporal LSTM)
70+
71+
nn.Conv2D nn.Conv2D nn.Conv2D (spatial 2D conv)
72+
73+
I I I (input frame)
7174
```
7275

7376
However, modules that operate on temporal data under the assumption that more temporal context is available than the current frame (e.g. the spatio-temporal `nn.Conv3d` used by many SotA video recognition models) cannot be directly applied.
@@ -76,7 +79,7 @@ __Not CIN__:
7679
```
7780
Θ (output)
7881
79-
Conv3D (spatio-temporal 3D conv)
82+
nn.Conv3D (spatio-temporal 3D conv)
8083
8184
----------------- (concatenate frames to clip)
8285
↑ ↑ ↑
@@ -87,70 +90,38 @@ Sometimes, though, the computations in such modules, can be cleverly restructure
8790

8891
__CIN__:
8992
```
90-
O O Θ (output)
91-
92-
ConvCo3D ConvCo3D ConvCo3D (continual spatio-temporal 3D conv)
93-
94-
I I I (input frame)
93+
O O Θ (output)
94+
95+
co.Conv3d co.Conv3d co.Conv3d (continual spatio-temporal 3D conv)
96+
97+
I I I (input frame)
9598
```
9699
Here, the `ϴ` outputs of `nn.Conv3D` and `co.Conv3d` are identical! ✨
97100

98-
## Modules
99-
This repository contains online inference-friendly versions of common network building blocks, inlcuding:
100-
101-
<!-- TODO: Replace with link to docs once they are set up -->
102-
- (Temporal) convolutions:
103-
- `co.Conv1d`
104-
- `co.Conv2d`
105-
- `co.Conv3d`
106-
107-
- (Temporal) batch normalisation:
108-
- `co.BatchNorm2d`
109-
110-
- (Temporal) pooling:
111-
- `co.AvgPool1d`
112-
- `co.AvgPool2d`
113-
- `co.AvgPool3d`
114-
- `co.MaxPool1d`
115-
- `co.MaxPool2d`
116-
- `co.MaxPool3d`
117-
- `co.AdaptiveAvgPool1d`
118-
- `co.AdaptiveAvgPool2d`
119-
- `co.AdaptiveAvgPool3d`
120-
- `co.AdaptiveMaxPool1d`
121-
- `co.AdaptiveMaxPool2d`
122-
- `co.AdaptiveMaxPool3d`
123-
124-
- Other
125-
- `co.Sequential` - sequential wrapper for modules
126-
- `co.Parallel` - parallel wrapper for modules
127-
- `co.Residual` - residual wrapper for modules
128-
- `co.Delay` - pure delay module
129-
<!-- - `co.Residual` - residual connection, which automatically adds delay if needed -->
130-
- `co.unsqueezed` - functional wrapper for non-continual modules
131-
- `co.continual` - conversion function from non-continual modules to continual moduls
101+
The last conversion from a non-CIN to a CIN is possible due to a recent breakthrough in Online Action Detection, namely [Continual Convolutions].
132102

133103
### Continual Convolutions
134-
Continual Convolutions can lead to major improvements in computational efficiency when online / frame-by-frame predictions are required.
135-
136-
Below, principle sketches comparing regular and continual convolutions are shown:
104+
The sketches below compare regular and continual convolutions during online / continual inference:
137105

138106
<div align="center">
139-
<img src="https://raw.githubusercontent.com/LukasHedegaard/continual-inference/main/figures/continual-convolution.png" width="500">
107+
<img src="https://raw.githubusercontent.com/LukasHedegaard/continual-inference/improved-nn-interop/figures/continual/regular-convolution.png" width="500">
140108
<br>
141109
Regular Convolution.
142-
A regular temporal convolutional layer leads to redundant computations during online processing of video clips, as illustrated by the repeated convolution of inputs (green b,c,d) with a kernel (blue α,β) in the temporal dimen- sion. Moreover, prior inputs (b,c,d) must be stored be- tween time-steps for online processing tasks.
143-
<br><br>
144-
<img src="https://raw.githubusercontent.com/LukasHedegaard/continual-inference/main/figures/regular-convolution.png" width="500">
110+
A regular temporal convolutional layer leads to redundant computations during online processing of video clips, as illustrated by the repeated convolution of inputs (green b,c,d) with a kernel (blue α,β) in the temporal dimension. Moreover, prior inputs (b,c,d) must be stored between time-steps for online processing tasks.
111+
<br><br>
112+
<img src="https://raw.githubusercontent.com/LukasHedegaard/continual-inference/improved-nn-interop/figures/continual/continual-convolution.png" width="500">
145113
<br>
146114
Continual Convolution.
147115
An input (green d or e) is convolved with a kernel (blue α, β). The intermediary feature-maps corresponding to all but the last temporal position are stored, while the last feature map and prior memory are summed to produce the resulting output. For a continual stream of inputs, Continual Convolutions produce identical outputs to regular convolutions.
148-
<br><br>
116+
<br><br>
149117
</div>
150118

119+
As illustrated, Continual Convolutions can lead to major improvements in computational efficiency when online / frame-by-frame predictions are required! 🚀
120+
151121

152122
For more information, we refer to the [seminal paper on Continual Convolutions](https://arxiv.org/abs/2106.00050).
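
The caption above can be made concrete with a minimal pure-Python sketch of a one-dimensional continual convolution. It is an illustration only and not the library's implementation: the `ContinualConv1d` class, its `memory` list, and the `regular_conv1d` helper are hypothetical names, and plain floats stand in for feature maps.

```python
def regular_conv1d(xs, kernel):
    """Sliding-window temporal convolution over a whole clip (no padding)."""
    k = len(kernel)
    return [
        sum(w * x for w, x in zip(kernel, xs[t : t + k]))
        for t in range(len(xs) - k + 1)
    ]


class ContinualConv1d:
    """Hypothetical step-wise convolution: one frame in, at most one output out."""

    def __init__(self, kernel):
        self.kernel = list(kernel)
        # Partial sums of the k - 1 outputs that are still waiting for future frames.
        self.memory = [0.0] * (len(self.kernel) - 1)
        self.steps_seen = 0

    def forward_step(self, x):
        k = len(self.kernel)
        # Convolve the new frame with every temporal position of the kernel.
        products = [w * x for w in self.kernel]
        # The last kernel position completes the oldest partial sum.
        output = (self.memory[0] if self.memory else 0.0) + products[-1]
        # Shift the memory and add the remaining products to the pending outputs.
        shifted = (self.memory[1:] + [0.0])[: k - 1]
        self.memory = [m + products[k - 2 - j] for j, m in enumerate(shifted)]
        self.steps_seen += 1
        # No valid output exists until k frames have been observed.
        return output if self.steps_seen >= k else None


clip = [1.0, 2.0, 3.0, 4.0, 5.0]
kernel = [0.5, -1.0, 2.0]

conv = ContinualConv1d(kernel)
stepwise = [conv.forward_step(x) for x in clip]

# After the warm-up steps, the step-wise outputs match the clip-wise ones.
assert [y for y in stepwise if y is not None] == regular_conv1d(clip, kernel)
```

In this sketch each frame is multiplied with the kernel exactly once, whereas re-running `regular_conv1d` on a sliding window would multiply every frame `len(kernel)` times.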
153123

124+
154125
## Forward modes
155126
The library components feature three distinct forward modes, which are handy for different situations.
156127

@@ -194,6 +165,114 @@ This method is handy for effient training on clip-based data.
194165
P I I I P (I: input frame, P: padding)
195166
```
196167

168+
169+
## Modules
170+
The repository contains custom online inference-friendly versions of common network building blocks, as well as handy wrappers and a global conversion function from `torch.nn` to `continual` (`co`) modules.
171+
172+
<!-- TODO: Replace with link to docs once they are set up -->
173+
- Convolutions:
174+
- `co.Conv1d`
175+
- `co.Conv2d`
176+
- `co.Conv3d`
177+
178+
- Pooling:
179+
- `co.AvgPool1d`
180+
- `co.AvgPool2d`
181+
- `co.AvgPool3d`
182+
- `co.MaxPool1d`
183+
- `co.MaxPool2d`
184+
- `co.MaxPool3d`
185+
- `co.AdaptiveAvgPool1d`
186+
- `co.AdaptiveAvgPool2d`
187+
- `co.AdaptiveAvgPool3d`
188+
- `co.AdaptiveMaxPool1d`
189+
- `co.AdaptiveMaxPool2d`
190+
- `co.AdaptiveMaxPool3d`
191+
192+
- Containers
193+
  - `co.Sequential` - Sequential wrapper for modules. This module automatically performs conversions of torch.nn modules, which are safe during continual inference. These include all batch normalisation and activation functions.
194+
- `co.Parallel` - Parallel wrapper for modules.
195+
- `co.Residual` - Residual wrapper for modules.
196+
- `co.Delay` - Pure delay module (e.g. needed in residuals).
197+
198+
- Converters
199+
<!-- - `co.Residual` - residual connection, which automatically adds delay if needed -->
200+
- `co.continual` - conversion function from non-continual modules to continual modules
201+
- `co.forward_stepping` - functional wrapper, which enhances temporally local non-continual modules with the forward_stepping functions
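
Why a pure delay is needed in residuals can be sketched in plain Python. The classes below are hypothetical toy stand-ins and not the library's `co.Delay` or `co.Residual`; the example only assumes that a temporal branch emits its result for a frame some steps after that frame arrived, so the shortcut has to wait the same number of steps.

```python
from collections import deque


class Delay:
    """Toy delay: returns the input from `steps` steps ago (None while filling)."""

    def __init__(self, steps):
        self.steps = steps
        self.buffer = deque(maxlen=steps + 1)

    def forward_step(self, x):
        self.buffer.append(x)
        if len(self.buffer) <= self.steps:
            return None
        return self.buffer[0]


class CentredMean3:
    """Toy temporal branch: mean of frames t-1, t, t+1, available one step late."""

    delay = 1

    def __init__(self):
        self.window = deque(maxlen=3)

    def forward_step(self, x):
        self.window.append(x)
        if len(self.window) < 3:
            return None
        return sum(self.window) / 3


class Residual:
    """Toy residual: adds the branch output to an equally delayed copy of the input."""

    def __init__(self, branch):
        self.branch = branch
        self.shortcut = Delay(branch.delay)

    def forward_step(self, x):
        out = self.branch.forward_step(x)
        skip = self.shortcut.forward_step(x)
        if out is None or skip is None:
            return None
        return out + skip


residual = Residual(CentredMean3())
outputs = [residual.forward_step(x) for x in [1.0, 2.0, 6.0, 4.0]]
print(outputs)
```

Without the delay, the third step would add the newest frame (6.0) to a mean that is centred on the previous frame (2.0), which mixes two different time-steps.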
202+
203+
204+
## Advanced examples
205+
206+
### Continual 3D [MBConv](https://arxiv.org/pdf/1801.04381.pdf)
207+
208+
<div align="center">
209+
<img src="https://raw.githubusercontent.com/LukasHedegaard/continual-inference/improved-nn-interop/figures/examples/mb_conv.png" width="150">
210+
<br>
211+
MobileNetV2 Inverted residual block. Source: https://arxiv.org/pdf/1801.04381.pdf
212+
</div>
213+
214+
```python3
215+
import continual as co
216+
from torch import nn
217+
218+
mb_conv = co.Residual(
219+
co.Sequential(
220+
co.Conv3d(32, 64, kernel_size=(1, 1, 1)),
221+
nn.BatchNorm3d(64),
222+
nn.ReLU6(),
223+
co.Conv3d(64, 64, kernel_size=(3, 3, 3), padding=(0, 1, 1), groups=64),
224+
nn.ReLU6(),
225+
co.Conv3d(64, 32, kernel_size=(1, 1, 1)),
226+
nn.BatchNorm3d(32),
227+
)
228+
)
229+
```
230+
231+
### Continual 3D [Inception module](https://arxiv.org/pdf/1409.4842v1.pdf)
232+
233+
<div align="center">
234+
<img src="https://raw.githubusercontent.com/LukasHedegaard/continual-inference/improved-nn-interop/figures/examples/inception_block.png" width="450">
235+
<br>
236+
Inception module with dimension reductions. Source: https://arxiv.org/pdf/1409.4842v1.pdf
237+
</div>
238+
239+
```python3
240+
import continual as co
241+
from torch import nn
242+
243+
def norm_relu(module, channels):
244+
return co.Sequential(
245+
module,
246+
nn.BatchNorm3d(channels),
247+
nn.ReLU(),
248+
)
249+
250+
inception_module = co.Parallel(
251+
co.Conv3d(192, 64, kernel_size=1),
252+
co.Sequential(
253+
norm_relu(co.Conv3d(192, 96, kernel_size=1), 96),
254+
norm_relu(co.Conv3d(96, 128, kernel_size=3, padding=1), 128),
255+
),
256+
co.Sequential(
257+
norm_relu(co.Conv3d(192, 16, kernel_size=1), 16),
258+
norm_relu(co.Conv3d(16, 32, kernel_size=3, padding=1), 32),
259+
),
260+
co.Sequential(
261+
co.MaxPool3d(kernel_size=3, padding=1, stride=1),
262+
norm_relu(co.Conv3d(192, 32, kernel_size=1), 32),
263+
),
264+
aggregation_fn="concat",
265+
)
266+
```
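
As a quick sanity check of the block above, the output width can be worked out by hand. The sketch assumes that `aggregation_fn="concat"` joins the branch outputs along the channel axis, which the diff does not state explicitly; `concat_channels` is a hypothetical helper.

```python
# Output channels of the last convolution in each of the four branches above.
branch_out_channels = {
    "1x1 conv": 64,
    "1x1 -> 3x3 conv (96 -> 128)": 128,
    "1x1 -> 3x3 conv (16 -> 32)": 32,
    "max-pool -> 1x1 conv": 32,
}


def concat_channels(widths):
    """Channel count after concatenating feature maps along the channel axis."""
    return sum(widths)


total_channels = concat_channels(branch_out_channels.values())
print(total_channels)
```

Under that assumption, a module stacked after this one would need `in_channels` equal to `total_channels`.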
267+
268+
269+
For additional full-fledged examples of complex Continual Inference Networks, see:
270+
271+
- [Continual 3D](https://github.com/LukasHedegaard/co3d)
272+
<!-- - [Continual Skeletons](https://github.com/LukasHedegaard/continual-skeletons) -->
273+
274+
275+
197276
## Compatibility
198277
The library modules are built to integrate seamlessly with other PyTorch projects.
199278
Specifically, extra care was taken to ensure out-of-the-box compatibility with:
@@ -202,13 +281,6 @@ Specifically, extra care was taken to ensure out-of-the-box compatibility with:
202281
- [ride](https://github.com/LukasHedegaard/ride)
203282

204283

205-
## Projects
206-
For full-fledged examples of complex Continual Inference Networks, see:
207-
208-
- [Continual 3D](https://github.com/LukasHedegaard/co3d)
209-
<!-- - [Continual Skeletons](https://github.com/LukasHedegaard/continual-skeletons) -->
210-
211-
212284
## Citations
213285
This library
214286
```bibtex

continual/__init__.py

Lines changed: 1 addition & 2 deletions
@@ -1,4 +1,3 @@
1-
from .batchnorm import BatchNorm2d # noqa: F401
21
from .container import Parallel, Residual, Sequential # noqa: F401
32
from .conv import Conv1d, Conv2d, Conv3d # noqa: F401
43
from .convert import continual, forward_stepping # noqa: F401
@@ -16,4 +15,4 @@
1615
MaxPool2d,
1716
MaxPool3d,
1817
)
19-
from .ptflops import register_ptflops # noqa: F401
18+
from .ptflops import _register_ptflops # noqa: F401
