
Commit ff66157

Merge pull request #33 from LukasHedegaard/develop
Add `forward_shrink` option to `Delay` and `Residual`
2 parents 9aa23ac + 605ec65 commit ff66157

File tree: 6 files changed (+61 −22 lines)

CHANGELOG.md

Lines changed: 6 additions & 0 deletions
```diff
@@ -9,6 +9,12 @@ From v1.0.0 and on, the project will adherence strictly to Semantic Versioning.
 
 ## [Unreleased]
 
+
+## [0.13.0]
+### Added
+- Add `forward_shrink` option to `Delay` and `Residual`.
+
+
 ## [0.12.0]
 ### Added
 - Add `Constant`.
```

README.md

Lines changed: 3 additions & 0 deletions
```diff
@@ -215,6 +215,9 @@ Below is a list of the modules and utilities included in the library:
 - `co.Add` - Adds a constant value.
 - `co.Multiply` - Multiplies with a constant factor.
 - `co.Unity` - Maps input to output without modification.
+- `co.Constant` - Maps input to and output with constant value.
+- `co.Zero` - Maps input to output of zeros.
+- `co.One` - Maps input to output of ones.
 
 - Converters
 <!-- - `co.Residual` - residual connection, which automatically adds delay if needed -->
```
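A quick sketch of the semantics of the three newly documented utility modules. Note this is an illustration based only on the one-line descriptions above; the constructor signatures, in particular `co.Constant` taking its fill value as the first argument, are assumptions rather than confirmed API:

```python
import torch
import continual as co

x = torch.ones(1, 1, 3)  # (N, C, T)

co.Zero()(x)          # tensor of zeros with the same shape as x
co.One()(x)           # tensor of ones with the same shape as x
co.Constant(42.0)(x)  # tensor filled with 42.0 (assumed signature)
```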

continual/container.py

Lines changed: 24 additions & 19 deletions
```diff
@@ -1,6 +1,7 @@
 from collections import OrderedDict
 from enum import Enum
 from functools import reduce, wraps
+from numbers import Number
 from typing import Callable, List, Optional, Sequence, Tuple, TypeVar, Union, overload
 
 import torch
@@ -82,11 +83,11 @@ def wrapped(inputs: Sequence[Tensor]) -> Tensor:
     return wrapped
 
 
-def int_from(tuple_or_int: Union[int, Tuple[int, ...]], dim=0) -> int:
-    if isinstance(tuple_or_int, int):
-        return tuple_or_int
+def num_from(tuple_or_num: Union[Number, Tuple[Number, ...]], dim=0) -> Number:
+    if isinstance(tuple_or_num, Number):
+        return tuple_or_num
 
-    return tuple_or_int[dim]
+    return tuple_or_num[dim]
 
 
 class FlattenableStateDict:
```
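The rename from `int_from` to `num_from` widens the helper from `int` to any `numbers.Number`, so float-valued attributes such as fractional strides pass through instead of failing the tuple-indexing path. A minimal standalone sketch of the behavior, with the helper reproduced from the diff above:

```python
from numbers import Number

# Same helper as in the diff above, reproduced for a self-contained demo
def num_from(tuple_or_num, dim=0):
    if isinstance(tuple_or_num, Number):
        return tuple_or_num
    return tuple_or_num[dim]

assert num_from(3) == 3                  # ints pass through, as before
assert num_from(0.5) == 0.5              # floats now pass through too
assert num_from((2, 1, 1)) == 2          # tuples are indexed at dim=0
assert num_from((2, 1, 1), dim=2) == 1   # ...or at an explicit dim
```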
```diff
@@ -206,8 +207,8 @@ def __init__(
         ]
 
         assert (
-            len(set(int_from(getattr(m, "stride", 1)) for _, m in modules)) == 1
-        ), f"Expected all modules to have the same stride, but got strides {[(int_from(getattr(m, 'stride', 1))) for _, m in modules]}"
+            len(set(num_from(getattr(m, "stride", 1)) for _, m in modules)) == 1
+        ), f"Expected all modules to have the same stride, but got strides {[(num_from(getattr(m, 'stride', 1))) for _, m in modules]}"
 
         for key, module in modules:
             self.add_module(key, module)
@@ -253,11 +254,11 @@ def delay(self) -> int:
 
     @property
     def stride(self) -> int:
-        return int_from(getattr(next(iter(self)), "stride", 1))
+        return num_from(getattr(next(iter(self)), "stride", 1))
 
     @property
     def padding(self) -> int:
-        return max(int_from(getattr(m, "padding", 0)) for m in self)
+        return max(num_from(getattr(m, "padding", 0)) for m in self)
 
     def clean_state(self):
         for m in self:
@@ -375,12 +376,12 @@ def delay(self):
     def stride(self) -> int:
         tot = 1
         for m in self:
-            tot *= int_from(getattr(m, "stride", 1))
+            tot *= num_from(getattr(m, "stride", 1))
         return tot
 
     @property
     def padding(self) -> int:
-        return max(int_from(getattr(m, "padding", 0)) for m in self)
+        return max(num_from(getattr(m, "padding", 0)) for m in self)
 
     @staticmethod
     def build_from(module: nn.Sequential) -> "Sequential":
@@ -466,8 +467,8 @@ def __init__(
         ]
 
         assert (
-            len(set(int_from(getattr(m, "stride", 1)) for _, m in modules)) == 1
-        ), f"Expected all modules to have the same stride, but got strides {[(int_from(getattr(m, 'stride', 1))) for _, m in modules]}"
+            len(set(num_from(getattr(m, "stride", 1)) for _, m in modules)) == 1
+        ), f"Expected all modules to have the same stride, but got strides {[(num_from(getattr(m, 'stride', 1))) for _, m in modules]}"
 
         for key, module in modules:
             self.add_module(key, module)
@@ -542,11 +543,11 @@ def delay(self) -> int:
 
     @property
     def stride(self) -> int:
-        return int_from(getattr(next(iter(self)), "stride", 1))
+        return num_from(getattr(next(iter(self)), "stride", 1))
 
     @property
     def padding(self) -> int:
-        return max(int_from(getattr(m, "padding", 0)) for m in self)
+        return max(num_from(getattr(m, "padding", 0)) for m in self)
 
     def clean_state(self):
         for m in self:
@@ -561,14 +562,18 @@ def Residual(
     module: CoModule,
     temporal_fill: PaddingMode = None,
     reduce: Reduction = "sum",
+    forward_shrink: bool = False,
 ):
+    assert num_from(getattr(module, "stride", 1)) == 1, (
+        "The simple `Residual` only works for modules with temporal stride=1. "
+        "Complex residuals can be achieved using `BroadcastReduce` or the `Broadcast`, `Parallel`, and `Reduce` modules."
+    )
+    temporal_fill = temporal_fill or getattr(
+        module, "temporal_fill", PaddingMode.REPLICATE.value
+    )
    return BroadcastReduce(
         # Residual first yields easier broadcasting in reduce functions
-        Delay(
-            delay=module.delay,
-            temporal_fill=temporal_fill
-            or getattr(module, "temporal_fill", PaddingMode.REPLICATE.value),
-        ),
+        Delay(module.delay, temporal_fill, forward_shrink),
         module,
         reduce=reduce,
         auto_delay=False,
```
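To see how the new option composes, here is a minimal sketch (assuming `Delay` and `Residual` are exported at the package root, as the README's `co.` prefix suggests). The inner module is mocked with a shrinking `Delay`, whose clip-wise `forward` trims `delay` steps from each end of the temporal axis, so the residual branch must shrink identically for the `"sum"` reduction to line up:

```python
import torch
import continual as co

# Mock a module whose clip-wise forward trims one step from each end of T
inner = co.Delay(delay=1, forward_shrink=True)
res = co.Residual(inner, forward_shrink=True)

x = torch.arange(6.0).reshape(1, 1, 6)  # (N, C, T): [0, 1, 2, 3, 4, 5]
y = res.forward(x)

# Both branches yield [1, 2, 3, 4]; the sum reduction adds them elementwise
assert torch.equal(y, torch.tensor([[[2.0, 4.0, 6.0, 8.0]]]))
```

Without `forward_shrink=True` on the residual branch, the identity branch would keep all 6 temporal steps while the inner module returned only 4, and the reduction would fail on mismatched shapes.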

continual/delay.py

Lines changed: 16 additions & 2 deletions
```diff
@@ -22,9 +22,20 @@ def __init__(
         self,
         delay: int,
         temporal_fill: PaddingMode = "zeros",
+        forward_shrink: bool = False,
     ):
+        """Initialise Delay block
+
+        Args:
+            delay (int): the number of steps to delay an output.
+            temporal_fill (PaddingMode, optional): Temporal state initialisation mode ("zeros" or "replicate"). Defaults to "zeros".
+            forward_shrink (int, optional): Whether to shrink the temporal dimension of the feature map during forward.
+                This is handy for residuals that are parallel to modules which reduce the number of temporal steps. Defaults to False.
+        """
+        assert delay >= 0
         assert temporal_fill in {"zeros", "replicate"}
         self._delay = delay
+        self.forward_shrink = forward_shrink
         self.make_padding = {"zeros": torch.zeros_like, "replicate": torch.clone}[
             temporal_fill
         ]
@@ -98,11 +109,14 @@ def forward_steps(self, input: Tensor, pad_end=False, update_state=True) -> Tensor:
 
     def forward(self, input: Tensor) -> Tensor:
         # No delay during regular forward
-        return input
+        if not self.forward_shrink or self.delay == 0:
+            return input
+        return input[:, :, self.delay : -self.delay]
 
     @property
     def delay(self) -> int:
         return self._delay
 
     def extra_repr(self):
-        return f"{self.delay}"
+        shrink_str = ", forward_shrink=True" if self.forward_shrink else ""
+        return f"{self.delay}" + shrink_str
```
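With `forward_shrink=True`, the clip-wise `forward` is no longer an identity: it trims `delay` steps from each end of the temporal axis (2·delay in total), so the delayed branch can shape-match a parallel module that shrinks the clip. A minimal sketch, assuming `Delay` is importable from the package root as in the tests below:

```python
import torch
from continual import Delay

x = torch.arange(7.0).reshape(1, 1, 7)  # (N, C, T)

print(Delay(delay=2).forward(x)[0, 0])
# tensor([0., 1., 2., 3., 4., 5., 6.])   identity, as before

print(Delay(delay=2, forward_shrink=True).forward(x)[0, 0])
# tensor([2., 3., 4.])   2 steps trimmed from each end
```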

setup.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -25,7 +25,7 @@ def from_file(file_name: str = "requirements.txt", comment_char: str = "#"):
 
 setup(
     name="continual-inference",
-    version="0.12.0",
+    version="0.13.0",
     description="Building blocks for Continual Inference Networks in PyTorch",
     long_description=long_description(),
     long_description_content_type="text/markdown",
```

tests/continual/test_delay.py

Lines changed: 11 additions & 0 deletions
```diff
@@ -100,3 +100,14 @@ def test_zero_delay():
 def test_repr():
     delay = Delay(delay=2)
     assert delay.__repr__() == "Delay(2)"
+
+    delay = Delay(delay=2, forward_shrink=True)
+    assert delay.__repr__() == "Delay(2, forward_shrink=True)"
+
+
+def test_forward_shrink():
+    sample = torch.rand((2, 2, 5, 3))
+    delay = Delay(delay=2, forward_shrink=True)
+
+    output = delay.forward(sample)
+    assert torch.equal(sample[:, :, 2:-2], output)
```
