Skip to content

Commit b78b20c

Browse files
Merge branch 'develop' of https://github.com/LukasHedegaard/continual-inference into develop
2 parents f7e733e + 9e73906 commit b78b20c

File tree

3 files changed

+24
-4
lines changed

3 files changed

+24
-4
lines changed

CHANGELOG.md

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,14 +9,18 @@ From v1.0.0 and on, the project will adherence strictly to Semantic Versioning.
99

1010
## [Unreleased]
1111

12-
13-
## [0.11.4]
12+
## [0.12.0]
1413
### Added
1514
- Add `Constant`.
1615
- Add `Zero`.
1716
- Add `One`.
1817

1918

19+
## [0.11.4]
20+
### Fixed
21+
- `co.ConvXd` cuda compatibility.
22+
23+
2024
## [0.11.3]
2125
### Added
2226
- Add `flatten_state_dict` state variable.

Makefile

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ clean:
3636
## Test the setup
3737
test:
3838
@echo ⚡⚡⚡ Testing ⚡⚡⚡
39-
py.test --cov continual --cov-report term-missing
39+
python -m pytest --cov continual --cov-report term-missing
4040

4141

4242
## Upload to codecov.io

continual/conv.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,8 @@ def __init__(
4545
groups: int = 1,
4646
bias: bool = True,
4747
padding_mode: PaddingMode = "zeros",
48+
device=None,
49+
dtype=None,
4850
temporal_fill: PaddingMode = "zeros",
4951
):
5052
assert issubclass(
@@ -82,6 +84,8 @@ def __init__(
8284
groups=groups,
8385
bias=bias,
8486
padding_mode=padding_mode,
87+
device=device,
88+
dtype=dtype,
8589
)
8690
self.make_padding = {
8791
PaddingMode.ZEROS.value: torch.zeros_like,
@@ -142,7 +146,7 @@ def _forward_step(self, input: Tensor, prev_state: State) -> Tuple[Tensor, State
142146
), f"A tensor of shape {(*self.input_shape_desciption[:2], *self.input_shape_desciption[3:])} should be passed as input but got {input.shape}"
143147

144148
# e.g. B, C -> B, C, 1
145-
x = input.unsqueeze(2)
149+
x = input.unsqueeze(2).to(device=self.weight.device)
146150

147151
if self.padding_mode == "zeros":
148152
x = self._conv_func(
@@ -239,6 +243,8 @@ def __init__(
239243
groups: int = 1,
240244
bias: bool = True,
241245
padding_mode: PaddingMode = "zeros",
246+
device=None,
247+
dtype=None,
242248
temporal_fill: PaddingMode = "zeros",
243249
):
244250
r"""Applies a continual 1D convolution over an input signal composed of several input
@@ -295,6 +301,8 @@ def __init__(
295301
groups,
296302
bias,
297303
padding_mode,
304+
device,
305+
dtype,
298306
temporal_fill,
299307
)
300308

@@ -338,6 +346,8 @@ def __init__(
338346
groups: int = 1,
339347
bias: bool = True,
340348
padding_mode: PaddingMode = "zeros",
349+
device=None,
350+
dtype=None,
341351
temporal_fill: PaddingMode = "zeros",
342352
):
343353
r"""Applies a continual 2D convolution over an input signal composed of several input
@@ -394,6 +404,8 @@ def __init__(
394404
groups,
395405
bias,
396406
padding_mode,
407+
device,
408+
dtype,
397409
temporal_fill,
398410
)
399411

@@ -437,6 +449,8 @@ def __init__(
437449
groups: int = 1,
438450
bias: bool = True,
439451
padding_mode: PaddingMode = "zeros",
452+
device=None,
453+
dtype=None,
440454
temporal_fill: PaddingMode = "zeros",
441455
):
442456
r"""Applies a continual 3D convolution over an input signal composed of several input
@@ -495,6 +509,8 @@ def __init__(
495509
groups,
496510
bias,
497511
padding_mode,
512+
device,
513+
dtype,
498514
temporal_fill,
499515
)
500516

0 commit comments

Comments
 (0)