-
Notifications
You must be signed in to change notification settings - Fork 0
/
RunNetwork.java
66 lines (56 loc) · 2.29 KB
/
RunNetwork.java
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
/*
* Authored by Bennett Liu on October 15th, 2019
*
 * RunNetwork contains a main function that imports a Network from a file containing a previously exported Network.
 * It then reads test cases from the console and runs them on the imported Network.
*/
import java.util.*;
import java.io.*;
/**
 * Command-line driver that loads a {@code Network} from a file, reads test cases from standard
 * input, and prints the network's results on those test cases via {@code NetworkTrainer}.
 */
public class RunNetwork
{
    /**
     * Prompts for a network file and a set of test cases, then prints the network's results.
     *
     * <p>Values are read from {@code System.in} in this order: the network file name (a single
     * whitespace-delimited token), the number of test cases, and then, for each test case, its
     * input values followed by its expected output values. The number of input and output values
     * per test case comes from the imported network's {@code inputs} and {@code outputs} fields.
     *
     * @param args command-line arguments (unused)
     * @throws IllegalArgumentException if the number of test cases entered is negative
     * @throws java.util.InputMismatchException if a numeric value is expected but something else is entered
     * @throws java.util.NoSuchElementException if standard input ends before all values are read
     */
    public static void main(String[] args)
    {
        // try-with-resources closes the scanner even if loading or evaluation throws
        try (Scanner in = new Scanner(System.in))
        {
            // Import network from file
            System.out.println("Enter the file that you'd like to import your network from: ");
            String fileName = in.next();
            Network network = new Network(new File(fileName));
            int inputNodes = network.inputs;   // The number of inputs per test case
            int outputNodes = network.outputs; // The number of outputs per test case

            // Read the number of test cases; reject negatives with a clear message rather than
            // letting the array allocation fail with a bare NegativeArraySizeException.
            System.out.println("How many test cases: ");
            int testCases = in.nextInt();
            if (testCases < 0)
            {
                throw new IllegalArgumentException(
                        "Number of test cases must be non-negative, got " + testCases);
            }

            // Read each test case's inputs followed by its expected outputs
            double[][] testInputs = new double[testCases][inputNodes];
            double[][] testOutputs = new double[testCases][outputNodes];
            for (int i = 0; i < testCases; i++)
            {
                System.out.println(String.format("Test Case %d", i + 1)); // prompts are 1-based
                readValues(in, "Input", testInputs[i]);
                readValues(in, "Output", testOutputs[i]);
            }

            // Initialize trainer and evaluate the initial network for all test cases
            NetworkTrainer trainer = new NetworkTrainer(network, testInputs, testOutputs);
            trainer.printResults();
        } // try (Scanner in = new Scanner(System.in))
    } // public static void main(String[] args)

    /**
     * Fills {@code dest} with doubles read from {@code in}, printing a prompt of the form
     * "{@code label n:}" (with 1-based {@code n}) before each value.
     *
     * @param in    the scanner to read values from
     * @param label the prompt label, e.g. "Input" or "Output"
     * @param dest  the array to fill; its length determines how many values are read
     */
    private static void readValues(Scanner in, String label, double[] dest)
    {
        for (int j = 0; j < dest.length; j++)
        {
            System.out.println(String.format("%s %d:", label, j + 1));
            dest[j] = in.nextDouble();
        }
    } // private static void readValues(Scanner in, String label, double[] dest)
} // public class RunNetwork