Skip to content

Commit

Permalink
Additional Synthetic Dataset
Browse files Browse the repository at this point in the history
  • Loading branch information
dom-lee committed Dec 12, 2022
1 parent c1a77b2 commit b4f8d88
Show file tree
Hide file tree
Showing 18 changed files with 1,072 additions and 279 deletions.
47 changes: 38 additions & 9 deletions experiments.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,12 +140,15 @@ def experiment_instance(d, nn, q):
############################################
class SBMJobRunner(Process):

def __init__(self, k, n, prob_p, prob_q, queue, num_runs=1, use_grid=False, **kwargs):
def __init__(self, k, n, prob_p, prob_q, queue, num_runs=1, use_grid=False,
use_complete=False, unequal_cluster=False, **kwargs):
super(SBMJobRunner, self).__init__(**kwargs)
self.k, self.n, self.prob_p, self.prob_q = k, n, prob_p, prob_q
self.queue = queue
self.num_runs = num_runs
self.use_grid = use_grid
self.use_complete = use_complete
self.unequal_cluster = unequal_cluster
self.d = self.k
if use_grid:
self.k = self.d * self.d
Expand All @@ -156,8 +159,14 @@ def run(self) -> None:
total_rand_scores = [0] * self.k

for run_no in range(self.num_runs):
if self.use_grid:
if self.use_grid and self.unequal_cluster:
dataset = pysc.datasets.SbmUnequalGridDataset(d=self.d, n=self.n, p=self.prob_p, q=self.prob_q)
elif self.use_grid:
dataset = pysc.datasets.SBMGridDataset(d=self.d, n=self.n, p=self.prob_p, q=self.prob_q)
elif self.use_complete:
dataset = pysc.datasets.SbmCompleteDataset(k=self.k, n=self.n, p=self.prob_p, q=self.prob_q)
elif self.unequal_cluster:
dataset = pysc.datasets.SbmUnequalCycleDataset(k=self.k, n=self.n, p=self.prob_p, q=self.prob_q)
else:
dataset = pysc.datasets.SbmCycleDataset(k=self.k, n=self.n, p=self.prob_p, q=self.prob_q)
logger.info(f"Starting experiment: {dataset}, run number {run_no}")
Expand All @@ -184,24 +193,34 @@ def run(self) -> None:
self.queue.put(None)


def run_sbm_experiment(n, k, prob_p, use_grid=False):
def run_sbm_experiment(n, k, prob_p, use_grid=False, use_complete=False,
unequal_cluster=False):
logger.info(f"Running experiment with SBM data.")

# For each set of SBM parameters, run 10 times.
num_runs = 10

# Start all of the sub-processes to do the clustering with different numbers of eigenvalues
processes = []
if use_grid:
if use_grid and unequal_cluster:
results_filename = "results/sbm/grid_unequal_results.csv"
elif use_grid:
results_filename = "results/sbm/grid_results.csv"
elif use_complete:
results_filename = "results/sbm/complete_results.csv"
elif unequal_cluster:
results_filename = "results/sbm/cycle_unequal_results.csv"
else:
results_filename = "results/sbm/cycle_results.csv"

with open(results_filename, 'w') as fout:
fout.write("k, n, p, q, poverq, eigenvectors, conductance, rand\n")
fout.flush()
for prob_q in numpy.linspace(prob_p / 10, prob_p, num=10):
q = Queue()
p = SBMJobRunner(k, n, prob_p, prob_q, q, num_runs=num_runs, use_grid=use_grid)
p = SBMJobRunner(k, n, prob_p, prob_q, q, num_runs=num_runs,
use_grid=use_grid, use_complete=use_complete,
unequal_cluster=unequal_cluster)
p.start()
processes.append(p)

Expand Down Expand Up @@ -395,19 +414,29 @@ def run_bsds_experiment(image_id=None):

def parse_args():
parser = argparse.ArgumentParser(description='Run the experiments.')
parser.add_argument('experiment', type=str, choices=['cycle', 'grid', 'mnist', 'usps', 'bsds'],
parser.add_argument('experiment', type=str,
choices=['cycle', 'grid', 'cycle_unequal', 'grid_unequal',
'complete', 'mnist', 'usps', 'bsds'],
help="which experiment to perform")
parser.add_argument('bsds_image', type=str, nargs='?', help="(optional) the BSDS ID of a single BSDS image file to segment")
parser.add_argument('bsds_image', type=str, nargs='?',
help="(optional) the BSDS ID of a single " \
"BSDS image file to segment")
return parser.parse_args()


def main():
args = parse_args()

if args.experiment == 'cycle':
run_sbm_experiment(1000, 10, 0.01)
run_sbm_experiment(1000, 5, 0.01)
elif args.experiment == 'cycle_unequal':
run_sbm_experiment([100, 120, 140, 160, 180, 200], 6, 0.01, unequal_cluster=True)
elif args.experiment == 'grid':
run_sbm_experiment(1000, 4, 0.01, use_grid=True)
run_sbm_experiment(1000, 5, 0.01, use_grid=True)
elif args.experiment == 'grid_unequal':
run_sbm_experiment([100 + 20 * i for i in range(4*4)], 4, 0.01, use_grid=True, unequal_cluster=True)
elif args.experiment == 'complete':
run_sbm_experiment(100, 5, 0.2, use_complete=True)
elif args.experiment == 'mnist':
run_mnist_experiment()
elif args.experiment == 'usps':
Expand Down
33 changes: 33 additions & 0 deletions plot_graph.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
import argparse
import matplotlib.pyplot as plt
import pysc.datasets

def parse_args():
parser = argparse.ArgumentParser(description='Run the experiments.')
parser.add_argument('experiment', type=str,
choices=['cycle', 'grid', 'cycle_unequal', 'grid_unequal',
'complete', 'mnist', 'usps', 'bsds'],
help="which experiment to perform")
parser.add_argument('bsds_image', type=str, nargs='?',
help="(optional) the BSDS ID of a single " \
"BSDS image file to segment")
return parser.parse_args()


def main():
args = parse_args()

if args.experiment == 'cycle':
dataset = pysc.datasets.SbmCycleDataset(k=10, n=1000, p=0.01, q=0.001)
elif args.experiment == 'grid':
dataset = pysc.datasets.SBMGridDataset(d=4, n=1000, p=0.01, q=0.001)
elif args.experiment == 'complete':
dataset = pysc.datasets.SbmCompleteDataset(k=5, n=100, p=0.2, q=0.01)

# Draw
dataset.graph.draw()
plt.show()


if __name__ == "__main__":
main()
103 changes: 103 additions & 0 deletions pysc/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -362,6 +362,109 @@ def __repr__(self):
return self.__str__()


class SbmCompleteDataset(Dataset):

def __init__(self, *args, k=6, n=50, p=0.3, q=0.05, **kwargs):
self.k, self.n, self.p, self.q = k, n, p, q
super(SbmCompleteDataset, self).__init__(self, *args, num_data_points=(n * k), **kwargs)

def load_data(self, data_file):
# The SBM dataset has no data
pass

def load_graph(self, graph_file=None, graph_type="knn10"):
# Generate the graph from the sbm
logger.info(f"Generating {self} graph from sbm...")
prob_mat = self.p * sp.sparse.eye(self.k) + \
self.q * sgtl.graph.complete_graph(self.k).adjacency_matrix()
self.graph = sgtl.random.sbm_equal_clusters(self.n * self.k, self.k,
prob_mat.toarray())

def load_gt_clusters(self, gt_clusters_file):
logger.info(f"Loading GT clusters for {self}...")

# We can just generate the ground truth clusters as needed
self.gt_clusters = [list(range(i * self.n, (i * self.n) + self.n)) for i in range(self.k)]
self.gt_labels = []
for cluster in range(self.k):
for _ in range(self.n):
self.gt_labels.append(cluster)

def __str__(self):
return f"sbmComplete({self.k}, {self.n}, {self.p}, {self.q})"

def __repr__(self):
return self.__str__()


class SbmUnequalCycleDataset(Dataset):

def __init__(self, *args, k=6, n=[10, 20, 30, 40, 50, 60], p=0.3, q=0.05, **kwargs):
self.k, self.n, self.p, self.q = k, n, p, q
super(SbmUnequalCycleDataset, self).__init__(self, *args, num_data_points=sum(n), **kwargs)

def load_data(self, data_file):
# The SBM dataset has no data
pass

def load_graph(self, graph_file=None, graph_type="knn10"):
# Generate the graph from the sbm
logger.info(f"Generating {self} graph from sbm...")
prob_mat = self.p * sp.sparse.eye(self.k) + \
self.q * sgtl.graph.cycle_graph(self.k).adjacency_matrix()
self.graph = sgtl.random.sbm(self.n, prob_mat.toarray())

def load_gt_clusters(self, gt_clusters_file):
logger.info(f"Loading GT clusters for {self}...")

# We can just generate the ground truth clusters as needed
self.gt_clusters = [list(range(sum(self.n[:i]), sum(self.n[:i+1]))) for i in range(self.k)]
self.gt_labels = []
for cluster in range(self.k):
for _ in range(self.n[cluster]):
self.gt_labels.append(cluster)

def __str__(self):
return f"sbmUnequalCycle({self.k}, {self.n}, {self.p}, {self.q})"

def __repr__(self):
return self.__str__()


class SbmUnequalGridDataset(Dataset):

def __init__(self, *args, d=3, n=[10, 20, 30, 40, 50, 60, 70, 80, 90], p=0.3, q=0.05, **kwargs):
self.d, self.n, self.p, self.q = d, n, p, q
super(SbmUnequalGridDataset, self).__init__(self, *args, num_data_points=sum(n), **kwargs)

def load_data(self, data_file):
# The SBM dataset has no data
pass

def load_graph(self, graph_file=None, graph_type="knn10"):
# Generate the graph from the sbm
logger.info(f"Generating {self} graph from sbm...")
prob_mat = self.p * sp.sparse.eye(self.d * self.d) + self.q * \
networkx.to_numpy_matrix(grid_graph((self.d, self.d)))
self.graph = sgtl.random.sbm(self.n, prob_mat.tolist())

def load_gt_clusters(self, gt_clusters_file):
logger.info(f"Loading GT clusters for {self}...")

# We can just generate the ground truth clusters as needed
self.gt_clusters = [list(range(sum(self.n[:i]), sum(self.n[:i+1]))) for
i in range(self.d * self.d)]
self.gt_labels = []
for cluster in range(self.d * self.d):
for _ in range(self.n[cluster]):
self.gt_labels.append(cluster)

def __str__(self):
return f"sbmUnequalGrid({self.d}, {self.n}, {self.p}, {self.q})"

def __repr__(self):
return self.__str__()

class BSDSDataset(Dataset):

def __init__(self, img_idx, *args, downsample_factor=None,
Expand Down
1 change: 1 addition & 0 deletions pysc/objfunc.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ def apply(graph: sgtl.Graph, clusters: List[List[int]]) -> float:
conductances.append(graph.conductance(cluster))
else:
conductances.append(1)

except ZeroDivisionError:
# In the case of a zero division error, it must be that one of the clusters is empty, return 1
return 1
Expand Down
18 changes: 9 additions & 9 deletions results/mnist/results.csv
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
k, d, eigenvectors, rand
3, None, 2, 0.28905224553755665
3, None, 3, 0.3849156686133719
3, None, 4, 0.4875666824315091
3, None, 5, 0.4894281122170645
3, None, 6, 0.5017134535797279
3, None, 7, 0.6411298401789535
3, None, 8, 0.5741322527508748
3, None, 9, 0.5644375504396668
3, None, 10, 0.5970835200708858
3, None, 2, 0.28915961063265805
3, None, 3, 0.38488138113558296
3, None, 4, 0.5247131156026261
3, None, 5, 0.4900618290421381
3, None, 6, 0.5212457223862366
3, None, 7, 0.6408013718318064
3, None, 8, 0.5749934584340691
3, None, 9, 0.5637025932464547
3, None, 10, 0.5970996153925824
101 changes: 101 additions & 0 deletions results/sbm/complete_results_10_1000_0.01.csv
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
k, n, p, q, poverq, eigenvectors, conductance, rand
10, 1000, 0.01, 0.001, 10.0, 1, 0.9942154716681323, 0.794531797179718
10, 1000, 0.01, 0.001, 10.0, 2, 0.8390756147625288, 0.8258927192719272
10, 1000, 0.01, 0.001, 10.0, 3, 0.8059145106188772, 0.8692812641264126
10, 1000, 0.01, 0.001, 10.0, 4, 0.790484538575552, 0.9066967176717672
10, 1000, 0.01, 0.001, 10.0, 5, 0.7059770412832592, 0.9403317111711171
10, 1000, 0.01, 0.001, 10.0, 6, 0.6549282178164204, 0.9602805100510052
10, 1000, 0.01, 0.001, 10.0, 7, 0.5878097713170346, 0.97362501650165
10, 1000, 0.01, 0.001, 10.0, 8, 0.5678457948468523, 0.9807183338333834
10, 1000, 0.01, 0.001, 10.0, 9, 0.5419374342007882, 0.9881687848784877
10, 1000, 0.01, 0.001, 10.0, 10, 0.47965878351067215, 0.9950394939493948
10, 1000, 0.01, 0.002, 5.0, 1, 0.9910742917481083, 0.7985732353235324
10, 1000, 0.01, 0.002, 5.0, 2, 0.8684235452517468, 0.8128586818681869
10, 1000, 0.01, 0.002, 5.0, 3, 0.8364365116599203, 0.8423175357535755
10, 1000, 0.01, 0.002, 5.0, 4, 0.8352890746782334, 0.867986694669467
10, 1000, 0.01, 0.002, 5.0, 5, 0.8333001800663004, 0.8882074707470746
10, 1000, 0.01, 0.002, 5.0, 6, 0.8052407041631202, 0.908023518351835
10, 1000, 0.01, 0.002, 5.0, 7, 0.7777342945348563, 0.9239243564356434
10, 1000, 0.01, 0.002, 5.0, 8, 0.7244209151034353, 0.9397077667766777
10, 1000, 0.01, 0.002, 5.0, 9, 0.7058196493657511, 0.9509872567256725
10, 1000, 0.01, 0.002, 5.0, 10, 0.6551385425380223, 0.9635732213221322
10, 1000, 0.01, 0.003, 3.3333333333333335, 1, 0.989974881492247, 0.7997549674967497
10, 1000, 0.01, 0.003, 3.3333333333333335, 2, 0.9025054539020341, 0.8065420442044203
10, 1000, 0.01, 0.003, 3.3333333333333335, 3, 0.8384928281644672, 0.8218328452845285
10, 1000, 0.01, 0.003, 3.3333333333333335, 4, 0.8292521967374128, 0.8326183778377836
10, 1000, 0.01, 0.003, 3.3333333333333335, 5, 0.8246221125050056, 0.841405494549455
10, 1000, 0.01, 0.003, 3.3333333333333335, 6, 0.8254184173619763, 0.8489595659565957
10, 1000, 0.01, 0.003, 3.3333333333333335, 7, 0.8145191624024862, 0.8567835343534353
10, 1000, 0.01, 0.003, 3.3333333333333335, 8, 0.8066932801515619, 0.8627620002000201
10, 1000, 0.01, 0.003, 3.3333333333333335, 9, 0.7948623957068361, 0.8688178297829783
10, 1000, 0.01, 0.003, 3.3333333333333335, 10, 0.7699172351489144, 0.8732717231723173
10, 1000, 0.01, 0.004, 2.5, 1, 0.9813003428308551, 0.8001861706170619
10, 1000, 0.01, 0.004, 2.5, 2, 0.9147702007166328, 0.8011065906590659
10, 1000, 0.01, 0.004, 2.5, 3, 0.8580793660647095, 0.8120321272127212
10, 1000, 0.01, 0.004, 2.5, 4, 0.8369961078409556, 0.8191220342034203
10, 1000, 0.01, 0.004, 2.5, 5, 0.8287189500792242, 0.821270395039504
10, 1000, 0.01, 0.004, 2.5, 6, 0.8171998737515, 0.8217899409940994
10, 1000, 0.01, 0.004, 2.5, 7, 0.820730000938657, 0.8214989418941894
10, 1000, 0.01, 0.004, 2.5, 8, 0.8192029637414835, 0.8214918511851185
10, 1000, 0.01, 0.004, 2.5, 9, 0.8157383522569104, 0.8217383958395839
10, 1000, 0.01, 0.004, 2.5, 10, 0.8069352042981934, 0.8218737393739375
10, 1000, 0.01, 0.005, 2.0, 1, 0.9856655355960816, 0.7972521572157215
10, 1000, 0.01, 0.005, 2.0, 2, 0.9218485635157261, 0.8017634843484348
10, 1000, 0.01, 0.005, 2.0, 3, 0.8630456559867996, 0.8113739333933394
10, 1000, 0.01, 0.005, 2.0, 4, 0.8419012573068023, 0.8180734793479347
10, 1000, 0.01, 0.005, 2.0, 5, 0.8348959656064923, 0.8202091209120912
10, 1000, 0.01, 0.005, 2.0, 6, 0.8272734952015922, 0.8203597299729972
10, 1000, 0.01, 0.005, 2.0, 7, 0.8294012357455939, 0.8200889468946894
10, 1000, 0.01, 0.005, 2.0, 8, 0.8314667509734676, 0.8199072827282728
10, 1000, 0.01, 0.005, 2.0, 9, 0.8252549858888429, 0.8202816801680168
10, 1000, 0.01, 0.005, 2.0, 10, 0.8175613338955543, 0.8204440364036403
10, 1000, 0.01, 0.006, 1.6666666666666667, 1, 0.9852832434572317, 0.8003018361836183
10, 1000, 0.01, 0.006, 1.6666666666666667, 2, 0.9286957766444746, 0.801378305830583
10, 1000, 0.01, 0.006, 1.6666666666666667, 3, 0.8708289470333181, 0.8109790019001899
10, 1000, 0.01, 0.006, 1.6666666666666667, 4, 0.8488895998965269, 0.8177930473047306
10, 1000, 0.01, 0.006, 1.6666666666666667, 5, 0.8391047465051746, 0.8200694889488949
10, 1000, 0.01, 0.006, 1.6666666666666667, 6, 0.8336463665227528, 0.8201744614461447
10, 1000, 0.01, 0.006, 1.6666666666666667, 7, 0.8364535856062257, 0.819813201320132
10, 1000, 0.01, 0.006, 1.6666666666666667, 8, 0.8374659771392506, 0.8199127732773277
10, 1000, 0.01, 0.006, 1.6666666666666667, 9, 0.8329609601516375, 0.8200748294829484
10, 1000, 0.01, 0.006, 1.6666666666666667, 10, 0.8251690719677836, 0.8202274787478748
10, 1000, 0.01, 0.007, 1.4285714285714286, 1, 0.9836181004065863, 0.8010404900490048
10, 1000, 0.01, 0.007, 1.4285714285714286, 2, 0.9367179265186314, 0.800498699869987
10, 1000, 0.01, 0.007, 1.4285714285714286, 3, 0.8787722110262937, 0.8112755215521554
10, 1000, 0.01, 0.007, 1.4285714285714286, 4, 0.8539647906353048, 0.8178326272627261
10, 1000, 0.01, 0.007, 1.4285714285714286, 5, 0.8449952121582667, 0.8199714771477149
10, 1000, 0.01, 0.007, 1.4285714285714286, 6, 0.8397336248714955, 0.8200684388438845
10, 1000, 0.01, 0.007, 1.4285714285714286, 7, 0.8413266391988177, 0.8197346674667468
10, 1000, 0.01, 0.007, 1.4285714285714286, 8, 0.8430334817480558, 0.8198148134813483
10, 1000, 0.01, 0.007, 1.4285714285714286, 9, 0.8409994187576462, 0.8199277047704768
10, 1000, 0.01, 0.007, 1.4285714285714286, 10, 0.8309083173223734, 0.8201494369436941
10, 1000, 0.01, 0.008, 1.25, 1, 0.9834373742288335, 0.7992344614461445
10, 1000, 0.01, 0.008, 1.25, 2, 0.9366916735235854, 0.8013412541254125
10, 1000, 0.01, 0.008, 1.25, 3, 0.8808967793986788, 0.8108300830083008
10, 1000, 0.01, 0.008, 1.25, 4, 0.8575000048726084, 0.8176582438243825
10, 1000, 0.01, 0.008, 1.25, 5, 0.8529635699757586, 0.8195776477647765
10, 1000, 0.01, 0.008, 1.25, 6, 0.8429547800706441, 0.8200764436443644
10, 1000, 0.01, 0.008, 1.25, 7, 0.8474861513516725, 0.8196706290629063
10, 1000, 0.01, 0.008, 1.25, 8, 0.849582442067484, 0.8196836163616362
10, 1000, 0.01, 0.008, 1.25, 9, 0.8446971197765347, 0.819906320632063
10, 1000, 0.01, 0.008, 1.25, 10, 0.8356768337964825, 0.8200890689068906
10, 1000, 0.01, 0.009000000000000001, 1.111111111111111, 1, 0.9843414718906704, 0.7975807860786079
10, 1000, 0.01, 0.009000000000000001, 1.111111111111111, 2, 0.9414145666670704, 0.8005392999299931
10, 1000, 0.01, 0.009000000000000001, 1.111111111111111, 3, 0.8826140551154822, 0.811200702070207
10, 1000, 0.01, 0.009000000000000001, 1.111111111111111, 4, 0.8636672617673282, 0.8175581778177816
10, 1000, 0.01, 0.009000000000000001, 1.111111111111111, 5, 0.8541005043235017, 0.8198287268726873
10, 1000, 0.01, 0.009000000000000001, 1.111111111111111, 6, 0.8485630110861436, 0.8200354455445545
10, 1000, 0.01, 0.009000000000000001, 1.111111111111111, 7, 0.8502890316156237, 0.8197364556455646
10, 1000, 0.01, 0.009000000000000001, 1.111111111111111, 8, 0.8512880565513952, 0.8197437203720372
10, 1000, 0.01, 0.009000000000000001, 1.111111111111111, 9, 0.847531594498587, 0.8199020662066205
10, 1000, 0.01, 0.009000000000000001, 1.111111111111111, 10, 0.8415487882109837, 0.8200232383238324
10, 1000, 0.01, 0.01, 1.0, 1, 0.9861505023968824, 0.7982626842684268
10, 1000, 0.01, 0.01, 1.0, 2, 0.9417969325190519, 0.8011308990899091
10, 1000, 0.01, 0.01, 1.0, 3, 0.8870422499898837, 0.8108628302830283
10, 1000, 0.01, 0.01, 1.0, 4, 0.8636831786908801, 0.8176137373737374
10, 1000, 0.01, 0.01, 1.0, 5, 0.8587399682120596, 0.8197107130713072
10, 1000, 0.01, 0.01, 1.0, 6, 0.8523392685215313, 0.819990529052905
10, 1000, 0.01, 0.01, 1.0, 7, 0.8543535629985783, 0.8196426762676268
10, 1000, 0.01, 0.01, 1.0, 8, 0.8534573514543343, 0.819779511951195
10, 1000, 0.01, 0.01, 1.0, 9, 0.8539698910714197, 0.8198205140514052
10, 1000, 0.01, 0.01, 1.0, 10, 0.8435235048065939, 0.8200486848684869
Loading

0 comments on commit b4f8d88

Please sign in to comment.