-
Notifications
You must be signed in to change notification settings - Fork 1
/
run_clustering.m
63 lines (52 loc) · 2.47 KB
/
run_clustering.m
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
% RUN_CLUSTERING  Simple example showing how our graph-Laplacian based
% semi-supervised clustering works on concentric circles.
%
% The script generates a dataset with generate_gaussian_circles, reveals a
% few labels per class, builds a nearest-neighbour graph Laplacian, runs
% SSL_clustering and plots the input, the true labels and the result.
% Both figures are also saved as fig1.png and fig2.png in the current folder.
%
% NOTE(review): no random seed is set, so the generated data and the randomly
% selected starting labels presumably differ between runs; call rng(<seed>)
% before generating the data if reproducible results are needed.
clear all
close all
nsamples = [200,400,600];         % number of samples in each class
radial_centers = [0, 1.5, 3];     % radial center of each class
decay_length = [0.25, 0.25, 0.3]; % decay length of each class (how far the points of a single class scatter)
label_certainty = 10;             % how certain we are of each observed label; higher is more certain
nlabels_pr_class = 6;             % number of initial (observed) labels selected in each class
nn = 29;                          % number of nearest neighbours
alpha = 100;  % hyperparameter setting the relative strength between the two terms. Suggested values are in the range 1 - 1000
beta = 1e-3;  % hyperparameter that stabilizes the often ill-conditioned matrix used in the clustering. Suggested values are in the range 1e-8 - 1e-2
%%
% Make the helper functions in ./src available. The folder is resolved
% relative to this script so the example also works when it is run from a
% different working directory. If mfilename returns '' (e.g. code evaluated
% outside a file), fullfile('', 'src') gives the relative path 'src', which
% matches the previous behaviour.
script_dir = fileparts(mfilename('fullpath'));
addpath(fullfile(script_dir,'src'))
% Generate the dataset. X holds the points (columns 1 and 2 are plotted
% below); labels_true presumably holds the true class of each point.
[X,labels_true] = generate_gaussian_circles(nsamples,radial_centers,decay_length);
% Randomly select the starting (observed) labels; idx_selected indexes them.
[labels,idx_selected] = select_starting_labels(labels_true,nlabels_pr_class);
nc = length(nsamples); % number of classes
n = length(labels);    % number of points
% Create the observed and the true label pseudo-probabilities
Yobs = convert_labels_to_pseudo_probabilities(labels,nc,label_certainty);
Ytrue = convert_labels_to_pseudo_probabilities(labels_true,nc,label_certainty);
% Show the data (trailing semicolon avoids echoing the figure handle)
fig1 = figure('Position',[0,0,1400,600]);
subplot(1,2,1)
scatter(X(:,1),X(:,2),200,convert_pseudo_probability_to_probability(Yobs),'.')
title('Input data')
subplot(1,2,2)
scatter(X(:,1),X(:,2),200,convert_pseudo_probability_to_probability(Ytrue),'.')
title('True labels')
% Diagonal weight matrix: 1 on the selected (labelled) points, 0 elsewhere
w = zeros(n,1);
w(idx_selected) = 1;
W = spdiags(w,0,n,n);
% Build the nearest-neighbour adjacency matrix and the graph Laplacian.
% epsilon is the median of dd (presumably the neighbour distances returned
% by getAdjacencyMatrix - TODO confirm).
[A,dd] = getAdjacencyMatrix(X,nn);
epsilon = median(dd(:));
[L,~,~] = getGraphLaplacian(X,A,epsilon);
% Cluster the data
Y = SSL_clustering(L,alpha,beta,Yobs,W);
% Plot the result (trailing semicolon avoids echoing the figure handle)
fig2 = figure('Position',[0,0,1400,600]);
subplot(1,2,1)
scatter(X(:,1),X(:,2),200,convert_pseudo_probability_to_probability(Y),'.')
title('Label probabilities')
subplot(1,2,2)
% Predicted label of each point = index of the largest entry in its row of Y
[~,label_pred] = max(Y,[],2);
U_pred_boost = convert_pseudo_probability_to_probability(convert_labels_to_pseudo_probabilities(label_pred,nc,label_certainty));
scatter(X(:,1),X(:,2),200,U_pred_boost,'.')
title('Labels predicted')
% Save both figures as PNG files (fig1.png, fig2.png) in the current folder
print(fig1,'fig1','-dpng')
print(fig2,'fig2','-dpng')