Skip to content

Commit af4fd44

Browse files
committed
Best state that could be reached.
Update README.md Update README.md Update README.md
1 parent 66a5f9f commit af4fd44

File tree

9 files changed

+183
-88
lines changed

9 files changed

+183
-88
lines changed

README.md

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,9 @@
11
# AI Project
22
Project for AI Course
3+
This project recognizes the type of a piece of clothing from a picture of it. To do so, we implemented three approaches: the Perceptron, MIRA, and the Kernelized Perceptron.
4+
5+
Further reading: https://github.com/zalandoresearch/fashion-mnist
6+
***
7+
Contributors:
8+
+ Mahsa Sheikhi
9+
+ Kiarash Golzadeh

src/Resource/Train/train_data2.txt

Lines changed: 22 additions & 0 deletions
Large diffs are not rendered by default.

src/Resource/Train/train_data_mira2.txt

Lines changed: 21 additions & 0 deletions
Large diffs are not rendered by default.

src/Source/Classifier.java

Lines changed: 41 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,12 @@
99
public abstract class Classifier {
1010
protected int[] labels;
1111
protected List<int[][]> images = new LinkedList<>();
12-
protected ArrayList<float[]> weight = new ArrayList<>();
12+
protected ArrayList<double[]> weight = new ArrayList<>();
13+
protected double[] mean = new double[Factors.NUMBER_OF_FACTORS], minValue = new double[Factors.NUMBER_OF_FACTORS],
14+
maxValue = new double[Factors.NUMBER_OF_FACTORS];
15+
protected int heldOutCnt = 0;
16+
protected int numberOfImages = 60000 - heldOutCnt ;
17+
protected Factors[] factorsList = new Factors[numberOfImages];
1318

1419
public Classifier(int[] labels, List<int[][]> images) {
1520
this.labels = labels;
@@ -32,32 +37,58 @@ public Classifier(String trainData) {
3237
if (weight.size() == 0)
3338
{
3439
for (int i = 0; i < Label.values().length; i++)
35-
weight.add(new float[Factors.NUMBER_OF_FACTORS]);
40+
weight.add(new double[Factors.NUMBER_OF_FACTORS]);
3641
}
3742
for (int i = 0; i < weightCnt; i++)
3843
{
3944
int weightInd = Integer.parseInt(scanner.nextLine().trim().split(" ")[1]);
4045
for (int j = 0; j < Factors.NUMBER_OF_FACTORS; j++)
4146
{
42-
float weightParameter = Float.parseFloat(scanner.next());
47+
double weightParameter = Double.parseDouble(scanner.next());
4348
weight.get(weightInd)[j] = weightParameter;
4449
}
4550
scanner.nextLine();
4651
}
47-
52+
for (int i = 0; i < Factors.NUMBER_OF_FACTORS; i++)
53+
{
54+
mean[i] = Double.parseDouble(scanner.next());
55+
minValue[i] = Double.parseDouble(scanner.next());
56+
maxValue[i] = Double.parseDouble(scanner.next());
57+
}
4858
}
4959

5060
public Classifier() {
5161
}
52-
62+
void normalize()
63+
{
64+
for (int i = 0; i < Factors.NUMBER_OF_FACTORS; i++)
65+
{
66+
maxValue[i] = Double.MIN_VALUE;
67+
minValue[i] = Double.MAX_VALUE;
68+
}
69+
for (Factors factors : factorsList)
70+
{
71+
double[] factorList = factors.getFactors();
72+
for (int i = 0; i < Factors.NUMBER_OF_FACTORS; i++)
73+
{
74+
mean[i] += factorList[i];
75+
maxValue[i] = (factorList[i] > maxValue[i] ? factorList[i] : maxValue[i]);
76+
minValue[i] = (factorList[i] < minValue[i] ? factorList[i] : minValue[i]);
77+
}
78+
}
79+
for (int i = 0; i < mean.length; i++)
80+
mean[i] /= numberOfImages;
81+
for (Factors factors : factorsList)
82+
factors.normalize(mean, minValue, maxValue);
83+
}
5384
abstract public void train();
5485

5586
abstract public Label test(int[][] image);
5687

5788
public void printWeights() {
5889
for (int i = 0; i < weight.size(); i++)
5990
{
60-
float[] factors = weight.get(i);
91+
double[] factors = weight.get(i);
6192
System.out.printf("Factor #%d :\n", i);
6293
for (int ind = 0; ind < factors.length; ind++)
6394
System.out.println("factor " + ind + " = " + factors[ind]);
@@ -77,12 +108,14 @@ public void printWeightsToFile(String trainData) {
77108
printWriter.println(weight.size());
78109
for (int i = 0; i < weight.size(); i++)
79110
{
80-
float[] factors = weight.get(i);
111+
double[] factors = weight.get(i);
81112
printWriter.printf("Factor %d :\n", i);
82-
for (float factor : factors)
113+
for (double factor : factors)
83114
printWriter.print(factor + " ");
84115
printWriter.println();
85116
}
117+
for (int i = 0; i < Factors.NUMBER_OF_FACTORS; i++)
118+
printWriter.println(mean[i] + " " + minValue[i] + " " + maxValue[i]);
86119
printWriter.flush();
87120
printWriter.close();
88121
}

src/Source/Factors.java

Lines changed: 24 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,11 @@
88
import java.util.Queue;
99

1010
public class Factors {
11+
public void normalize(double[] mean, double[] minValue, double[] maxValue) {
12+
for (int i = 0; i < NUMBER_OF_FACTORS; i++)
13+
factors[i] = (factors[i] - mean[i]) /255;// ((maxValue[i] - minValue[i]) < 1e-6 ? 1 : (maxValue[i] - minValue[i]));
14+
}
15+
1116
class IntPair extends AbstractMap.SimpleEntry<Integer, Integer> {
1217
public IntPair(Integer key, Integer value) {
1318
super(key, value);
@@ -16,7 +21,7 @@ public IntPair(Integer key, Integer value) {
1621

1722
// FIXME: Change this parameter when changing factors; otherwise it doesn't work properly.
1823
public static final int NUMBER_OF_FACTORS = 28*28 + 8 + 1;
19-
private float[] factors = new float[NUMBER_OF_FACTORS];
24+
private double[] factors = new double[NUMBER_OF_FACTORS];
2025
private int[][] image;
2126
private int height, width;
2227
private boolean factorsCalculated = false;
@@ -29,7 +34,7 @@ public Factors( int[][] image) {
2934
private void calcFactors() {
3035
//This divisions are there to simplify the process of dividing the picture to parts
3136
factorsCalculated = true;
32-
float result;
37+
double result;
3338
factors[0] = 1;
3439
int index = 1;
3540
final int division = 1;
@@ -63,14 +68,14 @@ private void calcFactors() {
6368
factors[index++] = result;
6469
}
6570

66-
public float[] getFactors() {
71+
public double[] getFactors() {
6772
if (!factorsCalculated)
6873
calcFactors();
6974
return factors;
7075
}
7176
// TODO: 09/06/2019 Implement all factors
7277

73-
private float topBottomRatio() {
78+
private double topBottomRatio() {
7479
int topPixels = 0, bottomPixels = 0;
7580
for (int i = 0; i < height; i++)
7681
for (int j = 0; j < width; j++)
@@ -81,10 +86,10 @@ private float topBottomRatio() {
8186
else bottomPixels++;
8287
}
8388
if (bottomPixels == 0) bottomPixels = 1;
84-
return topPixels / (float)(1.0 * bottomPixels);
89+
return topPixels / (double)(1.0 * bottomPixels);
8590
}
8691

87-
private float heightwidthRatio() {
92+
private double heightwidthRatio() {
8893
int h = 0, w = 0;
8994
int minHeight = 0, maxHeight= 0, minWidth= 0, maxWidth= 0;
9095

@@ -112,10 +117,10 @@ private float heightwidthRatio() {
112117
w = maxWidth - minWidth;
113118

114119
if (w == 0) w = 1;
115-
return h / (float)(1.0 * w);
120+
return h / (double)(1.0 * w);
116121
}
117122

118-
private float backgroundColorNumber() {
123+
private double backgroundColorNumber() {
119124
int start = height / 3, end = 2 * height / 3 , sum = 0;
120125
for (int i = start; i <= end; i++)
121126
for (int j = 0; j < width; ){
@@ -130,10 +135,10 @@ private float backgroundColorNumber() {
130135

131136
}
132137

133-
return (float)(1.0 * sum) / (float)(end - start +1);
138+
return (double)(1.0 * sum) / (double)(end - start +1);
134139
}
135140

136-
private float hasMoreBackgroundPixels() {
141+
private double hasMoreBackgroundPixels() {
137142
int pixels = 0;
138143
for (int i = 0; i < height; i++)
139144
for (int j = 0; j < width; j++)
@@ -146,7 +151,7 @@ private float hasMoreBackgroundPixels() {
146151
else return 1;
147152
}
148153

149-
private float leftRightRatio() {
154+
private double leftRightRatio() {
150155
int leftPixels = 0, rightPixels = 0;
151156
for (int i = 0; i < height; i++)
152157
for (int j = 0; j < width; j++)
@@ -157,12 +162,12 @@ private float leftRightRatio() {
157162
else rightPixels++;
158163
}
159164
if (rightPixels == 0) rightPixels = 1;
160-
return leftPixels / (float)(1.0 * rightPixels);
165+
return leftPixels / (double)(1.0 * rightPixels);
161166
}
162167

163168
//Calculate color change of consecutive pixels
164-
private float colorChange() {
165-
float ret = 0;
169+
private double colorChange() {
170+
double ret = 0;
166171
int differenceSum = 0;
167172
for (int i = 0; i < height; i++)
168173
for (int j = 0; j < width - 1; j++)
@@ -172,11 +177,11 @@ private float colorChange() {
172177
for (int i = 0; i < height - 1; i++)
173178
if (image[i][j] > 0 && image[i + 1][j] > 0)
174179
differenceSum += Math.abs(image[i][j] - image[i + 1][j]);
175-
ret = differenceSum / (float)(1.0 * (width - 1) * (height) + (width - 1) * height);
180+
ret = differenceSum / (double)(1.0 * (width - 1) * (height) + (width - 1) * height);
176181
return ret;
177182
}
178183

179-
private float ratioOfPixelsUnderSecondaryDiagonal() {
184+
private double ratioOfPixelsUnderSecondaryDiagonal() {
180185
int pixelCount = 0, underDiagonalPixelCount = 0;
181186
for (int i = 0; i < height; i++)
182187
for (int j = 0; j < width; j++)
@@ -189,12 +194,12 @@ private float ratioOfPixelsUnderSecondaryDiagonal() {
189194
}
190195
}
191196
if (pixelCount == 0) pixelCount = 1;
192-
return underDiagonalPixelCount / (float)(1.0 * pixelCount);
197+
return underDiagonalPixelCount / (double)(1.0 * pixelCount);
193198
}
194199

195200

196201
private IntPair findCenterOfMass() {
197-
float sumOfPoints = 0, weightedSumX = 0, weightedSumY = 0;
202+
double sumOfPoints = 0, weightedSumX = 0, weightedSumY = 0;
198203
for (int i = 0; i < height; i++)
199204
for (int j = 0; j < width; j++)
200205
{
@@ -208,7 +213,7 @@ private IntPair findCenterOfMass() {
208213
}
209214

210215
public int numberOfRings() {
211-
float[][] copyImage = new float[height][width];
216+
double[][] copyImage = new double[height][width];
212217
for (int i = 0; i < height; i++)
213218
for (int j = 0; j < width; j++)
214219
copyImage[i][j] = image[i][j];

src/Source/KernelizedPerceptron.java

Lines changed: 34 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -3,44 +3,42 @@
33
import java.io.File;
44
import java.io.FileNotFoundException;
55
import java.io.PrintWriter;
6-
import java.util.List;
7-
import java.util.Scanner;
6+
import java.util.*;
87

98
public class KernelizedPerceptron extends Perceptron {
10-
// float[][] kernelList = new float[numberOfImages][numberOfImages];
9+
private ArrayList<HashSet<Integer>> nonZeroAlpha = new ArrayList<>();
1110
public KernelizedPerceptron(int[] labels, List<int[][]> images) {
1211
super(labels, images);
1312
weight.clear();
1413
for (int i = 0; i < Label.values().length; i++)
15-
weight.add(new float[numberOfImages]);
16-
// for (int i = 0; i < numberOfImages; i++)
17-
// for (int j = 0; j < numberOfImages; j++)
18-
// kernelList[i][j] = kernel(factorsList[i], factorsList[j]);
14+
{
15+
weight.add(new double[numberOfImages]);
16+
nonZeroAlpha.add(new HashSet<>());
17+
}
1918
}
2019

21-
22-
2320
private double kernel(int iFactorIndex, int testFactorIndex) {
2421
double ret = 0;
25-
float[] iFactorFloat = factorsList[iFactorIndex].getFactors(),
26-
testFactorFloat = factorsList[testFactorIndex].getFactors();
22+
double[] iFactorDouble = factorsList[iFactorIndex].getFactors(),
23+
testFactorDouble = factorsList[testFactorIndex].getFactors();
2724
for (int i = 0; i < Factors.NUMBER_OF_FACTORS; i++)
28-
ret += iFactorFloat[i] * testFactorFloat[i];
25+
ret += iFactorDouble[i] * testFactorDouble[i];
2926
return (ret + 1) * (ret + 1);
3027
}
3128

3229
@Override
3330
public Label test(int[][] image) {
3431
Factors factors = new Factors(image);
32+
factors.normalize(mean, minValue, maxValue);
3533
return decideLabel(factors);
3634
}
3735

3836
private double kernel(int iFactorIndex, Factors test) {
3937
double ret = 0;
40-
float[] iFactorFloat = factorsList[iFactorIndex].getFactors(),
41-
testFactorFloat = test.getFactors();
38+
double[] iFactorDouble = factorsList[iFactorIndex].getFactors(),
39+
testFactorDouble = test.getFactors();
4240
for (int i = 0; i < Factors.NUMBER_OF_FACTORS; i++)
43-
ret += iFactorFloat[i] * testFactorFloat[i];
41+
ret += iFactorDouble[i] * testFactorDouble[i];
4442
return (ret + 1) * (ret + 1);
4543
}
4644
@Override
@@ -52,9 +50,9 @@ protected Label decideLabel(Factors factors) {
5250
for (Label label : Label.values())
5351
{
5452
dotProduct = 0;
55-
for (int image = 0; image <= maxImageIndex; image++)
53+
for (int image : nonZeroAlpha.get(label.ordinal()))
5654
{
57-
float alpha = weight.get(label.ordinal())[image];
55+
double alpha = weight.get(label.ordinal())[image];
5856
if (alpha != 0)
5957
dotProduct += weight.get(label.ordinal())[image] * kernel(image, factors);
6058
}
@@ -77,9 +75,9 @@ protected Label decideLabel(int factorsIndex) {
7775
for (Label label : Label.values())
7876
{
7977
dotProduct = 0;
80-
for (int image = 0; image <= maxImageIndex; image++)
78+
for (int image : nonZeroAlpha.get(label.ordinal()))
8179
{
82-
float alpha = weight.get(label.ordinal())[image];
80+
double alpha = weight.get(label.ordinal())[image];
8381
if (alpha != 0)
8482
dotProduct += weight.get(label.ordinal())[image] * kernel(image, factorsIndex);
8583
}
@@ -96,7 +94,11 @@ protected Label decideLabel(int factorsIndex) {
9694
@Override
9795
protected void updateWeights(int image, Label decidedLabel) {
9896
weight.get(decidedLabel.ordinal())[image]--;
97+
if (Math.abs (weight.get(decidedLabel.ordinal())[image]-0) > 1e-6)
98+
nonZeroAlpha.get(decidedLabel.ordinal()).add(image);
9999
weight.get(labels[image])[image]++;
100+
if (Math.abs (weight.get(labels[image])[image]-0) > 1e-6)
101+
nonZeroAlpha.get(labels[image]).add(image);
100102
}
101103

102104
@Override
@@ -115,14 +117,12 @@ public void printWeightsToFile(String trainData) {
115117
printWriter.println(weight.size());
116118
for (int i = 0; i < weight.size(); i++)
117119
{
118-
float[] factors = weight.get(i);
120+
double[] factors = weight.get(i);
119121
printWriter.printf("Label %d :\n", i);
120-
for (float factor : factors)
122+
for (double factor : factors)
121123
printWriter.print(factor + " ");
122124
printWriter.println();
123125
}
124-
125-
126126
printWriter.flush();
127127
printWriter.close();
128128
}
@@ -141,21 +141,30 @@ void loadWeightsFromFile(String trainData)
141141
// Here I tried to read the input train data file properly :)
142142
int imagesCnt = Integer.parseInt(scanner.nextLine());
143143
int weightCnt = Integer.parseInt(scanner.nextLine());
144+
numberOfImages = imagesCnt;
145+
nonZeroAlpha.clear();
146+
for(int i = 0; i < weight.size(); i++)
147+
nonZeroAlpha.add(new HashSet<>());
144148
weight.clear();
145149
for (int i = 0; i < Label.values().length; i++)
146-
weight.add(new float[imagesCnt]);
150+
weight.add(new double[imagesCnt]);
147151

148152
for (int i = 0; i < weightCnt; i++)
149153
{
150154
int weightInd = Integer.parseInt(scanner.nextLine().trim().split(" ")[1]);
151155
for (int j = 0; j < imagesCnt; j++)
152156
{
153-
float weightParameter = Float.parseFloat(scanner.next());
157+
double weightParameter = Double.parseDouble(scanner.next());
154158
weight.get(weightInd)[j] = weightParameter;
159+
if (Math.abs(weightParameter - 0) > 1e-6)
160+
nonZeroAlpha.get(i).add(j);
155161
}
156162
scanner.nextLine();
157163
}
158164
System.out.println("salam");
165+
scanner.close();
166+
167+
maxImageIndex = numberOfImages - 1;
159168
}
160169
public KernelizedPerceptron(String trainData) {
161170
super();

0 commit comments

Comments
 (0)