Skip to content

Commit 8d71f96

Browse files
author
xindong.he
committed
fix(yzj):fix device bug
1 parent 6ea3f9b commit 8d71f96

File tree

2 files changed

+2
-0
lines changed

2 files changed

+2
-0
lines changed

lzero/policy/efficientzero.py

+1
Original file line numberDiff line numberDiff line change
@@ -401,6 +401,7 @@ def _forward_learn(self, data: torch.Tensor) -> Dict[str, Union[float, int]]:
401401
# obtain the oracle latent states from representation function.
402402
beg_index, end_index = self._get_target_obs_index_in_step_k(step_k)
403403
obs_target_batch_tmp = default_collate(obs_target_batch[:, beg_index:end_index].squeeze())
404+
obs_target_batch_tmp = to_device(obs_target_batch_tmp, self._device)
404405
network_output = self._learn_model.initial_inference(obs_target_batch_tmp)
405406

406407
latent_state = to_tensor(latent_state)

lzero/policy/muzero.py

+1
Original file line numberDiff line numberDiff line change
@@ -369,6 +369,7 @@ def _forward_learn(self, data: Tuple[torch.Tensor]) -> Dict[str, Union[float, in
369369
# obtain the oracle latent states from representation function.
370370
beg_index, end_index = self._get_target_obs_index_in_step_k(step_k)
371371
obs_target_batch_tmp = default_collate(obs_target_batch[:, beg_index:end_index].squeeze())
372+
obs_target_batch_tmp = to_device(obs_target_batch_tmp, self._device)
372373
network_output = self._learn_model.initial_inference(obs_target_batch_tmp)
373374

374375
latent_state = to_tensor(latent_state)

0 commit comments

Comments
 (0)