Skip to content

Commit f4d7131

Browse files
committed
- Pull out magic numbers to constants
Signed-off-by: Shailesh Pant <[email protected]>
1 parent 9c4442e commit f4d7131

File tree

2 files changed

+29
-9
lines changed

2 files changed

+29
-9
lines changed

openfl/component/aggregator/aggregator.py

Lines changed: 13 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
from typing import List, Optional
1111

1212
import openfl.callbacks as callbacks_module
13+
from openfl.component.aggregator import constants
1314
from openfl.component.aggregator.straggler_handling import CutoffTimePolicy, StragglerPolicy
1415
from openfl.databases import PersistentTensorDB, TensorDB
1516
from openfl.interface.aggregation_functions import WeightedAverage
@@ -74,10 +75,10 @@ def __init__(
7475
assigner,
7576
use_delta_updates=True,
7677
straggler_handling_policy: StragglerPolicy = CutoffTimePolicy,
77-
rounds_to_train=256,
78+
rounds_to_train=constants.DEFAULT_ROUNDS_TO_TRAIN,
7879
single_col_cert_common_name=None,
7980
compression_pipeline=None,
80-
db_store_rounds=1,
81+
db_store_rounds=constants.DEFAULT_DB_STORE_ROUNDS,
8182
initial_tensor_dict=None,
8283
log_memory_usage=False,
8384
write_logs=False,
@@ -100,13 +101,13 @@ def __init__(
100101
assigner: Assigner object.
101102
straggler_handling_policy (optional): Straggler handling policy.
102103
rounds_to_train (int, optional): Number of rounds to train.
103-
Defaults to 256.
104+
Defaults to constants.DEFAULT_ROUNDS_TO_TRAIN.
104105
single_col_cert_common_name (str, optional): Common name for single
105106
collaborator certificate. Defaults to None.
106107
compression_pipeline (optional): Compression pipeline. Defaults to
107108
NoCompressionPipeline.
108109
db_store_rounds (int, optional): Rounds to store in TensorDB.
109-
Defaults to 1.
110+
Defaults to constants.DEFAULT_DB_STORE_ROUNDS.
110111
initial_tensor_dict (dict, optional): Initial tensor dictionary.
111112
callbacks: List of callbacks to be used during the experiment.
112113
"""
@@ -120,15 +121,17 @@ def __init__(
120121
"provide proper Public Key Infrastructure (PKI) security. "
121122
"Please use this mode with caution."
122123
)
123-
# FIXME: "" instead of None is for protobuf compatibility.
124-
self.single_col_cert_common_name = single_col_cert_common_name or ""
124+
# FIXME: using DEFAULT_CERT_COMMON_NAME for protobuf compatibility.
125+
self.single_col_cert_common_name = (
126+
single_col_cert_common_name or constants.DEFAULT_CERT_COMMON_NAME
127+
)
125128

126129
self.straggler_handling_policy = straggler_handling_policy()
127130

128131
self.rounds_to_train = rounds_to_train
129132
self.assigner = assigner
130133
if self.assigner.is_task_group_evaluation():
131-
self.rounds_to_train = 1
134+
self.rounds_to_train = constants.EVALUATION_ROUNDS
132135
logger.info(f"For evaluation tasks setting rounds_to_train = {self.rounds_to_train}")
133136

134137
self._end_of_round_check_done = [False] * rounds_to_train
@@ -143,7 +146,7 @@ def __init__(
143146

144147
self.tensor_db = TensorDB()
145148
if persist_checkpoint:
146-
persistent_db_path = persistent_db_path or "tensor.db"
149+
persistent_db_path = persistent_db_path or constants.DEFAULT_PERSISTENT_DB_PATH
147150
logger.info(
148151
"Persistent checkpoint is enabled, setting persistent db at path %s",
149152
persistent_db_path,
@@ -590,7 +593,8 @@ def get_aggregated_tensor(
590593

591594
tensor_key = TensorKey(tensor_name, self.uuid, round_number, report, tags)
592595
tensor_name, origin, round_number, report, tags = tensor_key
593-
596+
# TODO: This is a temporary fix. The tags should be updated in the
597+
# TensorDB.
594598
if "aggregated" in tags and "delta" in tags and round_number != 0:
595599
agg_tensor_key = TensorKey(tensor_name, origin, round_number, report, ("aggregated",))
596600
else:
Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
# Constants for Aggregator component
2+
3+
# Default number of rounds to train during federated learning
4+
DEFAULT_ROUNDS_TO_TRAIN = 256
5+
6+
# Number of rounds to train during evaluation mode
7+
EVALUATION_ROUNDS = 1
8+
9+
# Default number of rounds to store in the TensorDB
10+
DEFAULT_DB_STORE_ROUNDS = 1
11+
12+
# Default persistent database path if not provided
13+
DEFAULT_PERSISTENT_DB_PATH = "tensor.db"
14+
15+
# Default certificate common name to satisfy protobuf compatibility when none is provided
16+
DEFAULT_CERT_COMMON_NAME = ""

0 commit comments

Comments
 (0)