Skip to content

Commit 88fec55

Browse files
Merge pull request #60 from LukasHedegaard/develop
Develop
2 parents 8e71ffe + 6d39f3d commit 88fec55

File tree

5 files changed

+25
-10
lines changed

5 files changed

+25
-10
lines changed

CHANGELOG.md

Lines changed: 10 additions & 0 deletions
Original file line number · Diff line number · Diff line change
@@ -8,6 +8,16 @@ From v1.0.0 and on, the project will adhere strictly to Semantic Versioning.
88

99
## Unpublished
1010

11+
12+
## [1.1.2] - 2023-01-13
13+
14+
### Added
15+
- `query_index` argument to `SingleOutputTransformerEncoderLayer`.
16+
17+
### Fixed
18+
- `Residual` centred residual and `Delay` auto_shrink forward_step.
19+
20+
1121
## [1.1.1] - 2023-01-10
1222

1323
### Added

continual/__about__.py

Lines changed: 1 addition & 1 deletion
Original file line number · Diff line number · Diff line change
@@ -1,6 +1,6 @@
11
import time
22

3-
__version__ = "1.1.1"
3+
__version__ = "1.1.2"
44
__author__ = "Lukas Hedegaard"
55
__author_email__ = "[email protected]"
66
__license__ = "Apache-2.0"

continual/delay.py

Lines changed: 5 additions & 4 deletions
Original file line number · Diff line number · Diff line change
@@ -68,7 +68,11 @@ def init_state(
6868
) -> State:
6969
padding = self._make_padding(first_output)
7070
state_buffer = torch.stack([padding for _ in range(self.delay)], dim=0)
71-
state_index = torch.tensor(-self.delay)
71+
state_index = torch.tensor(
72+
-2 * self.delay
73+
if self.auto_shrink and isinstance(self.auto_shrink, bool)
74+
else -self.delay
75+
)
7276
return state_buffer, state_index
7377

7478
def clean_state(self):
@@ -113,15 +117,12 @@ def forward_step(self, input: Tensor, update_state=True) -> Tensor:
113117
return CoModule.forward_step(self, input, update_state)
114118

115119
def forward_steps(self, input: Tensor, pad_end=False, update_state=True) -> Tensor:
116-
first_run = self.get_state() is None
117120
if self._delay == 0:
118121
return input
119122

120123
with temporary_parameter(self, "padding", (self.delay,)):
121124
output = CoModule.forward_steps(self, input, pad_end, update_state)
122125

123-
if first_run and self.auto_shrink in {True, "centered"}:
124-
output = output[:, :, self.delay :]
125126
return output
126127

127128
def forward(self, input: Tensor) -> Tensor:

continual/transformer.py

Lines changed: 5 additions & 3 deletions
Original file line number · Diff line number · Diff line change
@@ -164,6 +164,7 @@ def SingleOutputTransformerEncoderLayer(
164164
dtype=None,
165165
sequence_len: int = None,
166166
single_output_forward=False,
167+
query_index: int = -1,
167168
):
168169
"""Continual Single-output Transformer Encoder layer.
169170
@@ -191,6 +192,7 @@ def SingleOutputTransformerEncoderLayer(
191192
dtype: datatype of layer parameters. Defaults to None.
192193
sequence_len: length of token-sequence to perform attention across. Defaults to None.
193194
single_output_forward: whether to restrict the attention to the last token during forward. Defaults to False.
195+
query_index: the sequence position index to compute the attention for.
194196
195197
Examples::
196198
@@ -225,7 +227,7 @@ def SingleOutputTransformerEncoderLayer(
225227
bias=True,
226228
batch_first=True,
227229
embed_dim_second=True,
228-
query_index=-1,
230+
query_index=query_index,
229231
device=device,
230232
dtype=dtype,
231233
sequence_len=sequence_len,
@@ -462,7 +464,7 @@ def TransformerEncoderLayerFactory(
462464
463465
Examples::
464466
465-
encoder_layer = co.TransformerEncoderLayerFactory(d_model=512, nhead=8)
467+
encoder_layer = co.TransformerEncoderLayerFactory(d_model=512, nhead=8, sequence_len=32)
466468
transformer_encoder = co.TransformerEncoder(encoder_layer, num_layers=2)
467469
src = torch.rand(10, 512, 32)
468470
out = transformer_encoder(src)
@@ -527,7 +529,7 @@ class TransformerEncoder(Sequential):
527529
528530
Examples::
529531
530-
encoder_layer = co.TransformerEncoderLayerFactory(d_model=512, nhead=8)
532+
encoder_layer = co.TransformerEncoderLayerFactory(d_model=512, nhead=8, sequence_len=32)
531533
transformer_encoder = co.TransformerEncoder(encoder_layer, num_layers=2)
532534
src = torch.rand(10, 512, 32)
533535
out = transformer_encoder(src)

tests/continual/test_container.py

Lines changed: 4 additions & 2 deletions
Original file line number · Diff line number · Diff line change
@@ -280,7 +280,8 @@ def test_residual_shrink_centered():
280280

281281
# forward_steps
282282
co_res.clean_state()
283-
out_firsts = co_res.forward_steps(input[:, :, :-1], pad_end=False)
283+
_ = co_res.forward_step(input[:, :, 0])
284+
out_firsts = co_res.forward_steps(input[:, :, 1:-1], pad_end=False)
284285
assert torch.allclose(out_firsts, target[:, :, :3])
285286

286287
# forward_step
@@ -312,7 +313,8 @@ def test_residual_shrink_lagging():
312313

313314
# forward_steps
314315
co_res.clean_state()
315-
out_firsts = co_res.forward_steps(input[:, :, :-1], pad_end=False)
316+
_ = co_res.forward_step(input[:, :, 0])
317+
out_firsts = co_res.forward_steps(input[:, :, 1:-1], pad_end=False)
316318
assert torch.allclose(out_firsts, out_manual_res[:, :, :3])
317319

318320
# forward_step

0 commit comments

Comments (0)