@@ -446,62 +446,30 @@ def state_append_actions(self,state,actions:Optional[Tensor] = None):
     def get_optimal_actions(
         self,
         encoded_state,
-        return_q_values = False,
         actions: Optional[Tensor] = None,
-        prob_random_action: float = 0.5,
-        **kwargs
     ):
-        batch = encoded_state.shape[0]
-
-        if prob_random_action == 1:
-            return self.get_random_actions(batch)
-        prob_random_action = -1
-        sos_token = encoded_state
-        tokens = self.maybe_append_actions(sos_token, actions = actions)
-
-        action_bins = []
+        batch_size = encoded_state.shape[0]
+        action_bins = torch.empty(batch_size, self.num_actions, device = encoded_state.device, dtype = torch.long)
         cache = None
+        tokens = self.state_append_actions(encoded_state, actions = actions)
 
         for action_idx in range(self.num_actions):
-
             embed, cache = self.transformer(
                 tokens,
-                context = encoded_state,
+                context = None,
                 cache = cache,
                 return_cache = True
             )
-
-            last_embed = embed[:, action_idx]
-            bin_embeddings = self.action_bin_embeddings[action_idx]
-
-            q_values = einsum('b d, a d -> b a', last_embed, bin_embeddings)
-
-            selected_action_bins = q_values.argmax(dim = -1)
-
-            if prob_random_action > 0.:
-                random_mask = torch.zeros_like(selected_action_bins).float().uniform_(0., 1.) < prob_random_action
-                random_actions = self.get_random_actions(batch, 1)
-                random_actions = rearrange(random_actions, '... 1 -> ...')
-
-                selected_action_bins = torch.where(
-                    random_mask,
-                    random_actions,
-                    selected_action_bins
-                )
-
-            next_action_embed = bin_embeddings[selected_action_bins]
-
-            tokens, _ = pack((tokens, next_action_embed), 'b * d')
-
-            action_bins.append(selected_action_bins)
-
-        action_bins = torch.stack(action_bins, dim = -1)
-
-        if not return_q_values:
-            return action_bins
-
-        all_q_values = self.get_q_values(embed)
-        return action_bins, all_q_values
+            q_values = self.get_q_value_fuction(embed[:, 1:, :])
+            if action_idx == 0:
+                special_idx = action_idx
+            else:
+                special_idx = action_idx - 1
+            _, selected_action_indices = q_values[:, special_idx, :].max(dim = -1)
+            action_bins[:, action_idx] = selected_action_indices
+            now_actions = action_bins[:, 0:action_idx + 1]
+            tokens = self.state_append_actions(encoded_state, actions = now_actions)
+        return action_bins
 
     def forward(
         self,
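The hunk above drops the epsilon-greedy sampling from `get_optimal_actions` and makes decoding purely greedy and autoregressive: each step picks the highest-Q bin for one action dimension, writes it into a preallocated `action_bins` tensor, and rebuilds the token sequence with `state_append_actions` before decoding the next dimension. Below is a minimal standalone sketch of that selection pattern, where `q_fn` is a hypothetical stand-in for the transformer plus `get_q_value_fuction` call; names and shapes are illustrative, not the repository API.

```python
import torch

# Sketch only: greedy, autoregressive selection of discretized action bins.
# `q_fn(state, bins_so_far)` is assumed to return Q-values of shape
# (batch, num_bins) for the next action dimension.
def greedy_autoregressive_actions(q_fn, state, num_actions: int) -> torch.Tensor:
    batch_size = state.shape[0]
    action_bins = torch.empty(batch_size, num_actions, dtype = torch.long, device = state.device)
    for action_idx in range(num_actions):
        q_values = q_fn(state, action_bins[:, :action_idx])      # condition on earlier choices
        action_bins[:, action_idx] = q_values.argmax(dim = -1)   # pick the highest-Q bin
    return action_bins

# toy usage: 2 states, 3 action dimensions, 5 bins per dimension, random Q-values
dummy_q = lambda state, bins_so_far: torch.randn(state.shape[0], 5)
print(greedy_autoregressive_actions(dummy_q, torch.zeros(2, 8), num_actions = 3))
```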
@@ -585,28 +553,14 @@ def embed_texts(self, texts: List[str]):
         return self.conditioner.embed_texts(texts)
 
     @torch.no_grad()
-    def get_optimal_actions(
+    def get_actions(
         self,
         state,
-        return_q_values = False,
         actions: Optional[Tensor] = None,
-        **kwargs
     ):
         encoded_state = self.state_encode(state)
-        return self.q_head.get_optimal_actions(encoded_state, return_q_values = return_q_values, actions = actions)
-
-    def get_actions(
-        self,
-        state,
-        prob_random_action = 0., # otherwise known as epsilon in RL
-        **kwargs,
-    ):
-        batch_size = state.shape[0]
-        assert 0. <= prob_random_action <= 1.
+        return self.q_head.get_optimal_actions(encoded_state)
 
-        if random() < prob_random_action:
-            return self.get_random_actions(batch_size = batch_size)
-        return self.get_optimal_actions(state, **kwargs)
 
     def forward(
         self,
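With the second hunk, `get_actions` becomes a thin greedy wrapper: it encodes the state and delegates to `q_head.get_optimal_actions`, and both the `prob_random_action` (epsilon) argument and the random-action fallback are removed. If epsilon-greedy exploration is still wanted during rollouts, it would now have to live at the call site. A hedged sketch of one way to do that, assuming hypothetical `num_action_bins` and `num_actions` attributes on the model (placeholders, not names taken from this diff):

```python
import torch
from random import random

# Sketch only: epsilon-greedy wrapper around the now purely greedy get_actions.
# `model.num_action_bins` and `model.num_actions` are assumed attribute names.
def epsilon_greedy_actions(model, state, epsilon: float = 0.1):
    assert 0. <= epsilon <= 1.
    if random() < epsilon:
        # uniform random bins, filling the role of the removed get_random_actions path
        return torch.randint(
            0, model.num_action_bins,
            (state.shape[0], model.num_actions),
            device = state.device
        )
    return model.get_actions(state)
```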