mdp_Q_learning.r
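# mdp_Q_learning: solves a discounted MDP with the Q-learning algorithm.
# Interface, as implemented below:
#   P        transition probabilities - a list of A matrices of size SxS, or an SxSxA array
#   R        rewards - a list of A matrices of size SxS, an SxSxA array, or an SxA matrix
#   discount discount factor in ]0; 1]
#   N        optional number of iterations (default 10000; must be at least 10000)
# Returns a list with the action-value matrix Q, the value function V, the greedy policy,
# and mean_discrepancy, the mean absolute Q variation per block of 100 transitions.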
mdp_Q_learning <- function(P, R, discount, N) {
# check of arguments
if ( discount <= 0 | discount > 1 ) {
print('--------------------------------------------------------')
print('MDP Toolbox ERROR: Discount rate must be in ]0; 1]')
print('--------------------------------------------------------')
} else if ( !missing(N) && N < 10000 ) {
print('--------------------------------------------------------')
print('MDP Toolbox ERROR: N must be greater than or equal to 10000')
print('--------------------------------------------------------')
} else {
# Initialization of the optional argument
if (missing(N)) {
N <- 10000
}
# Assigning the number of states and actions
if (is.list(P)) {
S <- dim(P[[1]])[1]
A <- length(P)
} else {
S <- dim(P)[1]
A <- dim(P)[3]
}
# Initializations
Q <- matrix(0,S,A)
dQ <- 0   # scalar increment applied to Q at each step
mean_discrepancy <- NULL
discrepancy <- NULL
# Initial state choice
state <- sample(1:S, 1, replace=T)
for (n in 1:N) {
# Reinitialisation of trajectories every 100 transitions
if ( n %% 100 == 0 ) {
state <- sample(1:S, 1, replace=T)
}
# Action choice: epsilon-greedy, exploiting the current Q estimate with increasing probability
# the exploitation probability 1-(1/log(n+2)) can be changed
pn <- runif(1)
if ( pn < (1-(1/log(n+2))) ) {
a <- which.max(Q[state,])
} else {
a <- sample(1:A, 1, replace=T)
}
# Simulating the next state s_new and the reward associated with <s,s_new,a>
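# s_new is drawn from the categorical distribution P[state, . , a] by accumulating
# transition probabilities until the cumulative sum exceeds the uniform draw (inverse-CDF sampling)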
p_s_new <- runif(1)
p <- 0
s_new <- 0
while ((p < p_s_new) & (s_new < S)) {
s_new <- s_new + 1
if (is.list(P)) {
p <- p + P[[a]][state,s_new]
} else {
p <- p + P[state,s_new,a]
}
}
if (is.list(R)) {
r <- R[[a]][state,s_new]
} else {
if(length(dim(R)) == 3) {
r <- R[state,s_new,a]
} else {
r <- R[state,a]
}
}
# Updating the value of Q
# Decaying update coefficient (1/sqrt(n+2)) can be changed
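# Temporal-difference update: Q(s,a) <- Q(s,a) + alpha_n * ( r + discount * max_a' Q(s_new,a') - Q(s,a) )
# with learning rate alpha_n = 1/sqrt(n+2)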
delta <- r + discount*max(Q[s_new,]) - Q[state,a]
dQ <- (1/sqrt(n+2))*delta
Q[state,a] <- Q[state,a] + dQ
# Current state is updated
state <- s_new
# Recording the magnitude of the Q update; every 100 transitions their mean is stored
discrepancy <- c(discrepancy, abs(dQ))
if(length(discrepancy) == 100) {
mean_discrepancy <- c(mean_discrepancy, mean(discrepancy))
discrepancy <- NULL
}
}
# Compute the value function and the policy
V <- apply(Q, 1, max)
policy <- apply(Q, 1, which.max)
return(list("Q"=Q, "V"=V, "policy"=policy, "mean_discrepancy"=mean_discrepancy))
}
}
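# Example usage (illustrative sketch, not part of the toolbox): a hypothetical
# 2-state, 2-action MDP defined directly as an SxSxA transition array and an
# SxA reward matrix, solved with the function above.
# P_ex[s, s_new, a] is the probability of moving from s to s_new under action a.
P_ex <- array(0, dim = c(2, 2, 2))
P_ex[, , 1] <- matrix(c(0.8, 0.2,
                        0.1, 0.9), nrow = 2, byrow = TRUE)  # action 1
P_ex[, , 2] <- matrix(c(0.3, 0.7,
                        0.6, 0.4), nrow = 2, byrow = TRUE)  # action 2
# R_ex[s, a] is the reward for taking action a in state s.
R_ex <- matrix(c( 5, 10,
                 -1,  2), nrow = 2, byrow = TRUE)
set.seed(1)  # Q-learning is stochastic, so fix the seed for reproducibility
result <- mdp_Q_learning(P_ex, R_ex, discount = 0.9)
print(result$policy)   # greedy policy derived from the learned Q matrix
print(result$V)        # corresponding value estimates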