epsilon_greedy_with_softmax.py

import random

import numpy as np
import torch

from agents.policies.base_policy import Policy
from tools.parameter_scheduler import ParameterScheduler
from tools.rl_constants import Action


class EpsilonGreedySoftmaxPolicy(Policy):
    def __init__(self, action_size: int, epsilon_scheduler: ParameterScheduler, seed: int = None):
        super().__init__(action_size=action_size)
        self.epsilon_scheduler = epsilon_scheduler
        self.action_size = action_size
        # Initialize epsilon to the scheduler's starting value
        self.epsilon = self.epsilon_scheduler.initial
        # `if seed:` would silently skip a seed of 0, so compare against None
        if seed is not None:
            self.set_seed(seed)

    def step(self, episode_number: int):
        """Anneal epsilon according to the scheduler at the given episode."""
        self.epsilon = self.epsilon_scheduler.get_param(episode_number)
        return True

    def get_action(self, state: np.ndarray, model: torch.nn.Module) -> Action:
        """Select an action epsilon-greedily; when exploring, sample from a
        softmax over the action values rather than uniformly."""
        def _get_action_values():
            # Evaluate without tracking gradients, then restore train mode
            model.eval()
            with torch.no_grad():
                action_values = model(state, act=True)
            model.train()
            return action_values

        if self.training:
            action_values_ = _get_action_values()
            if random.random() > self.epsilon:
                # Exploit: take the highest-valued action
                action = action_values_.max(1)[1].item()
            else:
                # Explore: sample in proportion to the softmax of the action values
                probs = torch.nn.functional.softmax(action_values_, dim=1)
                action = np.random.choice(np.arange(0, self.action_size), p=probs.view(-1).numpy())
        else:
            # Evaluation: always act greedily
            action_values_ = _get_action_values()
            action = action_values_.max(1)[1].item()
        return Action(value=action)

    def get_deterministic_policy(self, state_action_values_dict: dict):
        """Map each state to its greedy (argmax) action."""
        deterministic_policy = {}
        for state, action_values in state_action_values_dict.items():
            deterministic_policy[state] = np.argmax(action_values)
        return deterministic_policy
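

# ---------------------------------------------------------------------------
# Usage sketch (illustrative, not part of the original file). `LinearScheduler`
# and `TinyQNetwork` below are hypothetical stand-ins for the repo's actual
# `ParameterScheduler` and Q-network; the only interfaces assumed are the ones
# used above: a scheduler exposing `.initial` and `.get_param(episode)`, and a
# model whose `forward` accepts `act=True` and returns a (1, action_size)
# tensor of action values.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    class LinearScheduler:
        """Hypothetical scheduler: decays epsilon linearly over episodes."""
        def __init__(self, initial, final, num_episodes):
            self.initial = initial
            self.final = final
            self.num_episodes = num_episodes

        def get_param(self, episode_number):
            frac = min(episode_number / self.num_episodes, 1.0)
            return self.initial + frac * (self.final - self.initial)

    class TinyQNetwork(torch.nn.Module):
        """Hypothetical Q-network; `act` is accepted only to match the
        `model(state, act=True)` call in get_action."""
        def __init__(self, state_size, action_size):
            super().__init__()
            self.fc = torch.nn.Linear(state_size, action_size)

        def forward(self, state, act=False):
            state = torch.as_tensor(state, dtype=torch.float32).view(1, -1)
            return self.fc(state)

    policy = EpsilonGreedySoftmaxPolicy(
        action_size=4,
        epsilon_scheduler=LinearScheduler(initial=1.0, final=0.05, num_episodes=100),
    )
    # Assumption: the Policy base class tracks a `training` flag, as read in
    # get_action; set it directly here for the sketch.
    policy.training = True
    model = TinyQNetwork(state_size=8, action_size=4)

    for episode in range(3):
        policy.step(episode)  # anneal epsilon for this episode
        action = policy.get_action(np.random.rand(8), model)
        print(episode, policy.epsilon, action)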