# Source: https://github.com/nikaashpuri/sarfa-saliency
import math
import numpy as np
from scipy.stats import entropy, wasserstein_distance
from scipy.spatial.distance import jensenshannon
def your_softmax(x):
    """Compute softmax values for each set of scores in x."""
    e_x = np.exp(x - np.max(x))  # subtract the max for numerical stability
    return e_x / e_x.sum()
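# Worked example (illustrative): your_softmax(np.array([1.0, 2.0, 3.0])) returns
# approximately [0.090, 0.245, 0.665]; subtracting the max first leaves the
# result unchanged because softmax is shift-invariant.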
def cross_entropy(dictP, dictQ, original_action):
    """
    Calculates the normalized cross entropy (KL divergence) of the Q-values of state Q w.r.t. state P.
    Input:
        dictP: Q-value dictionary of the perturbed state
        dictQ: Q-value dictionary of the original state
        original_action: the action under evaluation, excluded from the divergence
    Output:
        K: normalized KL divergence, 1. / (KL + 1.)
    """
    Q_p = []  # values of moves in dictP ∩ dictQ w.r.t. P
    Q_q = []  # values of moves in dictP ∩ dictQ w.r.t. Q
    for move in dictP:
        if move == original_action:
            # skip the original action for the KL divergence
            continue
        if move in dictQ:
            Q_p.append(dictP[move])
            Q_q.append(dictQ[move])
    # convert the Q-values into probability distributions
    Q_p = your_softmax(np.asarray(Q_p))
    Q_q = your_softmax(np.asarray(Q_q))
    KL = entropy(Q_q, Q_p)
    # KL = wasserstein_distance(Q_q, Q_p)
    # return KL / (KL + 1.)
    return 1. / (KL + 1.)
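# Worked example (hypothetical Q-values): with dictP = {0: 1.0, 1: 2.0, 2: 3.0},
# dictQ = {0: 1.0, 1: 3.0, 2: 2.0} and original_action = 0, the softmaxed
# distributions over the remaining actions {1, 2} are [0.269, 0.731] and
# [0.731, 0.269], so KL(Q_q || Q_p) ≈ 0.462 and the function returns
# 1 / (0.462 + 1) ≈ 0.684.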
def computeSaliencyUsingSarfa(original_action, dict_q_vals_before_perturbation, dict_q_vals_after_perturbation):
    answer = 0
    q_values_after_perturbation = np.asarray(list(dict_q_vals_after_perturbation.values()))
    q_values_before_perturbation = np.asarray(list(dict_q_vals_before_perturbation.values()))
    # Position of the original action in each dict's insertion order, so the softmax
    # output can be indexed even when the dicts are keyed by non-integer moves.
    idx_after = list(dict_q_vals_after_perturbation).index(original_action)
    idx_before = list(dict_q_vals_before_perturbation).index(original_action)
    # Probability of the original move in the perturbed and original states.
    # Changed from np.exp(q) / np.sum(np.exp(q_values)), which overflows with our
    # Q-values; the numerically stable softmax gives exactly the same result.
    probability_action_perturbed_state = your_softmax(q_values_after_perturbation)[idx_after]
    probability_action_original_state = your_softmax(q_values_before_perturbation)[idx_before]
    K = cross_entropy(dict_q_vals_after_perturbation, dict_q_vals_before_perturbation, original_action)
    dP = probability_action_original_state - probability_action_perturbed_state
    if probability_action_perturbed_state < probability_action_original_state:
        # harmonic mean of the probability drop dP and the specificity term K
        answer = 2 * dP * K / (dP + K)
    # QMaxChange and ActionGap are computationally expensive and unused here:
    # QmaxAnswer = computeSaliencyUsingQMaxChange(original_action, dict_q_vals_before_perturbation, dict_q_vals_after_perturbation)
    # action_gap_before_perturbation, action_gap_after_perturbation = computeSaliencyUsingActionGap(dict_q_vals_before_perturbation, dict_q_vals_after_perturbation)
    # The rest of the original return tuple is not computed, for efficiency:
    return answer  # , dP, K, QmaxAnswer, action_gap_before_perturbation, action_gap_after_perturbation
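# Worked example (illustrative numbers): with dP = 0.8 and K = 0.4 the SARFA
# score is the harmonic mean 2 * 0.8 * 0.4 / (0.8 + 0.4) ≈ 0.533; the score is
# high only when both the probability drop and the specificity term are high.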
def computeSaliencyUsingQMaxChange(original_action, dict_q_vals_before_perturbation, dict_q_vals_after_perturbation):
    # Binary saliency: 1 if the best (argmax) action changes after the perturbation.
    answer = 0
    best_action = None
    best_q_value = 0
    for move, q_value in dict_q_vals_after_perturbation.items():
        if best_action is None or q_value > best_q_value:
            best_action = move
            best_q_value = q_value
    if best_action != original_action:
        answer = 1
    return answer
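# Usage sketch (toy dictionaries): computeSaliencyUsingQMaxChange(0,
# {0: 2.0, 1: 1.0}, {0: 1.0, 1: 2.0}) returns 1, because the argmax action
# moves from 0 to 1 after the perturbation.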
def computeSaliencyUsingActionGap(dict_q_vals_before_perturbation, dict_q_vals_after_perturbation):
    # Action gap: difference between the two highest Q-values in each state.
    q_vals_before_perturbation = sorted(dict_q_vals_before_perturbation.values())
    q_vals_after_perturbation = sorted(dict_q_vals_after_perturbation.values())
    action_gap_before_perturbation = q_vals_before_perturbation[-1] - q_vals_before_perturbation[-2]
    action_gap_after_perturbation = q_vals_after_perturbation[-1] - q_vals_after_perturbation[-2]
    return action_gap_before_perturbation, action_gap_after_perturbation
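# Minimal usage sketch, not part of the original source: toy Q-value dictionaries
# stand in for a real agent's output, keyed by integer action indices. Action 0 is
# best before the perturbation and loses value after it, so SARFA reports a
# positive saliency and QMaxChange reports that the best action flipped.
if __name__ == "__main__":
    q_before = {0: 5.0, 1: 2.0, 2: 1.0}  # original state: action 0 is clearly best
    q_after = {0: 2.0, 1: 4.0, 2: 1.0}   # perturbed state: action 1 takes over
    saliency = computeSaliencyUsingSarfa(0, q_before, q_after)
    flipped = computeSaliencyUsingQMaxChange(0, q_before, q_after)
    gap_before, gap_after = computeSaliencyUsingActionGap(q_before, q_after)
    print("SARFA saliency:", saliency)      # ~0.80: the perturbed feature matters
    print("best action changed:", flipped)  # 1: argmax moved from action 0 to 1
    print("action gaps:", gap_before, gap_after)  # 3.0 before, 2.0 after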