JavaHdfsLR-v1.java
import spark.api.java.JavaRDD;
import spark.api.java.JavaSparkContext;
import spark.api.java.function.Function;
import spark.api.java.function.Function2;

import java.io.Serializable;
import java.util.Arrays;
import java.util.Random;
import java.util.StringTokenizer;

/**
 * Logistic regression trained with batch gradient descent, reading its
 * input data from HDFS.
 */
public class JavaHdfsLR {

  static int D = 16428;  // Number of dimensions
  static Random rand = new Random(42);

  // A labeled example: feature vector x and label y (expected in {-1, 1}).
  static class DataPoint implements Serializable {
    double[] x;
    double y;

    public DataPoint(double[] x, double y) {
      this.x = x;
      this.y = y;
    }
  }
  // Parses one line of sparse, LIBSVM-like input: a label followed by
  // space-separated 1-based "index:value" feature pairs.
  static class ParsePoint extends Function<String, DataPoint> {
    public DataPoint call(String line) {
      StringTokenizer itr = new StringTokenizer(line, " ");
      double y = Double.parseDouble(itr.nextToken());
      double[] x = new double[D];
      // Guard with hasMoreTokens() so a label-only line does not throw
      // NoSuchElementException.
      while (itr.hasMoreTokens()) {
        String tmp = itr.nextToken();
        if (!tmp.contains(":")) break;
        String[] strs = tmp.split(":");
        // Convert the 1-based feature index to a 0-based array index.
        x[Integer.parseInt(strs[0]) - 1] = Double.parseDouble(strs[1]);
      }
      return new DataPoint(x, y);
    }
  }
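  // For reference, a hypothetical input line in the format ParsePoint
  // expects (label first, then sparse 1-based index:value pairs):
  //
  //   1.0 3:0.5 4917:1.2
  //
  // parses to a DataPoint with y = 1.0 and x zero everywhere except
  // x[2] = 0.5 and x[4916] = 1.2.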
  // Element-wise vector addition; used as the reduce operator to sum
  // the per-point gradients.
  static class VectorSum extends Function2<double[], double[], double[]> {
    public double[] call(double[] a, double[] b) {
      double[] result = new double[D];
      for (int j = 0; j < D; j++) {
        result[j] = a[j] + b[j];
      }
      return result;
    }
  }
  // Gradient of the logistic loss at a single data point, given the
  // current weights.
  static class ComputeGradient extends Function<DataPoint, double[]> {
    double[] weights;

    public ComputeGradient(double[] weights) {
      this.weights = weights;
    }

    public double[] call(DataPoint p) {
      double[] gradient = new double[D];
      // The dot product does not depend on i, so compute it once rather
      // than once per dimension.
      double dot = dot(weights, p.x);
      for (int i = 0; i < D; i++) {
        gradient[i] = (1 / (1 + Math.exp(-p.y * dot)) - 1) * p.y * p.x[i];
      }
      return gradient;
    }
  }
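  // Derivation sketch: for the per-point logistic loss
  //   L(w) = log(1 + exp(-y * (w . x))),
  // the partial derivative is
  //   dL/dw_i = (1 / (1 + exp(-y * (w . x))) - 1) * y * x_i,
  // which is the expression ComputeGradient evaluates for each i.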
  public static double dot(double[] a, double[] b) {
    double x = 0;
    for (int i = 0; i < D; i++) {
      x += a[i] * b[i];
    }
    return x;
  }

  public static void printWeights(double[] a) {
    // Printing is disabled in this variant; with D = 16428 the full
    // weight vector would flood stdout. Re-enable for debugging.
    //System.out.println(Arrays.toString(a));
  }
  public static void main(String[] args) {
    if (args.length < 3) {
      System.err.println("Usage: JavaHdfsLR <master> <file> <iters>");
      System.exit(1);
    }

    JavaSparkContext sc = new JavaSparkContext(args[0], "JavaHdfsLR",
        System.getenv("SPARK_HOME"), System.getenv("SPARK_EXAMPLES_JAR"));
    JavaRDD<String> lines = sc.textFile(args[1]);
    // Parse once and cache: the points are re-scanned on every iteration.
    JavaRDD<DataPoint> points = lines.map(new ParsePoint()).cache();
    int ITERATIONS = Integer.parseInt(args[2]);

    // Initialize w to a random value in [-1, 1)
    double[] w = new double[D];
    for (int i = 0; i < D; i++) {
      w[i] = 2 * rand.nextDouble() - 1;
    }
    System.out.print("Initial w: ");
    printWeights(w);

    // Batch gradient descent: sum the per-point gradients with a
    // map/reduce pass, then apply the update with a fixed step size of 1.
    for (int i = 1; i <= ITERATIONS; i++) {
      System.out.println("On iteration " + i);
      double[] gradient = points.map(
          new ComputeGradient(w)
      ).reduce(new VectorSum());
      for (int j = 0; j < D; j++) {
        w[j] -= gradient[j];
      }
    }

    System.out.print("Final w: ");
    printWeights(w);
    System.exit(0);
  }
}
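A minimal sketch of a launch command matching the usage string above, assuming a Spark 0.x-era deployment where SPARK_HOME and SPARK_EXAMPLES_JAR are set in the environment; the master URL, HDFS path, and iteration count are illustrative placeholders:

    JavaHdfsLR local[2] hdfs://namenode:9000/user/data/lr_data.txt 10

With master local[2] the job runs locally on two worker threads; pointing the file argument at an hdfs:// URI instead reads the training data through HDFS.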