Skip to content

Commit 7d7f495

Browse files
Merge pull request #68 from LukasHedegaard/develop
Fix state_index device after clean_state
2 parents 63fd064 + a189713 commit 7d7f495

File tree

7 files changed

+19
-12
lines changed

7 files changed

+19
-12
lines changed

CHANGELOG.md

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,13 @@ From v1.0.0 and on, the project will adherence strictly to Semantic Versioning.
88

99
## Unpublished
1010

11+
12+
## [1.2.3] - 2023-06-16
13+
14+
### Fixed
15+
- Ensure state_index remains on the same device after clean_state.
16+
17+
1118
## [1.2.2] - 2023-05-24
1219

1320
### Fixed

continual/__about__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import time
22

3-
__version__ = "1.2.2"
3+
__version__ = "1.2.3"
44
__author__ = "Lukas Hedegaard"
55
__author_email__ = "[email protected]"
66
__license__ = "Apache-2.0"

continual/conv.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -155,7 +155,7 @@ def init_state(
155155
return (state_buffer, state_index, stride_index)
156156

157157
def clean_state(self):
158-
self.state_buffer = torch.tensor([])
158+
self.state_buffer = torch.tensor([], device=self.state_buffer.device)
159159
self.state_index = torch.tensor(0)
160160
self.stride_index = torch.tensor(0)
161161

continual/delay.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,7 @@ def init_state(
7676
return state_buffer, state_index
7777

7878
def clean_state(self):
79-
self.state_buffer = torch.tensor([])
79+
self.state_buffer = torch.tensor([], device=self.state_buffer.device)
8080
self.state_index = torch.tensor(0)
8181

8282
def get_state(self):

continual/multihead_attention/retroactive_mha.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -276,11 +276,11 @@ def set_state(self, state: State):
276276
) = state
277277

278278
def clean_state(self):
279-
self.d_mem = torch.tensor([])
280-
self.AV_mem = torch.tensor([])
281-
self.Q_mem = torch.tensor([])
282-
self.K_T_mem = torch.tensor([])
283-
self.V_mem = torch.tensor([])
279+
self.d_mem = torch.tensor([], device=self.d_mem.device)
280+
self.AV_mem = torch.tensor([], device=self.AV_mem.device)
281+
self.Q_mem = torch.tensor([], device=self.Q_mem.device)
282+
self.K_T_mem = torch.tensor([], device=self.K_T_mem.device)
283+
self.V_mem = torch.tensor([], device=self.V_mem.device)
284284
self.stride_index = torch.tensor(0)
285285

286286
def _forward_step(

continual/multihead_attention/single_output_mha.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -252,9 +252,9 @@ def set_state(self, state: State):
252252
) = state
253253

254254
def clean_state(self):
255-
self.Q_mem = torch.tensor([])
256-
self.K_T_mem = torch.tensor([])
257-
self.V_mem = torch.tensor([])
255+
self.Q_mem = torch.tensor([], device=self.Q_mem.device)
256+
self.K_T_mem = torch.tensor([], device=self.K_T_mem.device)
257+
self.V_mem = torch.tensor([], device=self.V_mem.device)
258258
self.stride_index = torch.tensor(0)
259259

260260
@property

continual/pooling.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -166,7 +166,7 @@ def init_state(
166166
return state_buffer, state_index, stride_index
167167

168168
def clean_state(self):
169-
self.state_buffer = torch.tensor([])
169+
self.state_buffer = torch.tensor([], device=self.state_buffer.device)
170170
self.state_index = torch.tensor(0)
171171
self.stride_index = torch.tensor(0)
172172

0 commit comments

Comments
 (0)