Skip to content

Commit

Permalink
Add a utility function which generates peer fl context in a more effi…
Browse files Browse the repository at this point in the history
…cient way

Update a few locations with the above utility function
  • Loading branch information
IsaacYangSLA committed Sep 29, 2023
1 parent f6eea0d commit 6aebf9c
Show file tree
Hide file tree
Showing 5 changed files with 22 additions and 18 deletions.
11 changes: 11 additions & 0 deletions nvflare/apis/utils/fl_context_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import copy
import logging

from nvflare.apis.fl_constant import FLContextKey, NonSerializableKeys
Expand All @@ -38,6 +39,16 @@ def get_serializable_data(fl_ctx: FLContext):
return new_fl_ctx


def gen_new_peer_ctx(fl_ctx: FLContext, need_deep_copy=False):
tmp_ctx = FLContext()
pub_props = fl_ctx.get_all_public_props()
if need_deep_copy:
pub_props = copy.deepcopy(pub_props)
tmp_ctx.set_public_props(pub_props)
new_peer_ctx = get_serializable_data(tmp_ctx)
return new_peer_ctx


def generate_log_message(fl_ctx: FLContext, msg: str):
_identity_ = "identity"
_my_run = "run"
Expand Down
6 changes: 2 additions & 4 deletions nvflare/private/fed/client/command_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import copy
import logging

from nvflare.apis.fl_constant import FLContextKey
from nvflare.apis.fl_context import FLContext
from nvflare.apis.utils.fl_context_utils import get_serializable_data
from nvflare.apis.utils.fl_context_utils import gen_new_peer_ctx
from nvflare.fuel.f3.cellnet.cell import Message as CellMessage
from nvflare.fuel.f3.cellnet.cell import MessageHeaderKey, ReturnCode
from nvflare.fuel.f3.cellnet.cell import make_reply as make_cellnet_reply
Expand Down Expand Up @@ -85,8 +84,7 @@ def aux_communication(self, request: CellMessage) -> CellMessage:
topic = request.get_header(MessageHeaderKey.TOPIC)
reply = self.engine.dispatch(topic=topic, request=shareable, fl_ctx=fl_ctx)

shared_fl_ctx = FLContext()
shared_fl_ctx.set_public_props(copy.deepcopy(get_serializable_data(fl_ctx).get_all_public_props()))
shared_fl_ctx = gen_new_peer_ctx(fl_ctx)
reply.set_header(key=FLContextKey.PEER_CONTEXT, value=shared_fl_ctx)

if reply is not None:
Expand Down
8 changes: 3 additions & 5 deletions nvflare/private/fed/client/communicator.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
from nvflare.apis.fl_context import FLContext
from nvflare.apis.fl_exception import FLCommunicationError
from nvflare.apis.shareable import Shareable
from nvflare.apis.utils.fl_context_utils import get_serializable_data
from nvflare.apis.utils.fl_context_utils import gen_new_peer_ctx
from nvflare.fuel.f3.cellnet.cell import FQCN, Cell
from nvflare.fuel.f3.cellnet.defs import MessageHeaderKey, ReturnCode
from nvflare.fuel.utils import fobs
Expand Down Expand Up @@ -164,8 +164,7 @@ def pull_task(self, servers, project_name, token, ssid, fl_ctx: FLContext):
"""
start_time = time.time()
shareable = Shareable()
shared_fl_ctx = FLContext()
shared_fl_ctx.set_public_props(get_serializable_data(fl_ctx).get_all_public_props())
shared_fl_ctx = gen_new_peer_ctx(fl_ctx)
shareable.set_header(ServerCommandKey.PEER_FL_CONTEXT, shared_fl_ctx)
client_name = fl_ctx.get_identity_name()
task_message = new_cell_message(
Expand Down Expand Up @@ -230,8 +229,7 @@ def submit_update(
ReturnCode
"""
start_time = time.time()
shared_fl_ctx = FLContext()
shared_fl_ctx.set_public_props(get_serializable_data(fl_ctx).get_all_public_props())
shared_fl_ctx = gen_new_peer_ctx(fl_ctx)
shareable.set_header(ServerCommandKey.PEER_FL_CONTEXT, shared_fl_ctx)

# shareable.add_cookie(name=FLContextKey.TASK_ID, data=task_id)
Expand Down
7 changes: 2 additions & 5 deletions nvflare/private/fed/server/server_command_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import copy
import logging

from nvflare.apis.fl_constant import FLContextKey, ServerCommandKey
from nvflare.apis.fl_context import FLContext
from nvflare.apis.utils.fl_context_utils import get_serializable_data
from nvflare.apis.utils.fl_context_utils import gen_new_peer_ctx
from nvflare.fuel.f3.cellnet.cell import Cell, MessageHeaderKey, ReturnCode, make_reply
from nvflare.fuel.f3.message import Message as CellMessage
from nvflare.fuel.utils import fobs
Expand Down Expand Up @@ -116,8 +114,7 @@ def aux_communicate(self, request: CellMessage) -> CellMessage:
engine = fl_ctx.get_engine()
reply = engine.dispatch(topic=topic, request=data, fl_ctx=fl_ctx)

shared_fl_ctx = FLContext()
shared_fl_ctx.set_public_props(copy.deepcopy(get_serializable_data(fl_ctx).get_all_public_props()))
shared_fl_ctx = gen_new_peer_ctx(fl_ctx)
reply.set_header(key=FLContextKey.PEER_CONTEXT, value=shared_fl_ctx)

if reply is not None:
Expand Down
8 changes: 4 additions & 4 deletions nvflare/private/fed/server/server_commands.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@

"""FL Admin commands."""

import copy
import logging
import time
from abc import ABC, abstractmethod
Expand All @@ -29,7 +28,7 @@
)
from nvflare.apis.fl_context import FLContext
from nvflare.apis.shareable import Shareable
from nvflare.apis.utils.fl_context_utils import get_serializable_data
from nvflare.apis.utils.fl_context_utils import gen_new_peer_ctx
from nvflare.private.defs import SpecialTaskName, TaskConstant
from nvflare.widgets.widget import WidgetID

Expand Down Expand Up @@ -180,8 +179,9 @@ def process(self, data: Shareable, fl_ctx: FLContext):
shareable.set_header(key=FLContextKey.TASK_ID, value=task_id)

shareable.set_header(key=ServerCommandKey.TASK_NAME, value=taskname)
shared_fl_ctx = FLContext()
shared_fl_ctx.set_public_props(copy.deepcopy(get_serializable_data(fl_ctx).get_all_public_props()))

# TODO
shared_fl_ctx = gen_new_peer_ctx(fl_ctx)
shareable.set_header(key=FLContextKey.PEER_CONTEXT, value=shared_fl_ctx)

if taskname != SpecialTaskName.TRY_AGAIN:
Expand Down

0 comments on commit 6aebf9c

Please sign in to comment.