diff --git a/nvflare/fuel/f3/cellnet/cell.py b/nvflare/fuel/f3/cellnet/cell.py index 575d4b21a7..80ddb13be9 100644 --- a/nvflare/fuel/f3/cellnet/cell.py +++ b/nvflare/fuel/f3/cellnet/cell.py @@ -37,7 +37,8 @@ ServiceUnavailable, ) from nvflare.fuel.f3.cellnet.fqcn import FQCN, FqcnInfo, same_family -from nvflare.fuel.f3.cellnet.utils import decode_payload, encode_payload, format_log_message, make_reply, new_message +from nvflare.fuel.f3.cellnet.registry import Callback, Registry +from nvflare.fuel.f3.cellnet.utils import decode_payload, encode_payload, format_log_message, make_reply from nvflare.fuel.f3.comm_config import CommConfigurator from nvflare.fuel.f3.communicator import Communicator, MessageReceiver from nvflare.fuel.f3.connection import Connection @@ -86,7 +87,7 @@ def to_dict(self): @staticmethod def from_dict(d: dict): msg_dict = d.get("message") - msg = new_message(headers=msg_dict.get("headers"), payload=msg_dict.get("payload")) + msg = Message(headers=msg_dict.get("headers"), payload=msg_dict.get("payload")) return TargetMessage(target=d.get("target"), channel=d.get("channel"), topic=d.get("topic"), message=msg) @@ -110,46 +111,6 @@ def get_fqcn(self): return self.info.fqcn -class _CB: - def __init__(self, cb, args, kwargs): - self.cb = cb - self.args = args - self.kwargs = kwargs - - -class _Registry: - def __init__(self): - self.reg = {} # channel/topic => _CB - - @staticmethod - def _item_key(channel: str, topic: str) -> str: - return f"{channel}:{topic}" - - def set(self, channel: str, topic: str, items): - key = self._item_key(channel, topic) - self.reg[key] = items - - def append(self, channel: str, topic: str, items): - key = self._item_key(channel, topic) - item_list = self.reg.get(key) - if not item_list: - item_list = [] - self.reg[key] = item_list - item_list.append(items) - - def find(self, channel: str, topic: str): - items = self.reg.get(self._item_key(channel, topic)) - if not items: - # try topic * in channel - items = 
self.reg.get(self._item_key(channel, "*")) - - if not items: - # try topic * in channel * - items = self.reg.get(self._item_key("*", "*")) - - return items - - class _Waiter(threading.Event): def __init__(self, targets: List[str]): super().__init__() @@ -215,7 +176,7 @@ def send(self): f"{self.cell.get_fqcn()}: bulk sender {self.target} sending bulk size {len(messages_to_send)}" ) tms = [m.to_dict() for m in messages_to_send] - bulk_msg = new_message(payload=tms) + bulk_msg = Message(None, tms) send_errs = self.cell.fire_and_forget( channel=_CHANNEL, topic=_TOPIC_BULK, targets=[self.target], message=bulk_msg ) @@ -380,12 +341,12 @@ def __init__( self.communicator.register_message_receiver(app_id=self.APP_ID, receiver=self) self.communicator.register_monitor(monitor=self) - self.req_reg = _Registry() - self.in_req_filter_reg = _Registry() # for request received - self.out_reply_filter_reg = _Registry() # for reply going out - self.out_req_filter_reg = _Registry() # for request sent - self.in_reply_filter_reg = _Registry() # for reply received - self.error_handler_reg = _Registry() + self.req_reg = Registry() + self.in_req_filter_reg = Registry() # for request received + self.out_reply_filter_reg = Registry() # for reply going out + self.out_req_filter_reg = Registry() # for request sent + self.in_reply_filter_reg = Registry() # for reply received + self.error_handler_reg = Registry() self.cell_connected_cb = None self.cell_connected_cb_args = None self.cell_connected_cb_kwargs = None @@ -866,7 +827,7 @@ def stop(self): channel=_CHANNEL, topic=_TOPIC_BYE, targets=targets, - request=new_message(), + request=Message(), timeout=0.5, optional=True, ) @@ -905,39 +866,39 @@ def register_request_cb(self, channel: str, topic: str, cb, *args, **kwargs): """ if not callable(cb): raise ValueError(f"specified request_cb {type(cb)} is not callable") - self.req_reg.set(channel, topic, _CB(cb, args, kwargs)) + self.req_reg.set(channel, topic, Callback(cb, args, kwargs)) def 
add_incoming_request_filter(self, channel: str, topic: str, cb, *args, **kwargs): if not callable(cb): raise ValueError(f"specified incoming_request_filter {type(cb)} is not callable") - self.in_req_filter_reg.append(channel, topic, _CB(cb, args, kwargs)) + self.in_req_filter_reg.append(channel, topic, Callback(cb, args, kwargs)) def add_outgoing_reply_filter(self, channel: str, topic: str, cb, *args, **kwargs): if not callable(cb): raise ValueError(f"specified outgoing_reply_filter {type(cb)} is not callable") - self.out_reply_filter_reg.append(channel, topic, _CB(cb, args, kwargs)) + self.out_reply_filter_reg.append(channel, topic, Callback(cb, args, kwargs)) def add_outgoing_request_filter(self, channel: str, topic: str, cb, *args, **kwargs): if not callable(cb): raise ValueError(f"specified outgoing_request_filter {type(cb)} is not callable") - self.out_req_filter_reg.append(channel, topic, _CB(cb, args, kwargs)) + self.out_req_filter_reg.append(channel, topic, Callback(cb, args, kwargs)) def add_incoming_reply_filter(self, channel: str, topic: str, cb, *args, **kwargs): if not callable(cb): raise ValueError(f"specified incoming_reply_filter {type(cb)} is not callable") - self.in_reply_filter_reg.append(channel, topic, _CB(cb, args, kwargs)) + self.in_reply_filter_reg.append(channel, topic, Callback(cb, args, kwargs)) def add_error_handler(self, channel: str, topic: str, cb, *args, **kwargs): if not callable(cb): raise ValueError(f"specified error_handler {type(cb)} is not callable") - self.error_handler_reg.set(channel, topic, _CB(cb, args, kwargs)) + self.error_handler_reg.set(channel, topic, Callback(cb, args, kwargs)) def _filter_outgoing_request(self, channel: str, topic: str, request: Message) -> Union[None, Message]: cbs = self.out_req_filter_reg.find(channel, topic) if not cbs: return None for _cb in cbs: - assert isinstance(_cb, _CB) + assert isinstance(_cb, Callback) reply = self._try_cb(request, _cb.cb, *_cb.args, **_cb.kwargs) if reply: return reply 
@@ -1112,7 +1073,7 @@ def _send_target_messages( self.logger.debug(f"{self.my_info.fqcn}: invoking outgoing request filters") assert isinstance(req_filters, list) for f in req_filters: - assert isinstance(f, _CB) + assert isinstance(f, Callback) r = self._try_cb(req, f.cb, *f.args, **f.kwargs) if r: send_errs[t] = ReturnCode.FILTER_ERROR @@ -1343,7 +1304,7 @@ def _peer_goodbye(self, request: Message): self.logger.debug(f"{self.my_info.fqcn}: agent for {peer_ep.name} is already gone") # ack back - return new_message() + return Message() def _receive_bulk_message(self, request: Message): target_msgs = request.payload @@ -1477,12 +1438,12 @@ def _process_request(self, origin: str, message: Message) -> Union[None, Message self.logger.debug(f"{self.my_info.fqcn}: invoking incoming request filters") assert isinstance(req_filters, list) for f in req_filters: - assert isinstance(f, _CB) + assert isinstance(f, Callback) reply = self._try_cb(message, f.cb, *f.args, **f.kwargs) if reply: return reply - assert isinstance(_cb, _CB) + assert isinstance(_cb, Callback) self.logger.debug(f"{self.my_info.fqcn}: calling registered request CB") cb_start = time.perf_counter() reply = self._try_cb(message, _cb.cb, *_cb.args, **_cb.kwargs) @@ -1490,7 +1451,6 @@ def _process_request(self, origin: str, message: Message) -> Union[None, Message self.req_cb_stats_pool.record_value(category=self._stats_category(message), value=cb_end - cb_start) if not reply: # the CB doesn't have anything to reply - self.logger.debug("no reply is returned from the CB") return None if not isinstance(reply, Message): @@ -1630,7 +1590,7 @@ def _process_reply(self, origin: str, message: Message, msg_type: str): self.logger.debug(f"{self.my_info.fqcn}: invoking incoming reply filters") assert isinstance(reply_filters, list) for f in reply_filters: - assert isinstance(f, _CB) + assert isinstance(f, Callback) self._try_cb(message, f.cb, *f.args, **f.kwargs) for rid in req_ids: @@ -1808,7 +1768,6 @@ def 
_process_received_msg(self, endpoint: Endpoint, connection: Connection, mess reply = self._process_request(origin, message) if not reply: - self.logger.debug(f"{self.my_info.fqcn}: don't send response - nothing to send") self.received_msg_counter_pool.increment( category=self._stats_category(message), counter_name=_CounterName.REPLY_NONE ) @@ -1863,7 +1822,7 @@ def _process_received_msg(self, endpoint: Endpoint, connection: Connection, mess self.logger.debug(f"{self.my_info.fqcn}: invoking outgoing reply filters") assert isinstance(reply_filters, list) for f in reply_filters: - assert isinstance(f, _CB) + assert isinstance(f, Callback) r = self._try_cb(reply, f.cb, *f.args, **f.kwargs) if r: reply = r diff --git a/nvflare/fuel/f3/cellnet/net_agent.py b/nvflare/fuel/f3/cellnet/net_agent.py index 0767717ae9..3c033947b3 100644 --- a/nvflare/fuel/f3/cellnet/net_agent.py +++ b/nvflare/fuel/f3/cellnet/net_agent.py @@ -27,7 +27,7 @@ from nvflare.fuel.f3.cellnet.connector_manager import ConnectorData from nvflare.fuel.f3.cellnet.defs import MessageHeaderKey, ReturnCode from nvflare.fuel.f3.cellnet.fqcn import FQCN -from nvflare.fuel.f3.cellnet.utils import make_reply, new_message +from nvflare.fuel.f3.cellnet.utils import make_reply from nvflare.fuel.f3.stats_pool import StatsPoolManager from nvflare.fuel.utils.config_service import ConfigService @@ -277,7 +277,7 @@ def stop_subnet(self, monitor: SubnetMonitor): cells_to_stop.append(member_fqcn) if cells_to_stop: return self.cell.broadcast_request( - channel=_CHANNEL, topic=_TOPIC_STOP_CELL, request=new_message(), targets=cells_to_stop, timeout=1.0 + channel=_CHANNEL, topic=_TOPIC_STOP_CELL, request=Message(), targets=cells_to_stop, timeout=1.0 ) else: return None @@ -310,7 +310,7 @@ def _subnet_heartbeat(self): channel=_CHANNEL, topic=_TOPIC_HEARTBEAT, targets=target, - message=new_message(payload={"subnet_id": subnet_id}), + message=Message(payload={"subnet_id": subnet_id}), ) # wait for interval time, but watch for 
"asked_to_stop" every 0.1 secs @@ -363,10 +363,10 @@ def _do_stop(self, request: Message) -> Union[None, Message]: def _do_stop_cell(self, request: Message) -> Union[None, Message]: self.stop() - return new_message() + return Message() def _do_route(self, request: Message) -> Union[None, Message]: - return new_message(payload=dict(request.headers)) + return Message(payload=dict(request.headers)) def _do_start_route(self, request: Message) -> Union[None, Message]: target_fqcn = request.payload @@ -375,14 +375,14 @@ def _do_start_route(self, request: Message) -> Union[None, Message]: return make_reply(ReturnCode.PROCESS_EXCEPTION, f"bad target fqcn {err}") assert isinstance(target_fqcn, str) reply_headers, req_headers = self.get_route_info(target_fqcn) - return new_message(payload={"request": dict(req_headers), "reply": dict(reply_headers)}) + return Message(payload={"request": dict(req_headers), "reply": dict(reply_headers)}) def _do_peers(self, request: Message) -> Union[None, Message]: - return new_message(payload=list(self.cell.agents.keys())) + return Message(payload=list(self.cell.agents.keys())) def get_peers(self, target_fqcn: str) -> (Union[None, dict], List[str]): reply = self.cell.send_request( - channel=_CHANNEL, topic=_TOPIC_PEERS, target=target_fqcn, timeout=1.0, request=new_message() + channel=_CHANNEL, topic=_TOPIC_PEERS, target=target_fqcn, timeout=1.0, request=Message() ) err = "" @@ -425,11 +425,11 @@ def _get_connectors(self) -> dict: return result def _do_connectors(self, request: Message) -> Union[None, Message]: - return new_message(payload=self._get_connectors()) + return Message(payload=self._get_connectors()) def get_connectors(self, target_fqcn: str) -> (dict, dict): reply = self.cell.send_request( - channel=_CHANNEL, topic=_TOPIC_CONNS, target=target_fqcn, timeout=1.0, request=new_message() + channel=_CHANNEL, topic=_TOPIC_CONNS, target=target_fqcn, timeout=1.0, request=Message() ) rc = reply.get_header(MessageHeaderKey.RETURN_CODE) if rc 
== ReturnCode.OK: @@ -481,7 +481,7 @@ def _get_url_use_of_cell(self, url: str): def get_url_use(self, url) -> dict: result = {self.cell.get_fqcn(): self._get_url_use_of_cell(url)} - replies = self._broadcast_to_subs(topic=_TOPIC_URL_USE, message=new_message(payload=url)) + replies = self._broadcast_to_subs(topic=_TOPIC_URL_USE, message=Message(payload=url)) for t, r in replies.items(): assert isinstance(r, Message) rc = r.get_header(MessageHeaderKey.RETURN_CODE) @@ -496,11 +496,11 @@ def get_url_use(self, url) -> dict: def _do_url_use(self, request: Message) -> Union[None, Message]: results = self.get_url_use(request.payload) - return new_message(payload=results) + return Message(payload=results) def get_route_info(self, target_fqcn: str) -> (dict, dict): reply = self.cell.send_request( - channel=_CHANNEL, topic=_TOPIC_ROUTE, target=target_fqcn, timeout=1.0, request=new_message() + channel=_CHANNEL, topic=_TOPIC_ROUTE, target=target_fqcn, timeout=1.0, request=Message() ) reply_headers = reply.headers rc = reply.get_header(MessageHeaderKey.RETURN_CODE, ReturnCode.OK) @@ -520,7 +520,7 @@ def start_route(self, from_fqcn: str, target_fqcn: str) -> (str, dict, dict): topic=_TOPIC_START_ROUTE, target=from_fqcn, timeout=1.0, - request=new_message(payload=target_fqcn), + request=Message(payload=target_fqcn), ) rc = reply.get_header(MessageHeaderKey.RETURN_CODE) if rc == ReturnCode.OK: @@ -537,7 +537,7 @@ def start_route(self, from_fqcn: str, target_fqcn: str) -> (str, dict, dict): def _do_report_cells(self, request: Message) -> Union[None, Message]: _, results = self.request_cells_info() - return new_message(payload=results) + return Message(payload=results) def stop(self): # ask all children to stop @@ -549,7 +549,7 @@ def stop_cell(self, target: str) -> str: # self.stop() # return ReturnCode.OK reply = self.cell.send_request( - channel=_CHANNEL, topic=_TOPIC_STOP_CELL, request=new_message(), target=target, timeout=1.0 + channel=_CHANNEL, topic=_TOPIC_STOP_CELL, 
request=Message(), target=target, timeout=1.0 ) rc = reply.get_header(MessageHeaderKey.RETURN_CODE) return rc @@ -573,7 +573,7 @@ def _request_speed_test(self, target_fqcn: str, num, size) -> Message: channel=_CHANNEL, topic=_TOPIC_ECHO, target=target_fqcn, - request=new_message(payload=payload), + request=Message(payload=payload), timeout=10.0, ) rc = r.get_header(MessageHeaderKey.RETURN_CODE, ReturnCode.OK) @@ -600,7 +600,7 @@ def _request_speed_test(self, target_fqcn: str, num, size) -> Message: end = time.perf_counter() total = end - start avg = total / num - return new_message( + return Message( payload={ "test": f"{size:,}KB {num} rounds between {self.cell.get_fqcn()} and {target_fqcn}", "prep": payload_prep_time, @@ -633,7 +633,7 @@ def _do_speed(self, request: Message) -> Union[None, Message]: return self._request_speed_test(to_fqcn, num, size) def _do_echo(self, request: Message) -> Union[None, Message]: - return new_message(payload=request.payload) + return Message(payload=request.payload) def _do_stress_test(self, params): if not isinstance(params, dict): @@ -660,7 +660,7 @@ def _do_stress_test(self, params): h = hashlib.md5(payload) d1 = h.digest() target = targets[random.randrange(len(targets))] - req = new_message(payload=payload) + req = Message(payload=payload) reply = self.cell.send_request(channel=_CHANNEL, topic=_TOPIC_ECHO, target=target, request=req, timeout=1.0) if target not in counts: counts[target] = 0 @@ -685,7 +685,7 @@ def _do_stress_test(self, params): def _do_stress(self, request: Message) -> Union[None, Message]: params = request.payload result = self._do_stress_test(params) - return new_message(payload=result) + return Message(payload=result) def start_stress_test(self, targets: list, num_rounds=10, timeout=5.0): self.cell.logger.info(f"{self.cell.get_fqcn()}: starting stress test on {targets}") @@ -702,7 +702,7 @@ def start_stress_test(self, targets: list, num_rounds=10, timeout=5.0): channel=_CHANNEL, topic=_TOPIC_STRESS, 
targets=msg_targets, - request=new_message(payload=payload), + request=Message(payload=payload), timeout=timeout, ) for t, r in replies.items(): @@ -728,7 +728,7 @@ def speed_test(self, from_fqcn: str, to_fqcn: str, num_tries, payload_size) -> d reply = self.cell.send_request( channel=_CHANNEL, topic=_TOPIC_SPEED, - request=new_message(payload={"to": to_fqcn, "num": num_tries, "size": payload_size}), + request=Message(payload={"to": to_fqcn, "num": num_tries, "size": payload_size}), target=from_fqcn, timeout=100.0, ) @@ -744,7 +744,7 @@ def speed_test(self, from_fqcn: str, to_fqcn: str, num_tries, payload_size) -> d return result def change_root(self, new_root_url: str): - self._broadcast_to_subs(topic=_TOPIC_CHANGE_ROOT, message=new_message(payload=new_root_url), timeout=0.0) + self._broadcast_to_subs(topic=_TOPIC_CHANGE_ROOT, message=Message(payload=new_root_url), timeout=0.0) def _do_change_root(self, request: Message) -> Union[None, Message]: new_root_url = request.payload @@ -768,7 +768,7 @@ def start_bulk_test(self, targets: list, size: int): channel=_CHANNEL, topic=_TOPIC_BULK_TEST, targets=msg_targets, - request=new_message(payload=size), + request=Message(payload=size), timeout=1.0, ) for t, r in replies.items(): @@ -786,14 +786,14 @@ def _do_bulk_test(self, request: Message) -> Union[None, Message]: for _ in range(size): num = random.randint(0, 100) nums.append(num) - msg = new_message(payload=num) + msg = Message(payload=num) self.cell.queue_message( channel=_CHANNEL, topic=_TOPIC_BULK_ITEM, targets=FQCN.ROOT_SERVER, message=msg, ) - return new_message(payload=f"queued: {nums}") + return Message(payload=f"queued: {nums}") def _do_bulk_item(self, request: Message) -> Union[None, Message]: num = request.payload @@ -805,7 +805,7 @@ def get_msg_stats_table(self, target: str, mode: str): reply = self.cell.send_request( channel=_CHANNEL, topic=_TOPIC_MSG_STATS, - request=new_message(payload={"mode": mode}), + request=Message(payload={"mode": mode}), 
timeout=1.0, target=target, ) @@ -820,11 +820,11 @@ def _do_msg_stats(self, request: Message) -> Union[None, Message]: mode = p.get("mode") headers, rows = self.cell.msg_stats_pool.get_table(mode) reply = {"headers": headers, "rows": rows} - return new_message(payload=reply) + return Message(payload=reply) def get_pool_list(self, target: str): reply = self.cell.send_request( - channel=_CHANNEL, topic=_TOPIC_LIST_POOLS, request=new_message(), timeout=1.0, target=target + channel=_CHANNEL, topic=_TOPIC_LIST_POOLS, request=Message(), timeout=1.0, target=target ) rc = reply.get_header(MessageHeaderKey.RETURN_CODE) err = reply.get_header(MessageHeaderKey.ERROR, "") @@ -835,13 +835,13 @@ def get_pool_list(self, target: str): def _do_list_pools(self, request: Message) -> Union[None, Message]: headers, rows = StatsPoolManager.get_table() reply = {"headers": headers, "rows": rows} - return new_message(payload=reply) + return Message(payload=reply) def show_pool(self, target: str, pool_name: str, mode: str): reply = self.cell.send_request( channel=_CHANNEL, topic=_TOPIC_SHOW_POOL, - request=new_message(payload={"mode": mode, "pool": pool_name}), + request=Message(payload={"mode": mode, "pool": pool_name}), timeout=1.0, target=target, ) @@ -858,7 +858,7 @@ def _do_show_pool(self, request: Message) -> Union[None, Message]: mode = p.get("mode", "") pool = StatsPoolManager.get_pool(pool_name) if not pool: - return new_message( + return Message( headers={ MessageHeaderKey.RETURN_CODE: ReturnCode.INVALID_REQUEST, MessageHeaderKey.ERROR: f"unknown pool '{pool_name}'", @@ -866,11 +866,11 @@ def _do_show_pool(self, request: Message) -> Union[None, Message]: ) headers, rows = pool.get_table(mode) reply = {"headers": headers, "rows": rows} - return new_message(payload=reply) + return Message(payload=reply) def get_comm_config(self, target: str): reply = self.cell.send_request( - channel=_CHANNEL, topic=_TOPIC_COMM_CONFIG, request=new_message(), timeout=1.0, target=target + 
channel=_CHANNEL, topic=_TOPIC_COMM_CONFIG, request=Message(), timeout=1.0, target=target ) rc = reply.get_header(MessageHeaderKey.RETURN_CODE) if rc != ReturnCode.OK: @@ -880,7 +880,7 @@ def get_comm_config(self, target: str): def get_config_vars(self, target: str): reply = self.cell.send_request( - channel=_CHANNEL, topic=_TOPIC_CONFIG_VARS, request=new_message(), timeout=1.0, target=target + channel=_CHANNEL, topic=_TOPIC_CONFIG_VARS, request=Message(), timeout=1.0, target=target ) rc = reply.get_header(MessageHeaderKey.RETURN_CODE) if rc != ReturnCode.OK: @@ -890,7 +890,7 @@ def get_config_vars(self, target: str): def get_process_info(self, target: str): reply = self.cell.send_request( - channel=_CHANNEL, topic=_TOPIC_PROCESS_INFO, request=new_message(), timeout=1.0, target=target + channel=_CHANNEL, topic=_TOPIC_PROCESS_INFO, request=Message(), timeout=1.0, target=target ) rc = reply.get_header(MessageHeaderKey.RETURN_CODE) if rc != ReturnCode.OK: @@ -900,11 +900,11 @@ def get_process_info(self, target: str): def _do_comm_config(self, request: Message) -> Union[None, Message]: info = self.cell.connector_manager.get_config_info() - return new_message(payload=info) + return Message(payload=info) def _do_config_vars(self, request: Message) -> Union[None, Message]: info = ConfigService.get_var_values() - return new_message(payload=info) + return Message(payload=info) def _do_process_info(self, request: Message) -> Union[None, Message]: @@ -918,11 +918,11 @@ def _do_process_info(self, request: Message) -> Union[None, Message]: for thread in threading.enumerate(): rows.append([f"Thread:{thread.ident}", thread.name]) - return new_message(payload={"headers": ["Resource", "Value"], "rows": rows}) + return Message(payload={"headers": ["Resource", "Value"], "rows": rows}) def _broadcast_to_subs(self, topic: str, message=None, timeout=1.0): if not message: - message = new_message() + message = Message() children, clients = self.cell.get_sub_cell_names() targets = [] diff 
--git a/nvflare/fuel/f3/cellnet/registry.py b/nvflare/fuel/f3/cellnet/registry.py new file mode 100644 index 0000000000..5d1bf1ff02 --- /dev/null +++ b/nvflare/fuel/f3/cellnet/registry.py @@ -0,0 +1,54 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Any + + +class Callback: + def __init__(self, cb, args, kwargs): + self.cb = cb + self.args = args + self.kwargs = kwargs + + +class Registry: + def __init__(self): + self.reg = {} # channel/topic => _CB + + @staticmethod + def _item_key(channel: str, topic: str) -> str: + return f"{channel}:{topic}" + + def set(self, channel: str, topic: str, items: Any): + key = self._item_key(channel, topic) + self.reg[key] = items + + def append(self, channel: str, topic: str, items: Any): + key = self._item_key(channel, topic) + item_list = self.reg.get(key) + if not item_list: + item_list = [] + self.reg[key] = item_list + item_list.append(items) + + def find(self, channel: str, topic: str) -> Any: + items = self.reg.get(self._item_key(channel, topic)) + if not items: + # try topic * in channel + items = self.reg.get(self._item_key(channel, "*")) + + if not items: + # try topic * in channel * + items = self.reg.get(self._item_key("*", "*")) + + return items diff --git a/nvflare/fuel/f3/cellnet/utils.py b/nvflare/fuel/f3/cellnet/utils.py index 95faa95a1f..03d38cb841 100644 --- a/nvflare/fuel/f3/cellnet/utils.py +++ b/nvflare/fuel/f3/cellnet/utils.py 
@@ -13,24 +13,16 @@ # limitations under the License. import nvflare.fuel.utils.fobs as fobs from nvflare.fuel.f3.cellnet.defs import Encoding, MessageHeaderKey -from nvflare.fuel.f3.message import Headers, Message +from nvflare.fuel.f3.message import Message def make_reply(rc: str, error: str = "", body=None) -> Message: - headers = Headers() - headers[MessageHeaderKey.RETURN_CODE] = rc + headers = {MessageHeaderKey.RETURN_CODE: rc} if error: headers[MessageHeaderKey.ERROR] = error return Message(headers, payload=body) -def new_message(headers: dict = None, payload=None): - msg_headers = Headers() - if headers: - msg_headers.update(headers) - return Message(msg_headers, payload) - - def format_log_message(fqcn: str, message: Message, log: str) -> str: parts = [ "[ME=" + fqcn, @@ -50,7 +42,7 @@ def encode_payload(message: Message): if not encoding: if message.payload is None: encoding = Encoding.NONE - elif isinstance(message.payload, bytes) or isinstance(message.payload, bytearray): + elif isinstance(message.payload, (bytes, bytearray, memoryview)): encoding = Encoding.BYTES else: encoding = Encoding.FOBS diff --git a/nvflare/fuel/f3/communicator.py b/nvflare/fuel/f3/communicator.py index d2854bf38d..4ab4463b86 100644 --- a/nvflare/fuel/f3/communicator.py +++ b/nvflare/fuel/f3/communicator.py @@ -260,7 +260,9 @@ def remove_connector(self, handle: str): def _exit_func(): - for c in _running_instances: + while _running_instances: + c = next(iter(_running_instances)) + # This call will remove the entry from the set c.stop() log.debug(f"Communicator {c.local_endpoint.name} was left running, stopped on exit") diff --git a/nvflare/fuel/f3/message.py b/nvflare/fuel/f3/message.py index 449032b89b..cb18e98089 100644 --- a/nvflare/fuel/f3/message.py +++ b/nvflare/fuel/f3/message.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
from abc import ABC, abstractmethod -from typing import Any +from typing import Any, Optional from nvflare.fuel.f3.connection import Connection from nvflare.fuel.f3.endpoint import Endpoint @@ -27,17 +27,8 @@ class AppIds: PUB_SUB = 3 -class Headers(dict): - - # Reserved Keys - MSG_ID = "_MSG_ID_" - TOPIC = "_TOPIC_" - DEST = "_DEST_" - JOB_ID = "_JOB_ID_" - - class Message: - def __init__(self, headers: Headers, payload: Any): + def __init__(self, headers: Optional[dict] = None, payload: Any = None): """Construct an FCI message""" self.headers = headers diff --git a/nvflare/fuel/f3/sfm/conn_manager.py b/nvflare/fuel/f3/sfm/conn_manager.py index 4386cfc2b9..149507fb95 100644 --- a/nvflare/fuel/f3/sfm/conn_manager.py +++ b/nvflare/fuel/f3/sfm/conn_manager.py @@ -26,7 +26,7 @@ from nvflare.fuel.f3.drivers.driver_params import DriverCap, DriverParams from nvflare.fuel.f3.drivers.net_utils import ssl_required from nvflare.fuel.f3.endpoint import Endpoint, EndpointMonitor, EndpointState -from nvflare.fuel.f3.message import Headers, Message, MessageReceiver +from nvflare.fuel.f3.message import Message, MessageReceiver from nvflare.fuel.f3.sfm.constants import HandshakeKeys, Types from nvflare.fuel.f3.sfm.prefix import PREFIX_LEN, Prefix from nvflare.fuel.f3.sfm.sfm_conn import SfmConnection @@ -177,7 +177,7 @@ def get_connections(self, name: str) -> Optional[List[SfmConnection]]: return sfm_endpoint.connections - def send_message(self, endpoint: Endpoint, app_id: int, headers: Headers, payload: BytesAlike): + def send_message(self, endpoint: Endpoint, app_id: int, headers: Optional[dict], payload: BytesAlike): """Send a message to endpoint for app The message is asynchronous, no response is expected. 
@@ -429,13 +429,13 @@ def close_connection(self, connection: Connection): if old_state != state: self.notify_monitors(sfm_endpoint.endpoint) - def send_loopback_message(self, endpoint: Endpoint, app_id: int, headers: Headers, payload: BytesAlike): + def send_loopback_message(self, endpoint: Endpoint, app_id: int, headers: Optional[dict], payload: BytesAlike): """Send message to itself""" # Call receiver in a different thread to avoid deadlock self.frame_mgr_executor.submit(self.loopback_message_task, endpoint, app_id, headers, payload) - def loopback_message_task(self, endpoint: Endpoint, app_id: int, headers: Headers, payload: BytesAlike): + def loopback_message_task(self, endpoint: Endpoint, app_id: int, headers: Optional[dict], payload: BytesAlike): receiver = self.receivers.get(app_id) if not receiver: diff --git a/nvflare/fuel/f3/sfm/sfm_conn.py b/nvflare/fuel/f3/sfm/sfm_conn.py index 3f5a22f0e2..b8ee398cc7 100644 --- a/nvflare/fuel/f3/sfm/sfm_conn.py +++ b/nvflare/fuel/f3/sfm/sfm_conn.py @@ -21,7 +21,6 @@ from nvflare.fuel.f3.connection import BytesAlike, Connection from nvflare.fuel.f3.endpoint import Endpoint -from nvflare.fuel.f3.message import Headers from nvflare.fuel.f3.sfm.constants import HandshakeKeys, Types from nvflare.fuel.f3.sfm.prefix import PREFIX_LEN, Prefix @@ -88,7 +87,7 @@ def send_handshake(self, frame_type: int): self.send_dict(frame_type, 1, data) - def send_data(self, app_id: int, stream_id: int, headers: Headers, payload: BytesAlike): + def send_data(self, app_id: int, stream_id: int, headers: Optional[dict], payload: BytesAlike): """Send user data""" prefix = Prefix(0, 0, Types.DATA, 0, 0, app_id, stream_id, 0) @@ -102,7 +101,7 @@ def send_dict(self, frame_type: int, stream_id: int, data: dict): payload = msgpack.packb(data) self.send_frame(prefix, None, payload) - def send_frame(self, prefix: Prefix, headers: Optional[Headers], payload: Optional[BytesAlike]): + def send_frame(self, prefix: Prefix, headers: Optional[dict], payload: 
class StreamCell:
    """Adds large-payload streaming APIs on top of a Cell.

    Supports four layers: raw byte streams, BLOBs, files and object streams.
    All sends are asynchronous and return a future that tracks progress and
    carries the final result.
    """

    def __init__(self, cell: Cell):
        self.cell = cell
        self.byte_streamer = ByteStreamer(cell)
        self.byte_receiver = ByteReceiver(cell)
        # BLOB, file and object streaming are all layered on the byte streamer/receiver
        self.blob_streamer = BlobStreamer(self.byte_streamer, self.byte_receiver)
        self.file_streamer = FileStreamer(self.byte_streamer, self.byte_receiver)
        self.object_streamer = ObjectStreamer(self.blob_streamer)

    @staticmethod
    def get_chunk_size():
        """Get the default chunk size used by StreamCell.

        Byte streams are broken into chunks of this size before being sent over Cellnet.
        """
        return ByteStreamer.get_chunk_size()

    def send_stream(self, channel: str, topic: str, target: str, message: Message) -> StreamFuture:
        """Send a byte-stream over a channel/topic asynchronously.

        The streaming is performed in a different thread. The streamer reads from the
        stream and sends the data in chunks until the stream reaches EOF.

        Args:
            channel: channel for the stream
            topic: topic for the stream
            target: destination cell FQCN
            message: the payload is the Stream to send

        Returns:
            StreamFuture that can be used to check status/progress or register callbacks.
            The future result is the number of bytes sent.

        Raises:
            StreamError: if the message payload is not a Stream
        """
        if not isinstance(message.payload, Stream):
            raise StreamError(f"Message payload is not a stream: {type(message.payload)}")

        return self.byte_streamer.send(channel, topic, target, message.headers, message.payload)

    def register_stream_cb(self, channel: str, topic: str, stream_cb: Callable, *args, **kwargs):
        """Register a callback for reading a byte stream.

        The stream_cb must have the following signature:
            stream_cb(future: StreamFuture, stream: Stream, resume: bool, *args, **kwargs) -> int
                future: represents the ongoing streaming; done when streaming completes
                stream: the stream to read the received data from
                resume: True if this is a restarted stream
                It returns the offset to resume from if this is a restarted stream.

        The callback is invoked in a dedicated thread and may block.

        Args:
            channel: the channel of the request
            topic: topic of the request
            stream_cb: called when a stream is started; also provides the restart
                offset for restarted streams
            *args: positional args to be passed to the callback
            **kwargs: keyword args to be passed to the callback
        """
        self.byte_receiver.register_callback(channel, topic, stream_cb, *args, **kwargs)

    def send_blob(self, channel: str, topic: str, target: str, message: Message) -> StreamFuture:
        """Send a BLOB (Binary Large Object) to the target.

        The payload of the message is the BLOB, which must fit in memory on the
        receiving end.

        Args:
            channel: channel for the message
            topic: topic of the message
            target: destination cell IDs
            message: the headers and the blob as payload

        Returns:
            StreamFuture that can be used to check status/progress and get the result.
            The future result is the total number of bytes sent.

        Raises:
            StreamError: if the message payload is not a byte-like object
        """
        if not isinstance(message.payload, (bytes, bytearray, memoryview)):
            raise StreamError(f"Message payload is not a byte array: {type(message.payload)}")

        return self.blob_streamer.send(channel, topic, target, message)

    def register_blob_cb(self, channel: str, topic: str, blob_cb, *args, **kwargs):
        """Register a callback for receiving a BLOB.

        The callback is invoked when streaming starts; the whole BLOB is delivered
        through the future. The callback must have the following signature:
            blob_cb(future: StreamFuture, *args, **kwargs)
        The future's result is the final BLOB received.

        Args:
            channel: the channel of the request
            topic: topic of the request
            blob_cb: the callback to handle the BLOB
        """
        self.blob_streamer.register_blob_callback(channel, topic, blob_cb, *args, **kwargs)

    def send_file(self, channel: str, topic: str, target: str, message: Message) -> StreamFuture:
        """Send a file to target using the stream API.

        Args:
            channel: channel for the message
            topic: topic for the message
            target: destination cell FQCN
            message: the headers and the full path of the file to be sent as payload

        Returns:
            StreamFuture that can be used to check status/progress and get the total bytes sent.

        Raises:
            StreamError: if the payload is not a string or the file is missing/unreadable
        """
        if not isinstance(message.payload, str):
            raise StreamError(f"Message payload is not a file name: {type(message.payload)}")

        file_name = message.payload
        if not os.path.isfile(file_name) or not os.access(file_name, os.R_OK):
            raise StreamError(f"File {file_name} doesn't exist or isn't readable")

        return self.file_streamer.send(channel, topic, target, message)

    def register_file_cb(self, channel: str, topic: str, file_cb, *args, **kwargs):
        """Register a callback for file receiving.

        The callback must have the following signature:
            file_cb(future: StreamFuture, file_name: str, *args, **kwargs) -> str
                future: represents the file-receiving task; its result is the final file path
                It returns the full path where the file will be written to.

        Args:
            channel: the channel of the request
            topic: topic of the request
            file_cb: called when a file transfer starts
        """
        self.file_streamer.register_file_callback(channel, topic, file_cb, *args, **kwargs)

    def send_objects(self, channel: str, topic: str, target: str, message: Message) -> ObjectStreamFuture:
        """Send a sequence of objects to the destination.

        Each object is sent as a BLOB, so each must fit in memory.

        Args:
            channel: channel for the message
            topic: topic of the message
            target: destination cell IDs
            message: headers and the payload, which is an iterator providing the next object

        Returns:
            ObjectStreamFuture that can be used to check status/progress or register callbacks.

        Raises:
            StreamError: if the message payload is not an ObjectIterator
        """
        if not isinstance(message.payload, ObjectIterator):
            raise StreamError(f"Message payload is not an object iterator: {type(message.payload)}")

        return self.object_streamer.stream_objects(channel, topic, target, message.headers, message.payload)

    def register_objects_cb(
        self, channel: str, topic: str, object_stream_cb: Callable, object_cb: Callable, *args, **kwargs
    ):
        """Register callbacks for receiving objects.

        The callbacks must have the following signatures:
            object_stream_cb(future: ObjectStreamFuture, resume: bool, *args, **kwargs) -> int
                future: represents the streaming of all objects; an object CB can be
                    registered with the future to receive each object
                resume: True if this is a restarted stream
                It returns the index to restart from if this is a restarted stream.

            object_cb(obj_sid: str, index: int, message: Message, *args, **kwargs)
                obj_sid: object stream ID
                index: the index of the object, starting from 0
                message: the headers and the payload, which is the object

        Args:
            channel: the channel of the request
            topic: topic of the request
            object_stream_cb: the callback invoked when an object stream is started
            object_cb: the callback invoked when each object is received
            *args: positional args to be passed to the callbacks
            **kwargs: keyword args to be passed to the callbacks
        """
        # Fix: unpack args/kwargs. Passing the bare `args, kwargs` objects put a tuple and a
        # dict into register_object_callbacks' *args, which were then forwarded as two bogus
        # positional arguments to the user callbacks.
        self.object_streamer.register_object_callbacks(channel, topic, object_stream_cb, object_cb, *args, **kwargs)
class BlobStream(Stream):
    """An in-memory Stream over a BLOB.

    read() hands out zero-copy slices of at most chunk_size bytes until the
    blob is exhausted, then returns EOS.
    """

    def __init__(self, blob: BytesAlike, headers: Optional[dict]):
        super().__init__(len(blob), headers)
        # wrap_view gives a memoryview-like wrapper so slicing does not copy
        self.blob_view = wrap_view(blob)

    def read(self, chunk_size: int) -> BytesAlike:
        total = self.get_size()
        if self.pos >= total:
            return EOS

        end = min(self.pos + chunk_size, total)
        chunk = self.blob_view[self.pos : end]
        self.pos = end
        return chunk
class BlobStreamer:
    """Sends and receives whole BLOBs by layering on the byte streamer/receiver."""

    def __init__(self, byte_streamer: ByteStreamer, byte_receiver: ByteReceiver):
        self.byte_streamer = byte_streamer
        self.byte_receiver = byte_receiver

    def send(self, channel: str, topic: str, target: str, message: Message) -> StreamFuture:
        """Send the message payload (a BLOB) as a byte stream; returns the stream future."""
        # Expose the BLOB as a Stream so the byte layer can chunk it without copying
        stream = BlobStream(message.payload, message.headers)
        return self.byte_streamer.send(channel, topic, target, message.headers, stream)

    def register_blob_callback(self, channel, topic, blob_cb: Callable, *args, **kwargs):
        """Register blob_cb to be invoked with the reassembled BLOB for channel/topic."""
        # BlobHandler collects the incoming chunks and fires blob_cb via its future
        handler = BlobHandler(blob_cb)
        self.byte_receiver.register_callback(channel, topic, handler.handle_blob_cb, *args, **kwargs)
class RxTask:
    """Receiving-side state for one byte stream."""

    def __init__(self, sid: str, origin: str):
        self.sid = sid
        self.origin = origin
        # channel/topic/headers are unknown until the first chunk (seq 0) arrives
        self.channel = None
        self.topic = None
        self.headers = None
        self.size = 0
        self.stream_future = None
        # Reassembled, in-order chunks waiting to be read
        self.buffers = deque()
        # Chunks that arrived ahead of next_seq, keyed by sequence number
        self.out_seq_buffers: Dict[int, BytesAlike] = {}
        self.next_seq = 0  # next expected sequence number
        self.offset = 0  # bytes handed to the reader so far
        self.offset_ack = 0  # offset last acknowledged to the sender
        self.eos = False
        # Wakes up a read() that is blocked on an empty buffer queue
        self.waiter = threading.Event()

    def __str__(self):
        return f"[Rx{self.sid} from {self.origin} for {self.channel}/{self.topic}]"
RxTask): + super().__init__(task.size, task.headers) + self.cell = cell + self.task = task + + def read(self, chunk_size: int) -> bytes: + if self.closed: + raise StreamError("Read from closed stream") + + if (not self.task.buffers) and self.task.eos: + return EOS + + # Block indefinitely if buffers are empty + if not self.task.buffers: + self.task.waiter.clear() + self.task.waiter.wait() + + buf = self.task.buffers.popleft() + if 0 < chunk_size < len(buf): + result = buf[0:chunk_size] + # Put leftover to the head of the queue + self.task.buffers.appendleft(buf[chunk_size:]) + else: + result = buf + + self.task.offset += len(buf) + if self.task.offset - self.task.offset_ack > ACK_INTERVAL: + # Send ACK + message = Message() + message.add_headers( + { + StreamHeaderKey.STREAM_ID: self.task.sid, + StreamHeaderKey.DATA_TYPE: StreamDataType.ACK, + StreamHeaderKey.OFFSET: self.task.offset, + } + ) + self.cell.fire_and_forget(STREAM_CHANNEL, STREAM_ACK_TOPIC, self.task.origin, message) + self.task.offset_ack = self.task.offset + + self.task.stream_future.set_progress(self.task.offset) + + return result + + def close(self): + if not self.task.stream_future.done(): + self.task.stream_future.set_result(self.task.offset) + self.closed = True + + +class ByteReceiver: + def __init__(self, cell: Cell): + self.cell = cell + self.cell.register_request_cb(channel=STREAM_CHANNEL, topic=STREAM_DATA_TOPIC, cb=self._data_handler) + self.registry = Registry() + self.rx_task_map = {} + + def register_callback(self, channel: str, topic: str, stream_cb: Callable, *args, **kwargs): + if not callable(stream_cb): + raise StreamError(f"specified stream_cb {type(stream_cb)} is not callable") + + self.registry.set(channel, topic, Callback(stream_cb, args, kwargs)) + + def _data_handler(self, message: Message): + + sid = message.get_header(StreamHeaderKey.STREAM_ID) + origin = message.get_header(MessageHeaderKey.ORIGIN) + seq = message.get_header(StreamHeaderKey.SEQUENCE) + error = 
message.get_header(StreamHeaderKey.ERROR_MSG, None) + task = self.rx_task_map.get(sid, None) + if not task: + if error: + log.debug(f"Received error for non-existing stream: {sid} from {origin}") + return + + task = RxTask(sid, origin) + self.rx_task_map[sid] = task + + if error: + self._stop_task(task, StreamError(f"Received error from {origin}: {error}"), notify=False) + return + + if seq == 0: + # Handle new stream + task.channel = message.get_header(StreamHeaderKey.CHANNEL) + task.topic = message.get_header(StreamHeaderKey.TOPIC) + task.headers = message.headers + + task.stream_future = StreamFuture(sid, message.headers) + task.size = message.get_header(StreamHeaderKey.SIZE, 0) + task.stream_future.set_size(task.size) + + # Invoke callback + callback = self.registry.find(task.channel, task.topic) + if not callback: + self._stop_task(task, StreamError(f"No callback is registered for {task.channel}/{task.topic}")) + return + + stream_thread_pool.submit(self._callback_wrapper, task, callback) + + if seq == task.next_seq: + self._append(task, message.payload) + task.next_seq += 1 + + # Try to reassemble out-of-seq buffers + while task.next_seq in task.out_seq_buffers: + chunk = task.out_seq_buffers.pop(task.next_seq) + self._append(task, chunk) + task.next_seq += 1 + + else: + # Out-of-seq chunk reassembly + if len(task.out_seq_buffers) >= MAX_OUT_SEQ_CHUNKS: + self._stop_task(task, StreamError(f"Too many out-of-sequence chunks: {len(task.out_seq_buffers)}")) + return + else: + task.out_seq_buffers[seq] = message.payload + + data_type = message.get_header(StreamHeaderKey.DATA_TYPE) + if data_type == StreamDataType.FINAL: + # Task is not done till all buffers are read so future is not set here + self._stop_task(task) + + def _callback_wrapper(self, task: RxTask, callback: Callback): + """A wrapper to catch all exceptions in the callback""" + try: + stream = RxStream(self.cell, task) + return callback.cb(task.stream_future, stream, False, *callback.args, 
**callback.kwargs) + except Exception as ex: + msg = f"{task} callback {callback.cb} throws exception: {ex}" + log.error(msg) + self._stop_task(task, StreamError(msg)) + + @staticmethod + def _append(task: RxTask, buf: bytes): + if not buf: + return + + task.buffers.append(buf) + + # Wake up blocking read() + if not task.waiter.is_set(): + task.waiter.set() + + def _stop_task(self, task: RxTask, error: StreamError = None, notify=True): + if error: + log.error(f"Stream error: {error}") + task.stream_future.set_exception(error) + + if notify: + message = Message() + + message.add_headers( + { + StreamHeaderKey.STREAM_ID: task.sid, + StreamHeaderKey.DATA_TYPE: StreamDataType.ERROR, + StreamHeaderKey.ERROR_MSG: str(error), + } + ) + self.cell.fire_and_forget(STREAM_CHANNEL, STREAM_ACK_TOPIC, task.origin, message) + + task.eos = True diff --git a/nvflare/fuel/f3/streaming/byte_streamer.py b/nvflare/fuel/f3/streaming/byte_streamer.py new file mode 100644 index 0000000000..91ada5504b --- /dev/null +++ b/nvflare/fuel/f3/streaming/byte_streamer.py @@ -0,0 +1,211 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
class TxTask:
    """Sending-side state for one byte stream."""

    def __init__(self, channel: str, topic: str, target: str, headers: dict, stream: Stream):
        self.sid = gen_stream_id()
        self.channel = channel
        self.topic = topic
        self.target = target
        self.headers = headers
        self.stream = stream
        # Staging buffer: reads smaller than a full chunk are coalesced here
        self.buffer = bytearray(STREAM_CHUNK_SIZE)
        self.buffer_size = 0
        # A full-sized read is kept and sent as-is to avoid a copy
        self.direct_buf: Optional[bytes] = None
        self.stream_future = None
        self.task_future = None
        # Set when an ACK arrives and the flow-control window may advance
        self.ack_waiter = threading.Event()
        self.seq = 0  # sequence number of the next chunk to send
        self.offset = 0  # bytes transmitted so far
        self.offset_ack = 0  # highest offset acknowledged by the receiver
        self.stop = False

    def __str__(self):
        return f"Tx[{self.sid} to {self.target} for {self.channel}/{self.topic}]"
stream_thread_pool.submit(self._transmit_task, tx_task) + + return future + + def _transmit_task(self, task: TxTask): + + while not task.stop: + buf = task.stream.read(STREAM_CHUNK_SIZE) + if not buf: + # End of Stream + self._transmit(task, final=True) + self._stop_task(task) + return + + # Flow control + window = task.offset - task.offset_ack + # It may take several ACKs to clear up the window + while window > STREAM_WINDOW_SIZE: + log.debug(f"{task} window size {window} exceeds limit: {STREAM_WINDOW_SIZE}") + task.ack_waiter.clear() + result = task.ack_waiter.wait(timeout=STREAM_ACK_WAIT) + if not result: + self._stop_task(task, StreamError(f"{task} ACK timeouts after {STREAM_ACK_WAIT} seconds")) + return + + if task.stop: + return + + window = task.offset - task.offset_ack + + size = len(buf) + if size > STREAM_CHUNK_SIZE: + raise StreamError(f"Stream returns invalid size: {size} for {task}") + if size + task.buffer_size > STREAM_CHUNK_SIZE: + self._transmit(task) + + if size == STREAM_CHUNK_SIZE: + task.direct_buf = buf + else: + task.buffer[task.buffer_size : task.buffer_size + size] = buf + task.buffer_size += size + + def _transmit(self, task: TxTask, final=False): + + if task.buffer_size == 0: + payload = None + elif task.buffer_size == STREAM_CHUNK_SIZE: + if task.direct_buf: + payload = task.direct_buf + else: + payload = task.buffer + else: + payload = wrap_view(task.buffer)[0 : task.buffer_size] + + message = Message(None, payload) + + if task.offset == 0: + # User headers are only included in the first chunk + if task.headers: + message.add_headers(task.headers) + + message.add_headers( + { + StreamHeaderKey.CHANNEL: task.channel, + StreamHeaderKey.TOPIC: task.topic, + } + ) + + message.add_headers( + { + StreamHeaderKey.STREAM_ID: task.sid, + StreamHeaderKey.DATA_TYPE: StreamDataType.FINAL if final else StreamDataType.CHUNK, + StreamHeaderKey.SEQUENCE: task.seq, + StreamHeaderKey.OFFSET: task.offset, + } + ) + + errors = 
self.cell.fire_and_forget(STREAM_CHANNEL, STREAM_DATA_TOPIC, task.target, message) + error = errors.get(task.target) + if error: + msg = f"Message sending error to target {task.target}: {error}" + log.debug(msg) + self._stop_task(task, StreamError(msg)) + return + + # Update state + task.seq += 1 + task.offset += task.buffer_size + task.buffer_size = 0 + task.direct_buf = None + + # Update future + task.stream_future.set_progress(task.offset) + + def _stop_task(self, task: TxTask, error: StreamError = None, notify=True): + if error: + log.debug(f"Stream error: {error}") + task.stream_future.set_exception(error) + + if notify: + message = Message(None, None) + message.add_headers( + { + StreamHeaderKey.STREAM_ID: task.sid, + StreamHeaderKey.DATA_TYPE: StreamDataType.ERROR, + StreamHeaderKey.OFFSET: task.offset, + StreamHeaderKey.ERROR_MSG: str(error), + } + ) + self.cell.fire_and_forget(STREAM_CHANNEL, STREAM_DATA_TOPIC, task.target, message) + else: + # Result is the number of bytes streamed + task.stream_future.set_result(task.offset) + task.stop = True + + def _ack_handler(self, message: Message): + origin = message.get_header(MessageHeaderKey.ORIGIN) + sid = message.get_header(StreamHeaderKey.STREAM_ID) + task = self.tx_task_map.get(sid, None) + if not task: + raise StreamError(f"Unknown stream ID {sid} received from {origin}") + + error = message.get_header(StreamHeaderKey.ERROR_MSG, None) + if error: + self._stop_task(task, StreamError(f"Received error from {origin}: {error}"), notify=False) + return + + offset = message.get_header(StreamHeaderKey.OFFSET, None) + if offset > task.offset_ack: + task.offset_ack = offset + + if not task.ack_waiter.is_set(): + task.ack_waiter.set() diff --git a/nvflare/fuel/f3/streaming/file_streamer.py b/nvflare/fuel/f3/streaming/file_streamer.py new file mode 100644 index 0000000000..7dd785f873 --- /dev/null +++ b/nvflare/fuel/f3/streaming/file_streamer.py @@ -0,0 +1,105 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. 
class FileStream(Stream):
    """A Stream that reads a file from disk in chunks."""

    def __init__(self, file_name: str, headers: Optional[dict]):
        self.file = open(file_name, "rb")
        # The caller (send_file) has validated this is a regular, readable file,
        # so fstat gives the exact stream size without moving the file position.
        size = os.fstat(self.file.fileno()).st_size
        super().__init__(size, headers)

    def read(self, chunk_size: int) -> BytesAlike:
        # Returns b"" at EOF, which the byte layer treats as end-of-stream
        return self.file.read(chunk_size)

    def close(self):
        self.closed = True
        self.file.close()
class FileStreamer:
    """Sends and receives files by layering on the byte streamer/receiver."""

    def __init__(self, byte_streamer: ByteStreamer, byte_receiver: ByteReceiver):
        self.byte_streamer = byte_streamer
        self.byte_receiver = byte_receiver

    def send(self, channel: str, topic: str, target: str, message: Message) -> StreamFuture:
        """Stream the file named by the message payload; returns the stream future."""
        path = message.payload
        file_stream = FileStream(path, message.headers)

        # The receiver needs the size and the base name to create the destination file
        extra_headers = {
            StreamHeaderKey.SIZE: file_stream.get_size(),
            StreamHeaderKey.FILE_NAME: Path(path).name,
        }
        message.add_headers(extra_headers)

        return self.byte_streamer.send(channel, topic, target, message.headers, file_stream)

    def register_file_callback(self, channel, topic, file_cb: Callable, *args, **kwargs):
        """Register file_cb to be invoked when a file transfer starts on channel/topic."""
        handler = FileHandler(file_cb)
        self.byte_receiver.register_callback(channel, topic, handler.handle_file_cb, *args, **kwargs)
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import logging +from typing import Callable, Optional + +from nvflare.fuel.f3.cellnet.defs import MessageHeaderKey +from nvflare.fuel.f3.message import Message +from nvflare.fuel.f3.streaming.blob_streamer import BlobStreamer +from nvflare.fuel.f3.streaming.stream_const import StreamHeaderKey +from nvflare.fuel.f3.streaming.stream_types import ObjectIterator, ObjectStreamFuture, StreamFuture +from nvflare.fuel.f3.streaming.stream_utils import gen_stream_id, stream_thread_pool + +log = logging.getLogger(__name__) + + +class ObjectTxTask: + def __init__(self, channel: str, topic: str, target: str, headers: dict, iterator: ObjectIterator): + self.obj_sid = gen_stream_id() + self.index = 0 + self.channel = channel + self.topic = topic + self.target = target + self.headers = headers if headers else {} + self.iterator = iterator + self.object_future = None + self.stop = False + + def __str__(self): + return f"ObjTx[{self.obj_sid}/{self.index} to {self.target} for {self.channel}/{self.topic}]" + + +class ObjectRxTask: + def __init__(self, obj_sid: str, channel: str, topic: str, origin: str, headers: dict): + self.obj_sid = obj_sid + self.index = 0 + self.channel = channel + self.topic = topic + self.origin = origin + self.headers = headers + self.object_future: Optional[ObjectStreamFuture] = None + + def __str__(self): + return f"ObjRx[{self.obj_sid}/{self.index} from {self.origin} for {self.channel}/{self.topic}]" + + +class ObjectHandler: + def __init__(self, object_stream_cb: Callable, object_cb: Callable, obj_tasks: dict): + self.object_stream_cb = 
object_stream_cb + self.object_cb = object_cb + self.obj_tasks = obj_tasks + + def object_done(self, future: StreamFuture, obj_sid: str, index: int, *args, **kwargs): + blob = future.result() + self.object_cb(obj_sid, index, Message(future.get_headers(), blob), *args, **kwargs) + + def handle_object(self, future: StreamFuture, *args, **kwargs): + headers = future.get_headers() + obj_sid = headers.get(StreamHeaderKey.OBJECT_STREAM_ID, None) + + if not obj_sid: + return + + task = self.obj_tasks.get(obj_sid, None) + if not task: + # Handle new object stream + origin = headers.get(MessageHeaderKey.ORIGIN) + channel = headers.get(StreamHeaderKey.CHANNEL) + topic = headers.get(StreamHeaderKey.TOPIC) + task = ObjectRxTask(obj_sid, channel, topic, origin, headers) + task.object_future = ObjectStreamFuture(obj_sid, headers) + + stream_thread_pool.submit(self.object_stream_cb, task.object_future, *args, **kwargs) + + task.object_future.set_index(task.index) + task.index += 1 + future.add_done_callback(self.object_done, future, task.obj_sid, task.index) + + +class ObjectStreamer: + def __init__(self, blob_streamer: BlobStreamer): + self.blob_streamer = blob_streamer + self.obj_tasks = {} + + def stream_objects( + self, channel: str, topic: str, target: str, headers: dict, iterator: ObjectIterator + ) -> ObjectStreamFuture: + tx_task = ObjectTxTask(channel, topic, target, headers, iterator) + tx_task.object_future = ObjectStreamFuture(tx_task.obj_sid, headers) + stream_thread_pool.submit(self._streaming_task, tx_task) + + return tx_task.object_future + + def register_object_callbacks( + self, channel, topic, object_stream_cb: Callable, object_cb: Callable, *args, **kwargs + ): + handler = ObjectHandler(object_stream_cb, object_cb, self.obj_tasks) + self.blob_streamer.register_blob_callback(channel, topic, handler.handle_object, *args, **kwargs) + + def _streaming_task(self, task: ObjectTxTask): + + for obj in task.iterator: + + task.object_future.set_index(task.index) + + 
task.headers.update( + { + StreamHeaderKey.OBJECT_STREAM_ID: task.obj_sid, + StreamHeaderKey.OBJECT_INDEX: task.index, + } + ) + blob_future = self.blob_streamer.send(task.channel, task.topic, task.target, task.headers, obj) + + # Wait till it's done + bytes_sent = blob_future.result() + log.debug(f"Stream {task.obj_sid} Object {task.index} is sent ({bytes_sent}") + task.index += 1 + + task.object_future.set_result(task.index) diff --git a/nvflare/fuel/f3/streaming/stream_const.py b/nvflare/fuel/f3/streaming/stream_const.py new file mode 100644 index 0000000000..9e5910a9fd --- /dev/null +++ b/nvflare/fuel/f3/streaming/stream_const.py @@ -0,0 +1,50 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+STREAM_PREFIX = "sm__" +STREAM_CHANNEL = STREAM_PREFIX + "STREAM" +STREAM_DATA_TOPIC = STREAM_PREFIX + "DATA" +STREAM_ACK_TOPIC = STREAM_PREFIX + "ACK" +# End of Stream indicator +EOS = bytes() + + +class StreamDataType: + # Payload chunk + CHUNK = 1 + # Final chunk, end of stream + FINAL = 2 + # ACK with last received offset + ACK = 3 + # Resume request + RESUME = 4 + # Resume ack with offset to start + RESUME_ACK = 5 + # Streaming failed + ERROR = 6 + + +class StreamHeaderKey: + + # Try to keep the key small to reduce the overhead + STREAM_ID = STREAM_PREFIX + "id" + DATA_TYPE = STREAM_PREFIX + "dt" + SIZE = STREAM_PREFIX + "sz" + SEQUENCE = STREAM_PREFIX + "sq" + OFFSET = STREAM_PREFIX + "os" + ERROR_MSG = STREAM_PREFIX + "em" + CHANNEL = STREAM_PREFIX + "ch" + FILE_NAME = STREAM_PREFIX + "fn" + TOPIC = STREAM_PREFIX + "tp" + OBJECT_STREAM_ID = STREAM_PREFIX + "os" + OBJECT_INDEX = STREAM_PREFIX + "oi" diff --git a/nvflare/fuel/f3/streaming/stream_types.py b/nvflare/fuel/f3/streaming/stream_types.py new file mode 100644 index 0000000000..266069dcf8 --- /dev/null +++ b/nvflare/fuel/f3/streaming/stream_types.py @@ -0,0 +1,271 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+import logging +import threading +from abc import ABC, abstractmethod +from collections.abc import Iterator +from typing import Any, Callable, Optional + +from nvflare.fuel.f3.connection import BytesAlike +from nvflare.fuel.f3.streaming.stream_utils import gen_stream_id + +log = logging.getLogger(__name__) + + +class StreamError(Exception): + """All stream API throws this error""" + + pass + + +class StreamCancelled(StreamError): + """Streaming is cancelled by sender""" + + pass + + +class Stream(ABC): + """A raw, read-only, seekable binary byte stream""" + + def __init__(self, size: int = 0, headers: Optional[dict] = None): + """Constructor for stream + + Args: + size: The total size of stream. 0 if unknown + headers: Optional headers to be passed to the receiver + """ + self.size = size + self.pos = 0 + self.headers = headers + self.closed = False + + def get_size(self) -> int: + return self.size + + def get_pos(self): + return self.pos + + def get_headers(self) -> Optional[dict]: + return self.headers + + @abstractmethod + def read(self, chunk_size: int) -> BytesAlike: + """Read and return up to chunk_size bytes. It can return less but not more than the chunk_size. + An empty bytes object is returned if the stream reaches the end. + + Args: + chunk_size: Up to (but maybe less) this many bytes will be returned + + Returns: + Binary data. If empty, it means the stream is depleted (EOF) + """ + pass + + def close(self): + """Close the stream""" + self.closed = True + + def seek(self, offset: int): + """Change the stream position to the given byte offset. + Args: + offset: Offset relative to the start of the stream + + Exception: + StreamError: If the stream is not seekable + """ + self.pos = offset + + +class ObjectIterator(Iterator, ABC): + """An object iterator that returns next object + The __next__() method must be defined to return next object. 
+ """ + + def __init__(self, headers: Optional[dict] = None): + self.sid = gen_stream_id() + self.headers = headers + self.index = 0 + + def get_headers(self) -> Optional[dict]: + return self.headers + + def stream_id(self): + return self.sid + + def get_index(self) -> int: + return self.index + + def set_index(self, index: int): + self.index = index + + +class StreamFuture: + """Future class for all stream calls. + + Fashioned after concurrent.futures.Future + """ + + def __init__(self, stream_id: str, headers: Optional[dict] = None): + self.stream_id = stream_id + self.headers = headers + self.waiter = threading.Event() + self.lock = threading.Lock() + self.error: Optional[StreamError] = None + self.value = None + self.size = 0 + self.progress = 0 + self.done_callbacks = [] + + def get_stream_id(self) -> str: + return self.stream_id + + def get_headers(self) -> Optional[dict]: + return self.headers + + def get_size(self) -> int: + return self.size + + def set_size(self, size: int): + self.size = size + + def get_progress(self) -> int: + return self.progress + + def set_progress(self, progress: int): + self.progress = progress + + def cancel(self): + """Cancel the future if possible. + + Returns True if the future was cancelled, False otherwise. A future + cannot be cancelled if it is running or has already completed. 
+ """ + + with self.lock: + if self.error or self.result: + return False + + self.error = StreamCancelled("Stream is cancelled") + + return True + + def cancelled(self): + with self.lock: + return isinstance(self.error, StreamCancelled) + + def running(self): + """Return True if the future is currently executing.""" + with self.lock: + return not self.waiter.isSet() + + def done(self): + """Return True of the future was cancelled or finished executing.""" + with self.lock: + return self.error or self.waiter.isSet() + + def add_done_callback(self, done_cb: Callable, *args, **kwargs): + """Attaches a callable that will be called when the future finishes. + + Args: + done_cb: A callable that will be called with this future completes + """ + with self.lock: + self.done_callbacks.append((done_cb, args, kwargs)) + + def result(self, timeout=None) -> Any: + """Return the result of the call that the future represents. + + Args: + timeout: The number of seconds to wait for the result if the future + isn't done. If None, then there is no limit on the wait time. + + Returns: + The final result + + Raises: + CancelledError: If the future was cancelled. + TimeoutError: If the future didn't finish executing before the given + timeout. + """ + + self.waiter.wait(timeout) + + if self.error: + raise self.error + + return self.value + + def exception(self, timeout=None): + """Return the exception raised by the call that the future represents. + + Args: + timeout: The number of seconds to wait for the exception if the + future isn't done. If None, then there is no limit on the wait + time. + + Returns: + The exception raised by the call that the future represents or None + if the call completed without raising. + + Raises: + CancelledError: If the future was cancelled. + TimeoutError: If the future didn't finish executing before the given + timeout. 
+ """ + + self.waiter.wait(timeout) + return self.error + + def set_result(self, value: Any): + """Sets the return value of work associated with the future.""" + + with self.lock: + if self.error: + raise StreamError("Invalid state, future already failed") + self.value = value + self.waiter.set() + + self._invoke_callbacks() + + def set_exception(self, exception): + """Sets the result of the future as being the given exception.""" + with self.lock: + self.error = exception + self.waiter.set() + + self._invoke_callbacks() + + def _invoke_callbacks(self): + for callback, args, kwargs in self.done_callbacks: + try: + callback(self, args, kwargs) + except Exception as ex: + log.error(f"Exception calling callback for {callback}: {ex}") + + +class ObjectStreamFuture(StreamFuture): + def __init__(self, stream_id: str, headers: Optional[dict] = None): + super().__init__(stream_id, headers) + self.index = 0 + + def get_index(self) -> int: + """Current object index, which is only available for ObjectStream""" + return self.index + + def set_index(self, index: int): + """Set current object index""" + self.index = index + + def get_progress(self): + return self.index diff --git a/nvflare/fuel/f3/streaming/stream_utils.py b/nvflare/fuel/f3/streaming/stream_utils.py new file mode 100644 index 0000000000..4942fea922 --- /dev/null +++ b/nvflare/fuel/f3/streaming/stream_utils.py @@ -0,0 +1,41 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +import threading +import time +from concurrent.futures import ThreadPoolExecutor + +from nvflare.fuel.f3.connection import BytesAlike + +STREAM_THREAD_POOL_SIZE = 128 + +stream_thread_pool = ThreadPoolExecutor(STREAM_THREAD_POOL_SIZE, "stm") +lock = threading.Lock() +start_time = time.time() * 1000000 # microseconds +stream_count = 0 + + +def wrap_view(buffer: BytesAlike) -> memoryview: + if isinstance(buffer, memoryview): + view = buffer + else: + view = memoryview(buffer) + + return view + + +def gen_stream_id(): + global lock, stream_count, start_time + with lock: + stream_count += 1 + return f"SID{(start_time + stream_count):16.0f}" diff --git a/nvflare/fuel/f3/streaming/tools/__init__.py b/nvflare/fuel/f3/streaming/tools/__init__.py new file mode 100644 index 0000000000..4fc50543f1 --- /dev/null +++ b/nvflare/fuel/f3/streaming/tools/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/nvflare/fuel/f3/streaming/tools/file_receiver.py b/nvflare/fuel/f3/streaming/tools/file_receiver.py new file mode 100644 index 0000000000..5654eb765f --- /dev/null +++ b/nvflare/fuel/f3/streaming/tools/file_receiver.py @@ -0,0 +1,84 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import logging +import os +import sys +import time + +from nvflare.fuel.f3.cellnet.cell import Cell +from nvflare.fuel.f3.stream_cell import StreamCell +from nvflare.fuel.f3.streaming.stream_types import StreamFuture +from nvflare.fuel.f3.streaming.stream_utils import stream_thread_pool +from nvflare.fuel.f3.streaming.tools.utils import RX_CELL, TEST_CHANNEL, TEST_TOPIC, setup_log + + +class FileReceiver: + """Utility to receive files sent from another cell""" + + def __init__(self, listening_url: str, out_folder: str): + self.cell = Cell(RX_CELL, listening_url, secure=False, credentials={}) + self.stream_cell = StreamCell(self.cell) + self.stream_cell.register_file_cb(TEST_CHANNEL, TEST_TOPIC, self.file_cb) + self.cell.start() + self.out_folder = out_folder + self.file_received = 0 + + def stop(self): + self.cell.stop() + + def file_cb(self, future: StreamFuture, original_name: str): + out_file = os.path.join(self.out_folder, original_name) + stream_thread_pool.submit(self.monitor_status, future) + print(f"Received file {original_name}, writing to {out_file} ...") + return out_file + + def monitor_status(self, future: StreamFuture): + + start = time.time() + + while True: + if future.done(): + break + + progress = future.get_progress() + percent = progress * 100.0 / future.get_size() + print(f"Received {progress} bytes {percent:.2f}% done") + time.sleep(1) + + name = future.result() + print(f"Time elapsed: {(time.time() 
- start):.3f} seconds") + print(f"File {name} is sent") + self.file_received += 1 + + return name + + +if __name__ == "__main__": + setup_log(logging.INFO) + if len(sys.argv) != 3: + print(f"Usage: {sys.argv[0]} listening_url out_folder") + sys.exit(1) + + listening_url = sys.argv[1] + out_folder = sys.argv[2] + + receiver = FileReceiver(listening_url, out_folder) + + while True: + if receiver.file_received >= 1: + break + time.sleep(1) + + receiver.stop() + print(f"Done. Files received: {receiver.file_received}") diff --git a/nvflare/fuel/f3/streaming/tools/file_sender.py b/nvflare/fuel/f3/streaming/tools/file_sender.py new file mode 100644 index 0000000000..75a8042725 --- /dev/null +++ b/nvflare/fuel/f3/streaming/tools/file_sender.py @@ -0,0 +1,80 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+import logging +import sys +import threading +import time + +from nvflare.fuel.f3.cellnet.cell import Cell, CellAgent +from nvflare.fuel.f3.message import Message +from nvflare.fuel.f3.stream_cell import StreamCell +from nvflare.fuel.f3.streaming.tools.utils import RX_CELL, TEST_CHANNEL, TEST_TOPIC, TX_CELL, setup_log + + +class FileSender: + """Utility to send a file to another cell""" + + def __init__(self, url: str): + core_cell = Cell(TX_CELL, url, secure=False, credentials={}) + self.stream_cell = StreamCell(core_cell) + core_cell.set_cell_connected_cb(self.cell_connected) + core_cell.start() + self.cell = core_cell + self.ready = threading.Event() + + def stop(self): + self.cell.stop() + + def wait(self): + self.ready.wait() + + def send(self, file_to_send: str): + future = self.stream_cell.send_file(TEST_CHANNEL, TEST_TOPIC, RX_CELL, Message(None, file_to_send)) + + while True: + if future.done(): + break + + time.sleep(1) + progress = future.get_progress() + percent = progress * 100.0 / future.get_size() + print(f"Sent {progress} bytes {percent:.2f}% done") + + size = future.result() + print(f"Total {size} bytes sent for file {file_to_send}") + + def cell_connected(self, agent: CellAgent): + if agent.get_fqcn() == RX_CELL: + self.ready.set() + + +if __name__ == "__main__": + setup_log(logging.INFO) + if len(sys.argv) != 3: + print(f"Usage: {sys.argv[0]} connect_url file_name") + sys.exit(1) + + connect_url = sys.argv[1] + file_name = sys.argv[2] + sender = FileSender(connect_url) + print("Waiting for receiver to be online ...") + sender.wait() + print(f"Sending file {file_name} ...") + + start = time.time() + sender.send(file_name) + print(f"Time elapsed: {(time.time()-start):.3f} seconds") + + sender.stop() + print("Done") diff --git a/nvflare/fuel/f3/streaming/tools/receiver.py b/nvflare/fuel/f3/streaming/tools/receiver.py new file mode 100644 index 0000000000..f8c3ca1f07 --- /dev/null +++ b/nvflare/fuel/f3/streaming/tools/receiver.py @@ -0,0 +1,65 @@ +# 
Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import logging +import time + +from nvflare.fuel.f3.cellnet.cell import Cell +from nvflare.fuel.f3.stream_cell import StreamCell +from nvflare.fuel.f3.streaming.stream_types import StreamFuture +from nvflare.fuel.f3.streaming.tools.utils import BUF_SIZE, RX_CELL, TEST_CHANNEL, TEST_TOPIC, make_buffer, setup_log + + +class Receiver: + """Test BLOB receiving""" + + def __init__(self, listening_url: str): + cell = Cell(RX_CELL, listening_url, secure=False, credentials={}) + cell.start() + self.stream_cell = StreamCell(cell) + self.stream_cell.register_blob_cb(TEST_CHANNEL, TEST_TOPIC, self.blob_cb) + self.futures = {} + + def get_futures(self) -> dict: + return self.futures + + def blob_cb(self, stream_future: StreamFuture, *args, **kwargs): + sid = stream_future.get_stream_id() + print(f"Stream {sid} received") + self.futures[sid] = stream_future + + +if __name__ == "__main__": + setup_log(logging.DEBUG) + url = "tcp://localhost:1234" + receiver = Receiver(url) + time.sleep(2) + result = None + while True: + if receiver.get_futures: + for sid, fut in receiver.get_futures().items(): + if fut.done(): + result = fut.result() + break + else: + print(f"{sid} Progress: {fut.get_progress()}") + time.sleep(1) + if result: + break + + buffer = make_buffer(BUF_SIZE) + + if buffer == result: + print("Result is correct") + else: + print("Result is wrong") diff --git 
a/nvflare/fuel/f3/streaming/tools/sender.py b/nvflare/fuel/f3/streaming/tools/sender.py new file mode 100644 index 0000000000..ee9fb7de6b --- /dev/null +++ b/nvflare/fuel/f3/streaming/tools/sender.py @@ -0,0 +1,53 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import logging +import time + +from nvflare.fuel.f3.cellnet.cell import Cell +from nvflare.fuel.f3.message import Message +from nvflare.fuel.f3.stream_cell import StreamCell +from nvflare.fuel.f3.streaming.stream_types import StreamFuture +from nvflare.fuel.f3.streaming.tools.utils import ( + BUF_SIZE, + RX_CELL, + TEST_CHANNEL, + TEST_TOPIC, + TX_CELL, + make_buffer, + setup_log, +) + + +class Sender: + """Test BLOB sending""" + + def __init__(self, url: str): + core_cell = Cell(TX_CELL, url, secure=False, credentials={}) + self.stream_cell = StreamCell(core_cell) + core_cell.start() + + def send(self, blob: bytes) -> StreamFuture: + return self.stream_cell.send_blob(TEST_CHANNEL, TEST_TOPIC, RX_CELL, Message(None, blob)) + + +if __name__ == "__main__": + setup_log(logging.INFO) + connect_url = "tcp://localhost:1234" + sender = Sender(connect_url) + time.sleep(2) + + buffer = make_buffer(BUF_SIZE) + fut = sender.send(buffer) + n = fut.result() + print(f"Bytes sent: {n}") diff --git a/nvflare/fuel/f3/streaming/tools/utils.py b/nvflare/fuel/f3/streaming/tools/utils.py new file mode 100644 index 0000000000..5d2d59c97c --- /dev/null +++ 
b/nvflare/fuel/f3/streaming/tools/utils.py @@ -0,0 +1,51 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import logging + +BUF_SIZE = 64 * 1024 * 1024 + 1 +TEST_CHANNEL = "stream" +TEST_TOPIC = "test" +TX_CELL = "sender" +RX_CELL = "server" + + +def make_buffer(size: int) -> bytearray: + + buf = bytearray(size) + buf_len = 0 + n = 0 + while True: + temp = n.to_bytes(8, "big", signed=False) + temp_len = len(temp) + if (buf_len + temp_len) > size: + temp_len = size - buf_len + buf[buf_len : buf_len + temp_len] = temp[0:temp_len] + buf_len += temp_len + n += 1 + if buf_len >= size: + break + + return buf + + +def setup_log(level): + logging.basicConfig(level=level) + formatter = logging.Formatter( + fmt="%(relativeCreated)6d [%(threadName)-12s] [%(levelname)-5s] %(name)s: %(message)s" + ) + handler = logging.StreamHandler() + handler.setFormatter(formatter) + root_log = logging.getLogger() + root_log.handlers.clear() + root_log.addHandler(handler) diff --git a/nvflare/private/aux_runner.py b/nvflare/private/aux_runner.py index 215207c418..fcacc18cf6 100644 --- a/nvflare/private/aux_runner.py +++ b/nvflare/private/aux_runner.py @@ -23,7 +23,6 @@ from nvflare.apis.shareable import ReservedHeaderKey, Shareable, make_reply from nvflare.fuel.f3.cellnet.cell import Message, MessageHeaderKey from nvflare.fuel.f3.cellnet.cell import ReturnCode as CellReturnCode -from nvflare.fuel.f3.cellnet.cell import 
new_message from nvflare.fuel.f3.cellnet.fqcn import FQCN from nvflare.private.defs import CellChannel from nvflare.security.logging import secure_format_traceback @@ -251,7 +250,7 @@ def _send_to_cell( for name in target_names: target_fqcns.append(FQCN.join([name, job_id])) - cell_msg = new_message(payload=request) + cell_msg = Message(payload=request) if timeout > 0: cell_replies = cell.broadcast_request( channel=channel, topic=topic, request=cell_msg, targets=target_fqcns, timeout=timeout, optional=optional diff --git a/nvflare/private/defs.py b/nvflare/private/defs.py index c712b598f8..edcbed3be9 100644 --- a/nvflare/private/defs.py +++ b/nvflare/private/defs.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from nvflare.fuel.f3.message import Headers, Message +from nvflare.fuel.f3.message import Message from nvflare.fuel.hci.server.constants import ConnProps @@ -177,7 +177,7 @@ class CellMessageHeaderKeys: def new_cell_message(headers: dict, payload=None): - msg_headers = Headers() + msg_headers = {} if headers: msg_headers.update(headers) return Message(msg_headers, payload) diff --git a/nvflare/private/fed/cmi.py b/nvflare/private/fed/cmi.py index 050f9f1639..efbfcea3cd 100644 --- a/nvflare/private/fed/cmi.py +++ b/nvflare/private/fed/cmi.py @@ -20,7 +20,6 @@ from nvflare.apis.shareable import ReservedHeaderKey, Shareable, make_reply from nvflare.fuel.f3.cellnet.cell import FQCN, Cell, Message, MessageHeaderKey from nvflare.fuel.f3.cellnet.cell import ReturnCode as CellReturnCode -from nvflare.fuel.f3.cellnet.cell import new_message from nvflare.private.defs import CellMessageHeaderKeys @@ -66,7 +65,7 @@ def __init__( self.cell.add_incoming_reply_filter(channel="*", topic="*", cb=self._filter_incoming_message) def new_cmi_message(self, fl_ctx: FLContext, headers=None, payload=None): - msg = new_message(headers, payload) + msg = Message(headers, payload) 
msg.set_prop(self.PROP_KEY_FL_CTX, fl_ctx) return msg diff --git a/tests/unit_test/fuel/f3/streaming/__init__.py b/tests/unit_test/fuel/f3/streaming/__init__.py new file mode 100644 index 0000000000..4fc50543f1 --- /dev/null +++ b/tests/unit_test/fuel/f3/streaming/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/tests/unit_test/fuel/f3/streaming/streaming_test.py b/tests/unit_test/fuel/f3/streaming/streaming_test.py new file mode 100644 index 0000000000..6d3ad18adf --- /dev/null +++ b/tests/unit_test/fuel/f3/streaming/streaming_test.py @@ -0,0 +1,79 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+import threading + +import pytest + +from nvflare.fuel.f3.cellnet.cell import Cell +from nvflare.fuel.f3.message import Message +from nvflare.fuel.f3.stream_cell import StreamCell +from nvflare.fuel.f3.streaming.stream_types import StreamFuture +from nvflare.fuel.f3.streaming.tools.utils import RX_CELL, TEST_CHANNEL, TEST_TOPIC, TX_CELL, make_buffer +from nvflare.fuel.utils.network_utils import get_open_ports + +WAIT_SEC = 10 + + +class State: + def __init__(self): + self.done = threading.Event() + self.result = None + + +class TestStreamCell: + @pytest.fixture + def port(self): + return get_open_ports(1)[0] + + @pytest.fixture + def state(self): + return State() + + @pytest.fixture + def server_cell(self, port, state): + listening_url = f"tcp://localhost:{port}" + cell = Cell(RX_CELL, listening_url, secure=False, credentials={}) + stream_cell = StreamCell(cell) + stream_cell.register_blob_cb(TEST_CHANNEL, TEST_TOPIC, self.blob_cb, state=state) + cell.start() + + return stream_cell + + @pytest.fixture + def client_cell(self, port, state): + connect_url = f"tcp://localhost:{port}" + cell = Cell(TX_CELL, connect_url, secure=False, credentials={}) + stream_cell = StreamCell(cell) + cell.start() + + return stream_cell + + def test_streaming_blob(self, server_cell, client_cell, state): + + size = 64 * 1024 * 1024 + 123 + buffer = make_buffer(size) + + send_future = client_cell.send_blob(TEST_CHANNEL, TEST_TOPIC, RX_CELL, Message(None, buffer)) + bytes_sent = send_future.result() + assert bytes_sent == len(buffer) + + if not state.done.wait(timeout=30): + raise Exception("Data not received after 30 seconds") + + assert buffer == state.result + + def blob_cb(self, future: StreamFuture, **kwargs): + state = kwargs.get("state") + state.result = future.result() + state.done.set()