3
3
import os
4
4
from DiffPrivate_FedLearning import run_differentially_private_federated_averaging
5
5
from MNIST_reader import Data
6
+ import argparse
7
+ import sys
6
8
7
- # Specs for the model that we would like to train in differentially private federated fashion:
8
- hidden1 = 600
9
- hidden2 = 100
10
9
11
- # Specs for the differentially private federated fashion learning process.
12
- N = 100
13
- Batches = 10
14
- save_dir = os .getcwd ()
10
+ def sample (N , b ,e ,m , sigma , eps , save_dir , log_dir ):
15
11
16
- # A data object that already satisfies client structure and has the following attributes:
17
- # DATA.data_set : A list of labeld training examples.
18
- # DATA.client_set : A
19
- DATA = Data (save_dir , N )
12
+ # Specs for the model that we would like to train in differentially private federated fashion:
13
+ hidden1 = 600
14
+ hidden2 = 100
20
15
21
- with tf . Graph (). as_default ():
16
+ # Specs for the differentially private federated fashion learning process.
22
17
23
- # Building the model that we would like to train in differentially private federated fashion.
24
- # We will need the tensorflow training operation for that model, its loss and an evaluation method:
18
+ # A data object that already satisfies client structure and has the following attributes:
19
+ # DATA.data_set : A list of labeld training examples.
20
+ # DATA.client_set : A
21
+ DATA = Data (save_dir , N )
25
22
26
- train_op , eval_correct , loss , data_placeholder , labels_placeholder = mnist . mnist_fully_connected_model ( Batches , hidden1 , hidden2 )
23
+ with tf . Graph (). as_default ():
27
24
28
- Accuracy_accountant , Delta_accountant , model = \
29
- run_differentially_private_federated_averaging (loss , train_op , eval_correct , DATA , data_placeholder , labels_placeholder )
30
-
31
- '''
32
- def main(_):
33
- data = Data(FLAGS.save_dir, FLAGS.n)
34
- train_op, eval_correct, loss = mnist_inference.mnist_fully_connected_model()
35
- run_differentially_private_federated_averaging(loss, train_op, eval_correct, data)
25
+ # Building the model that we would like to train in differentially private federated fashion.
26
+ # We will need the tensorflow training operation for that model, its loss and an evaluation method:
36
27
28
+ train_op , eval_correct , loss , data_placeholder , labels_placeholder = mnist .mnist_fully_connected_model (b , hidden1 , hidden2 )
37
29
38
- class Flag:
39
- def __init__(self, n, b, e, record_privacy, m, sigma, eps, save_dir, log_dir, max_comm_rounds, gm, PrivAgent):
40
- if not save_dir:
41
- save_dir = os.getcwd()
42
- if not log_dir:
43
- log_dir = os.path.join(os.getenv('TEST_TMPDIR', '/tmp'), 'tensorflow/mnist/logs/fully_connected_feed')
44
- if tf.gfile.Exists(log_dir):
45
- tf.gfile.DeleteRecursively(log_dir)
46
- tf.gfile.MakeDirs(log_dir)
47
- self.n = n
48
- self.sigma = sigma
49
- self.eps = eps
50
- self.m = m
51
- self.b = b
52
- self.e = e
53
- self.record_privacy = record_privacy
54
- self.save_dir = save_dir
55
- self.log_dir = log_dir
56
- self.max_comm_rounds = max_comm_rounds
57
- self.gm = gm
58
- self.PrivAgentName = PrivAgent.Name
30
+ Accuracy_accountant , Delta_accountant , model = \
31
+ run_differentially_private_federated_averaging (loss , train_op , eval_correct , DATA , data_placeholder ,
32
+ labels_placeholder , b = b , e = e ,m = m , sigma = sigma , eps = eps ,
33
+ save_dir = save_dir , log_dir = log_dir )
59
34
35
+ def main (_ ):
36
+ sample (N = FLAGS .N , b = FLAGS .b , e = FLAGS .e ,m = FLAGS .m , sigma = FLAGS .sigma , eps = FLAGS .eps , save_dir = None , log_dir = FLAGS .log_dir )
60
37
61
38
if __name__ == '__main__' :
62
39
parser = argparse .ArgumentParser ()
63
40
parser .add_argument (
64
- '--PrivAgentName ',
41
+ '--save_dir ' ,
65
42
type = str ,
66
- default='default_Priv_Agent' ,
67
- help='Sets the name of the used Privacy agent '
43
+ default = os . getcwd () ,
44
+ help = 'directory to store progress '
68
45
)
69
46
parser .add_argument (
70
- '--n ',
47
+ '--N ' ,
71
48
type = int ,
72
49
default = 100 ,
73
50
help = 'Total Number of clients participating'
@@ -102,12 +79,6 @@ def __init__(self, n, b, e, record_privacy, m, sigma, eps, save_dir, log_dir, ma
102
79
default = 4 ,
103
80
help = 'Epochs per client'
104
81
)
105
- parser.add_argument(
106
- '--record_privacy',
107
- type=bool,
108
- default=True,
109
- help='Epochs per client'
110
- )
111
82
parser .add_argument (
112
83
'--save_dir' ,
113
84
type = str ,
@@ -121,12 +92,6 @@ def __init__(self, n, b, e, record_privacy, m, sigma, eps, save_dir, log_dir, ma
121
92
'tensorflow/mnist/logs/fully_connected_feed' ),
122
93
help = 'Directory to put the log data.'
123
94
)
124
- parser.add_argument(
125
- '--max_comm_rounds',
126
- type=int,
127
- default=3000,
128
- help='Maximum number of communication rounds'
129
- )
130
95
FLAGS , unparsed = parser .parse_known_args ()
131
96
tf .app .run (main = main , argv = [sys .argv [0 ]] + unparsed )
132
- '''
97
+
0 commit comments