@@ -143,7 +143,7 @@ def compute_estimate(
        if self._settings.use_actions:
            actions = self.get_action_input(mini_batch)
            dones = torch.as_tensor(
-                mini_batch[BufferKey.DONE], dtype=torch.float
+                mini_batch[BufferKey.DONE], dtype=torch.float, device=default_device()
            ).unsqueeze(1)
            action_inputs = torch.cat([actions, dones], dim=1)
            hidden, _ = self.encoder(inputs, action_inputs)
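Why this hunk matters: `torch.as_tensor` allocates on the CPU unless a device is given, so when `actions` already lives on the GPU the subsequent `torch.cat` fails with a device-mismatch error. A minimal sketch, assuming a CUDA device is available (the shapes and values are illustrative, not from the source):

import torch

actions = torch.ones(4, 2, device="cuda")  # encoder inputs already on the GPU
# Without device=, as_tensor() returns a CPU tensor:
dones = torch.as_tensor([0.0, 0.0, 1.0, 0.0], dtype=torch.float).unsqueeze(1)
# torch.cat([actions, dones], dim=1)  -> RuntimeError: tensors on different devices

# The fix mirrors the hunk above: allocate dones on the same device up front.
dones = torch.as_tensor(
    [0.0, 0.0, 1.0, 0.0], dtype=torch.float, device=actions.device
).unsqueeze(1)
action_inputs = torch.cat([actions, dones], dim=1)  # same device, concatenates cleanly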
@@ -162,7 +162,7 @@ def compute_loss(
        """
        Given a policy mini_batch and an expert mini_batch, computes the loss of the discriminator.
        """
-        total_loss = torch.zeros(1)
+        total_loss = torch.zeros(1, device=default_device())
        stats_dict: Dict[str, np.ndarray] = {}
        policy_estimate, policy_mu = self.compute_estimate(
            policy_batch, use_vail_noise=True
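The same pattern applies to the loss accumulator: `torch.zeros(1)` defaults to the CPU, so an in-place add of a GPU loss term raises. A short sketch under the same CUDA assumption:

import torch

total_loss = torch.zeros(1)                      # CPU by default
batch_loss = torch.tensor([0.3], device="cuda")  # loss term computed on the GPU
# total_loss += batch_loss  -> RuntimeError: tensors on different devices

total_loss = torch.zeros(1, device=batch_loss.device)
total_loss += batch_loss                         # accumulates on one device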
@@ -219,21 +219,23 @@ def compute_gradient_magnitude(
        expert_inputs = self.get_state_inputs(expert_batch)
        interp_inputs = []
        for policy_input, expert_input in zip(policy_inputs, expert_inputs):
-            obs_epsilon = torch.rand(policy_input.shape)
+            obs_epsilon = torch.rand(policy_input.shape, device=policy_input.device)
            interp_input = obs_epsilon * policy_input + (1 - obs_epsilon) * expert_input
            interp_input.requires_grad = True  # For gradient calculation
            interp_inputs.append(interp_input)
        if self._settings.use_actions:
            policy_action = self.get_action_input(policy_batch)
            expert_action = self.get_action_input(expert_batch)
-            action_epsilon = torch.rand(policy_action.shape)
+            action_epsilon = torch.rand(
+                policy_action.shape, device=policy_action.device
+            )
            policy_dones = torch.as_tensor(
-                policy_batch[BufferKey.DONE], dtype=torch.float
+                policy_batch[BufferKey.DONE], dtype=torch.float, device=default_device()
            ).unsqueeze(1)
            expert_dones = torch.as_tensor(
-                expert_batch[BufferKey.DONE], dtype=torch.float
+                expert_batch[BufferKey.DONE], dtype=torch.float, device=default_device()
            ).unsqueeze(1)
-            dones_epsilon = torch.rand(policy_dones.shape)
+            dones_epsilon = torch.rand(policy_dones.shape, device=policy_dones.device)
            action_inputs = torch.cat(
                [
                    action_epsilon * policy_action
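This gradient-penalty hunk samples each mixing coefficient on the device of the tensor it scales, so the interpolation between policy and expert samples never mixes CPU and GPU tensors. A self-contained sketch of that pattern, assuming a CUDA device; the helper name `interpolate` is hypothetical, not from the source:

import torch

def interpolate(policy_input: torch.Tensor, expert_input: torch.Tensor) -> torch.Tensor:
    # Draw epsilon on the same device as the operands it scales.
    obs_epsilon = torch.rand(policy_input.shape, device=policy_input.device)
    interp_input = obs_epsilon * policy_input + (1 - obs_epsilon) * expert_input
    interp_input.requires_grad = True  # the penalty differentiates w.r.t. this input
    return interp_input

policy_obs = torch.rand(8, 10, device="cuda")
expert_obs = torch.rand(8, 10, device="cuda")
interp = interpolate(policy_obs, expert_obs)  # ready for the gradient computation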