forked from Wei-1/Scala-Machine-Learning
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathA3CTest.scala
63 lines (54 loc) · 2.3 KB
/
A3CTest.scala
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
// Wei Chen - Asynchronous Advantage Actor Critic
// 2017-10-01
import com.scalaml.TestData._
import com.scalaml.algorithm.A3C
import org.scalatest.funsuite.AnyFunSuite
/** ScalaTest suite that drives [[A3C]] against a tiny five-node decision graph.
  *
  * Graph layout used by the simulator: node 0 links to nodes 1 and 2, node 1 links to
  * nodes 3 and 4. Nodes 2, 3 and 4 have no outgoing links, so reaching one of them
  * ends an episode. A state is a length-5 array whose largest entry marks the
  * current node.
  */
class A3CSuite extends AnyFunSuite {

  // Arguments forwarded to A3C.train / A3C.result below.
  // NOTE(review): their exact meaning is defined by A3C, which is not visible here.
  val scale = 1
  val limit = 10000
  val epoch = 100

  /** Shape of the environment callback: (state, action) => (next state, reward, done). */
  private type Simulator = (Array[Double], Int) => (Array[Double], Double, Boolean)

  /** Number of nodes in the graph, i.e. the length of every state vector. */
  private val nodeCount = 5

  /** Outgoing links per node; a node missing from this map has no moves. */
  private val graph: Map[Int, Array[Int]] = Map(
    0 -> Array(1, 2),
    1 -> Array(3, 4)
  )

  /** Index of the largest element (the first one on ties). */
  private def argmax(values: Array[Double]): Int =
    values.zipWithIndex.maxBy(_._1)._2

  /** Builds a simulator over `graph` that pays `rewards(node)` on arrival at `node`.
    *
    * Nodes without an entry in `rewards` pay 0.0. The returned function yields `null`
    * when called from a node that has no outgoing links.
    * NOTE(review): the `null` result is kept exactly as the test originally had it;
    * whether A3C ever triggers or checks for it cannot be confirmed from this file.
    */
  private def simulatorFor(rewards: Map[Int, Double]): Simulator =
    (paras, act) =>
      graph.get(argmax(paras)) match {
        case Some(options) if options.nonEmpty =>
          val target = options(act)
          // One-hot encoding of the node we moved to.
          val nextState = Array.tabulate(nodeCount)(i => if (i == target) 1.0 else 0.0)
          // The episode is over when the target node offers no further moves.
          val finished = graph.get(target).forall(_.isEmpty)
          (nextState, rewards.getOrElse(target, 0.0), finished)
        case _ =>
          null
      }

  /** Trains a fresh A3C instance with the given rewards and returns the argmax of
    * `paras` on the last element of `result(epoch)`.
    */
  private def finalNode(rewards: Map[Int, Double]): Int = {
    val startState = Array(1.0, 0.0, 0.0, 0.0, 0.0) // one-hot at node 0
    val learner =
      new A3C(Array(5, 4), Array(5, 4), startState, 2, simulatorFor(rewards), 10)
    learner.train(limit, scale, epoch)
    val trace = learner.result(epoch)
    argmax(trace.last.paras)
  }

  test("A3C Test : Result 1") { // Not Finished Yet
    // Node 4 pays far more than node 2, so the learned path should end at node 4.
    assert(finalNode(Map(2 -> 10.0, 3 -> 0.0, 4 -> 100.0)) == 4)
  }

  test("A3C Test : Result 2") { // Not Finished Yet
    // Node 4 pays only slightly more than node 2; node 4 is still the expected end.
    assert(finalNode(Map(2 -> 10.0, 3 -> 0.0, 4 -> 12.0)) == 4)
  }
}