1010from typing import List , Optional
1111
1212import openfl .callbacks as callbacks_module
13+ from openfl .component .aggregator import constants
1314from openfl .component .aggregator .straggler_handling import CutoffTimePolicy , StragglerPolicy
1415from openfl .databases import PersistentTensorDB , TensorDB
1516from openfl .interface .aggregation_functions import WeightedAverage
@@ -74,10 +75,10 @@ def __init__(
7475 assigner ,
7576 use_delta_updates = True ,
7677 straggler_handling_policy : StragglerPolicy = CutoffTimePolicy ,
77- rounds_to_train = 256 ,
78+ rounds_to_train = constants . DEFAULT_ROUNDS_TO_TRAIN ,
7879 single_col_cert_common_name = None ,
7980 compression_pipeline = None ,
80- db_store_rounds = 1 ,
81+ db_store_rounds = constants . DEFAULT_DB_STORE_ROUNDS ,
8182 initial_tensor_dict = None ,
8283 log_memory_usage = False ,
8384 write_logs = False ,
@@ -100,13 +101,13 @@ def __init__(
100101 assigner: Assigner object.
101102 straggler_handling_policy (optional): Straggler handling policy.
102103 rounds_to_train (int, optional): Number of rounds to train.
103- Defaults to 256 .
104+ Defaults to constants.DEFAULT_ROUNDS_TO_TRAIN .
104105 single_col_cert_common_name (str, optional): Common name for single
105106 collaborator certificate. Defaults to None.
106107 compression_pipeline (optional): Compression pipeline. Defaults to
107108 NoCompressionPipeline.
108109 db_store_rounds (int, optional): Rounds to store in TensorDB.
109- Defaults to 1 .
110+ Defaults to constants.DEFAULT_DB_STORE_ROUNDS .
110111 initial_tensor_dict (dict, optional): Initial tensor dictionary.
111112 callbacks: List of callbacks to be used during the experiment.
112113 """
@@ -120,15 +121,17 @@ def __init__(
120121 "provide proper Public Key Infrastructure (PKI) security. "
121122 "Please use this mode with caution."
122123 )
123- # FIXME: "" instead of None is for protobuf compatibility.
124- self .single_col_cert_common_name = single_col_cert_common_name or ""
124+ # FIXME: using DEFAULT_CERT_COMMON_NAME for protobuf compatibility.
125+ self .single_col_cert_common_name = (
126+ single_col_cert_common_name or constants .DEFAULT_CERT_COMMON_NAME
127+ )
125128
126129 self .straggler_handling_policy = straggler_handling_policy ()
127130
128131 self .rounds_to_train = rounds_to_train
129132 self .assigner = assigner
130133 if self .assigner .is_task_group_evaluation ():
131- self .rounds_to_train = 1
134+ self .rounds_to_train = constants . EVALUATION_ROUNDS
132135 logger .info (f"For evaluation tasks setting rounds_to_train = { self .rounds_to_train } " )
133136
134137 self ._end_of_round_check_done = [False ] * rounds_to_train
@@ -143,7 +146,7 @@ def __init__(
143146
144147 self .tensor_db = TensorDB ()
145148 if persist_checkpoint :
146- persistent_db_path = persistent_db_path or "tensor.db"
149+ persistent_db_path = persistent_db_path or constants . DEFAULT_PERSISTENT_DB_PATH
147150 logger .info (
148151 "Persistent checkpoint is enabled, setting persistent db at path %s" ,
149152 persistent_db_path ,
@@ -590,7 +593,8 @@ def get_aggregated_tensor(
590593
591594 tensor_key = TensorKey (tensor_name , self .uuid , round_number , report , tags )
592595 tensor_name , origin , round_number , report , tags = tensor_key
593-
596+ # TODO: This is a temporary fix. The tags should be updated in the
597+ # TensorDB.
594598 if "aggregated" in tags and "delta" in tags and round_number != 0 :
595599 agg_tensor_key = TensorKey (tensor_name , origin , round_number , report , ("aggregated" ,))
596600 else :
0 commit comments