Commit a83b3b8

bmind7 and maryam-zia authored
[Bugfix] Fix CUDA/CPU mismatch in threaded training (#6245)
* Ensure tensors use default device in torch policy and utils

Co-authored-by: maryam-zia <[email protected]>
1 parent a277771 commit a83b3b8
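This commit addresses a device-mismatch failure: when training runs on CUDA, tensors created without an explicit device default to CPU, and any arithmetic that mixes the two raises a runtime error. A minimal sketch of the failure mode and of the fix pattern used throughout the commit, written in plain PyTorch rather than the ml-agents API:

import torch

# Pick whatever device training actually uses; ml-agents resolves this via default_device().
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

policy_input = torch.randn(4, 8, device=device)   # lives on the training device
obs_epsilon = torch.rand(policy_input.shape)       # implicitly allocated on CPU

try:
    interp = obs_epsilon * policy_input             # RuntimeError when device is CUDA
except RuntimeError as err:
    print(err)

# Fix: allocate new tensors on an explicit device so both operands always match.
obs_epsilon = torch.rand(policy_input.shape, device=policy_input.device)
interp = obs_epsilon * policy_input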

6 files changed: +34 −22 lines


ml-agents/mlagents/trainers/optimizer/torch_optimizer.py

Lines changed: 2 additions & 2 deletions
@@ -1,5 +1,5 @@
 from typing import Dict, Optional, Tuple, List
-from mlagents.torch_utils import torch
+from mlagents.torch_utils import torch, default_device
 import numpy as np
 from collections import defaultdict

@@ -162,7 +162,7 @@ def get_trajectory_value_estimates(
             memory = self.critic_memory_dict[agent_id]
         else:
             memory = (
-                torch.zeros((1, 1, self.critic.memory_size))
+                torch.zeros((1, 1, self.critic.memory_size), device=default_device())
                 if self.policy.use_recurrent
                 else None
             )

ml-agents/mlagents/trainers/poca/optimizer_torch.py

Lines changed: 2 additions & 2 deletions
@@ -608,12 +608,12 @@ def get_trajectory_and_baseline_value_estimates(
             _init_baseline_mem = self.baseline_memory_dict[agent_id]
         else:
             _init_value_mem = (
-                torch.zeros((1, 1, self.critic.memory_size))
+                torch.zeros((1, 1, self.critic.memory_size), device=default_device())
                 if self.policy.use_recurrent
                 else None
             )
             _init_baseline_mem = (
-                torch.zeros((1, 1, self.critic.memory_size))
+                torch.zeros((1, 1, self.critic.memory_size), device=default_device())
                 if self.policy.use_recurrent
                 else None
             )

ml-agents/mlagents/trainers/policy/torch_policy.py

Lines changed: 11 additions & 6 deletions
@@ -69,13 +69,17 @@ def export_memory_size(self) -> int:
         return self._export_m_size

     def _extract_masks(self, decision_requests: DecisionSteps) -> np.ndarray:
+        device = default_device()
         mask = None
         if self.behavior_spec.action_spec.discrete_size > 0:
             num_discrete_flat = np.sum(self.behavior_spec.action_spec.discrete_branches)
-            mask = torch.ones([len(decision_requests), num_discrete_flat])
+            mask = torch.ones(
+                [len(decision_requests), num_discrete_flat], device=device
+            )
             if decision_requests.action_mask is not None:
                 mask = torch.as_tensor(
-                    1 - np.concatenate(decision_requests.action_mask, axis=1)
+                    1 - np.concatenate(decision_requests.action_mask, axis=1),
+                    device=device,
                 )
         return mask

@@ -91,11 +95,12 @@ def evaluate(
         """
         obs = decision_requests.obs
         masks = self._extract_masks(decision_requests)
-        tensor_obs = [torch.as_tensor(np_ob) for np_ob in obs]
+        device = default_device()
+        tensor_obs = [torch.as_tensor(np_ob, device=device) for np_ob in obs]

-        memories = torch.as_tensor(self.retrieve_memories(global_agent_ids)).unsqueeze(
-            0
-        )
+        memories = torch.as_tensor(
+            self.retrieve_memories(global_agent_ids), device=device
+        ).unsqueeze(0)
         with torch.no_grad():
             action, run_out, memories = self.actor.get_action_and_stats(
                 tensor_obs, masks=masks, memories=memories
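The change in evaluate follows the same pattern: look the device up once per call and pass it to every conversion from numpy. A hedged sketch of that conversion step, with illustrative names rather than the ml-agents API:

import numpy as np
import torch

def to_device_tensors(obs_list, memories_np, device):
    # Move each numpy observation and the recurrent memories onto the same device.
    tensor_obs = [torch.as_tensor(np_ob, device=device) for np_ob in obs_list]
    memories = torch.as_tensor(memories_np, device=device).unsqueeze(0)
    return tensor_obs, memories

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
obs, mem = to_device_tensors(
    [np.zeros((1, 4), dtype=np.float32)], np.zeros((1, 8), dtype=np.float32), device
)
print(obs[0].device, mem.shape)  # same device; memories gain a leading sequence dim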

ml-agents/mlagents/trainers/torch_entities/components/reward_providers/gail_reward_provider.py

Lines changed: 9 additions & 7 deletions
@@ -143,7 +143,7 @@ def compute_estimate(
         if self._settings.use_actions:
             actions = self.get_action_input(mini_batch)
             dones = torch.as_tensor(
-                mini_batch[BufferKey.DONE], dtype=torch.float
+                mini_batch[BufferKey.DONE], dtype=torch.float, device=default_device()
             ).unsqueeze(1)
             action_inputs = torch.cat([actions, dones], dim=1)
             hidden, _ = self.encoder(inputs, action_inputs)

@@ -162,7 +162,7 @@ def compute_loss(
         """
         Given a policy mini_batch and an expert mini_batch, computes the loss of the discriminator.
         """
-        total_loss = torch.zeros(1)
+        total_loss = torch.zeros(1, device=default_device())
         stats_dict: Dict[str, np.ndarray] = {}
         policy_estimate, policy_mu = self.compute_estimate(
             policy_batch, use_vail_noise=True

@@ -219,21 +219,23 @@ def compute_gradient_magnitude(
         expert_inputs = self.get_state_inputs(expert_batch)
         interp_inputs = []
         for policy_input, expert_input in zip(policy_inputs, expert_inputs):
-            obs_epsilon = torch.rand(policy_input.shape)
+            obs_epsilon = torch.rand(policy_input.shape, device=policy_input.device)
             interp_input = obs_epsilon * policy_input + (1 - obs_epsilon) * expert_input
             interp_input.requires_grad = True  # For gradient calculation
             interp_inputs.append(interp_input)
         if self._settings.use_actions:
             policy_action = self.get_action_input(policy_batch)
             expert_action = self.get_action_input(expert_batch)
-            action_epsilon = torch.rand(policy_action.shape)
+            action_epsilon = torch.rand(
+                policy_action.shape, device=policy_action.device
+            )
             policy_dones = torch.as_tensor(
-                policy_batch[BufferKey.DONE], dtype=torch.float
+                policy_batch[BufferKey.DONE], dtype=torch.float, device=default_device()
             ).unsqueeze(1)
             expert_dones = torch.as_tensor(
-                expert_batch[BufferKey.DONE], dtype=torch.float
+                expert_batch[BufferKey.DONE], dtype=torch.float, device=default_device()
             ).unsqueeze(1)
-            dones_epsilon = torch.rand(policy_dones.shape)
+            dones_epsilon = torch.rand(policy_dones.shape, device=policy_dones.device)
             action_inputs = torch.cat(
                 [
                     action_epsilon * policy_action
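In compute_gradient_magnitude the epsilon tensors are now drawn on the device of the tensors they blend, so the element-wise interpolation never crosses devices. A small self-contained sketch of that interpolation step, with assumed shapes and plain PyTorch:

import torch

def interpolate(policy_input: torch.Tensor, expert_input: torch.Tensor) -> torch.Tensor:
    # Epsilon is sampled on the same device as the inputs it mixes.
    obs_epsilon = torch.rand(policy_input.shape, device=policy_input.device)
    interp = obs_epsilon * policy_input + (1 - obs_epsilon) * expert_input
    interp.requires_grad = True  # the gradient penalty differentiates w.r.t. this tensor
    return interp

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
mixed = interpolate(torch.randn(4, 8, device=device), torch.randn(4, 8, device=device))
print(mixed.device, mixed.requires_grad)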

ml-agents/mlagents/trainers/torch_entities/networks.py

Lines changed: 4 additions & 2 deletions
@@ -1,7 +1,7 @@
 from typing import Callable, List, Dict, Tuple, Optional, Union, Any
 import abc

-from mlagents.torch_utils import torch, nn
+from mlagents.torch_utils import torch, nn, default_device

 from mlagents_envs.base_env import ActionSpec, ObservationSpec, ObservationType
 from mlagents.trainers.torch_entities.action_model import ActionModel

@@ -87,7 +87,9 @@ def update_normalization(self, buffer: AgentBuffer) -> None:
         obs = ObsUtil.from_buffer(buffer, len(self.processors))
         for vec_input, enc in zip(obs, self.processors):
             if isinstance(enc, VectorInput):
-                enc.update_normalization(torch.as_tensor(vec_input.to_ndarray()))
+                enc.update_normalization(
+                    torch.as_tensor(vec_input.to_ndarray(), device=default_device())
+                )

     def copy_normalization(self, other_encoder: "ObservationEncoder") -> None:
         if self.normalize:

ml-agents/mlagents/trainers/torch_entities/utils.py

Lines changed: 6 additions & 3 deletions
@@ -1,5 +1,5 @@
 from typing import List, Optional, Tuple, Dict
-from mlagents.torch_utils import torch, nn
+from mlagents.torch_utils import torch, nn, default_device
 from mlagents.trainers.torch_entities.layers import LinearEncoder, Initialization
 import numpy as np

@@ -233,7 +233,8 @@ def list_to_tensor(
         Converts a list of numpy arrays into a tensor. MUCH faster than
         calling as_tensor on the list directly.
         """
-        return torch.as_tensor(np.asanyarray(ndarray_list), dtype=dtype)
+        device = default_device()
+        return torch.as_tensor(np.asanyarray(ndarray_list), dtype=dtype, device=device)

     @staticmethod
     def list_to_tensor_list(

@@ -243,8 +244,10 @@ def list_to_tensor_list(
         Converts a list of numpy arrays into a list of tensors. MUCH faster than
         calling as_tensor on the list directly.
         """
+        device = default_device()
         return [
-            torch.as_tensor(np.asanyarray(_arr), dtype=dtype) for _arr in ndarray_list
+            torch.as_tensor(np.asanyarray(_arr), dtype=dtype, device=device)
+            for _arr in ndarray_list
        ]

     @staticmethod
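list_to_tensor and list_to_tensor_list now resolve the device once and pass it to every as_tensor call. A brief usage illustration, with hypothetical arrays, of why the helpers batch through np.asanyarray first:

import numpy as np
import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
ndarray_list = [np.ones(3, dtype=np.float32) for _ in range(4)]

# Stacking with np.asanyarray and converting once avoids the slow per-element path
# torch.as_tensor takes on a plain Python list, and lands on the chosen device.
batched = torch.as_tensor(np.asanyarray(ndarray_list), dtype=torch.float32, device=device)
print(batched.shape, batched.device)  # torch.Size([4, 3]) on the chosen device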
