Skip to content

Commit

Permalink
FCI Stream API (#1758)
Browse files Browse the repository at this point in the history
* Added more detail when recursive data is found in FOBS

* Stream API

* Changed API to use future and message

* Made object_cb required

* Revamped to new StreamCell API

* Added handling for the special case when the first chunk arrives late

* Increased the reassembly buffer size

* Added tools

* Fixed several deadlock errors

* Added a unit test for streaming and a utility to send files

* Addressed issues in the PR
  • Loading branch information
nvidianz authored Jun 23, 2023
1 parent 8b10e7c commit 4eaa7df
Show file tree
Hide file tree
Showing 29 changed files with 1,946 additions and 142 deletions.
89 changes: 24 additions & 65 deletions nvflare/fuel/f3/cellnet/cell.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,8 @@
ServiceUnavailable,
)
from nvflare.fuel.f3.cellnet.fqcn import FQCN, FqcnInfo, same_family
from nvflare.fuel.f3.cellnet.utils import decode_payload, encode_payload, format_log_message, make_reply, new_message
from nvflare.fuel.f3.cellnet.registry import Callback, Registry
from nvflare.fuel.f3.cellnet.utils import decode_payload, encode_payload, format_log_message, make_reply
from nvflare.fuel.f3.comm_config import CommConfigurator
from nvflare.fuel.f3.communicator import Communicator, MessageReceiver
from nvflare.fuel.f3.connection import Connection
Expand Down Expand Up @@ -86,7 +87,7 @@ def to_dict(self):
@staticmethod
def from_dict(d: dict):
msg_dict = d.get("message")
msg = new_message(headers=msg_dict.get("headers"), payload=msg_dict.get("payload"))
msg = Message(headers=msg_dict.get("headers"), payload=msg_dict.get("payload"))
return TargetMessage(target=d.get("target"), channel=d.get("channel"), topic=d.get("topic"), message=msg)


Expand All @@ -110,46 +111,6 @@ def get_fqcn(self):
return self.info.fqcn


class _CB:
    """Bundle a callback with the extra arguments that were registered with it.

    The three values are stored untouched as the attributes ``cb``, ``args``
    and ``kwargs``; code holding a ``_CB`` reads them directly.
    """

    def __init__(self, cb, args, kwargs):
        # Plain record: no copying or validation of the supplied values.
        self.cb, self.args, self.kwargs = cb, args, kwargs


class _Registry:
    """Table of registered items looked up by channel and topic.

    Entries live in ``self.reg`` under a ``"<channel>:<topic>"`` string key.
    ``set`` stores a value as-is, ``append`` accumulates values in a list, and
    ``find`` falls back to the ``"*"`` wildcard entries when there is no exact
    match.
    """

    def __init__(self):
        self.reg = {}  # channel/topic => _CB

    @staticmethod
    def _item_key(channel: str, topic: str) -> str:
        """Return the dictionary key used for the given channel/topic pair."""
        return "{}:{}".format(channel, topic)

    def set(self, channel: str, topic: str, items):
        """Store ``items`` for the channel/topic pair, replacing any previous entry."""
        self.reg[self._item_key(channel, topic)] = items

    def append(self, channel: str, topic: str, items):
        """Add ``items`` as one element of the list kept for the channel/topic pair.

        A missing or empty entry is replaced by a fresh one-element list.
        """
        key = self._item_key(channel, topic)
        existing = self.reg.get(key)
        if existing:
            existing.append(items)
        else:
            self.reg[key] = [items]

    def find(self, channel: str, topic: str):
        """Look up the entry for a channel/topic pair, with wildcard fallback.

        Candidates are tried in order: the exact pair, then topic ``"*"`` in
        the same channel, then ``"*"``/``"*"``. The first non-empty entry is
        returned; if none is found, the result of the last lookup is returned.
        """
        candidates = (
            self._item_key(channel, topic),
            self._item_key(channel, "*"),
            self._item_key("*", "*"),
        )
        found = None
        for key in candidates:
            found = self.reg.get(key)
            if found:
                break
        return found

class _Waiter(threading.Event):
def __init__(self, targets: List[str]):
super().__init__()
Expand Down Expand Up @@ -215,7 +176,7 @@ def send(self):
f"{self.cell.get_fqcn()}: bulk sender {self.target} sending bulk size {len(messages_to_send)}"
)
tms = [m.to_dict() for m in messages_to_send]
bulk_msg = new_message(payload=tms)
bulk_msg = Message(None, tms)
send_errs = self.cell.fire_and_forget(
channel=_CHANNEL, topic=_TOPIC_BULK, targets=[self.target], message=bulk_msg
)
Expand Down Expand Up @@ -380,12 +341,12 @@ def __init__(

self.communicator.register_message_receiver(app_id=self.APP_ID, receiver=self)
self.communicator.register_monitor(monitor=self)
self.req_reg = _Registry()
self.in_req_filter_reg = _Registry() # for request received
self.out_reply_filter_reg = _Registry() # for reply going out
self.out_req_filter_reg = _Registry() # for request sent
self.in_reply_filter_reg = _Registry() # for reply received
self.error_handler_reg = _Registry()
self.req_reg = Registry()
self.in_req_filter_reg = Registry() # for request received
self.out_reply_filter_reg = Registry() # for reply going out
self.out_req_filter_reg = Registry() # for request sent
self.in_reply_filter_reg = Registry() # for reply received
self.error_handler_reg = Registry()
self.cell_connected_cb = None
self.cell_connected_cb_args = None
self.cell_connected_cb_kwargs = None
Expand Down Expand Up @@ -866,7 +827,7 @@ def stop(self):
channel=_CHANNEL,
topic=_TOPIC_BYE,
targets=targets,
request=new_message(),
request=Message(),
timeout=0.5,
optional=True,
)
Expand Down Expand Up @@ -905,39 +866,39 @@ def register_request_cb(self, channel: str, topic: str, cb, *args, **kwargs):
"""
if not callable(cb):
raise ValueError(f"specified request_cb {type(cb)} is not callable")
self.req_reg.set(channel, topic, _CB(cb, args, kwargs))
self.req_reg.set(channel, topic, Callback(cb, args, kwargs))

def add_incoming_request_filter(self, channel: str, topic: str, cb, *args, **kwargs):
if not callable(cb):
raise ValueError(f"specified incoming_request_filter {type(cb)} is not callable")
self.in_req_filter_reg.append(channel, topic, _CB(cb, args, kwargs))
self.in_req_filter_reg.append(channel, topic, Callback(cb, args, kwargs))

def add_outgoing_reply_filter(self, channel: str, topic: str, cb, *args, **kwargs):
if not callable(cb):
raise ValueError(f"specified outgoing_reply_filter {type(cb)} is not callable")
self.out_reply_filter_reg.append(channel, topic, _CB(cb, args, kwargs))
self.out_reply_filter_reg.append(channel, topic, Callback(cb, args, kwargs))

def add_outgoing_request_filter(self, channel: str, topic: str, cb, *args, **kwargs):
if not callable(cb):
raise ValueError(f"specified outgoing_request_filter {type(cb)} is not callable")
self.out_req_filter_reg.append(channel, topic, _CB(cb, args, kwargs))
self.out_req_filter_reg.append(channel, topic, Callback(cb, args, kwargs))

def add_incoming_reply_filter(self, channel: str, topic: str, cb, *args, **kwargs):
if not callable(cb):
raise ValueError(f"specified incoming_reply_filter {type(cb)} is not callable")
self.in_reply_filter_reg.append(channel, topic, _CB(cb, args, kwargs))
self.in_reply_filter_reg.append(channel, topic, Callback(cb, args, kwargs))

def add_error_handler(self, channel: str, topic: str, cb, *args, **kwargs):
if not callable(cb):
raise ValueError(f"specified error_handler {type(cb)} is not callable")
self.error_handler_reg.set(channel, topic, _CB(cb, args, kwargs))
self.error_handler_reg.set(channel, topic, Callback(cb, args, kwargs))

def _filter_outgoing_request(self, channel: str, topic: str, request: Message) -> Union[None, Message]:
cbs = self.out_req_filter_reg.find(channel, topic)
if not cbs:
return None
for _cb in cbs:
assert isinstance(_cb, _CB)
assert isinstance(_cb, Callback)
reply = self._try_cb(request, _cb.cb, *_cb.args, **_cb.kwargs)
if reply:
return reply
Expand Down Expand Up @@ -1112,7 +1073,7 @@ def _send_target_messages(
self.logger.debug(f"{self.my_info.fqcn}: invoking outgoing request filters")
assert isinstance(req_filters, list)
for f in req_filters:
assert isinstance(f, _CB)
assert isinstance(f, Callback)
r = self._try_cb(req, f.cb, *f.args, **f.kwargs)
if r:
send_errs[t] = ReturnCode.FILTER_ERROR
Expand Down Expand Up @@ -1343,7 +1304,7 @@ def _peer_goodbye(self, request: Message):
self.logger.debug(f"{self.my_info.fqcn}: agent for {peer_ep.name} is already gone")

# ack back
return new_message()
return Message()

def _receive_bulk_message(self, request: Message):
target_msgs = request.payload
Expand Down Expand Up @@ -1477,20 +1438,19 @@ def _process_request(self, origin: str, message: Message) -> Union[None, Message
self.logger.debug(f"{self.my_info.fqcn}: invoking incoming request filters")
assert isinstance(req_filters, list)
for f in req_filters:
assert isinstance(f, _CB)
assert isinstance(f, Callback)
reply = self._try_cb(message, f.cb, *f.args, **f.kwargs)
if reply:
return reply

assert isinstance(_cb, _CB)
assert isinstance(_cb, Callback)
self.logger.debug(f"{self.my_info.fqcn}: calling registered request CB")
cb_start = time.perf_counter()
reply = self._try_cb(message, _cb.cb, *_cb.args, **_cb.kwargs)
cb_end = time.perf_counter()
self.req_cb_stats_pool.record_value(category=self._stats_category(message), value=cb_end - cb_start)
if not reply:
# the CB doesn't have anything to reply
self.logger.debug("no reply is returned from the CB")
return None

if not isinstance(reply, Message):
Expand Down Expand Up @@ -1630,7 +1590,7 @@ def _process_reply(self, origin: str, message: Message, msg_type: str):
self.logger.debug(f"{self.my_info.fqcn}: invoking incoming reply filters")
assert isinstance(reply_filters, list)
for f in reply_filters:
assert isinstance(f, _CB)
assert isinstance(f, Callback)
self._try_cb(message, f.cb, *f.args, **f.kwargs)

for rid in req_ids:
Expand Down Expand Up @@ -1808,7 +1768,6 @@ def _process_received_msg(self, endpoint: Endpoint, connection: Connection, mess
reply = self._process_request(origin, message)

if not reply:
self.logger.debug(f"{self.my_info.fqcn}: don't send response - nothing to send")
self.received_msg_counter_pool.increment(
category=self._stats_category(message), counter_name=_CounterName.REPLY_NONE
)
Expand Down Expand Up @@ -1863,7 +1822,7 @@ def _process_received_msg(self, endpoint: Endpoint, connection: Connection, mess
self.logger.debug(f"{self.my_info.fqcn}: invoking outgoing reply filters")
assert isinstance(reply_filters, list)
for f in reply_filters:
assert isinstance(f, _CB)
assert isinstance(f, Callback)
r = self._try_cb(reply, f.cb, *f.args, **f.kwargs)
if r:
reply = r
Expand Down
Loading

0 comments on commit 4eaa7df

Please sign in to comment.