Commit 8e71ffe

Merge pull request #59 from LukasHedegaard/develop
Add support for GroupNorm and InstanceNorm
2 parents 18027cf + 9bb757b commit 8e71ffe

File tree (5 files changed: +126, -4 lines)

- CHANGELOG.md
- README.md
- continual/__about__.py
- continual/convert.py
- tests/continual/test_norm.py
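In practical terms, this release lets `GroupNorm`, `LayerNorm`, and suitably configured `InstanceNorm` modules pass through the library's automatic conversion like the other normalization layers. A minimal sketch, assuming the `continual()` function shown in continual/convert.py below is exported as `co.continual` (module arguments are illustrative):

import continual as co
from torch import nn

# GroupNorm (and LayerNorm) now go through the naive forward_stepping mapping
g_norm = co.continual(nn.GroupNorm(num_groups=2, num_channels=2))

# InstanceNorm is accepted only with affine and running stats enabled
i_norm = co.continual(nn.InstanceNorm3d(2, affine=True, track_running_stats=True))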

CHANGELOG.md

Lines changed: 5 additions & 0 deletions
@@ -8,6 +8,11 @@ From v1.0.0 and on, the project will adherence strictly to Semantic Versioning.
 
 ## Unpublished
 
+## [1.1.1] - 2023-01-10
+
+### Added
+- Support for `GroupNorm` and `InstanceNorm`
+
 
 ## [1.1.0] - 2022-12-19
 

README.md

Lines changed: 6 additions & 1 deletion
@@ -456,14 +456,19 @@ We support drop-in interoperability with with the following _torch.nn_ modules:
 - `nn.BatchNorm1d`
 - `nn.BatchNorm2d`
 - `nn.BatchNorm3d`
-- `nn.LayerNorm`
+- `nn.GroupNorm`,
+- `nn.InstanceNorm1d` (affine=True, track_running_stats=True required)
+- `nn.InstanceNorm2d` (affine=True, track_running_stats=True required)
+- `nn.InstanceNorm3d` (affine=True, track_running_stats=True required)
+- `nn.LayerNorm` (only non-temporal dimensions must be specified)
 
 </details>
 
 <details>
 <summary><b>Dropout</b></summary>
 
 - `nn.Dropout`
+- `nn.Dropout1d`
 - `nn.Dropout2d`
 - `nn.Dropout3d`
 - `nn.AlphaDropout`
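The parenthesised notes above carry the practical constraints: `InstanceNorm` modules must be built with `affine=True` and `track_running_stats=True`, and `LayerNorm` should normalize only the non-temporal dimensions of the input. A hedged sketch for a `(batch, channel, time, height, width)` clip, mirroring the shapes used in the new test below (the exact network is illustrative):

import torch
from torch import nn
import continual as co

net = co.Sequential.build_from(
    nn.Sequential(
        nn.Conv3d(1, 2, kernel_size=(5, 3, 3), padding=(0, 1, 1)),
        nn.LayerNorm([3, 3]),  # spatial dims only; no temporal axis
        nn.InstanceNorm3d(2, affine=True, track_running_stats=True),
    )
)
net.eval()
out = net.forward_steps(torch.randn(1, 1, 10, 3, 3))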

continual/__about__.py

Lines changed: 2 additions & 2 deletions
@@ -1,9 +1,9 @@
 import time
 
-__version__ = "1.1.0"
+__version__ = "1.1.1"
 __author__ = "Lukas Hedegaard"
 __author_email__ = "[email protected]"
 __license__ = "Apache-2.0"
 __copyright__ = f'Copyright (c) 2021-{time.strftime("%Y")}, {__author__}'
 __homepage__ = "https://github.com/lukashedegaard/continual-inference"
-__docs__ = "Building blocks for Continual Inference Networks in PyTorch"
+__docs__ = "A Python library for Continual Inference Networks in PyTorch"

continual/convert.py

Lines changed: 29 additions & 1 deletion
@@ -2,7 +2,7 @@
 
 from functools import wraps
 from types import FunctionType
-from typing import Callable, Type
+from typing import Callable, Type, Union
 
 from torch import Tensor, nn
 
@@ -150,6 +150,8 @@ def forward_with_callmode(*args, **kwargs):
     nn.BatchNorm1d,
     nn.BatchNorm2d,
     nn.BatchNorm3d,
+    nn.LayerNorm,
+    nn.GroupNorm,
     # >> Dropout modules
     nn.Dropout,
     nn.Dropout2d,
@@ -159,6 +161,28 @@
 }
 
 
+_circumvent_message = " to work with automatic conversion. You can circumvent this by wrapping the module in `co.forward_stepping(your_module)`. Note however, that this may break correspondence between forward and forward_step."
+
+
+def _instance_norm_condition(
+    module: Union[nn.InstanceNorm1d, nn.InstanceNorm2d, nn.InstanceNorm3d]
+):
+    assert module.affine, (
+        f"{type(module)} must be specified with `affine==True`" + _circumvent_message
+    )
+    assert module.track_running_stats, (
+        f"{type(module)} must be specified with `track_running_stats==True`"
+        + _circumvent_message
+    )
+
+
+CONDITIONAL_MAPPING = {
+    nn.InstanceNorm1d: _instance_norm_condition,
+    nn.InstanceNorm2d: _instance_norm_condition,
+    nn.InstanceNorm3d: _instance_norm_condition,
+}
+
+
 class ModuleNotRegisteredError(Exception):
     ...
 
@@ -206,6 +230,10 @@ def continual(module: nn.Module) -> CoModule:
     if type(module) in NAIVE_MAPPING:
         return forward_stepping(module)
 
+    if type(module) in CONDITIONAL_MAPPING:
+        CONDITIONAL_MAPPING[type(module)](module)
+        return forward_stepping(module)
+
     assert type(module) in MODULE_MAPPING, (
         f"A registered conversion for {module} was not found. "
         "You can register a custom conversion as follows:"

tests/continual/test_norm.py

Lines changed: 84 additions & 0 deletions
@@ -0,0 +1,84 @@
+import torch
+from torch import nn
+
+import continual as co
+
+
+def test_nn_norms():
+    S = 3
+
+    long_example_clip = torch.normal(mean=torch.zeros(10 * 3 * 3)).reshape(
+        (1, 1, 10, 3, 3)
+    )
+
+    b_norm = nn.BatchNorm3d(1)
+    b_norm.weight = nn.Parameter(3 * torch.ones_like(b_norm.weight))
+    b_norm.bias = nn.Parameter(1 * torch.ones_like(b_norm.bias))
+
+    i_norm = nn.InstanceNorm3d(2, affine=True, track_running_stats=True)
+    i_norm.weight = nn.Parameter(4 * torch.ones_like(i_norm.weight))
+    i_norm.bias = nn.Parameter(2 * torch.ones_like(i_norm.bias))
+
+    l_norm = nn.LayerNorm([S, S])  # NB: Doesn't work over temporal axis
+    l_norm.weight = nn.Parameter(5 * torch.ones_like(l_norm.weight))
+    l_norm.bias = nn.Parameter(3 * torch.ones_like(l_norm.bias))
+
+    g_norm = nn.GroupNorm(2, 2)
+    g_norm.weight = nn.Parameter(6 * torch.ones_like(g_norm.weight))
+    g_norm.bias = nn.Parameter(4 * torch.ones_like(g_norm.bias))
+
+    seq = nn.Sequential(
+        b_norm,
+        nn.Conv3d(
+            in_channels=1,
+            out_channels=2,
+            kernel_size=(5, S, S),
+            bias=True,
+            padding=(0, 1, 1),
+            padding_mode="zeros",
+        ),
+        i_norm,
+        l_norm,
+        g_norm,
+        nn.Conv3d(
+            in_channels=2,
+            out_channels=1,
+            kernel_size=(3, S, S),
+            bias=True,
+            padding=(0, 1, 1),
+            padding_mode="zeros",
+        ),
+        nn.MaxPool3d(kernel_size=(1, 2, 2)),
+    )
+    seq.eval()
+
+    coseq = co.Sequential.build_from(seq)
+    coseq.eval()
+
+    assert coseq.delay == (5 - 1) + (3 - 1)
+
+    # forward
+    output = seq.forward(long_example_clip)
+    co_output = coseq.forward(long_example_clip)
+    assert torch.allclose(output, co_output)
+
+    # forward_steps
+    co_output_firsts_0 = coseq.forward_steps(
+        long_example_clip[:, :, :-1], update_state=False
+    )
+    co_output_firsts = coseq.forward_steps(long_example_clip[:, :, :-1])
+    assert torch.allclose(co_output_firsts, co_output_firsts_0, atol=1e-7)
+    assert torch.allclose(co_output_firsts, output[:, :, :-1], atol=1e-7)
+
+    # forward_step
+    co_output_last_0 = coseq.forward_step(
+        long_example_clip[:, :, -1], update_state=False
+    )
+    co_output_last = coseq.forward_step(long_example_clip[:, :, -1])
+    assert torch.allclose(co_output_last, co_output_last_0, atol=1e-7)
+    assert torch.allclose(co_output_last, output[:, :, -1], atol=1e-7)
+
+    # Clean state can be used to restart seq computation
+    coseq.clean_state()
+    co_output_firsts = coseq.forward_steps(long_example_clip[:, :, :-1])
+    assert torch.allclose(co_output_firsts, output[:, :, :-1], atol=1e-7)
