Skip to content

Commit c4ae37f

Browse files
Merge pull request #34 from LukasHedegaard/develop
Receptive field, Phantom padding, Reshape, Lambda update and various fixes
2 parents ff66157 + 79f7468 commit c4ae37f

21 files changed

+337
-101
lines changed

CHANGELOG.md

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,26 @@ From v1.0.0 and on, the project will adhere strictly to Semantic Versioning.
1010
## [Unreleased]
1111

1212

13+
## [0.14.0]
14+
### Added
15+
- Added `phantom_padding` to `Residual`.
16+
- Added `receptive_field` property.
17+
- Added `Reshape` module.
18+
19+
### Changed
20+
- Rename `forward_shrink` argument to `auto_shrink` in `Delay`.
21+
- Torch requirement to v1.9.
22+
- Replace `Lambda` unsqueeze_step with takes_time and new default to False.
23+
24+
### Fixed
25+
- `padding` property in sequence.
26+
- `delay` property in sequence.
27+
- `strict` mode in `load_state_dict`.
28+
29+
### Removed
30+
- Assertion error in `BroadcastReduce` for modules with different delays.
31+
32+
1333
## [0.13.0]
1434
### Added
1535
- Add `forward_shrink` option to `Delay` and `Residual`.

README.md

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,10 @@ last = conv.forward_step(example[:, :, 4])
5252

5353
assert torch.allclose(output[:, :, : conv.delay], firsts)
5454
assert torch.allclose(output[:, :, conv.delay], last)
55+
56+
# Temporal properties
57+
assert conv.receptive_field == 3
58+
assert conv.delay == 2
5559
```
5660

5761
See the "Advanced Examples" section for additional examples.
@@ -211,6 +215,7 @@ Below is a list of the modules and utilities included in the library:
211215

212216
- Other
213217
- `co.Delay` - Pure delay module (e.g. needed in residuals).
218+
- `co.Reshape` - Reshape non-temporal dimensions.
214219
- `co.Lambda` - Lambda module which wraps any function.
215220
- `co.Add` - Adds a constant value.
216221
- `co.Multiply` - Multiplies with a constant factor.
@@ -361,7 +366,7 @@ inception_module = co.BroadcastReduce(
361366
),
362367
co.Sequential(
363368
norm_relu(co.Conv3d(192, 16, kernel_size=1), 16),
364-
norm_relu(co.Conv3d(16, 32, kernel_size=3, padding=1), 32),
369+
norm_relu(co.Conv3d(16, 32, kernel_size=5, padding=2), 32),
365370
),
366371
co.Sequential(
367372
co.MaxPool3d(kernel_size=(1, 3, 3), padding=(0, 1, 1), stride=1),

continual/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,4 +26,5 @@
2626
MaxPool3d,
2727
)
2828
from .ptflops import _register_ptflops # noqa: F401
29+
from .shape import Reshape # noqa: F401
2930
from .utils import flat_state_dict, load_state_dict, state_dict # noqa: F401

continual/closure.py

Lines changed: 26 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -9,16 +9,20 @@
99

1010

1111
class Lambda(CoModule, nn.Module):
12-
"""Module wrapper for stateless functions
12+
"""Module wrapper for stateless functions.
1313
1414
NB: Operations performed in a Lambda are not counted in `ptflops`
15+
16+
Args:
17+
fn (Callable[[Tensor], Tensor]): Function to be called during forward.
18+
takes_time (bool, optional): If True, `fn` receives all steps; if False, it receives one step with no time dimension. Defaults to False.
1519
"""
1620

17-
def __init__(self, fn: Callable[[Tensor], Tensor], unsqueeze_step=True):
21+
def __init__(self, fn: Callable[[Tensor], Tensor], takes_time=False):
1822
nn.Module.__init__(self)
1923
assert callable(fn), "The passed function should be callable."
2024
self.fn = fn
21-
self.unsqueeze_step = unsqueeze_step
25+
self.takes_time = takes_time
2226

2327
def __repr__(self) -> str:
2428
s = self.fn.__name__
@@ -47,26 +51,27 @@ def __repr__(self) -> str:
4751
return f"Lambda({s})"
4852

4953
def forward(self, input: Tensor) -> Tensor:
50-
return self.fn(input)
54+
if self.takes_time:
55+
return self.fn(input)
56+
57+
return torch.stack(
58+
[self.fn(input[:, :, t]) for t in range(input.shape[2])], dim=2
59+
)
60+
61+
def forward_steps(self, input: Tensor, pad_end=False, update_state=True) -> Tensor:
62+
return self.forward(input)
5163

5264
def forward_step(self, input: Tensor, update_state=True) -> Tensor:
53-
if self.unsqueeze_step:
65+
if self.takes_time:
5466
input = input.unsqueeze(dim=2)
5567
output = self.fn(input)
56-
if self.unsqueeze_step:
68+
if self.takes_time:
5769
output = output.squeeze(dim=2)
5870
return output
5971

60-
def forward_steps(self, input: Tensor, pad_end=False, update_state=True) -> Tensor:
61-
return self.fn(input)
62-
63-
@property
64-
def delay(self) -> int:
65-
return 0
66-
6772
@staticmethod
68-
def build_from(fn: Callable[[Tensor], Tensor]) -> "Lambda":
69-
return Lambda(fn)
73+
def build_from(fn: Callable[[Tensor], Tensor], takes_time=False) -> "Lambda":
74+
return Lambda(fn, takes_time)
7075

7176

7277
def _multiply(x: Tensor, factor: Union[float, int, Tensor]):
@@ -76,7 +81,7 @@ def _multiply(x: Tensor, factor: Union[float, int, Tensor]):
7681
def Multiply(factor) -> Lambda:
7782
"""Create Lambda with multiplication function"""
7883
fn = partial(_multiply, factor=factor)
79-
return Lambda(fn)
84+
return Lambda(fn, takes_time=True)
8085

8186

8287
def _add(x: Tensor, constant: Union[float, int, Tensor]):
@@ -86,7 +91,7 @@ def _add(x: Tensor, constant: Union[float, int, Tensor]):
8691
def Add(constant) -> Lambda:
8792
"""Create Lambda with addition function"""
8893
fn = partial(_add, constant=constant)
89-
return Lambda(fn)
94+
return Lambda(fn, takes_time=True)
9095

9196

9297
def _unity(x: Tensor):
@@ -95,18 +100,18 @@ def _unity(x: Tensor):
95100

96101
def Unity() -> Lambda:
97102
"""Create Lambda with addition function"""
98-
return Lambda(_unity)
103+
return Lambda(_unity, takes_time=True)
99104

100105

101106
def Constant(constant: float):
102-
return Lambda(lambda x: constant * torch.ones_like(x))
107+
return Lambda(lambda x: constant * torch.ones_like(x), takes_time=True)
103108

104109

105110
def Zero() -> Lambda:
106111
"""Create Lambda with zero output"""
107-
return Lambda(torch.zeros_like)
112+
return Lambda(torch.zeros_like, takes_time=True)
108113

109114

110115
def One() -> Lambda:
111116
"""Create Lambda with zero output"""
112-
return Lambda(torch.ones_like)
117+
return Lambda(torch.ones_like, takes_time=True)

0 commit comments

Comments
 (0)