-
Notifications
You must be signed in to change notification settings - Fork 1
/
softmax_policy.py
29 lines (23 loc) · 1.03 KB
/
softmax_policy.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
import numpy as np
import torch
from agents.policies.base_policy import Policy
from tools.rl_constants import Action
class SoftmaxPolicy(Policy):
    """Stochastic policy that samples actions from a softmax distribution
    over a model's action-value estimates.
    """

    def __init__(self, action_size: int, seed: int = None):
        """
        :param action_size: Number of discrete actions to choose from.
        :param seed: Optional seed for the sampling RNG. (Previously this
            argument was accepted but silently ignored; it is now honored
            for reproducibility.)
        """
        super().__init__(action_size=action_size)
        self.action_size = action_size
        # Dedicated generator: reproducible when seeded, and sampling no
        # longer mutates the global numpy random state.
        self.rng = np.random.default_rng(seed)

    def get_action(self, state: np.ndarray, model: torch.nn.Module) -> Action:
        """Sample an action index from softmax(model(state, act=True)).

        :param state: Observation passed straight through to the model.
        :param model: Network whose ``forward(state, act=True)`` returns
            action values (assumed to flatten to ``action_size`` entries —
            TODO confirm against the model implementation).
        :return: ``Action`` wrapping a 1-element numpy array holding the
            sampled action index.
        """
        was_training = model.training  # remember mode so we can restore it
        model.eval()
        with torch.no_grad():
            action_values = model.forward(state, act=True)
        if was_training:
            # Only re-enable train mode if the model was training before;
            # the original unconditionally forced train mode.
            model.train()
        # Explicit dim avoids PyTorch's deprecated implicit-dim softmax.
        probs = torch.nn.functional.softmax(action_values, dim=-1)
        chosen = self.rng.choice(self.action_size, p=probs.view(-1).numpy())
        return Action(value=np.array([chosen]))

    def get_deterministic_policy(self, state_action_values_dict: dict) -> dict:
        """Map each state to its greedy (argmax) action.

        :param state_action_values_dict: Mapping of state -> array-like of
            per-action values.
        :return: Mapping of state -> index of the maximum-valued action.
        """
        return {
            state: np.argmax(values)
            for state, values in state_action_values_dict.items()
        }