forked from Wei-1/Scala-Machine-Learning
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathPERTest.scala
68 lines (61 loc) · 2.51 KB
/
PERTest.scala
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
// Wei Chen - Deep Q Network
// 2017-09-01
import com.scalaml.TestData._
import com.scalaml.algorithm.PER
import org.scalatest.funsuite.AnyFunSuite
/** Tests for the `PER` learner on a tiny five-node decision tree.
  *
  * The environment is a directed graph: node 0 leads to nodes 1 and 2, and
  * node 1 leads to nodes 3 and 4. Nodes 2, 3 and 4 have no outgoing links.
  * A state is a one-hot array of length 5 and the current node is the index
  * of its largest entry. Both tests expect the learner to pick action 0 at
  * node 0 and action 1 at node 1, which ends on node 4.
  */
class PERSuite extends AnyFunSuite {

  // Hyper-parameters forwarded unchanged to `PER.train` and `PER.result`.
  val learning_rate = 0.1
  val scale = 1
  val limit = 10000
  val epoch = 100

  // Number of nodes in the graph; also the length of a one-hot state array.
  private val nodeCount = 5

  // Outgoing links per node. A node without an entry has no moves.
  private val links: Map[Int, Array[Int]] = Map(
    0 -> Array(1, 2),
    1 -> Array(3, 4)
  )

  /** Index of the largest entry of `values`. */
  private def argMax(values: Array[Double]): Int =
    values.zipWithIndex.maxBy(_._1)._2

  /** One environment transition, shared by every test case.
    *
    * @param scores reward for arriving at a node; nodes not listed give 0.0
    * @param paras  one-hot encoding of the current node
    * @param act    index into the current node's outgoing links
    * @return (one-hot next state, reward, whether the next node has no
    *         outgoing links), or `null` when the current node has no moves
    */
  private def step(
      scores: Map[Int, Double],
      paras: Array[Double],
      act: Int
  ): (Array[Double], Double, Boolean) = {
    val moves = links.getOrElse(argMax(paras), Array.empty[Int])
    if (moves.isEmpty) {
      // NOTE(review): `null` is kept as the "no move possible" value because
      // the simulator contract expected by PER is not visible here -- confirm
      // before replacing it with an Option or an exception.
      null
    } else {
      val endloc = moves(act)
      val next = Array.tabulate(nodeCount)(i => if (i == endloc) 1.0 else 0.0)
      val done = links.getOrElse(endloc, Array.empty[Int]).isEmpty
      (next, scores.getOrElse(endloc, 0.0), done)
    }
  }

  /** Trains a fresh `PER` on the graph with the given reward table and checks
    * that the reported result takes 0 -> 1 -> 4.
    *
    * NOTE(review): the meaning of the constructor arguments `Array(5, 4)`,
    * `2` and `10` is defined by `PER`, which is outside this file --
    * presumably layer sizes, action count and a batch/memory setting; confirm.
    */
  private def assertLearnsPathToNode4(scores: Map[Int, Double]) = {
    def simulator(paras: Array[Double], act: Int): (Array[Double], Double, Boolean) =
      step(scores, paras, act)

    val ql = new PER(Array(5, 4), Array(1.0, 0.0, 0.0, 0.0, 0.0), 2, simulator, 10)
    ql.train(limit, learning_rate, scale, epoch)
    val result = ql.result(epoch)
    assert(result.size == 3)
    assert(result.head.bestAct == 0)
    assert(result(1).bestAct == 1)
    assert(result.last.paras.zipWithIndex.maxBy(_._1)._2 == 4)
  }

  test("PER Test : Result 1") { // Case 1
    // Node 4 pays far more than the immediate reward at node 2.
    assertLearnsPathToNode4(Map(2 -> 10.0, 3 -> 0.0, 4 -> 100.0))
  }

  test("PER Test : Result 2") { // Case 2
    // Node 4 pays only slightly more than node 2; the same path is expected.
    assertLearnsPathToNode4(Map(2 -> 10.0, 3 -> 0.0, 4 -> 12.0))
  }
}