
Commit 5988d14

good use for eval
1 parent 066ff45 commit 5988d14

2 files changed: +19 -65 lines

ding/model/template/qtransformer.py (+16 -62)
@@ -446,62 +446,30 @@ def state_append_actions(self,state,actions:Optional[Tensor] = None):
     def get_optimal_actions(
         self,
         encoded_state,
-        return_q_values = False,
         actions: Optional[Tensor] = None,
-        prob_random_action: float = 0.5,
-        **kwargs
     ):
-        batch = encoded_state.shape[0]
-
-        if prob_random_action == 1:
-            return self.get_random_actions(batch)
-        prob_random_action = -1
-        sos_token = encoded_state
-        tokens = self.maybe_append_actions(sos_token, actions = actions)
-
-        action_bins = []
+        batch_size = encoded_state.shape[0]
+        action_bins = torch.empty(batch_size, self.num_actions, device=encoded_state.device,dtype=torch.long)
         cache = None
+        tokens = self.state_append_actions(encoded_state, actions = actions)

         for action_idx in range(self.num_actions):
-
             embed, cache = self.transformer(
                 tokens,
-                context = encoded_state,
+                context = None,
                 cache = cache,
                 return_cache = True
             )
-
-            last_embed = embed[:, action_idx]
-            bin_embeddings = self.action_bin_embeddings[action_idx]
-
-            q_values = einsum('b d, a d -> b a', last_embed, bin_embeddings)
-
-            selected_action_bins = q_values.argmax(dim = -1)
-
-            if prob_random_action > 0.:
-                random_mask = torch.zeros_like(selected_action_bins).float().uniform_(0., 1.) < prob_random_action
-                random_actions = self.get_random_actions(batch, 1)
-                random_actions = rearrange(random_actions, '... 1 -> ...')
-
-                selected_action_bins = torch.where(
-                    random_mask,
-                    random_actions,
-                    selected_action_bins
-                )
-
-            next_action_embed = bin_embeddings[selected_action_bins]
-
-            tokens, _ = pack((tokens, next_action_embed), 'b * d')
-
-            action_bins.append(selected_action_bins)
-
-        action_bins = torch.stack(action_bins, dim = -1)
-
-        if not return_q_values:
-            return action_bins
-
-        all_q_values = self.get_q_values(embed)
-        return action_bins, all_q_values
+            q_values = self.get_q_value_fuction(embed[:, 1:, :])
+            if action_idx ==0 :
+                special_idx=action_idx
+            else :
+                special_idx=action_idx-1
+            _, selected_action_indices = q_values[:,special_idx,:].max(dim=-1)
+            action_bins[:, action_idx] = selected_action_indices
+            now_actions=action_bins[:,0:action_idx+1]
+            tokens = self.state_append_actions(encoded_state, actions = now_actions)
+        return action_bins

     def forward(
         self,
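
The rewritten loop selects action bins greedily, one slot at a time: it re-tokenizes the state plus the actions chosen so far, runs the transformer, and reads Q-values for the current slot from a shared Q-head over the action-token embeddings, instead of the old per-slot `action_bin_embeddings` and the `prob_random_action` branch. Below is a minimal, self-contained sketch of that decoding pattern; the toy `append_actions`, `backbone`, and `q_head` are illustrative stand-ins for the module's own `state_append_actions`, `transformer`, and `get_q_value_fuction`, and the cache/`context` handling from the diff is omitted.

```python
import torch
import torch.nn as nn

# Toy dimensions (illustrative only, not the repository's configuration).
dim, num_actions, num_bins = 16, 3, 8

action_embed = nn.Embedding(num_bins, dim)            # embeds chosen bins into action tokens
placeholder = nn.Parameter(torch.zeros(1, 1, dim))    # stands in before any action is chosen
backbone = nn.TransformerEncoder(
    nn.TransformerEncoderLayer(d_model=dim, nhead=4, batch_first=True), num_layers=1
)
q_head = nn.Linear(dim, num_bins)                      # shared Q-head over every action slot


def append_actions(state_token, actions=None):
    """Token sequence: the state token, then either a placeholder (nothing chosen yet)
    or the embeddings of the actions chosen so far."""
    batch = state_token.shape[0]
    tail = placeholder.expand(batch, -1, -1) if actions is None else action_embed(actions)
    return torch.cat([state_token.unsqueeze(1), tail], dim=1)


@torch.no_grad()
def greedy_decode(state_token):
    batch = state_token.shape[0]
    bins = torch.empty(batch, num_actions, dtype=torch.long)
    tokens = append_actions(state_token)                   # [state, placeholder]
    for idx in range(num_actions):
        embed = backbone(tokens)                           # (batch, seq, dim)
        q_values = q_head(embed[:, 1:, :])                 # drop the state slot
        slot = 0 if idx == 0 else idx - 1                  # slot whose Q-values score the next action
        bins[:, idx] = q_values[:, slot, :].argmax(dim=-1)
        tokens = append_actions(state_token, bins[:, :idx + 1])
    return bins


print(greedy_decode(torch.randn(2, dim)).shape)            # torch.Size([2, 3])
```
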
@@ -585,28 +553,14 @@ def embed_texts(self, texts: List[str]):
         return self.conditioner.embed_texts(texts)

     @torch.no_grad()
-    def get_optimal_actions(
+    def get_actions(
         self,
         state,
-        return_q_values = False,
         actions: Optional[Tensor] = None,
-        **kwargs
     ):
         encoded_state = self.state_encode(state)
-        return self.q_head.get_optimal_actions(encoded_state, return_q_values = return_q_values, actions = actions)
-
-    def get_actions(
-        self,
-        state,
-        prob_random_action = 0., # otherwise known as epsilon in RL
-        **kwargs,
-    ):
-        batch_size = state.shape[0]
-        assert 0. <= prob_random_action <= 1.
+        return self.q_head.get_optimal_actions(encoded_state)

-        if random() < prob_random_action:
-            return self.get_random_actions(batch_size = batch_size)
-        return self.get_optimal_actions(state, **kwargs)

     def forward(
         self,
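
With the epsilon-greedy wrapper removed, the remaining `get_actions` just encodes the state and delegates to the Q-head's greedy decoder, so exploration, if still wanted, has to be applied outside the model. A small sketch of epsilon-greedy over bin indices that could live on the policy/collector side (this is not part of the commit, purely an illustration):

```python
import torch

def epsilon_greedy_bins(greedy_bins: torch.Tensor, num_bins: int, epsilon: float) -> torch.Tensor:
    """Replace each greedily chosen bin with a uniform random bin with probability epsilon."""
    random_bins = torch.randint_like(greedy_bins, num_bins)
    explore = torch.rand(greedy_bins.shape, device=greedy_bins.device) < epsilon
    return torch.where(explore, random_bins, greedy_bins)
```
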

ding/policy/qtransformer.py (+3 -3)
@@ -414,7 +414,7 @@ def _discretize_action(self, actions):

     def _get_actions(self, obs):
         # evaluate to get action
-        action = self._eval_model.get_optimal_actions(obs)
+        action = self._target_model.get_actions(obs)
         action = 2*action/256.0-1
         return action
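
The second line maps the returned bin indices back to a continuous action: assuming 256 bins, as the constant suggests, `2*action/256.0-1` sends bin 0 to -1.0, bin 128 to 0.0, and bin 255 to about 0.992, i.e. the range [0, 255] onto [-1, 1). A quick check of that arithmetic:

```python
import torch

bins = torch.tensor([0, 128, 255], dtype=torch.float32)
print(2 * bins / 256.0 - 1)   # tensor([-1.0000,  0.0000,  0.9922])
```
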

@@ -442,8 +442,8 @@ def _state_dict_learn(self) -> Dict[str, Any]:
             - state_dict (:obj:`Dict[str, Any]`): The dict of current policy learn state, for saving and restoring.
         """
         ret = {
-            'model': self._model.state_dict(),
-            'ema_model': self._ema_model.state_dict(),
+            'model': self._learn_model.state_dict(),
+            'ema_model': self._target_model.state_dict(),
             'optimizer_q': self._optimizer_q.state_dict(),
         }
         if self._auto_alpha:
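
Since the saved dict now reads from `_learn_model` and `_target_model`, the matching restore path would be expected to load the same keys back into those members. A hedged sketch only; the repository's actual `_load_state_dict_learn` may differ, and the `_auto_alpha` branch is omitted:

```python
from typing import Any, Dict

def _load_state_dict_learn(self, state_dict: Dict[str, Any]) -> None:
    # Mirrors the keys written by _state_dict_learn above (illustrative only).
    self._learn_model.load_state_dict(state_dict['model'])
    self._target_model.load_state_dict(state_dict['ema_model'])
    self._optimizer_q.load_state_dict(state_dict['optimizer_q'])
```
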
