import collections
-from typing import List

-import numpy as np
import torch

from genrl.agents.deep.dqn.base import DQN
@@ -64,25 +62,27 @@ def prioritized_q_loss(agent: DQN, batch: collections.namedtuple):
    return loss


-def categorical_greedy_action(agent: DQN, state: torch.Tensor) -> np.ndarray:
+def categorical_greedy_action(agent: DQN, state: torch.Tensor) -> torch.Tensor:
    """Greedy action selection for Categorical DQN

    Args:
        agent (:obj:`DQN`): The agent
-        state (:obj:`np.ndarray`): Current state of the environment
+        state (:obj:`torch.Tensor`): Current state of the environment

    Returns:
-        action (:obj:`np.ndarray`): Action taken by the agent
+        action (:obj:`torch.Tensor`): Action taken by the agent
    """
-    q_value_dist = agent.model(state.unsqueeze(0)).detach().numpy()
+    q_value_dist = agent.model(state.unsqueeze(0)).detach()
    # We need to scale and discretise the Q-value distribution obtained above
-    q_value_dist = q_value_dist * np.linspace(agent.v_min, agent.v_max, agent.num_atoms)
+    q_value_dist = q_value_dist * torch.linspace(
+        agent.v_min, agent.v_max, agent.num_atoms
+    )
    # Then we find the action with the highest Q-values for all discrete regions
    # Current shape of the q_value_dist is [1, n_envs, action_dim, num_atoms]
    # So we take the sum of all the individual atom q_values and then take argmax
    # along action dim to get the optimal action. Since batch_size is 1 for this
    # function, we squeeze the first dimension out.
-    action = np.argmax(q_value_dist.sum(-1), axis=-1).squeeze(0)
+    action = torch.argmax(q_value_dist.sum(-1), dim=-1).squeeze(0)
    return action
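A minimal standalone sketch (not part of the commit) of the same atom-weighted greedy selection, assuming toy values for v_min, v_max, num_atoms, n_envs and action_dim:

import torch

v_min, v_max, num_atoms = -10.0, 10.0, 51
n_envs, action_dim = 1, 4

# Fake per-atom probabilities, shape [1, n_envs, action_dim, num_atoms]
q_dist = torch.softmax(torch.randn(1, n_envs, action_dim, num_atoms), dim=-1)

# Weight each atom by its support value, sum to get the expected Q per action,
# then argmax over the action dimension and drop the size-1 batch dimension
support = torch.linspace(v_min, v_max, num_atoms)
q_values = (q_dist * support).sum(-1)                # [1, n_envs, action_dim]
action = torch.argmax(q_values, dim=-1).squeeze(0)   # [n_envs]
print(action.shape)  # torch.Size([1])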
@@ -119,9 +119,9 @@ def categorical_q_values(agent: DQN, states: torch.Tensor, actions: torch.Tensor
def categorical_q_target(
    agent: DQN,
-    next_states: np.ndarray,
-    rewards: List[float],
-    dones: List[bool],
+    next_states: torch.Tensor,
+    rewards: torch.Tensor,
+    dones: torch.Tensor,
):
    """Projected Distribution of Q-values
@@ -140,8 +140,10 @@ def categorical_q_target(
    support = torch.linspace(agent.v_min, agent.v_max, agent.num_atoms)

    next_q_value_dist = agent.target_model(next_states) * support
-    next_actions = torch.argmax(next_q_value_dist.sum(-1), axis=-1)
-    next_actions = next_actions[:, :, np.newaxis, np.newaxis]
+    next_actions = (
+        torch.argmax(next_q_value_dist.sum(-1), dim=-1).unsqueeze(-1).unsqueeze(-1)
+    )
+
    next_actions = next_actions.expand(
        agent.batch_size, agent.env.n_envs, 1, agent.num_atoms
    )
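A minimal sketch (not part of the commit) of the reshaping the new torch-only lines perform, assuming toy sizes; the two unsqueeze(-1) calls stand in for the old [:, :, np.newaxis, np.newaxis] indexing before the expand:

import torch

batch_size, n_envs, action_dim, num_atoms = 32, 2, 4, 51
next_q_value_dist = torch.rand(batch_size, n_envs, action_dim, num_atoms)

# Greedy next action per (batch, env), then add two trailing singleton dims
next_actions = torch.argmax(next_q_value_dist.sum(-1), dim=-1)  # [32, 2]
next_actions = next_actions.unsqueeze(-1).unsqueeze(-1)         # [32, 2, 1, 1]

# Broadcast the chosen action index across the atom dimension
next_actions = next_actions.expand(batch_size, n_envs, 1, num_atoms)
print(next_actions.shape)  # torch.Size([32, 2, 1, 51])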