Skip to content

Commit b355b38

Browse files
committed
Mainly changes in style:
correcting typos, enforcing conventions.
1 parent 7fea82c commit b355b38

File tree

5 files changed

+66
-69
lines changed

5 files changed

+66
-69
lines changed

DiffPrivate_FedLearning.py

+5-1
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ def run_differentially_private_federated_averaging(loss, train_op, eval_correct,
1313
label_placeholder, privacy_agent=None, b=10, e=4,
1414
record_privacy=True, m=0, sigma=0, eps=8, save_dir=None,
1515
log_dir=None, max_comm_rounds=3000, gm=True,
16-
saver_func=create_save_dir):
16+
saver_func=create_save_dir, save_params=False):
1717

1818
"""
1919
This function will simulate a federated learning setting and enable differential privacy tracking. It will detect
@@ -65,6 +65,7 @@ def run_differentially_private_federated_averaging(loss, train_op, eval_correct,
6565
:param gm: Whether to use a Gaussian Mechanism or not.
6666
:param saver_func: A function that specifies where and how to save progress: Note that the usual tensorflow
6767
tracking will not work
68+
:param save_params: Whether to save all weights throughout training.
6869
6970
:return:
7071
@@ -277,4 +278,7 @@ def run_differentially_private_federated_averaging(loss, train_op, eval_correct,
277278
# Print the progress and state of affairs.
278279
print(' - Epsilon-Delta Privacy:' + str([FLAGS.eps, delta]))
279280

281+
if save_params:
282+
weights_accountant.save_params(save_dir)
283+
280284
return [], [], []

Helper_Functions.py

+5-1
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,11 @@ def __init__(self,sess,model,Sigma, real_round):
7676
self.num_weights = len(self.Weights)
7777
self.round = real_round
7878

79+
def save_params(self, save_dir):
    """
    Pickle this object to a per-round file inside save_dir.

    The file is named 'Wweights_accountant_round_<round>.pkl', where <round> is the value of self.round.

    :param save_dir: Path of the directory in which the pickle file is written.
    """
    # str() is needed because self.round may be an integer (it is also used as an index elsewhere);
    # concatenating an integer to a string directly raises a TypeError.
    # NOTE(review): the 'Wweights' prefix looks like a typo; it is kept unchanged so the file name stays stable.
    file_name = save_dir + '/Wweights_accountant_round_' + str(self.round) + '.pkl'
    # The context manager closes the file even if pickling fails.
    with open(file_name, "wb") as filehandler:
        pickle.dump(self, filehandler)
83+
7984
def allocate(self, sess):
8085

8186
self.Weights = [np.concatenate((self.Weights[i], np.expand_dims(sess.run(tf.trainable_variables()[i]), -1)), -1)
@@ -146,7 +151,6 @@ def Update_via_GaussianMechanism(self, sess, Acc, FLAGS, Computed_deltas):
146151
delta = Computed_deltas[self.round]
147152
return New_model, delta
148153

149-
150154
def create_save_dir(FLAGS):
151155
'''
152156
:return: Returns a path that is used to store training progress; the path also identifies the chosen setup uniquely.

MNIST_reader.py

+8-6
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,8 @@ def read(dataset = "training", path = "."):
1919
else:
2020
raise ValueError, "dataset must be 'testing' or 'training'"
2121

22+
print(fname_lbl)
23+
2224
# Load everything in some numpy arrays
2325
with open(fname_lbl, 'rb') as flbl:
2426
magic, num = struct.unpack(">II", flbl.read(8))
@@ -36,10 +38,10 @@ def read(dataset = "training", path = "."):
3638
return img, lbl
3739

3840

39-
def get_data():
41+
def get_data(d):
4042
# load the data
41-
x_train, y_train = read('training', os.getcwd() + '/MNIST_original')
42-
x_test, y_test = read('testing', os.getcwd() + '/MNIST_original')
43+
x_train, y_train = read('training', d + '/MNIST_original')
44+
x_test, y_test = read('testing', d + '/MNIST_original')
4345

4446
# create validation set
4547
x_vali = list(x_train[50000:].astype(float))
@@ -63,6 +65,6 @@ def get_data():
6365

6466
class Data:
6567
def __init__(self, save_dir, n):
66-
raw_directory = save_dir + '/DATA/'
67-
self.client_set = pickle.load(open(raw_directory + 'clients/' + str(n) + '_clients.pkl', 'rb'))
68-
self.sorted_x_train, self.sorted_y_train, self.x_vali, self.y_vali, self.x_test, self.y_test = get_data()
68+
raw_directory = save_dir + '/DATA'
69+
self.client_set = pickle.load(open(raw_directory + '/clients/' + str(n) + '_clients.pkl', 'rb'))
70+
self.sorted_x_train, self.sorted_y_train, self.x_vali, self.y_vali, self.x_test, self.y_test = get_data(save_dir)

RUNME.sh

+22
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
#!/bin/sh
# Download the MNIST data set into DiffPrivate_FedLearning/MNIST_original,
# unpack it, and then run Create_clients.py.

BASE_URL="http://yann.lecun.com/exdb/mnist"

echo "Downloading the MNIST-data set and creating clients"

# Stop if a directory cannot be entered; otherwise the files would be
# downloaded and unpacked in the wrong place.
cd DiffPrivate_FedLearning || exit 1
# -p: do not fail when the directory is left over from an earlier run.
mkdir -p MNIST_original
cd MNIST_original || exit 1

for FILE in train-images-idx3-ubyte train-labels-idx1-ubyte \
            t10k-images-idx3-ubyte t10k-labels-idx1-ubyte
do
    curl -O "$BASE_URL/$FILE.gz"
    gunzip "$FILE.gz"
done

cd .. || exit 1
python Create_clients.py

# The option prefix must be two ASCII hyphens so the commands can be copied verbatim.
echo "You can now run differentially private federated learning on the MNIST data set. Type python sample.py --help for help"
echo "An example: …python sample.py --N 100… would run differentially private federated learning on 100 clients for a privacy budget of (epsilon = 8, delta = 0.001)"
echo "For more information on how to use the functions please refer to their documentation"

sample.py

+26-61
Original file line numberDiff line numberDiff line change
@@ -3,71 +3,48 @@
33
import os
44
from DiffPrivate_FedLearning import run_differentially_private_federated_averaging
55
from MNIST_reader import Data
6+
import argparse
7+
import sys
68

7-
# Specs for the model that we would like to train in differentially private federated fashion:
8-
hidden1 = 600
9-
hidden2 = 100
109

11-
# Specs for the differentially private federated fashion learning process.
12-
N = 100
13-
Batches = 10
14-
save_dir = os.getcwd()
10+
def sample(N, b,e,m, sigma, eps, save_dir, log_dir):
1511

16-
# A data object that already satisfies client structure and has the following attributes:
17-
# DATA.data_set : A list of labeld training examples.
18-
# DATA.client_set : A
19-
DATA = Data(save_dir, N)
12+
# Specs for the model that we would like to train in differentially private federated fashion:
13+
hidden1 = 600
14+
hidden2 = 100
2015

21-
with tf.Graph().as_default():
16+
# Specs for the differentially private federated fashion learning process.
2217

23-
# Building the model that we would like to train in differentially private federated fashion.
24-
# We will need the tensorflow training operation for that model, its loss and an evaluation method:
18+
# A data object that already satisfies client structure and has the following attributes:
19+
# DATA.data_set : A list of labeled training examples.
20+
# DATA.client_set : A
21+
DATA = Data(save_dir, N)
2522

26-
train_op, eval_correct, loss, data_placeholder, labels_placeholder = mnist.mnist_fully_connected_model(Batches, hidden1, hidden2)
23+
with tf.Graph().as_default():
2724

28-
Accuracy_accountant, Delta_accountant, model = \
29-
run_differentially_private_federated_averaging(loss, train_op, eval_correct, DATA, data_placeholder, labels_placeholder)
30-
31-
'''
32-
def main(_):
33-
data = Data(FLAGS.save_dir, FLAGS.n)
34-
train_op, eval_correct, loss = mnist_inference.mnist_fully_connected_model()
35-
run_differentially_private_federated_averaging(loss, train_op, eval_correct, data)
25+
# Building the model that we would like to train in differentially private federated fashion.
26+
# We will need the tensorflow training operation for that model, its loss and an evaluation method:
3627

28+
train_op, eval_correct, loss, data_placeholder, labels_placeholder = mnist.mnist_fully_connected_model(b, hidden1, hidden2)
3729

38-
class Flag:
39-
def __init__(self, n, b, e, record_privacy, m, sigma, eps, save_dir, log_dir, max_comm_rounds, gm, PrivAgent):
40-
if not save_dir:
41-
save_dir = os.getcwd()
42-
if not log_dir:
43-
log_dir = os.path.join(os.getenv('TEST_TMPDIR', '/tmp'), 'tensorflow/mnist/logs/fully_connected_feed')
44-
if tf.gfile.Exists(log_dir):
45-
tf.gfile.DeleteRecursively(log_dir)
46-
tf.gfile.MakeDirs(log_dir)
47-
self.n = n
48-
self.sigma = sigma
49-
self.eps = eps
50-
self.m = m
51-
self.b = b
52-
self.e = e
53-
self.record_privacy = record_privacy
54-
self.save_dir = save_dir
55-
self.log_dir = log_dir
56-
self.max_comm_rounds = max_comm_rounds
57-
self.gm = gm
58-
self.PrivAgentName = PrivAgent.Name
30+
Accuracy_accountant, Delta_accountant, model = \
31+
run_differentially_private_federated_averaging(loss, train_op, eval_correct, DATA, data_placeholder,
32+
labels_placeholder, b=b, e=e,m=m, sigma=sigma, eps=eps,
33+
save_dir=save_dir, log_dir=log_dir)
5934

35+
def main(_):
    """
    Entry point handed to tf.app.run: forwards the parsed command-line FLAGS to sample().

    :param _: Argument passed in by tf.app.run; it is not used.
    """
    # Forward the parsed --save_dir instead of None: Data() builds its paths by
    # concatenating save_dir with '/DATA', which fails for None.
    sample(N=FLAGS.N, b=FLAGS.b, e=FLAGS.e, m=FLAGS.m, sigma=FLAGS.sigma, eps=FLAGS.eps,
           save_dir=FLAGS.save_dir, log_dir=FLAGS.log_dir)
6037

6138
if __name__ == '__main__':
6239
parser = argparse.ArgumentParser()
6340
parser.add_argument(
64-
'--PrivAgentName',
41+
'--save_dir',
6542
type=str,
66-
default='default_Priv_Agent',
67-
help='Sets the name of the used Privacy agent'
43+
default=os.getcwd(),
44+
help='directory to store progress'
6845
)
6946
parser.add_argument(
70-
'--n',
47+
'--N',
7148
type=int,
7249
default=100,
7350
help='Total Number of clients participating'
@@ -102,12 +79,6 @@ def __init__(self, n, b, e, record_privacy, m, sigma, eps, save_dir, log_dir, ma
10279
default=4,
10380
help='Epochs per client'
10481
)
105-
parser.add_argument(
106-
'--record_privacy',
107-
type=bool,
108-
default=True,
109-
help='Epochs per client'
110-
)
11182
parser.add_argument(
11283
'--save_dir',
11384
type=str,
@@ -121,12 +92,6 @@ def __init__(self, n, b, e, record_privacy, m, sigma, eps, save_dir, log_dir, ma
12192
'tensorflow/mnist/logs/fully_connected_feed'),
12293
help='Directory to put the log data.'
12394
)
124-
parser.add_argument(
125-
'--max_comm_rounds',
126-
type=int,
127-
default=3000,
128-
help='Maximum number of communication rounds'
129-
)
13095
FLAGS, unparsed = parser.parse_known_args()
13196
tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)
132-
'''
97+

0 commit comments

Comments
 (0)