Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 11 additions & 9 deletions openfl/component/aggregator/aggregator.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from typing import List, Optional

import openfl.callbacks as callbacks_module
from openfl.component import constants
from openfl.component.aggregator.straggler_handling import CutoffTimePolicy, StragglerPolicy
from openfl.databases import PersistentTensorDB, TensorDB
from openfl.interface.aggregation_functions import WeightedAverage
Expand Down Expand Up @@ -74,10 +75,10 @@ def __init__(
assigner,
use_delta_updates=True,
straggler_handling_policy: StragglerPolicy = CutoffTimePolicy,
rounds_to_train=256,
rounds_to_train=constants.ROUNDS_TO_TRAIN,
single_col_cert_common_name=None,
compression_pipeline=None,
db_store_rounds=1,
db_store_rounds=constants.DB_STORE_ROUNDS,
initial_tensor_dict=None,
log_memory_usage=False,
write_logs=False,
Expand All @@ -100,13 +101,13 @@ def __init__(
assigner: Assigner object.
straggler_handling_policy (optional): Straggler handling policy.
rounds_to_train (int, optional): Number of rounds to train.
Defaults to 256.
Defaults to constants.ROUNDS_TO_TRAIN.
single_col_cert_common_name (str, optional): Common name for single
collaborator certificate. Defaults to None.
compression_pipeline (optional): Compression pipeline. Defaults to
NoCompressionPipeline.
db_store_rounds (int, optional): Rounds to store in TensorDB.
Defaults to 1.
Defaults to constants.DB_STORE_ROUNDS.
initial_tensor_dict (dict, optional): Initial tensor dictionary.
callbacks: List of callbacks to be used during the experiment.
"""
Expand All @@ -120,15 +121,15 @@ def __init__(
"provide proper Public Key Infrastructure (PKI) security. "
"Please use this mode with caution."
)
# FIXME: "" instead of None is for protobuf compatibility.
self.single_col_cert_common_name = single_col_cert_common_name or ""
# FIXME: using CERT_COMMON_NAME for protobuf compatibility.
self.single_col_cert_common_name = single_col_cert_common_name or constants.CERT_COMMON_NAME

self.straggler_handling_policy = straggler_handling_policy()

self.rounds_to_train = rounds_to_train
self.assigner = assigner
if self.assigner.is_task_group_evaluation():
self.rounds_to_train = 1
self.rounds_to_train = constants.EVALUATION_ROUNDS
logger.info(f"For evaluation tasks setting rounds_to_train = {self.rounds_to_train}")

self._end_of_round_check_done = [False] * rounds_to_train
Expand All @@ -143,7 +144,7 @@ def __init__(

self.tensor_db = TensorDB()
if persist_checkpoint:
persistent_db_path = persistent_db_path or "tensor.db"
persistent_db_path = persistent_db_path or constants.PERSISTENT_DB_PATH
logger.info(
"Persistent checkpoint is enabled, setting persistent db at path %s",
persistent_db_path,
Expand Down Expand Up @@ -590,7 +591,8 @@ def get_aggregated_tensor(

tensor_key = TensorKey(tensor_name, self.uuid, round_number, report, tags)
tensor_name, origin, round_number, report, tags = tensor_key

# TODO: This is a temporary fix. The tags should be updated in the
# TensorDB.
if "aggregated" in tags and "delta" in tags and round_number != 0:
agg_tensor_key = TensorKey(tensor_name, origin, round_number, report, ("aggregated",))
else:
Expand Down
23 changes: 10 additions & 13 deletions openfl/component/collaborator/collaborator.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from typing import List, Optional, Tuple

import openfl.callbacks as callbacks_module
from openfl.component import constants
from openfl.databases import TensorDB
from openfl.pipelines import NoCompressionPipeline, TensorCodec
from openfl.protocols import utils
Expand Down Expand Up @@ -78,11 +79,11 @@ def __init__(
client,
task_runner,
task_config,
opt_treatment="RESET",
device_assignment_policy="CPU_ONLY",
delta_updates=False,
opt_treatment=constants.OPT_TREATMENT,
device_assignment_policy=constants.DEVICE_ASSIGNMENT_POLICY,
delta_updates=constants.DELTA_UPDATES,
compression_pipeline=None,
db_store_rounds=1,
db_store_rounds=constants.DB_STORE_ROUNDS,
log_memory_usage=False,
write_logs=False,
callbacks: Optional[List] = None,
Expand All @@ -97,23 +98,19 @@ def __init__(
task_runner (object): The task runner object.
task_config (dict): The task configuration.
opt_treatment (str, optional): The optimizer state treatment.
Defaults to 'RESET'.
Defaults to constants.OPT_TREATMENT.
device_assignment_policy (str, optional): The device assignment
policy. Defaults to 'CPU_ONLY'.
policy. Defaults to constants.DEVICE_ASSIGNMENT_POLICY.
delta_updates (bool, optional): If True, only model delta gets
sent. If False, whole model gets sent to collaborator.
Defaults to False.
Defaults to constants.DELTA_UPDATES.
compression_pipeline (object, optional): The compression pipeline.
Defaults to None.
db_store_rounds (int, optional): The number of rounds to store in
the database. Defaults to 1.
the database. Defaults to constants.DB_STORE_ROUNDS.
callbacks (list, optional): List of callbacks. Defaults to None.
"""
self.single_col_cert_common_name = None

if self.single_col_cert_common_name is None:
self.single_col_cert_common_name = "" # for protobuf compatibility
# we would really want this as an object
self.single_col_cert_common_name = constants.CERT_COMMON_NAME # for protobuf compatibility

self.collaborator_name = collaborator_name
self.aggregator_uuid = aggregator_uuid
Expand Down
10 changes: 10 additions & 0 deletions openfl/component/constants.py
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would suggest renaming it to defaults.py.
Using defaults.ROUNDS_TO_TRAIN increases readability compared to constants.ROUNDS_TO_TRAIN in my opinion.

Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
# Centralized constants for openfl/component modules

ROUNDS_TO_TRAIN = 256
EVALUATION_ROUNDS = 1
DB_STORE_ROUNDS = 1
PERSISTENT_DB_PATH = "tensor.db"
CERT_COMMON_NAME = ""
OPT_TREATMENT = "RESET"
DEVICE_ASSIGNMENT_POLICY = "CPU_ONLY"
DELTA_UPDATES = False
Loading