Skip to content

Commit 9aa23ac

Browse files
Merge pull request #32 from LukasHedegaard/develop
Add Constant, Zero, and One
2 parents 9e73906 + b78b20c commit 9aa23ac

File tree

5 files changed

+40
-3
lines changed

5 files changed

+40
-3
lines changed

CHANGELOG.md

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,12 @@ From v1.0.0 and on, the project will adherence strictly to Semantic Versioning.
99

1010
## [Unreleased]
1111

12+
## [0.12.0]
13+
### Added
14+
- Add `Constant`.
15+
- Add `Zero`.
16+
- Add `One`.
17+
1218

1319
## [0.11.4]
1420
### Fixed

continual/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from .closure import Add, Lambda, Multiply, Unity # noqa: F401
1+
from .closure import Add, Constant, Lambda, Multiply, One, Unity, Zero # noqa: F401
22
from .container import ( # noqa: F401
33
Broadcast,
44
BroadcastReduce,

continual/closure.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
from inspect import getsource
33
from typing import Callable, Union
44

5+
import torch
56
from torch import Tensor, nn
67

78
from .module import CoModule
@@ -95,3 +96,17 @@ def _unity(x: Tensor):
9596
def Unity() -> Lambda:
9697
"""Create Lambda with addition function"""
9798
return Lambda(_unity)
99+
100+
101+
def Constant(constant: float):
102+
return Lambda(lambda x: constant * torch.ones_like(x))
103+
104+
105+
def Zero() -> Lambda:
106+
"""Create Lambda with zero output"""
107+
return Lambda(torch.zeros_like)
108+
109+
110+
def One() -> Lambda:
111+
"""Create Lambda with zero output"""
112+
return Lambda(torch.ones_like)

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ def from_file(file_name: str = "requirements.txt", comment_char: str = "#"):
2525

2626
setup(
2727
name="continual-inference",
28-
version="0.11.4",
28+
version="0.12.0",
2929
description="Building blocks for Continual Inference Networks in PyTorch",
3030
long_description=long_description(),
3131
long_description_content_type="text/markdown",

tests/continual/test_closure.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import torch
22

3-
from continual.closure import Add, Lambda, Multiply, Unity
3+
from continual.closure import Add, Constant, Lambda, Multiply, One, Unity, Zero
44

55

66
def test_add():
@@ -76,3 +76,19 @@ def local_always42(x):
7676
def test_unity():
7777
x = torch.ones((1, 1, 2, 2))
7878
assert torch.equal(x, Unity()(x))
79+
80+
81+
def test_constant():
82+
x = torch.randn((1, 1, 2, 2))
83+
const = 42
84+
assert torch.equal(const * torch.ones_like(x), Constant(const)(x))
85+
86+
87+
def test_zero():
88+
x = torch.randn((1, 1, 2, 2))
89+
assert torch.equal(torch.zeros_like(x), Zero()(x))
90+
91+
92+
def test_one():
93+
x = torch.randn((1, 1, 2, 2))
94+
assert torch.equal(torch.ones_like(x), One()(x))

0 commit comments

Comments
 (0)