forked from Wei-1/Scala-Machine-Learning
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy path QNeuralLearning.scala
64 lines (59 loc) · 2.34 KB
/
QNeuralLearning.scala
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
// Wei Chen - Q Neural Learning
// 2017-08-16
package com.scalaml.algorithm
/** Q-learning agent whose action-value function Q(s, a) is approximated by a `NeuralNetwork`,
  * exploring with an epsilon-greedy policy whose epsilon decays linearly from 1.0 down to 0.1.
  *
  * Simulator contract: `(nextstate, reward, end) = simulator(state, action)`.
  *
  * @param layer_neurons             hidden-layer sizes; the network input size is `initparas.size`
  *                                  and its output size is `actnumber` (one Q-value per action)
  * @param initparas                 initial state vector
  * @param actnumber                 number of discrete actions; actions are indices `0 until actnumber`
  * @param simulator                 environment step: given a state vector and an action index,
  *                                  returns the next state vector, the reward, and an "episode ended" flag
  * @param batchsize_number          batch size passed to `nn.config`
  * @param epsilon_saturation_number number of action choices over which epsilon decays from 1.0 to 0.1
  * @param nn_learning_rate          learning rate passed to every `nn.train` call
  */
class QNeuralLearning(
  val layer_neurons: Array[Int],
  val initparas: Array[Double],
  val actnumber: Int,
  val simulator: (Array[Double], Int) => (Array[Double], Double, Boolean),
  val batchsize_number: Int = 100,
  val epsilon_saturation_number: Int = 10000,
  val nn_learning_rate: Double = 0.01
) {
  /** Q-network: maps a state vector to `actnumber` Q-value estimates. */
  val nn = new NeuralNetwork()
  nn.config(initparas.size +: layer_neurons :+ actnumber,
    _batchSize = batchsize_number, _gradientClipping = true)

  /** One environment state, able to run Q-learning updates starting from itself. */
  class QNState(val paras: Array[Double]) {

    /** Runs an epsilon-greedy rollout of at most `epoch` further steps from this state and trains
      * the network on the way back, so the deepest visited state is trained first.
      *
      * @param lr    Q-learning blend factor between the current estimate and the new target
      *              (distinct from `nn_learning_rate`, which is used by the network itself)
      * @param df    discount factor applied to the next state's value
      * @param epoch remaining look-ahead steps; this method recurses once per step
      * @return the largest entry of this state's Q-vector after its update
      */
    def learn(lr: Double, df: Double, epoch: Int): Double = {
      val q_s = nn.predictOne(paras)
      // Exploration draws a uniform action with nextInt(n); the former `nextInt.abs % n`
      // could return a negative index because Int.MinValue.abs is still negative.
      val act =
        if (scala.util.Random.nextDouble() > epsilon) q_s.zipWithIndex.maxBy(_._1)._2
        else scala.util.Random.nextInt(actnumber)
      // Linear decay with a hard floor of 0.1: a plain `-=` can overshoot below the floor
      // (even below 0 when depsilon is large) because of floating-point rounding.
      if (epsilon > 0.1) epsilon = scala.math.max(0.1, epsilon - depsilon)
      val (newparas, newreward, newfinish) = simulator(paras, act)
      if (epoch > 0 && !newfinish) {
        // Target r + df * V(s'), where V(s') is what the next state's own rollout returns.
        val target = newreward + df * new QNState(newparas).learn(lr, df, epoch - 1)
        q_s(act) = (1 - lr) * q_s(act) + lr * target
      } else {
        // Episode ended or look-ahead exhausted: the target is the immediate reward only.
        q_s(act) = newreward
      }
      nn.train(Array(paras), Array(q_s), _learningRate = nn_learning_rate)
      q_s.max
    }

    /** Index of the highest-valued action, computed once when this state is constructed. */
    val bestAct: Int = nn.predictOne(paras).zipWithIndex.maxBy(_._1)._2
  }

  /** Current exploration probability of the epsilon-greedy policy. */
  var epsilon = 1.0
  /** Amount subtracted from `epsilon` after each action choice while it is above 0.1. */
  var depsilon = 0.9 / epsilon_saturation_number
  /** State every training episode starts from (initially `initparas`). */
  var state = new QNState(initparas)

  /** Runs `number` training episodes, each a rollout from the current `state`.
    *
    * @param number how many episodes to run
    * @param lr     Q-learning blend factor forwarded to `QNState.learn`
    * @param df     discount factor forwarded to `QNState.learn`
    * @param epoch  maximum rollout length per episode
    */
  def train(number: Int = 1, lr: Double = 0.1, df: Double = 0.6, epoch: Int = 100): Unit =
    for (_ <- 0 until number) state.learn(lr, df, epoch)

  /** Replays the greedy policy from `initparas` for at most `epoch` steps, stopping early when
    * the simulator reports the episode has ended.
    *
    * @param epoch maximum number of simulator steps
    * @return the visited states in order, starting with the initial state; the state produced
    *         by the step that ended the episode is included
    */
  def result(epoch: Int = 100): Array[QNState] = {
    // Builder avoids re-copying the whole array on every appended state.
    val visited = Array.newBuilder[QNState]
    var curstate = new QNState(initparas)
    visited += curstate
    var steps = 0
    var finished = false
    while (steps < epoch && !finished) {
      steps += 1
      val (newparas, _, newfinish) = simulator(curstate.paras, curstate.bestAct)
      finished = newfinish
      curstate = new QNState(newparas)
      visited += curstate
    }
    visited.result()
  }
}